Изменения

Перейти к: навигация, поиск

Кросс-валидация

115 байт добавлено, 25 март
k-fold кросс-валидация
\\ CV_k = \frac{1}{k} \sum_{i=1}^{k} Q(\mu(T^l \setminus F_i),F_i) \to min </tex>.
<font color="green"># Пример кода для k-fold кросс-валидации:</font>
'''# Пример классификатора, cпособного проводить различие между всего лишь двумя
'''# классами, "пятерка" и "не пятерка" из набор данных MNIST</font>
'''import''' numpy '''as''' np
'''from''' sklearn.model_selection '''import''' StratifiedKFold
y = y.astype(np.uint8)
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
y_train_5 = (y_train == 5) <font color="green"> # True для всех пятерок, False для в сех остальных цифр. Задача опознать пятерки</font>
y_test_5 = (y_test == 5)
sgd_clf = SGDClassifier(random_state=42) <font color="green"> #классификатор на основе метода стохастического градиентного спуска (Stochastic Gradient Descent SGD)</font> <font color="green"># Разбиваем обучающий набора на 3 блока,</font> # выработку прогнозов и их оценку осуществляем на каждом блоке с использованием модели, обученной на остальных блоках</font>
skfolds = StratifiedKFold(n_splits=3, random_state=42)
for train_index, test_index in skfolds.split(X_train, y_train_5):
n_correct = sum(y_pred == y_test_fold)
print(n_correct / len(y_pred))
<font color="green"># print 0.95035
# 0.96035
# 0.9604</font>
=== t×k-fold кросс-валидация ===
187
правок

Навигация