.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/applications/plot_face_recognition.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_applications_plot_face_recognition.py>`
        to download the full example code or to run this example in your browser via JupyterLite or Binder

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_applications_plot_face_recognition.py:


===================================================
Face recognition example using eigenfaces and SVMs
===================================================

The dataset used in this example is a preprocessed excerpt of the
"Labeled Faces in the Wild" dataset, also known as LFW_:

  http://vis-www.cs.umass.edu/lfw/lfw-funneled.tgz (233 MB)

.. _LFW: http://vis-www.cs.umass.edu/lfw/

.. GENERATED FROM PYTHON SOURCE LINES 15-27

.. code-block:: Python

    from time import time

    import matplotlib.pyplot as plt
    from scipy.stats import loguniform

    from sklearn.datasets import fetch_lfw_people
    from sklearn.decomposition import PCA
    from sklearn.metrics import ConfusionMatrixDisplay, classification_report
    from sklearn.model_selection import RandomizedSearchCV, train_test_split
    from sklearn.preprocessing import StandardScaler
    from sklearn.svm import SVC








.. GENERATED FROM PYTHON SOURCE LINES 28-29

Download the data if it is not already on disk, and load it as NumPy arrays.

.. GENERATED FROM PYTHON SOURCE LINES 29-51

.. code-block:: Python


    lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=0.4)

    # introspect the images arrays to find the shapes (for plotting)
    n_samples, h, w = lfw_people.images.shape

    # for machine learning we use the flattened 2D data directly (relative pixel
    # positions are ignored by this model)
    X = lfw_people.data
    n_features = X.shape[1]

    # the label to predict is the id of the person
    y = lfw_people.target
    target_names = lfw_people.target_names
    n_classes = target_names.shape[0]

    print("Total dataset size:")
    print("n_samples: %d" % n_samples)
    print("n_features: %d" % n_features)
    print("n_classes: %d" % n_classes)






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Total dataset size:
    n_samples: 1288
    n_features: 1850
    n_classes: 7


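``lfw_people.images`` and ``lfw_people.data`` store the same pixels in two layouts: one 2D array per image, or one flattened row per image. The sketch below illustrates the relationship with random arrays instead of the real download; the 50 x 37 image size is an assumption, consistent with the ``n_features: 1850`` printed above.

.. code-block:: Python

    import numpy as np

    # Random stand-in for lfw_people.images (assumed 50 x 37 pixels per image,
    # since 50 * 37 = 1850 matches n_features above).
    rng = np.random.default_rng(0)
    n_samples, h, w = 5, 50, 37
    images = rng.random((n_samples, h, w))

    # Flatten each image into one feature row: (n_samples, h, w) -> (n_samples, h * w).
    X = images.reshape(n_samples, h * w)

    # Reshaping a row back to (h, w) recovers the original image exactly.
    restored = X[0].reshape(h, w)
    print(X.shape)  # (5, 1850)

Once flattened, each pixel is just one column of ``X``; the model has no notion of which pixels are neighbours, which is what the comment about relative pixel positions refers to.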


.. GENERATED FROM PYTHON SOURCE LINES 52-53

Split into a training set and a test set, keeping 25% of the data for testing,
then standardize the features using statistics computed on the training set only.

.. GENERATED FROM PYTHON SOURCE LINES 53-62

.. code-block:: Python


    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=42
    )

    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)


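``StandardScaler`` learns one mean and one standard deviation per feature (here, per pixel) on the training set and reuses them to transform the test set. The following NumPy sketch reproduces that arithmetic on random data; unlike ``StandardScaler``, it does not guard against zero-variance features.

.. code-block:: Python

    import numpy as np

    # Random stand-in data with a non-zero mean and non-unit spread.
    rng = np.random.default_rng(0)
    X_train = rng.normal(loc=5.0, scale=2.0, size=(100, 3))
    X_test = rng.normal(loc=5.0, scale=2.0, size=(20, 3))

    # Per-feature (column) statistics, computed on the training set only.
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)  # population std (ddof=0), as StandardScaler uses

    X_train_scaled = (X_train - mean) / std
    # The test set reuses the training statistics, so it never influences them.
    X_test_scaled = (X_test - mean) / std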






.. GENERATED FROM PYTHON SOURCE LINES 63-65

Compute a PCA (eigenfaces) on the face dataset, treated as an unlabeled
dataset: this is unsupervised feature extraction and dimensionality reduction.

.. GENERATED FROM PYTHON SOURCE LINES 65-84

.. code-block:: Python


    n_components = 150

    print(
        "Extracting the top %d eigenfaces from %d faces" % (n_components, X_train.shape[0])
    )
    t0 = time()
    pca = PCA(n_components=n_components, svd_solver="randomized", whiten=True).fit(X_train)
    print("done in %0.3fs" % (time() - t0))

    eigenfaces = pca.components_.reshape((n_components, h, w))

    print("Projecting the input data on the eigenfaces orthonormal basis")
    t0 = time()
    X_train_pca = pca.transform(X_train)
    X_test_pca = pca.transform(X_test)
    print("done in %0.3fs" % (time() - t0))






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Extracting the top 150 eigenfaces from 966 faces
    done in 0.076s
    Projecting the input data on the eigenfaces orthonormal basis
    done in 0.007s


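Conceptually, the eigenfaces are the leading principal axes of the centered training matrix, and ``whiten=True`` divides each projected coordinate by its standard deviation so that every component has unit variance. The sketch below illustrates this with an exact NumPy SVD on random data rather than the randomized solver used above, so it shows the underlying math rather than scikit-learn's implementation; the signs of individual components may differ from those returned by ``PCA``.

.. code-block:: Python

    import numpy as np

    # Random stand-in data: 200 "faces" with 30 features each.
    rng = np.random.default_rng(0)
    X_train = rng.normal(size=(200, 30))
    X_test = rng.normal(size=(50, 30))
    n_components = 5

    # Center with the training mean, as PCA does.
    mean = X_train.mean(axis=0)
    X_centered = X_train - mean

    # Thin SVD: the rows of Vt are the principal axes, sorted by decreasing singular value.
    _, S, Vt = np.linalg.svd(X_centered, full_matrices=False)
    components = Vt[:n_components]

    # Variance of the training data along each kept axis (sample variance, ddof=1).
    explained_variance = S[:n_components] ** 2 / (X_train.shape[0] - 1)


    def project_whitened(X):
        # Center with the *training* mean, project, then scale each axis to unit variance.
        return ((X - mean) @ components.T) / np.sqrt(explained_variance)


    X_train_white = project_whitened(X_train)
    X_test_white = project_whitened(X_test)

The test set is centered and projected with quantities learned on the training set only, mirroring ``pca.transform(X_test)``.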


.. GENERATED FROM PYTHON SOURCE LINES 85-86

Train an SVM classification model, tuning ``C`` and ``gamma`` with a
cross-validated randomized search.

.. GENERATED FROM PYTHON SOURCE LINES 86-102

.. code-block:: Python


    print("Fitting the classifier to the training set")
    t0 = time()
    param_grid = {
        "C": loguniform(1e3, 1e5),
        "gamma": loguniform(1e-4, 1e-1),
    }
    clf = RandomizedSearchCV(
        SVC(kernel="rbf", class_weight="balanced"), param_grid, n_iter=10
    )
    clf = clf.fit(X_train_pca, y_train)
    print("done in %0.3fs" % (time() - t0))
    print("Best estimator found by randomized search:")
    print(clf.best_estimator_)






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Fitting the classifier to the training set
    done in 5.682s
    Best estimator found by randomized search:
    SVC(C=76823.03433306456, class_weight='balanced', gamma=0.0034189458230957995)


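The randomized search above draws ``C`` and ``gamma`` from log-uniform distributions, so each order of magnitude within the range receives the same probability, and ``gamma`` controls the width of the RBF kernel :math:`K(x, x') = \exp(-\gamma \|x - x'\|^2)`. A NumPy-only sketch of both ideas follows; ``sample_loguniform`` and ``rbf_kernel_matrix`` are hypothetical helpers written for illustration, not SciPy or scikit-learn functions.

.. code-block:: Python

    import numpy as np


    def sample_loguniform(low, high, size, rng):
        # Draw the base-10 exponent uniformly, then exponentiate: log10(value) is uniform.
        return 10.0 ** rng.uniform(np.log10(low), np.log10(high), size=size)


    def rbf_kernel_matrix(A, B, gamma):
        # Squared Euclidean distances between every row of A and every row of B.
        sq_dists = (
            np.sum(A**2, axis=1)[:, None]
            + np.sum(B**2, axis=1)[None, :]
            - 2.0 * A @ B.T
        )
        # Clip tiny negative values caused by floating-point cancellation.
        return np.exp(-gamma * np.maximum(sq_dists, 0.0))


    rng = np.random.default_rng(0)
    # Same ranges as the search above.
    C_samples = sample_loguniform(1e3, 1e5, size=10, rng=rng)
    gamma_samples = sample_loguniform(1e-4, 1e-1, size=10, rng=rng)

    X = rng.normal(size=(4, 3))
    K = rbf_kernel_matrix(X, X, gamma=0.01)

Sampling the exponent uniformly gives the same distribution as ``scipy.stats.loguniform``, whatever the base of the logarithm. Larger ``gamma`` values make the kernel decay faster with distance, so each training face influences a smaller neighbourhood.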

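In the code above, the scaler and the PCA are fitted on the whole training set before the cross-validated search, so each validation fold has already contributed to the preprocessing. Chaining all steps in a ``Pipeline`` and searching over it avoids this, since every step is then refitted within each cross-validation split; hyperparameters of a step are addressed as ``<step name>__<parameter>``. A minimal sketch on synthetic data from ``make_classification`` (a stand-in for the faces; the smaller ``n_components`` and ``n_iter`` and the fixed ``random_state`` values are assumptions chosen to keep it fast and reproducible):

.. code-block:: Python

    from scipy.stats import loguniform
    from sklearn.datasets import make_classification
    from sklearn.decomposition import PCA
    from sklearn.model_selection import RandomizedSearchCV, train_test_split
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.svm import SVC

    # Synthetic stand-in for the face data: 300 samples, 50 features, 3 classes.
    X, y = make_classification(
        n_samples=300, n_features=50, n_informative=10, n_classes=3, random_state=0
    )
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=42
    )

    # Scaling and PCA are refitted inside every cross-validation split.
    pipe = Pipeline(
        [
            ("scaler", StandardScaler()),
            ("pca", PCA(n_components=20, whiten=True, random_state=0)),
            ("svc", SVC(kernel="rbf", class_weight="balanced")),
        ]
    )
    param_distributions = {
        "svc__C": loguniform(1e3, 1e5),
        "svc__gamma": loguniform(1e-4, 1e-1),
    }
    search = RandomizedSearchCV(pipe, param_distributions, n_iter=5, random_state=0)
    search.fit(X_train, y_train)
    test_accuracy = search.score(X_test, y_test)

``search.best_estimator_`` is then the whole refitted pipeline, so ``search.predict(X_test)`` applies scaling, projection and classification in one call.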

.. GENERATED FROM PYTHON SOURCE LINES 103-104

Quantitative evaluation of the model quality on the test set

.. GENERATED FROM PYTHON SOURCE LINES 104-118

.. code-block:: Python


    print("Predicting people's names on the test set")
    t0 = time()
    y_pred = clf.predict(X_test_pca)
    print("done in %0.3fs" % (time() - t0))

    print(classification_report(y_test, y_pred, target_names=target_names))
    ConfusionMatrixDisplay.from_estimator(
        clf, X_test_pca, y_test, display_labels=target_names, xticks_rotation="vertical"
    )
    plt.tight_layout()
    plt.show()





.. image-sg:: /auto_examples/applications/images/sphx_glr_plot_face_recognition_001.png
   :alt: plot face recognition
   :srcset: /auto_examples/applications/images/sphx_glr_plot_face_recognition_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Predicting people's names on the test set
    done in 0.044s
                       precision    recall  f1-score   support

         Ariel Sharon       0.75      0.69      0.72        13
         Colin Powell       0.72      0.87      0.79        60
      Donald Rumsfeld       0.77      0.63      0.69        27
        George W Bush       0.88      0.95      0.91       146
    Gerhard Schroeder       0.95      0.80      0.87        25
          Hugo Chavez       0.90      0.60      0.72        15
           Tony Blair       0.93      0.75      0.83        36

             accuracy                           0.84       322
            macro avg       0.84      0.75      0.79       322
         weighted avg       0.85      0.84      0.84       322

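For each class, ``classification_report`` shows precision (the fraction of predictions of that class that are correct), recall (the fraction of true members of that class that are retrieved), their harmonic mean F1, and the support (the number of true samples of that class). ``macro avg`` is the unweighted mean over classes, while ``weighted avg`` weights each class by its support. A NumPy sketch that derives these numbers from a confusion matrix, using small made-up labels rather than the predictions above:

.. code-block:: Python

    import numpy as np

    # Made-up labels for three classes (not the face predictions above).
    y_true = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2])
    y_pred = np.array([0, 1, 1, 1, 0, 2, 2, 1, 2])
    n_classes = 3

    # Confusion matrix: rows are true classes, columns are predicted classes.
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)

    tp = np.diag(cm)
    # A class that is never predicted (or never present) would divide by zero here;
    # classification_report handles that case through its zero_division option.
    precision = tp / cm.sum(axis=0)  # correct predictions / all predictions of the class
    recall = tp / cm.sum(axis=1)  # correct predictions / all true samples of the class
    f1 = 2 * precision * recall / (precision + recall)
    support = cm.sum(axis=1)

    accuracy = tp.sum() / cm.sum()
    macro_f1 = f1.mean()  # "macro avg": every class counts equally
    weighted_f1 = np.average(f1, weights=support)  # "weighted avg": weighted by support

Macro averages fall below weighted ones when the smaller classes are harder, as in the report above, where the ``macro avg`` recall (0.75) trails the ``weighted avg`` recall (0.84).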




.. GENERATED FROM PYTHON SOURCE LINES 119-120

Qualitative evaluation of the predictions using matplotlib

.. GENERATED FROM PYTHON SOURCE LINES 120-134

.. code-block:: Python



    def plot_gallery(images, titles, h, w, n_row=3, n_col=4):
        """Helper function to plot a gallery of portraits"""
        plt.figure(figsize=(1.8 * n_col, 2.4 * n_row))
        plt.subplots_adjust(bottom=0, left=0.01, right=0.99, top=0.90, hspace=0.35)
        for i in range(n_row * n_col):
            plt.subplot(n_row, n_col, i + 1)
            plt.imshow(images[i].reshape((h, w)), cmap=plt.cm.gray)
            plt.title(titles[i], size=12)
            plt.xticks(())
            plt.yticks(())


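``plot_gallery`` relies on matplotlib's implicit ``plt.subplot`` interface and assumes that at least ``n_row * n_col`` images are passed. A hypothetical variant built on the object-oriented ``plt.subplots`` API, which also hides unused grid cells when fewer images are given, might look like the following; the ``Agg`` backend is selected only so the sketch runs without a display.

.. code-block:: Python

    import matplotlib

    matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
    import matplotlib.pyplot as plt
    import numpy as np


    def plot_gallery_axes(images, titles, h, w, n_row=3, n_col=4):
        """Hypothetical variant of plot_gallery built on plt.subplots."""
        # squeeze=False always returns a 2D array of Axes, even for a 1 x 1 grid.
        fig, axes = plt.subplots(
            n_row, n_col, figsize=(1.8 * n_col, 2.4 * n_row), squeeze=False
        )
        fig.subplots_adjust(bottom=0, left=0.01, right=0.99, top=0.90, hspace=0.35)
        # zip stops at the shortest input, so images beyond the grid are ignored.
        for ax, image, title in zip(axes.flat, images, titles):
            ax.imshow(image.reshape((h, w)), cmap="gray")
            ax.set_title(title, size=12)
            ax.set_xticks(())
            ax.set_yticks(())
        # Hide the grid cells left empty when there are fewer images than cells.
        for ax in axes.flat[len(images):]:
            ax.set_axis_off()
        return fig


    # Usage with random stand-in images (50 x 37 pixels is an assumption).
    rng = np.random.default_rng(0)
    fake_images = rng.random((5, 50 * 37))
    fig = plot_gallery_axes(fake_images, ["image %d" % i for i in range(5)], 50, 37)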







.. GENERATED FROM PYTHON SOURCE LINES 135-136

Plot the result of the prediction on a portion of the test set.

.. GENERATED FROM PYTHON SOURCE LINES 136-149

.. code-block:: Python



    def title(y_pred, y_test, target_names, i):
        pred_name = target_names[y_pred[i]].rsplit(" ", 1)[-1]
        true_name = target_names[y_test[i]].rsplit(" ", 1)[-1]
        return "predicted: %s\ntrue:      %s" % (pred_name, true_name)


    prediction_titles = [
        title(y_pred, y_test, target_names, i) for i in range(y_pred.shape[0])
    ]

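    # X_test was standardized above: undo the scaling to display the original pixels
    plot_gallery(scaler.inverse_transform(X_test), prediction_titles, h, w)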



.. image-sg:: /auto_examples/applications/images/sphx_glr_plot_face_recognition_002.png
   :alt: predicted: Bush true:      Bush, predicted: Bush true:      Bush, predicted: Blair true:      Blair, predicted: Bush true:      Bush, predicted: Bush true:      Bush, predicted: Bush true:      Bush, predicted: Schroeder true:      Schroeder, predicted: Powell true:      Powell, predicted: Bush true:      Bush, predicted: Bush true:      Bush, predicted: Bush true:      Bush, predicted: Bush true:      Bush
   :srcset: /auto_examples/applications/images/sphx_glr_plot_face_recognition_002.png
   :class: sphx-glr-single-img

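The ``title`` helper keeps only the last word of each name, so ``"George W Bush"`` becomes ``"Bush"``. Since the gallery only shows the first twelve test faces, it can be more informative to select the misclassified ones. A short sketch with made-up labels; the ``last_name`` helper is hypothetical:

.. code-block:: Python

    import numpy as np

    # Made-up labels standing in for target_names, y_test and y_pred.
    target_names = np.array(["Ariel Sharon", "Colin Powell", "George W Bush"])
    y_test = np.array([2, 1, 0, 2])
    y_pred = np.array([2, 0, 0, 2])


    def last_name(full_name):
        # Split once from the right on a space: "George W Bush" -> "Bush".
        return full_name.rsplit(" ", 1)[-1]


    titles = [
        "predicted: %s\ntrue:      %s"
        % (last_name(target_names[p]), last_name(target_names[t]))
        for p, t in zip(y_pred, y_test)
    ]

    # Indices of the test samples the classifier got wrong,
    # e.g. to select X_test[misclassified] for display.
    misclassified = np.flatnonzero(y_pred != y_test)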




.. GENERATED FROM PYTHON SOURCE LINES 150-151

Plot the gallery of the most significant eigenfaces.

.. GENERATED FROM PYTHON SOURCE LINES 151-157

.. code-block:: Python


    eigenface_titles = ["eigenface %d" % i for i in range(eigenfaces.shape[0])]
    plot_gallery(eigenfaces, eigenface_titles, h, w)

    plt.show()




.. image-sg:: /auto_examples/applications/images/sphx_glr_plot_face_recognition_003.png
   :alt: eigenface 0, eigenface 1, eigenface 2, eigenface 3, eigenface 4, eigenface 5, eigenface 6, eigenface 7, eigenface 8, eigenface 9, eigenface 10, eigenface 11
   :srcset: /auto_examples/applications/images/sphx_glr_plot_face_recognition_003.png
   :class: sphx-glr-single-img

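PCA sorts its components by decreasing explained variance, so ``eigenface 0`` is the direction along which the training faces vary the most, and ``pca.explained_variance_ratio_`` gives each component's share of the total variance. Each row of ``pca.components_`` has ``h * w`` entries, which is why it can be reshaped into an image. A NumPy sketch of both facts on random data (the 6 x 5 image size is arbitrary):

.. code-block:: Python

    import numpy as np

    rng = np.random.default_rng(0)
    h, w, n_components = 6, 5, 4
    X = rng.normal(size=(40, h * w))

    # Principal axes from the SVD of the centered data; singular values come out sorted.
    X_centered = X - X.mean(axis=0)
    _, S, Vt = np.linalg.svd(X_centered, full_matrices=False)

    # Share of the total variance captured by each component, largest first.
    explained_variance_ratio = S**2 / np.sum(S**2)

    # Each principal axis has h * w entries, so it reshapes into an "eigenface" image.
    eigenfaces = Vt[:n_components].reshape((n_components, h, w))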




.. GENERATED FROM PYTHON SOURCE LINES 158-162

The face recognition problem would be solved much more effectively by training
convolutional neural networks, but this family of models is outside the scope of
the scikit-learn library. Interested readers should instead try PyTorch or
TensorFlow to implement such models.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 22.070 seconds)


.. _sphx_glr_download_auto_examples_applications_plot_face_recognition.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/1.4.X?urlpath=lab/tree/notebooks/auto_examples/applications/plot_face_recognition.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: lite-badge

      .. image:: images/jupyterlite_badge_logo.svg
        :target: ../../lite/lab/?path=auto_examples/applications/plot_face_recognition.ipynb
        :alt: Launch JupyterLite
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_face_recognition.ipynb <plot_face_recognition.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_face_recognition.py <plot_face_recognition.py>`


.. include:: plot_face_recognition.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_