.. rst-class:: sphx-glr-example-title
.. _sphx_glr_auto_examples_decomposition_plot_image_denoising.py:
=========================================
Image denoising using dictionary learning
=========================================
An example comparing how well noisy fragments of a raccoon face image are
reconstructed by first fitting an online :ref:`DictionaryLearning` model and
then applying various transform methods.
The dictionary is fitted on the distorted left half of the image, and
subsequently used to reconstruct the right half. Note that even better
performance could be achieved by fitting to an undistorted (i.e.
noiseless) image, but here we start from the assumption that it is not
available.
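The patch-based workflow described above can be sketched in plain NumPy. This is only an illustration of the idea (cut the image into overlapping patches, then rebuild it by averaging the overlaps); the helper names ``extract_patches`` and ``reconstruct_from_patches`` are hypothetical stand-ins, not the scikit-learn functions used in the script.

.. code-block:: python

    import numpy as np
    from numpy.lib.stride_tricks import sliding_window_view


    def extract_patches(image, patch_size):
        """Return every overlapping patch, scanned row by row (hypothetical helper)."""
        ph, pw = patch_size
        windows = sliding_window_view(image, (ph, pw))
        return windows.reshape(-1, ph, pw).copy()


    def reconstruct_from_patches(patches, image_shape):
        """Rebuild an image by averaging the overlapping patches (hypothetical helper)."""
        height, width = image_shape
        ph, pw = patches.shape[1:]
        total = np.zeros(image_shape)
        counts = np.zeros(image_shape)
        n_cols = width - pw + 1  # number of patch positions per image row
        for index, patch in enumerate(patches):
            row, col = divmod(index, n_cols)
            total[row:row + ph, col:col + pw] += patch
            counts[row:row + ph, col:col + pw] += 1
        return total / counts


    image = np.arange(48, dtype=float).reshape(6, 8)
    patches = extract_patches(image, (3, 3))
    restored = reconstruct_from_patches(patches, image.shape)

When the patches are left unmodified, averaging the overlaps returns the input image. In the denoising setting each patch is first replaced by its sparse approximation, so the averaging also smooths out per-patch errors.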
A common practice for evaluating the results of image denoising is to look
at the difference between the reconstruction and the original image. If the
reconstruction is perfect, this difference will look like Gaussian noise.
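As a small sketch of this evaluation, the difference image can be summarised by a single number, the square root of the sum of its squared entries, which is the Frobenius norm. The array sizes, the seed, and the variable names below are assumptions chosen for illustration only.

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)  # arbitrary seed, for repeatability only
    reference = rng.random((4, 6))  # stand-in for the clean image
    noisy = reference + 0.075 * rng.standard_normal(reference.shape)

    difference = noisy - reference
    # Square root of the sum of squares, written out by hand ...
    norm_manual = np.sqrt(np.sum(difference ** 2))
    # ... equals NumPy's default norm for a 2-D array (the Frobenius norm).
    norm_numpy = np.linalg.norm(difference)

A smaller norm means the image is closer to the reference, which is how the numbers in the figure titles can be compared.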
It can be seen from the plots that the result of :ref:`omp` with two
non-zero coefficients is a bit less biased than when keeping only one
(the edges look less prominent). It is also closer to the ground
truth in Frobenius norm (9.64 versus 10.55 in the run shown here).
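To illustrate why a second coefficient cannot increase the approximation error, here is a minimal greedy orthogonal matching pursuit in NumPy. It is a sketch under stated assumptions (a random dictionary with unit-norm atoms stored as rows, and a hypothetical helper named ``omp_sketch``), not the solver that scikit-learn uses.

.. code-block:: python

    import numpy as np


    def omp_sketch(dictionary, signal, n_nonzero):
        """Greedy OMP: dictionary rows are unit-norm atoms (hypothetical helper)."""
        residual = signal.copy()
        support = []
        coefs = np.zeros(0)
        for _ in range(n_nonzero):
            if np.linalg.norm(residual) < 1e-12:
                break  # the signal is already represented exactly
            # Select the atom most correlated with the current residual.
            correlations = dictionary @ residual
            support.append(int(np.argmax(np.abs(correlations))))
            # Re-fit all selected coefficients jointly by least squares.
            atoms = dictionary[support].T
            coefs, *_ = np.linalg.lstsq(atoms, signal, rcond=None)
            residual = signal - atoms @ coefs
        code = np.zeros(dictionary.shape[0])
        code[support] = coefs
        return code


    rng = np.random.default_rng(0)  # arbitrary seed
    D = rng.standard_normal((20, 10))
    D /= np.linalg.norm(D, axis=1, keepdims=True)
    x = rng.standard_normal(10)

    code_one = omp_sketch(D, x, 1)
    code_two = omp_sketch(D, x, 2)
    err_one = np.linalg.norm(x - code_one @ D)
    err_two = np.linalg.norm(x - code_two @ D)

Because the two-atom fit solves a least-squares problem over a support that contains the one-atom support, ``err_two`` is never larger than ``err_one``.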
The result of :ref:`least_angle_regression` is much more strongly biased: the
difference is reminiscent of the local intensity value of the original image.
Thresholding is clearly not useful for denoising, but it is here to show that
it can produce a suggestive output with very high speed, and thus be useful
for other tasks such as object classification, where performance is not
necessarily related to visualisation.
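The speed of thresholding comes from the fact that it needs only one matrix product and an element-wise operation. The sketch below assumes a soft-thresholding rule applied to the projections of the signal onto the atoms; the helper name ``threshold_code`` and the toy identity dictionary are illustrative assumptions, and the exact rule used by the library should be checked against its documentation.

.. code-block:: python

    import numpy as np


    def threshold_code(dictionary, signal, alpha):
        """Soft-threshold the projections onto the atoms (hypothetical helper)."""
        projections = dictionary @ signal
        # Shrink every projection towards zero by alpha; small ones become zero.
        return np.sign(projections) * np.maximum(np.abs(projections) - alpha, 0.0)


    dictionary = np.eye(3)  # toy dictionary: one atom per coordinate
    signal = np.array([0.5, -0.05, 0.2])
    code = threshold_code(dictionary, signal, alpha=0.1)

No iterative solver is involved, which is consistent with the short run time reported for this method, but every surviving coefficient is shrunk, so the reconstruction is not a least-squares fit.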
.. rst-class:: sphx-glr-horizontal
*
.. image:: /auto_examples/decomposition/images/sphx_glr_plot_image_denoising_001.png
:alt: Dictionary learned from face patches Train time 3.3s on 22692 patches
:class: sphx-glr-multi-img
*
.. image:: /auto_examples/decomposition/images/sphx_glr_plot_image_denoising_002.png
:alt: Distorted image, Image, Difference (norm: 11.65)
:class: sphx-glr-multi-img
*
.. image:: /auto_examples/decomposition/images/sphx_glr_plot_image_denoising_003.png
:alt: Orthogonal Matching Pursuit 1 atom (time: 0.9s), Image, Difference (norm: 10.55)
:class: sphx-glr-multi-img
*
.. image:: /auto_examples/decomposition/images/sphx_glr_plot_image_denoising_004.png
:alt: Orthogonal Matching Pursuit 2 atoms (time: 2.2s), Image, Difference (norm: 9.64)
:class: sphx-glr-multi-img
*
.. image:: /auto_examples/decomposition/images/sphx_glr_plot_image_denoising_005.png
:alt: Least-angle regression 5 atoms (time: 19.2s), Image, Difference (norm: 14.04)
:class: sphx-glr-multi-img
*
.. image:: /auto_examples/decomposition/images/sphx_glr_plot_image_denoising_006.png
:alt: Thresholding alpha=0.1 (time: 0.1s), Image, Difference (norm: 14.98)
:class: sphx-glr-multi-img
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
Distorting image...
Extracting reference patches...
done in 0.01s.
Learning the dictionary...
done in 3.34s.
Extracting noisy patches...
done in 0.00s.
Orthogonal Matching Pursuit
1 atom...
done in 0.88s.
Orthogonal Matching Pursuit
2 atoms...
done in 2.16s.
Least-angle regression
5 atoms...
done in 19.24s.
Thresholding
alpha=0.1...
done in 0.14s.
|
.. code-block:: default
print(__doc__)
from time import time
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from sklearn.decomposition import MiniBatchDictionaryLearning
from sklearn.feature_extraction.image import extract_patches_2d
from sklearn.feature_extraction.image import reconstruct_from_patches_2d
try:  # SciPy >= 0.16 have face in misc
    from scipy.misc import face
    face = face(gray=True)
except ImportError:
    face = sp.face(gray=True)
# Convert from uint8 representation with values between 0 and 255 to
# a floating point representation with values between 0 and 1.
face = face / 255.
# downsample for higher speed
face = face[::4, ::4] + face[1::4, ::4] + face[::4, 1::4] + face[1::4, 1::4]
face /= 4.0
height, width = face.shape
# Distort the right half of the image
print('Distorting image...')
distorted = face.copy()
distorted[:, width // 2:] += 0.075 * np.random.randn(height, width // 2)
# Extract all reference patches from the left half of the image
print('Extracting reference patches...')
t0 = time()
patch_size = (7, 7)
data = extract_patches_2d(distorted[:, :width // 2], patch_size)
data = data.reshape(data.shape[0], -1)
data -= np.mean(data, axis=0)
data /= np.std(data, axis=0)
print('done in %.2fs.' % (time() - t0))
# #############################################################################
# Learn the dictionary from reference patches
print('Learning the dictionary...')
t0 = time()
dico = MiniBatchDictionaryLearning(n_components=100, alpha=1, n_iter=500)
V = dico.fit(data).components_
dt = time() - t0
print('done in %.2fs.' % dt)
plt.figure(figsize=(4.2, 4))
for i, comp in enumerate(V[:100]):
    plt.subplot(10, 10, i + 1)
    plt.imshow(comp.reshape(patch_size), cmap=plt.cm.gray_r,
               interpolation='nearest')
    plt.xticks(())
    plt.yticks(())
plt.suptitle('Dictionary learned from face patches\n' +
             'Train time %.1fs on %d patches' % (dt, len(data)),
             fontsize=16)
plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)
# #############################################################################
# Display the distorted image
def show_with_diff(image, reference, title):
    """Helper function to display denoising"""
    plt.figure(figsize=(5, 3.3))
    plt.subplot(1, 2, 1)
    plt.title('Image')
    plt.imshow(image, vmin=0, vmax=1, cmap=plt.cm.gray,
               interpolation='nearest')
    plt.xticks(())
    plt.yticks(())
    plt.subplot(1, 2, 2)
    difference = image - reference
    plt.title('Difference (norm: %.2f)' % np.sqrt(np.sum(difference ** 2)))
    plt.imshow(difference, vmin=-0.5, vmax=0.5, cmap=plt.cm.PuOr,
               interpolation='nearest')
    plt.xticks(())
    plt.yticks(())
    plt.suptitle(title, size=16)
    plt.subplots_adjust(0.02, 0.02, 0.98, 0.79, 0.02, 0.2)
show_with_diff(distorted, face, 'Distorted image')
# #############################################################################
# Extract noisy patches and reconstruct them using the dictionary
print('Extracting noisy patches... ')
t0 = time()
data = extract_patches_2d(distorted[:, width // 2:], patch_size)
data = data.reshape(data.shape[0], -1)
intercept = np.mean(data, axis=0)
data -= intercept
print('done in %.2fs.' % (time() - t0))
transform_algorithms = [
    ('Orthogonal Matching Pursuit\n1 atom', 'omp',
     {'transform_n_nonzero_coefs': 1}),
    ('Orthogonal Matching Pursuit\n2 atoms', 'omp',
     {'transform_n_nonzero_coefs': 2}),
    ('Least-angle regression\n5 atoms', 'lars',
     {'transform_n_nonzero_coefs': 5}),
    ('Thresholding\n alpha=0.1', 'threshold', {'transform_alpha': .1})]
reconstructions = {}
for title, transform_algorithm, kwargs in transform_algorithms:
    print(title + '...')
    reconstructions[title] = face.copy()
    t0 = time()
    dico.set_params(transform_algorithm=transform_algorithm, **kwargs)
    code = dico.transform(data)
    patches = np.dot(code, V)
    patches += intercept
    patches = patches.reshape(len(data), *patch_size)
    if transform_algorithm == 'threshold':
        patches -= patches.min()
        patches /= patches.max()
    reconstructions[title][:, width // 2:] = reconstruct_from_patches_2d(
        patches, (height, width // 2))
    dt = time() - t0
    print('done in %.2fs.' % dt)
    show_with_diff(reconstructions[title], face,
                   title + ' (time: %.1fs)' % dt)
plt.show()
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 28.114 seconds)