Изменения

Перейти к: навигация, поиск
Precison-recall кривая
'''Precison-recall (PR) кривая.''' Избавиться от указанной проблемы с несбалансированными классами можно, перейдя от ROC-кривой к PR-кривой. Она определяется аналогично ROC-кривой, только по осям откладываются не FPR и TPR, а полнота (по оси абсцисс) и точность (по оси ординат). Критерием качества семейства алгоритмов выступает '''площадь под PR-кривой''' (англ. '''Area Under the Curve — AUC-PR''')
[[Файл:prPR_curve.png]]  <font color="green"># Код отрисовки Precison-recrecall кривой</font> '''# На примере классификатора, способного проводить различие между всего лишь двумя классами '''# "пятерка" и "не пятерка" из набор данных MNIST '''from''' sklearn.png|600pxmetrics '''import''' precision_recall_curve '''import''' matplotlib.pyplot '''as''' plt '''import''' numpy '''as''' np '''from''' sklearn.datasets '''import''' fetch_openml '''from''' sklearn.model_selection '''import''' cross_val_predict '''from''' sklearn.linear_model '''import''' SGDClassifier mnist = fetch_openml('mnist_784', version=1) X, y = mnist["data"], mnist["target"] y = y.astype(np.uint8) X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:] y_train_5 = (y_train == 5) # True для всех пятерок, False для в сех остальных цифр. Задача опознать пятерки y_test_5 = (y_test == 5) sgd_clf = SGDClassifier(random_state=42) #классификатор на основе метода стохастического градиентного спуска (Stochastic Gradient Descent SGD) sgd_clf.fit(X_train, y_train_5) #обучаем классификатор распозновать пятерки на целом обучающем наборе y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3) y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method="decision_function") precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores) def plot_precision_recall_vs_threshold(precisions, recalls, thresholds): plt.plot(recalls, precisions, linewidth=2) plt.xlabel('Recall') plt.ylabel('Precision') plt.title('Precision-Recall curve') plt.savefig("Precision_Recall_curve.png") plot_precision_recall_vs_threshold(precisions, recalls, thresholds) plt.show()
== Оценки качества регрессии ==
187
правок

Навигация