Изменения

Перейти к: навигация, поиск
Precison-recall кривая
y = y.astype(np.uint8)
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
y_train_5 = (y_train == 5) <font color="green"># True для всех пятерок, False для в сех остальных цифр. Задача опознать пятерки</font>
y_test_5 = (y_test == 5)
sgd_clf = SGDClassifier(random_state=42) <font color="green"> #классификатор на основе метода стохастического градиентного спуска (англ. Stochastic Gradient Descent SGD)</font> sgd_clf.fit(X_train, y_train_5) <font color="green">#обучаем классификатор распозновать пятерки на целом обучающем наборе</font>
<font color="green"># Для расчета матрицы ошибок сначала понадобится иметь набор прогнозов, чтобы их можно было сравнивать с фактическими целями</font>
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
y = y.astype(np.uint8)
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
y_train_5 = (y_train == 5) <font color="green"># True для всех пятерок, False для в сех остальных цифр. Задача опознать пятерки</font>
y_test_5 = (y_test == 5)
y_train_perfect_predictions = y_train_5 <font color="green"># притворись, что мы достигли совершенства</font>
print(confusion_matrix(y_train_5, y_train_perfect_predictions))
<font color="green"># array([[54579, 0],
# array([[53892, 687]
# [ 1891, 3530]])</font>
print(accuracy_score(y_train_5, y_train_pred)) <font color="green"> # == (53892 + 3530) / (53892 + 3530 + 1891 +687)</font>
<font color="green"># 0.9570333333333333</font>
[[Файл:PR_curve.png]]
<font color="green"># Код отрисовки Precison-recall кривой</font>
'''# На примере классификатора, способного проводить различие между всего лишь двумя классами
'''# "пятерка" и "не пятерка" из набора рукописных цифр MNIST</font>
'''from''' sklearn.metrics '''import''' precision_recall_curve
'''import''' matplotlib.pyplot '''as''' plt
187
правок

Навигация