.. rst-class:: sphx-glr-example-title
.. _sphx_glr_auto_examples_linear_model_plot_sgd_comparison.py:
==================================
Comparing various online solvers
==================================
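This example compares how several online solvers perform on the
handwritten digits dataset. For each solver, the test error rate is
averaged over repeated random train/test splits and plotted against the
proportion of the data used for training.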

.. image:: /auto_examples/linear_model/images/sphx_glr_plot_sgd_comparison_001.png
    :alt: plot sgd comparison
    :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    training SGD
    training ASGD
    training Perceptron
    training Passive-Aggressive I
    training Passive-Aggressive II
    training SAG

|

.. code-block:: default


    # Author: Rob Zinkov
    # License: BSD 3 clause

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn import datasets

    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import SGDClassifier, Perceptron
    from sklearn.linear_model import PassiveAggressiveClassifier
    from sklearn.linear_model import LogisticRegression

    # Fractions of the data held out for testing, and the number of
    # random splits averaged for each fraction.
    heldout = [0.95, 0.90, 0.75, 0.50, 0.01]
    rounds = 20
    X, y = datasets.load_digits(return_X_y=True)

    classifiers = [
        ("SGD", SGDClassifier(max_iter=100)),
        ("ASGD", SGDClassifier(average=True)),
        ("Perceptron", Perceptron()),
        ("Passive-Aggressive I", PassiveAggressiveClassifier(loss='hinge',
                                                             C=1.0, tol=1e-4)),
        ("Passive-Aggressive II", PassiveAggressiveClassifier(loss='squared_hinge',
                                                              C=1.0, tol=1e-4)),
        ("SAG", LogisticRegression(solver='sag', tol=1e-1, C=1.e4 / X.shape[0]))
    ]

    # Proportion of the data used for training (x-axis of the plot).
    xx = 1. - np.array(heldout)

    for name, clf in classifiers:
        print("training %s" % name)
        # Reuse the same seed so every classifier sees the same splits.
        rng = np.random.RandomState(42)
        yy = []
        for i in heldout:
            yy_ = []
            for r in range(rounds):
                X_train, X_test, y_train, y_test = \
                    train_test_split(X, y, test_size=i, random_state=rng)
                clf.fit(X_train, y_train)
                y_pred = clf.predict(X_test)
                # Test error rate for this split.
                yy_.append(1 - np.mean(y_pred == y_test))
            # Mean error over all rounds for this held-out fraction.
            yy.append(np.mean(yy_))
        plt.plot(xx, yy, label=name)

    plt.legend(loc="upper right")
    plt.xlabel("Proportion train")
    plt.ylabel("Test Error Rate")
    plt.show()

.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 20.047 seconds)
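
The measurement made for each point on the plot can be isolated as a small
helper. The sketch below is illustrative and is not part of the original
script: the function name ``mean_test_error`` and its parameters are
hypothetical, and the choice of ``Perceptron``, a 50% held-out fraction and
3 rounds is only there to keep the run short.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import Perceptron
from sklearn.model_selection import train_test_split


def mean_test_error(clf, X, y, test_size, rounds=5, seed=42):
    """Average test error of ``clf`` over ``rounds`` random splits.

    Hypothetical helper: mirrors the inner loops of the example script
    for a single classifier and a single held-out fraction.
    """
    rng = np.random.RandomState(seed)
    errors = []
    for _ in range(rounds):
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=rng)
        clf.fit(X_train, y_train)
        # Error rate is the fraction of misclassified test samples.
        errors.append(1 - np.mean(clf.predict(X_test) == y_test))
    return float(np.mean(errors))


X, y = datasets.load_digits(return_X_y=True)
err = mean_test_error(Perceptron(), X, y, test_size=0.5, rounds=3)
print("Perceptron, 50%% held out: %.3f" % err)
```

Calling the helper once per entry in ``heldout`` should give one curve of
the plot; seeding the generator inside the function keeps the splits
identical across classifiers, as in the script.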