.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_neighbors_plot_digits_kde_sampling.py>` to download the full example code or to run this example in your browser via Binder.
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_auto_examples_neighbors_plot_digits_kde_sampling.py:
=========================
Kernel Density Estimation
=========================
This example shows how kernel density estimation (KDE), a non-parametric
density estimation technique, can be used to learn a generative model for
a dataset. Once the model is fitted, new samples can be drawn from it.
These samples follow the density estimated from the data, so they resemble
the training examples without being copies of them.
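To make the idea of a generative model concrete, here is a minimal sketch of a one-dimensional Gaussian KDE that uses only the Python standard library. It is not the scikit-learn implementation: the helper names, the toy data, and the bandwidth value are assumptions made up for illustration. The density is the average of Gaussian bumps centred on the training points, and a sample is drawn by picking a training point at random and adding Gaussian noise whose standard deviation equals the bandwidth.

```python
import math
import random


def gaussian_kde_density(x, data, bandwidth):
    """Evaluate a 1-D Gaussian kernel density estimate at x."""
    norm = 1.0 / (len(data) * bandwidth * math.sqrt(2.0 * math.pi))
    return norm * sum(
        math.exp(-0.5 * ((x - xi) / bandwidth) ** 2) for xi in data
    )


def gaussian_kde_sample(data, bandwidth, n_samples, seed=0):
    """Draw samples: pick a training point, then add Gaussian noise."""
    rng = random.Random(seed)
    return [rng.gauss(rng.choice(data), bandwidth) for _ in range(n_samples)]


# Toy 1-D data with two clusters (hypothetical values, for illustration only).
data = [-2.1, -1.9, -2.0, -1.8, 2.0, 2.2, 1.9, 2.1]
bandwidth = 0.3

samples = gaussian_kde_sample(data, bandwidth, n_samples=1000)
near_cluster = sum(1 for s in samples if abs(abs(s) - 2.0) < 1.5)

print("fraction of samples near a cluster:", near_cluster / len(samples))
print("density at 2.0:", gaussian_kde_density(2.0, data, bandwidth))
print("density at 0.0:", gaussian_kde_density(0.0, data, bandwidth))
```

The density is high near the two clusters and low between them, and the drawn samples concentrate in the same regions. The full example applies the same principle in a 15-dimensional PCA space.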
.. image:: /auto_examples/neighbors/images/sphx_glr_plot_digits_kde_sampling_001.png
    :alt: Selection from the input data, "New" digits drawn from the kernel density model
    :class: sphx-glr-single-img
.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    best bandwidth: 3.79269019073225

|
.. code-block:: default

    import numpy as np
    import matplotlib.pyplot as plt

    from sklearn.datasets import load_digits
    from sklearn.neighbors import KernelDensity
    from sklearn.decomposition import PCA
    from sklearn.model_selection import GridSearchCV

    # load the data
    digits = load_digits()

    # project the 64-dimensional data to a lower dimension
    pca = PCA(n_components=15, whiten=False)
    data = pca.fit_transform(digits.data)

    # use grid search cross-validation to optimize the bandwidth
    params = {'bandwidth': np.logspace(-1, 1, 20)}
    grid = GridSearchCV(KernelDensity(), params)
    grid.fit(data)

    print("best bandwidth: {0}".format(grid.best_estimator_.bandwidth))

    # use the best estimator to compute the kernel density estimate
    kde = grid.best_estimator_

    # sample 44 new points from the fitted density and map them
    # back to the original 64-dimensional pixel space
    new_data = kde.sample(44, random_state=0)
    new_data = pca.inverse_transform(new_data)

    # turn data into a 4x11 grid
    new_data = new_data.reshape((4, 11, -1))
    real_data = digits.data[:44].reshape((4, 11, -1))

    # plot real digits and resampled digits
    fig, ax = plt.subplots(9, 11, subplot_kw=dict(xticks=[], yticks=[]))
    for j in range(11):
        # row 4 is left empty to separate the two groups of digits
        ax[4, j].set_visible(False)
        for i in range(4):
            im = ax[i, j].imshow(real_data[i, j].reshape((8, 8)),
                                 cmap=plt.cm.binary, interpolation='nearest')
            im.set_clim(0, 16)
            im = ax[i + 5, j].imshow(new_data[i, j].reshape((8, 8)),
                                     cmap=plt.cm.binary, interpolation='nearest')
            im.set_clim(0, 16)

    ax[0, 5].set_title('Selection from the input data')
    ax[5, 5].set_title('"New" digits drawn from the kernel density model')

    plt.show()
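The grid search in the script picks the bandwidth whose held-out log-likelihood is highest, since ``GridSearchCV`` falls back to the estimator's ``score`` method and ``KernelDensity.score`` returns the total log-likelihood of the data. The following standard-library sketch illustrates that selection rule on one-dimensional toy data. It is a simplified stand-in, not what scikit-learn runs internally: the helper names, the synthetic data, and the interleaved fold assignment are all assumptions made for illustration. Only the bandwidth grid mirrors ``np.logspace(-1, 1, 20)`` from the script.

```python
import math
import random


def log_density(x, train, bandwidth):
    """Log of a 1-D Gaussian KDE at x, computed with log-sum-exp."""
    exponents = [-0.5 * ((x - t) / bandwidth) ** 2 for t in train]
    m = max(exponents)
    log_sum = m + math.log(sum(math.exp(e - m) for e in exponents))
    return log_sum - math.log(len(train) * bandwidth * math.sqrt(2.0 * math.pi))


def cv_log_likelihood(data, bandwidth, n_folds=5):
    """Mean held-out log-likelihood over interleaved folds."""
    fold_scores = []
    for k in range(n_folds):
        held_out = data[k::n_folds]
        train = [x for i, x in enumerate(data) if i % n_folds != k]
        fold_scores.append(
            sum(log_density(x, train, bandwidth) for x in held_out)
        )
    return sum(fold_scores) / n_folds


# Synthetic two-cluster data (hypothetical, for illustration only).
rng = random.Random(0)
data = [rng.gauss(rng.choice([-2.0, 2.0]), 0.5) for _ in range(200)]

# Same grid as np.logspace(-1, 1, 20): 20 values from 0.1 to 10.
grid = [10 ** (-1 + 2 * i / 19) for i in range(20)]

scores = {h: cv_log_likelihood(data, h) for h in grid}
best = max(grid, key=lambda h: scores[h])
print("best bandwidth:", best)
```

A bandwidth that is too small overfits the training folds and scores poorly on held-out points, while one that is too large blurs the two clusters together; the cross-validated log-likelihood penalizes both.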
.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 4.812 seconds)
.. _sphx_glr_download_auto_examples_neighbors_plot_digits_kde_sampling.py:
.. only:: html

    .. container:: sphx-glr-footer
        :class: sphx-glr-footer-example

        .. container:: binder-badge

            .. image:: https://mybinder.org/badge_logo.svg
                :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/0.23.X?urlpath=lab/tree/notebooks/auto_examples/neighbors/plot_digits_kde_sampling.ipynb
                :width: 150 px

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: plot_digits_kde_sampling.py <plot_digits_kde_sampling.py>`

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: plot_digits_kde_sampling.ipynb <plot_digits_kde_sampling.ipynb>`
.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_