{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# RBF SVM parameters\n\n\nThis example illustrates the effect of the parameters ``gamma`` and ``C`` of\nthe Radial Basis Function (RBF) kernel SVM.\n\nIntuitively, the ``gamma`` parameter defines how far the influence of a single\ntraining example reaches, with low values meaning 'far' and high values meaning\n'close'. The ``gamma`` parameter can be seen as the inverse of the radius of\ninfluence of samples selected by the model as support vectors.\n\nThe ``C`` parameter trades off misclassification of training examples against\nsimplicity of the decision surface. A low ``C`` makes the decision surface\nsmooth, while a high ``C`` aims at classifying all training examples correctly\nby giving the model freedom to select more samples as support vectors.\n\nThe first plot is a visualization of the decision function for a variety of\nparameter values on a simplified classification problem involving only 2 input\nfeatures and 2 possible target classes (binary classification). Note that this\nkind of plot is not possible for problems with more features or target\nclasses.\n\nThe second plot is a heatmap of the classifier's cross-validation accuracy as a\nfunction of ``C`` and ``gamma``. For this example we explore a relatively large\ngrid for illustration purposes. In practice, a logarithmic grid from\n$10^{-3}$ to $10^3$ is usually sufficient. If the best parameters\nlie on the boundaries of the grid, it can be extended in that direction in a\nsubsequent search.\n\nNote that the heatmap plot has a special colorbar with a midpoint value close\nto the score values of the best performing models so as to make it easy to tell\nthem apart at a glance.\n\nThe behavior of the model is very sensitive to the ``gamma`` parameter. If\n``gamma`` is too large, the radius of the area of influence of the support\nvectors only includes the support vector itself and no amount of\nregularization with ``C`` will be able to prevent overfitting.\n\nWhen ``gamma`` is very small, the model is too constrained and cannot capture\nthe complexity or \"shape\" of the data. The region of influence of any selected\nsupport vector would include the whole training set. The resulting model will\nbehave similarly to a linear model with a set of hyperplanes that separate the\ncenters of high density of any pair of two classes.\n\nFor intermediate values, we can see on the second plot that good models can\nbe found on a diagonal of ``C`` and ``gamma``. Smooth models (lower ``gamma``\nvalues) can be made more complex by selecting a larger number of support\nvectors (larger ``C`` values), hence the diagonal of well-performing models.\n\nFinally, one can also observe that for some intermediate values of ``gamma`` we\nget equally performing models when ``C`` becomes very large: it is not\nnecessary to regularize by limiting the number of support vectors. The radius of\nthe RBF kernel alone acts as a good structural regularizer. In practice, though,\nit might still be interesting to limit the number of support vectors with a\nlower value of ``C`` so as to favor models that use less memory and that are\nfaster to predict.\n\nWe should also note that small differences in scores result from the random\nsplits of the cross-validation procedure. Those spurious variations can be\nsmoothed out by increasing the number of CV iterations ``n_splits`` at the\nexpense of compute time. Increasing the number of ``C_range`` and\n``gamma_range`` steps will increase the resolution of the hyper-parameter\nheatmap.\n\n\n"
]
},
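{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the intuition above, not part of the original example: the RBF kernel value\n",
"``exp(-gamma * ||x - x'||^2)`` measures how much a support vector at ``x'`` influences a point ``x``.\n",
"The helper name ``rbf_kernel_value`` and the sample distances are illustrative assumptions.\n",
"With a low ``gamma`` the influence should still be noticeable far from the support vector, while\n",
"with a high ``gamma`` it should fade within a short distance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def rbf_kernel_value(x, x_prime, gamma):\n",
"    # Hypothetical helper: K(x, x') = exp(-gamma * ||x - x'||^2)\n",
"    diff = np.asarray(x, dtype=float) - np.asarray(x_prime, dtype=float)\n",
"    return float(np.exp(-gamma * np.dot(diff, diff)))\n",
"\n",
"\n",
"# Illustrative setup: one support vector at the origin and three points\n",
"# at increasing distance from it along the first axis.\n",
"support_vector = [0.0, 0.0]\n",
"distances = [0.1, 1.0, 3.0]\n",
"\n",
"for gamma in [1e-1, 1, 1e1]:\n",
"    values = [rbf_kernel_value([d, 0.0], support_vector, gamma)\n",
"              for d in distances]\n",
"    print(\"gamma=%g -> kernel values at distances %s: %s\"\n",
"          % (gamma, distances, np.round(values, 4)))"
]
},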
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"print(__doc__)\n\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom matplotlib.colors import Normalize\n\nfrom sklearn.svm import SVC\nfrom sklearn.preprocessing import StandardScaler\nfrom sklearn.datasets import load_iris\nfrom sklearn.model_selection import StratifiedShuffleSplit\nfrom sklearn.model_selection import GridSearchCV\n\n\n# Utility function to move the midpoint of a colormap to be around\n# the values of interest.\n\nclass MidpointNormalize(Normalize):\n\n    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):\n        self.midpoint = midpoint\n        Normalize.__init__(self, vmin, vmax, clip)\n\n    def __call__(self, value, clip=None):\n        # Map vmin -> 0, midpoint -> 0.5 and vmax -> 1 by linear interpolation\n        x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]\n        return np.ma.masked_array(np.interp(value, x, y))\n\n# #############################################################################\n# Load and prepare data set\n#\n# Dataset for grid search\n\niris = load_iris()\nX = iris.data\ny = iris.target\n\n# Dataset for decision function visualization: we only keep the first two\n# features in X and sub-sample the dataset to keep only 2 classes and\n# make it a binary classification problem.\n\nX_2d = X[:, :2]\nX_2d = X_2d[y > 0]\ny_2d = y[y > 0]\ny_2d -= 1\n\n# It is usually a good idea to scale the data for SVM training.\n# We are cheating a bit in this example by scaling all of the data,\n# instead of fitting the transformation on the training set and\n# just applying it on the test set.\n\nscaler = StandardScaler()\nX = scaler.fit_transform(X)\nX_2d = scaler.fit_transform(X_2d)\n\n# #############################################################################\n# Train classifiers\n#\n# For an initial search, a logarithmic grid with base\n# 10 is often helpful. Using a base of 2, a finer\n# tuning can be achieved but at a much higher cost.\n\nC_range = np.logspace(-2, 10, 13)\ngamma_range = np.logspace(-9, 3, 13)\nparam_grid = dict(gamma=gamma_range, C=C_range)\ncv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)\ngrid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv)\ngrid.fit(X, y)\n\nprint(\"The best parameters are %s with a score of %0.2f\"\n      % (grid.best_params_, grid.best_score_))\n\n# Now we need to fit a classifier for all parameters in the 2d version\n# (we use a smaller set of parameters here because it takes a while to train)\n\nC_2d_range = [1e-2, 1, 1e2]\ngamma_2d_range = [1e-1, 1, 1e1]\nclassifiers = []\nfor C in C_2d_range:\n    for gamma in gamma_2d_range:\n        clf = SVC(C=C, gamma=gamma)\n        clf.fit(X_2d, y_2d)\n        classifiers.append((C, gamma, clf))\n\n# #############################################################################\n# Visualization\n#\n# Draw visualization of parameter effects\n\nplt.figure(figsize=(8, 6))\nxx, yy = np.meshgrid(np.linspace(-3, 3, 200), np.linspace(-3, 3, 200))\nfor (k, (C, gamma, clf)) in enumerate(classifiers):\n    # evaluate decision function in a grid\n    Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])\n    Z = Z.reshape(xx.shape)\n\n    # visualize decision function for these parameters\n    plt.subplot(len(C_2d_range), len(gamma_2d_range), k + 1)\n    plt.title(\"gamma=10^%d, C=10^%d\" % (np.log10(gamma), np.log10(C)),\n              size='medium')\n\n    # visualize parameter's effect on decision function\n    plt.pcolormesh(xx, yy, -Z, cmap=plt.cm.RdBu)\n    plt.scatter(X_2d[:, 0], X_2d[:, 1], c=y_2d, cmap=plt.cm.RdBu_r,\n                edgecolors='k')\n    plt.xticks(())\n    plt.yticks(())\n    plt.axis('tight')\n\n# Rows of the score matrix correspond to C values, columns to gamma values\nscores = grid.cv_results_['mean_test_score'].reshape(len(C_range),\n                                                     len(gamma_range))\n\n# Draw heatmap of the validation accuracy as a function of gamma and C\n#\n# The scores are encoded as colors with the hot colormap which varies from dark\n# red to bright yellow. As the most interesting scores are all located in the\n# 0.92 to 0.97 range we use a custom normalizer to set the mid-point to 0.92 so\n# as to make it easier to visualize the small variations of score values in the\n# interesting range while not brutally collapsing all the low score values to\n# the same color.\n\nplt.figure(figsize=(8, 6))\nplt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)\nplt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot,\n           norm=MidpointNormalize(vmin=0.2, midpoint=0.92))\nplt.xlabel('gamma')\nplt.ylabel('C')\nplt.colorbar()\nplt.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45)\nplt.yticks(np.arange(len(C_range)), C_range)\nplt.title('Validation accuracy')\nplt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}