.. _sphx_glr_auto_examples_neural_networks_plot_mnist_filters.py:


=====================================
Visualization of MLP weights on MNIST
=====================================

Sometimes looking at the learned coefficients of a neural network can provide
insight into the learning behavior. For example, if the weights look
unstructured, maybe some were not used at all; if very large coefficients
exist, maybe the regularization was too low or the learning rate too high.

This example shows how to plot some of the first-layer weights in an
MLPClassifier trained on the MNIST dataset.

The input data consists of 28x28 pixel handwritten digits, leading to 784
features in the dataset. Therefore, the first-layer weight matrix has the shape
(784, hidden_layer_sizes[0]), and each column of it, which holds the incoming
weights of one hidden unit, can be visualized as a 28x28 pixel image.

To make the example run faster, we use very few hidden units and train only
for a very short time. Training longer would result in weights with a much
smoother spatial appearance.



.. image:: /auto_examples/neural_networks/images/sphx_glr_plot_mnist_filters_001.png
    :align: center


.. rst-class:: sphx-glr-script-out

 Out::

    ________________________________________________________________________________
    [Memory] Calling __main__--home-ubuntu-scikit-learn-examples-neural_networks-.fetch_mnist...
    fetch_mnist()
    _____________________________________________________fetch_mnist - 41.1s, 0.7min
    Iteration 1, loss = 88770.19492622
    Iteration 2, loss = 94144.72785948
    Iteration 3, loss = 94116.48942606
    Iteration 4, loss = 94088.25915097
    Training loss did not improve more than tol=0.000100 for two consecutive epochs. Stopping.
    Training set score: 0.112367
    Test set score: 0.113500
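
The weight-matrix shape stated in the introduction can be checked without
downloading MNIST. The sketch below is an illustration under assumptions: it
fits an MLPClassifier for a handful of iterations on 200 random 784-pixel
"images" (a hypothetical stand-in for MNIST, with hypothetical labels), then
reshapes each column of ``coefs_[0]`` into a 28x28 array, one per hidden unit.

.. code-block:: python

    import warnings

    import numpy as np
    from sklearn.exceptions import ConvergenceWarning
    from sklearn.neural_network import MLPClassifier

    # Hypothetical stand-in for MNIST: 200 random "images" of 28x28 = 784
    # pixels with values in [0, 1], and 10 random class labels.
    rng = np.random.RandomState(0)
    X = rng.rand(200, 28 * 28)
    y = rng.randint(0, 10, size=200)

    n_hidden = 16
    mlp = MLPClassifier(hidden_layer_sizes=(n_hidden,), max_iter=5,
                        random_state=0)
    with warnings.catch_warnings():
        # Five iterations are not enough to converge; silence that warning.
        warnings.simplefilter("ignore", category=ConvergenceWarning)
        mlp.fit(X, y)

    # coefs_[0] connects the 784 input pixels to the hidden units.
    W = mlp.coefs_[0]
    print(W.shape)  # (784, 16)

    # Column j holds the incoming weights of hidden unit j; transposing and
    # reshaping gives one 28x28 "filter" image per hidden unit.
    filters = W.T.reshape(n_hidden, 28, 28)
    print(filters.shape)  # (16, 28, 28)

    # The same unit's weights, seen as a flat column and as an image.
    assert np.array_equal(filters[3].ravel(), W[:, 3])

Reshaping ``W.T`` rather than ``W`` keeps the hidden-unit index first, so
``filters[j]`` is the image of column ``j``. The reshape to 28x28 assumes each
image's 784 pixels are stored row by row.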
.. code-block:: python

    import io

    from scipy.io.arff import loadarff
    import matplotlib.pyplot as plt

    from sklearn.datasets import get_data_home
    from sklearn.externals.joblib import Memory
    from sklearn.neural_network import MLPClassifier

    try:
        from urllib.request import urlopen
    except ImportError:
        # Python 2
        from urllib2 import urlopen

    print(__doc__)

    # Cache the result of fetch_mnist on disk under scikit-learn's data home,
    # so that reruns skip the download.
    memory = Memory(get_data_home())


    @memory.cache()
    def fetch_mnist():
        # Download MNIST in ARFF format from OpenML and parse it with SciPy.
        content = urlopen(
            'https://www.openml.org/data/download/52667/mnist_784.arff').read()
        data, meta = loadarff(io.StringIO(content.decode('utf8')))
        data = data.view([('pixels', '
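
The plotting step described in the introduction, laying out first-layer
weight columns as small grayscale images, can be sketched so that it runs
offline. As an assumption, the sketch uses scikit-learn's bundled 8x8 digits
dataset (``load_digits``) in place of MNIST, so each column is reshaped to 8x8
rather than 28x28; the grid size, iteration count and symmetric color scale are
illustrative choices, not taken from the original script.

.. code-block:: python

    import io
    import warnings

    import matplotlib
    matplotlib.use("Agg")  # render off-screen; no display needed
    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.datasets import load_digits
    from sklearn.exceptions import ConvergenceWarning
    from sklearn.neural_network import MLPClassifier

    # Assumption: the bundled 8x8 digits stand in for the 28x28 MNIST images.
    X, y = load_digits(return_X_y=True)
    X = X / 16.0  # load_digits pixel values range from 0 to 16
    side = 8

    mlp = MLPClassifier(hidden_layer_sizes=(16,), max_iter=50, random_state=1)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=ConvergenceWarning)
        mlp.fit(X, y)

    W = mlp.coefs_[0]  # shape (64, 16): one column per hidden unit
    # One symmetric color scale shared by every panel.
    vmax = np.abs(W).max()
    print("largest absolute first-layer weight: %.3f" % vmax)

    fig, axes = plt.subplots(4, 4, figsize=(6, 6))
    for coef, ax in zip(W.T, axes.ravel()):
        ax.matshow(coef.reshape(side, side), cmap="gray",
                   vmin=-vmax, vmax=vmax)
        ax.set_xticks(())
        ax.set_yticks(())

    # Write the figure to an in-memory PNG instead of a file on disk.
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)

Sharing one color scale across all panels is a deliberate choice here: with
per-panel scaling every filter would look equally strong, hiding a unit with
unusually large weights, the symptom of too little regularization mentioned
above.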