.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/semi_supervised/plot_label_propagation_digits.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_semi_supervised_plot_label_propagation_digits.py>`
        to download the full example code or to run this example in your browser via Binder

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_semi_supervised_plot_label_propagation_digits.py:


===================================================
Label Propagation digits: Demonstrating performance
===================================================

This example demonstrates the power of semi-supervised learning by
training a Label Spreading model to classify handwritten digits
from a very small set of labels.

The handwritten digit dataset has 1797 total points. The model is
trained on a random subset of 340 points, of which only 40 are labeled.
Results are reported as a confusion matrix and a series of per-class
metrics; with these settings the model reaches about 90% accuracy on
the 300 unlabeled points.

At the end, the top 10 most uncertain predictions will be shown.

.. GENERATED FROM PYTHON SOURCE LINES 18-22

.. code-block:: default


    # Authors: Clay Woolam
    # License: BSD


.. GENERATED FROM PYTHON SOURCE LINES 23-27

Data generation
---------------

We use the digits dataset. We only use a subset of randomly selected samples.

.. GENERATED FROM PYTHON SOURCE LINES 27-35

.. code-block:: default

    from sklearn import datasets
    import numpy as np

    digits = datasets.load_digits()
    rng = np.random.RandomState(2)
    indices = np.arange(len(digits.data))
    rng.shuffle(indices)


.. GENERATED FROM PYTHON SOURCE LINES 36-39

We select 340 samples, of which only 40 will be associated with a known label.
We therefore store the indices of the other 300 samples, whose labels we are
not supposed to know.

.. GENERATED FROM PYTHON SOURCE LINES 40-51

.. code-block:: default

    X = digits.data[indices[:340]]
    y = digits.target[indices[:340]]
    images = digits.images[indices[:340]]

    n_total_samples = len(y)
    n_labeled_points = 40

    indices = np.arange(n_total_samples)

    unlabeled_set = indices[n_labeled_points:]


.. GENERATED FROM PYTHON SOURCE LINES 52-53

Mask the labels of the unlabeled samples by setting them to -1, the value
that marks a sample as unlabeled

.. GENERATED FROM PYTHON SOURCE LINES 53-56

.. code-block:: default

    y_train = np.copy(y)
    y_train[unlabeled_set] = -1


.. GENERATED FROM PYTHON SOURCE LINES 57-62

Semi-supervised learning
------------------------

We fit a :class:`~sklearn.semi_supervised.LabelSpreading` and use it to predict
the unknown labels.

.. GENERATED FROM PYTHON SOURCE LINES 62-75

.. code-block:: default

    from sklearn.semi_supervised import LabelSpreading
    from sklearn.metrics import classification_report

    lp_model = LabelSpreading(gamma=0.25, max_iter=20)
    lp_model.fit(X, y_train)
    predicted_labels = lp_model.transduction_[unlabeled_set]
    true_labels = y[unlabeled_set]

    print(
        "Label Spreading model: %d labeled & %d unlabeled points (%d total)"
        % (n_labeled_points, n_total_samples - n_labeled_points, n_total_samples)
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Label Spreading model: 40 labeled & 300 unlabeled points (340 total)


.. GENERATED FROM PYTHON SOURCE LINES 76-77

Classification report

.. GENERATED FROM PYTHON SOURCE LINES 77-79

.. code-block:: default

    print(classification_report(true_labels, predicted_labels))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

                  precision    recall  f1-score   support

               0       1.00      1.00      1.00        27
               1       0.82      1.00      0.90        37
               2       1.00      0.86      0.92        28
               3       1.00      0.80      0.89        35
               4       0.92      1.00      0.96        24
               5       0.74      0.94      0.83        34
               6       0.89      0.96      0.92        25
               7       0.94      0.89      0.91        35
               8       1.00      0.68      0.81        31
               9       0.81      0.88      0.84        24

        accuracy                           0.90       300
       macro avg       0.91      0.90      0.90       300
    weighted avg       0.91      0.90      0.90       300


.. GENERATED FROM PYTHON SOURCE LINES 80-81

Confusion matrix

.. GENERATED FROM PYTHON SOURCE LINES 81-87

.. code-block:: default

    from sklearn.metrics import ConfusionMatrixDisplay

    ConfusionMatrixDisplay.from_predictions(
        true_labels, predicted_labels, labels=lp_model.classes_
    )


.. image-sg:: /auto_examples/semi_supervised/images/sphx_glr_plot_label_propagation_digits_001.png
   :alt: plot label propagation digits
   :srcset: /auto_examples/semi_supervised/images/sphx_glr_plot_label_propagation_digits_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 88-92

Plot the most uncertain predictions
-----------------------------------

Here, we will pick and show the 10 most uncertain predictions.

.. GENERATED FROM PYTHON SOURCE LINES 92-96

.. code-block:: default

    from scipy import stats

    pred_entropies = stats.distributions.entropy(lp_model.label_distributions_.T)


.. GENERATED FROM PYTHON SOURCE LINES 97-98

Pick the top 10 most uncertain labels

.. GENERATED FROM PYTHON SOURCE LINES 98-100

.. code-block:: default

    uncertainty_index = np.argsort(pred_entropies)[-10:]


.. GENERATED FROM PYTHON SOURCE LINES 101-102

Plot

.. GENERATED FROM PYTHON SOURCE LINES 102-118

.. code-block:: default

    import matplotlib.pyplot as plt

    f = plt.figure(figsize=(7, 5))
    for index, image_index in enumerate(uncertainty_index):
        image = images[image_index]

        sub = f.add_subplot(2, 5, index + 1)
        sub.imshow(image, cmap=plt.cm.gray_r)
        plt.xticks([])
        plt.yticks([])
        sub.set_title(
            "predict: %i\ntrue: %i" % (lp_model.transduction_[image_index], y[image_index])
        )

    f.suptitle("Learning with small amount of labeled data")
    plt.show()


.. image-sg:: /auto_examples/semi_supervised/images/sphx_glr_plot_label_propagation_digits_002.png
   :alt: Learning with small amount of labeled data, predict: 1 true: 2, predict: 2 true: 2, predict: 8 true: 8, predict: 1 true: 8, predict: 1 true: 8, predict: 1 true: 8, predict: 3 true: 3, predict: 8 true: 8, predict: 2 true: 2, predict: 7 true: 2
   :srcset: /auto_examples/semi_supervised/images/sphx_glr_plot_label_propagation_digits_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.332 seconds)


.. _sphx_glr_download_auto_examples_semi_supervised_plot_label_propagation_digits.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/1.2.X?urlpath=lab/tree/notebooks/auto_examples/semi_supervised/plot_label_propagation_digits.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_label_propagation_digits.py <plot_label_propagation_digits.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_label_propagation_digits.ipynb <plot_label_propagation_digits.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
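
The uncertainty ranking used in this example relies on the entropy of each
sample's predicted label distribution: the flatter the distribution, the higher
the entropy and the less certain the prediction. The minimal sketch below
illustrates that idea on a small hand-written array. The array values and the
names ``label_distributions`` and ``n_most_uncertain`` are assumptions made for
illustration only; they are not output of the fitted model, and the sketch calls
``scipy.stats.entropy`` directly.

.. code-block:: python

    import numpy as np
    from scipy import stats

    # Hypothetical label distributions: 4 samples (rows) over 3 classes (columns).
    # Each row sums to 1, like a row of a model's predicted class probabilities.
    label_distributions = np.array(
        [
            [0.98, 0.01, 0.01],  # very confident
            [0.34, 0.33, 0.33],  # almost uniform, hence very uncertain
            [0.70, 0.20, 0.10],
            [0.50, 0.50, 0.00],  # torn between two classes
        ]
    )

    # scipy.stats.entropy reduces along axis 0, so transpose to obtain
    # one entropy value per sample.
    pred_entropies = stats.entropy(label_distributions.T)

    # argsort sorts in ascending order, so the last entries are the
    # samples with the highest entropy.
    n_most_uncertain = 2
    uncertainty_index = np.argsort(pred_entropies)[-n_most_uncertain:]
    print(uncertainty_index)

The returned indices are ordered from less to more uncertain, so the last index
points at the sample with the flattest distribution.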