Изменения

Перейти к: навигация, поиск
Аккуратность (англ. Accuracy)
При этом, наша модель совершенно не обладает никакой предсказательной силой, так как изначально мы хотели определять письма со спамом. Преодолеть это нам поможет переход с общей для всех классов метрики к отдельным показателям качества классов.
 
<font color="green"># код для для подсчета аккуратности:</font>
'''# Пример классификатора, способного проводить различие между всего лишь двумя
'''# классами, "пятерка" и "не пятерка" из набор данных MNIST
'''import''' numpy '''as''' np
'''from''' sklearn.datasets '''import''' fetch_openml
'''from''' sklearn.model_selection '''import''' cross_val_predict
'''from''' sklearn.metrics '''import''' accuracy_score
'''from''' sklearn.linear_model '''import''' SGDClassifier
mnist = fetch_openml('mnist_784', version=1)
X, y = mnist["data"], mnist["target"]
y = y.astype(np.uint8)
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
y_train_5 = (y_train == 5) # True для всех пятерок, False для в сех остальных цифр. Задача опознать пятерки
y_test_5 = (y_test == 5)
sgd_clf = SGDClassifier(random_state=42) #классификатор на основе метода стохастического градиентного спуска (Stochastic Gradient Descent SGD)
sgd_clf.fit(X_train, y_train_5) #обучаем классификатор распозновать пятерки на целом обучающем наборе
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
# print(confusion_matrix(y_train_5, y_train_pred))
# array([[53892, 687]
# [ 1891, 3530]])
print(accuracy_score(y_train_5, y_train_pred)) # == (53892 + 3530) / (53892 + 3530 + 1891 +687)
# 0.9570333333333333
=== Точность (англ. Precision) ===
187
правок

Навигация