Receiver Operating Characteristic (ROC) with cross validation
Example of using the Receiver Operating Characteristic (ROC) metric to evaluate classifier output quality with cross-validation.
ROC curves typically feature the true positive rate (TPR) on the Y axis and the false positive rate (FPR) on the X axis. This means that the top left corner of the plot is the “ideal” point: a false positive rate of zero and a true positive rate of one. This is not very realistic, but it does mean that a larger area under the curve (AUC) is usually better.
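To make the axes and the area concrete, here is a minimal sketch on four made-up samples (`y_true` and `y_score` are illustration data, not taken from this example). `roc_curve` sweeps the decision threshold and returns one (FPR, TPR) point per threshold, and `auc` integrates those points with the trapezoidal rule. The pairwise count at the end checks the equivalent reading of AUC as the fraction of (positive, negative) pairs that the scores rank correctly.

```python
import numpy as np
from sklearn.metrics import auc, roc_curve

# Made-up labels and classifier scores, for illustration only
y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

# Each threshold yields one (FPR, TPR) point; FPR goes on the X axis
fpr, tpr, thresholds = roc_curve(y_true, y_score)
roc_auc = auc(fpr, tpr)  # trapezoidal area under the (fpr, tpr) points

# Same quantity counted directly: share of (positive, negative) pairs
# in which the positive sample gets the higher score (no ties here)
pos, neg = y_score[y_true == 1], y_score[y_true == 0]
pairwise_auc = np.mean(pos[:, None] > neg[None, :])

print("FPR:", fpr)
print("TPR:", tpr)
print("AUC:", roc_auc, "pairwise:", pairwise_auc)
```

Both computations give 0.75 for this data. A scorer that ranks every positive above every negative would reach 1.0 (the top left corner), while random scores average 0.5 (the diagonal).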
The “steepness” of ROC curves is also important, since it is ideal to maximize the true positive rate while minimizing the false positive rate.
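Two curves can enclose the same area yet differ in steepness. The sketch below uses two hypothetical scorers, `score_a` and `score_b`, on six made-up samples and finds the best TPR that a single threshold can reach within a given FPR budget. `tpr_at_fpr` is a helper written for this illustration, not a scikit-learn function.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve


def tpr_at_fpr(y_true, y_score, max_fpr):
    """Hypothetical helper: best TPR of any threshold with FPR <= max_fpr."""
    # Keep every threshold so no operating point is skipped
    fpr, tpr, _ = roc_curve(y_true, y_score, drop_intermediate=False)
    return tpr[fpr <= max_fpr].max()


y_true = np.array([1, 1, 1, 0, 0, 0])
# Scorer A ranks two positives first and one positive last
score_a = np.array([0.9, 0.8, 0.4, 0.7, 0.6, 0.5])
# Scorer B ranks one negative first, then all three positives
score_b = np.array([0.8, 0.7, 0.6, 0.9, 0.5, 0.4])

auc_a = roc_auc_score(y_true, score_a)
auc_b = roc_auc_score(y_true, score_b)
print(f"AUC A = {auc_a:.3f}, AUC B = {auc_b:.3f}")
for budget in (0.0, 0.5):
    print(f"FPR <= {budget}: TPR A = {tpr_at_fpr(y_true, score_a, budget):.3f}, "
          f"TPR B = {tpr_at_fpr(y_true, score_b, budget):.3f}")
```

Both scorers have an AUC of 2/3, but A's curve rises steeply from the origin (TPR 2/3 at zero FPR) while B reaches no positives until it accepts one false positive, after which it recovers all of them. Which shape is preferable depends on how many false positives the application can tolerate.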
This example shows the ROC response on the different datasets created by K-fold cross-validation. Taking all of these curves, it is possible to calculate the mean area under the curve and to see the variance of the curve when the training set is split into different subsets. This roughly shows how the classifier output is affected by changes in the training data, and how different the splits generated by K-fold cross-validation are from one another.
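Averaging the curves needs one extra step: each fold's ROC curve has its own set of FPR points, so the TPR values are first resampled onto a shared FPR grid with `np.interp` and then averaged point by point. The sketch below shows only that computation, without plotting. For brevity it assumes a synthetic `make_classification` dataset and a `LogisticRegression` model; the full example uses noisy iris data and a linear SVC instead.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import StratifiedKFold

# Synthetic binary problem (an assumption made for this sketch)
X, y = make_classification(n_samples=300, n_features=20, random_state=0)

mean_fpr = np.linspace(0, 1, 100)  # common FPR grid shared by all folds
tprs, aucs = [], []
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for train, test in cv.split(X, y):
    clf = LogisticRegression(max_iter=1000).fit(X[train], y[train])
    scores = clf.predict_proba(X[test])[:, 1]
    fpr, tpr, _ = roc_curve(y[test], scores)
    # Resample this fold's TPR onto the shared FPR grid
    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    interp_tpr[0] = 0.0  # every curve starts at the origin
    tprs.append(interp_tpr)
    aucs.append(auc(fpr, tpr))

# Point-wise mean curve and its +/- 1 std band, clipped to [0, 1]
mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0  # every curve ends at (1, 1)
mean_auc = auc(mean_fpr, mean_tpr)
std_tpr = np.std(tprs, axis=0)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)

print(f"AUC of the mean curve: {mean_auc:.3f}")
print(f"per-fold AUC: {np.mean(aucs):.3f} +/- {np.std(aucs):.3f}")
```

Note that the AUC of the averaged curve and the mean of the per-fold AUCs are two different summaries; they are usually close but need not be equal.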
```python
import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.metrics import auc
# plot_roc_curve was deprecated in scikit-learn 1.0 and removed in 1.2;
# RocCurveDisplay.from_estimator is its replacement
from sklearn.metrics import RocCurveDisplay
from sklearn.model_selection import StratifiedKFold

# #############################################################################
# Data IO and generation

# Import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
# Keep only classes 0 and 1 to obtain a binary problem
X, y = X[y != 2], y[y != 2]
n_samples, n_features = X.shape

# Add noisy features to make the problem harder
random_state = np.random.RandomState(0)
X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]

# #############################################################################
# Classification and ROC analysis

# Run classifier with cross-validation and plot ROC curves
cv = StratifiedKFold(n_splits=6)
classifier = svm.SVC(kernel='linear', probability=True,
                     random_state=random_state)

tprs = []
aucs = []
# Common FPR grid onto which every fold's ROC curve is interpolated
mean_fpr = np.linspace(0, 1, 100)

fig, ax = plt.subplots()
for i, (train, test) in enumerate(cv.split(X, y)):
    classifier.fit(X[train], y[train])
    viz = RocCurveDisplay.from_estimator(classifier, X[test], y[test],
                                         name='ROC fold {}'.format(i),
                                         alpha=0.3, lw=1, ax=ax)
    # numpy.interp replaces the deprecated scipy.interp
    interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr)
    interp_tpr[0] = 0.0  # force every curve to start at the origin
    tprs.append(interp_tpr)
    aucs.append(viz.roc_auc)

ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
        label='Chance', alpha=.8)

mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0  # force the mean curve to end at (1, 1)
mean_auc = auc(mean_fpr, mean_tpr)
std_auc = np.std(aucs)
ax.plot(mean_fpr, mean_tpr, color='b',
        label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
        lw=2, alpha=.8)

# Shade +/- 1 standard deviation of the TPR, clipped to [0, 1]
std_tpr = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
                label=r'$\pm$ 1 std. dev.')

ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05],
       title="Receiver operating characteristic example")
ax.legend(loc="lower right")
plt.show()
```
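If only the per-fold AUC values are needed and not the plot, `cross_val_score` with `scoring="roc_auc"` is a shorter route. The sketch below rebuilds the same noisy two-class iris data and the same six stratified folds. Because this scorer ranks samples by the SVC's `decision_function` rather than by Platt-scaled probabilities, its values may differ slightly from the fold AUCs shown in the legend.

```python
import numpy as np
from sklearn import datasets, svm
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Rebuild the noisy two-class iris problem used above
iris = datasets.load_iris()
X, y = iris.data[iris.target != 2], iris.target[iris.target != 2]
rng = np.random.RandomState(0)
X = np.c_[X, rng.randn(X.shape[0], 200 * X.shape[1])]

# One AUC per test fold; SVC needs no probability=True for this scorer
aucs = cross_val_score(svm.SVC(kernel="linear"), X, y,
                       cv=StratifiedKFold(n_splits=6), scoring="roc_auc")
print("AUC per fold:", np.round(aucs, 2))
print(f"mean AUC = {aucs.mean():.2f} +/- {aucs.std():.2f}")
```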