Изменения

Перейти к: навигация, поиск
k-fold кросс-валидация
<font color="green"># Пример кода для k-fold кросс-валидация:</font>
'''
'''import numpy as np
'''from sklearn.model_selection import StratifiedKFold
'''from sklearn.datasets import fetch_openml
'''from sklearn.base import clone
'''from sklearn.linear_model import SGDClassifier
'''
'''mnist = fetch_openml('mnist_784', version=1)
'''X, y = mnist["data"], mnist["target"]
'''some_digit = X[0] # признаки цифры пять
'''some_digit_image = some_digit.reshape(28, 28)
'''y = y.astype(np.uint8)
'''X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
'''y_train_5 = (y_train == 5) # True для всех пятерок, False для в сех остальных цифр
'''y_test_5 = (y_test == 5)
'''sgd_clf = SGDClassifier(random_state=42)
'''sgd_clf.fit(X_train, y_train_5)
'''sgd_clf.predict([some_digit])
'''skfolds = StratifiedKFold(n_splits=3, random_state=42)
'''for train_index, test_index in skfolds.split(X_train, y_train_5):
''''' X_test_fold = X_train[test_index]
''''' y_test_fold = y_train_5[test_index]
'''''
''''' clone_clf.fit(X_train_folds, y_train_folds)
''''' y_pred = clone_clf.predict(X_test_fold)
''''' n_correct = sum(y_pred == y_test_fold)
''''' print(n_correct / len(y_pred))
''' '''from sklearn.model_selection import cross_val_score '''cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")array([# print 0.9635595035, 0.9379596035, 0.95615])9604
== Источники информации ==
187
правок

Навигация