.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/tree/plot_unveil_tree_structure.py"
.. LINE NUMBERS ARE GIVEN BELOW.


.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py:


=========================================
Understanding the decision tree structure
=========================================

The decision tree structure can be analysed to gain further insight into the
relation between the features and the target to predict. In this example, we
show how to retrieve:

- the binary tree structure;
- the depth of each node and whether or not it's a leaf;
- the nodes that were reached by a sample using the ``decision_path`` method;
- the leaf that was reached by a sample using the ``apply`` method;
- the rules that were used to predict a sample;
- the decision path shared by a group of samples.

.. GENERATED FROM PYTHON SOURCE LINES 18-27

.. code-block:: default


    import numpy as np
    from matplotlib import pyplot as plt

    from sklearn import tree
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.tree import DecisionTreeClassifier








.. GENERATED FROM PYTHON SOURCE LINES 28-32

Train tree classifier
---------------------
First, we fit a :class:`~sklearn.tree.DecisionTreeClassifier` with at most
three leaves (``max_leaf_nodes=3``) on a training split of the iris dataset,
loaded with :func:`~sklearn.datasets.load_iris`.

.. GENERATED FROM PYTHON SOURCE LINES 32-41

.. code-block:: default


    iris = load_iris()
    X = iris.data
    y = iris.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
    clf.fit(X_train, y_train)






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)

.. GENERATED FROM PYTHON SOURCE LINES 42-76

Tree structure
--------------

The decision tree classifier has an attribute called ``tree_`` which gives
access to low-level attributes such as ``node_count``, the total number of
nodes, and ``max_depth``, the maximal depth of the tree. The
``tree_.compute_node_depths()`` method computes the depth of each node in the
tree. ``tree_`` also stores the entire binary tree structure, represented as
a number of parallel arrays. The i-th element of each array holds information
about the node ``i``. Node 0 is the tree's root. Some of the arrays only
apply to either leaves or split nodes; for these arrays, the values stored
for nodes of the other type are arbitrary. For example, the arrays
``feature`` and ``threshold`` only apply to split nodes, so the values stored
for leaf nodes in these arrays are arbitrary.

Among these arrays, we have:

  - ``children_left[i]``: id of the left child of node ``i`` or -1 if leaf
    node
  - ``children_right[i]``: id of the right child of node ``i`` or -1 if leaf
    node
  - ``feature[i]``: feature used for splitting node ``i``
  - ``threshold[i]``: threshold value at node ``i``
  - ``n_node_samples[i]``: the number of training samples reaching node
    ``i``
  - ``impurity[i]``: the impurity at node ``i``
  - ``weighted_n_node_samples[i]``: the weighted number of training samples
    reaching node ``i``
  - ``value[i, j, k]``: the summary of the training samples that reached node
    ``i`` for output ``j`` and class ``k``
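
To make the parallel-array layout concrete, the following sketch rebuilds by
hand the arrays of the five-node tree printed further below. The children,
features, thresholds and class counts are copied from that output rather than
read from a fitted model, the ``-2`` entries are placeholders for the
arbitrary leaf values, and ``n_node_samples`` is derived by summing the class
counts, which assumes unweighted samples. Using only NumPy, it checks two
properties that follow from the description above: leaves are the nodes whose
children are ``-1``, and every training sample reaching a split node goes to
exactly one of its two children.

.. code-block:: python

    import numpy as np

    # Hand-copied arrays for the 5-node tree printed in this example
    # (not read from a fitted model). -1 marks "no child"; -2 is only a
    # placeholder for the arbitrary feature/threshold values of leaves.
    children_left = np.array([1, -1, 3, -1, -1])
    children_right = np.array([2, -1, 4, -1, -1])
    feature = np.array([3, -2, 2, -2, -2])
    threshold = np.array([0.800000011920929, -2.0, 4.950000047683716, -2.0, -2.0])
    # Per-class training counts, shape (n_nodes, n_outputs, n_classes).
    value = np.array(
        [[[37, 34, 41]], [[37, 0, 0]], [[0, 34, 41]], [[0, 33, 3]], [[0, 1, 38]]],
        dtype=float,
    )
    # Assumption: no sample weights, so samples per node = sum of class counts.
    n_node_samples = value.sum(axis=(1, 2)).astype(int)

    # Element i of every array describes node i; a leaf has no children.
    is_leaf = children_left == -1
    split_nodes = np.flatnonzero(~is_leaf)

    for i in split_nodes:
        left, right = children_left[i], children_right[i]
        # Every training sample reaching a split node goes to exactly one child.
        assert n_node_samples[i] == n_node_samples[left] + n_node_samples[right]

    print("leaves:", np.flatnonzero(is_leaf))
    print("samples per node:", n_node_samples)

With these arrays, the sketch reports leaves 1, 3 and 4 and per-node sample
counts of 112, 37, 75, 36 and 39.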

Using the arrays, we can traverse the tree structure to compute various
properties. Below, we will compute the depth of each node and whether or not
it is a leaf.

.. GENERATED FROM PYTHON SOURCE LINES 76-129

.. code-block:: default


    n_nodes = clf.tree_.node_count
    children_left = clf.tree_.children_left
    children_right = clf.tree_.children_right
    feature = clf.tree_.feature
    threshold = clf.tree_.threshold
    values = clf.tree_.value

    node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
    is_leaves = np.zeros(shape=n_nodes, dtype=bool)
    stack = [(0, 0)]  # start with the root node id (0) and its depth (0)
    while len(stack) > 0:
        # `pop` ensures each node is only visited once
        node_id, depth = stack.pop()
        node_depth[node_id] = depth

        # If the left and right children of a node differ, it is a split
        # node (a leaf has both children set to -1)
        is_split_node = children_left[node_id] != children_right[node_id]
        # If a split node, append left and right children and depth to `stack`
        # so we can loop through them
        if is_split_node:
            stack.append((children_left[node_id], depth + 1))
            stack.append((children_right[node_id], depth + 1))
        else:
            is_leaves[node_id] = True

    print(
        "The binary tree structure has {n} nodes and has "
        "the following tree structure:\n".format(n=n_nodes)
    )
    for i in range(n_nodes):
        if is_leaves[i]:
            print(
                "{space}node={node} is a leaf node with value={value}.".format(
                    space=node_depth[i] * "\t", node=i, value=values[i]
                )
            )
        else:
            print(
                "{space}node={node} is a split node with value={value}: "
                "go to node {left} if X[:, {feature}] <= {threshold} "
                "else to node {right}.".format(
                    space=node_depth[i] * "\t",
                    node=i,
                    left=children_left[i],
                    feature=feature[i],
                    threshold=threshold[i],
                    right=children_right[i],
                    value=values[i],
                )
            )





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    The binary tree structure has 5 nodes and has the following tree structure:

    node=0 is a split node with value=[[37. 34. 41.]]: go to node 1 if X[:, 3] <= 0.800000011920929 else to node 2.
            node=1 is a leaf node with value=[[37.  0.  0.]].
            node=2 is a split node with value=[[ 0. 34. 41.]]: go to node 3 if X[:, 2] <= 4.950000047683716 else to node 4.
                    node=3 is a leaf node with value=[[ 0. 33.  3.]].
                    node=4 is a leaf node with value=[[ 0.  1. 38.]].

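In the printed tree, every child has a larger id than its parent. Assuming
that ordering holds (it does for the arrays above), the depths can also be
filled in with a single pass over the node ids in increasing order, without
an explicit stack. The sketch below uses children arrays copied by hand from
the output; it is an illustration only, not scikit-learn's own code.

.. code-block:: python

    import numpy as np

    # Children arrays of the 5-node tree shown above (hand-copied; -1 = leaf).
    children_left = np.array([1, -1, 3, -1, -1])
    children_right = np.array([2, -1, 4, -1, -1])
    n_nodes = len(children_left)

    # Assumption: a child always has a larger id than its parent, so by the
    # time node i is visited, its own depth is already final.
    node_depth = np.zeros(n_nodes, dtype=np.int64)
    for i in range(n_nodes):
        for child in (children_left[i], children_right[i]):
            if child != -1:
                assert child > i, "ordering assumption violated"
                node_depth[child] = node_depth[i] + 1

    print(node_depth)  # [0 1 1 2 2]

The resulting depths match the indentation levels of the printed structure.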



.. GENERATED FROM PYTHON SOURCE LINES 130-148

What is the values array used here?
-----------------------------------
The ``tree_.value`` array is a 3D array of shape
[``n_nodes``, ``n_outputs``, ``n_classes``]. For each node, it holds the
weighted number of training samples reaching that node, for each output and
each class (with no sample weights, these are plain counts).

For example, in the above tree built on the iris dataset, the root node has
``value = [37, 34, 41]``, indicating that there are 37 samples of class 0,
34 samples of class 1, and 41 samples of class 2 at the root node. As the
samples are split while traversing the tree, the ``value`` array changes from
node to node. The left child of the root node has ``value = [37, 0, 0]``
because all 37 samples in the left child node are from class 0.

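Note: In this example, ``n_outputs=1``, so the ``value`` of each node is a
2D array with a single row of per-class counts (``[[37. 34. 41.]]`` at the
root). The tree classifier can also handle multi-output problems, in which
case each node's ``value`` has one row per output.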
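
The sketch below, again on class counts copied by hand from the printed tree,
indexes ``value[i, j, k]`` as node ``i``, output ``j``, class ``k``, checks
that a split node's counts are the sum of its children's counts, and
normalizes each leaf's counts for output 0 into class fractions whose
arg-max is the majority class at that leaf. That these fractions correspond
to what the classifier's ``predict_proba`` reports for samples landing in the
leaf is an assumption of this illustration.

.. code-block:: python

    import numpy as np

    # Hand-copied class counts, shape (n_nodes, n_outputs, n_classes) = (5, 1, 3).
    value = np.array(
        [[[37, 34, 41]], [[37, 0, 0]], [[0, 34, 41]], [[0, 33, 3]], [[0, 1, 38]]],
        dtype=float,
    )
    children_left = np.array([1, -1, 3, -1, -1])
    children_right = np.array([2, -1, 4, -1, -1])

    n_nodes, n_outputs, n_classes = value.shape
    print(value[0, 0, 2])  # node 0, output 0, class 2 -> 41.0

    # Counts at a split node are the sum of the counts at its two children.
    for i in np.flatnonzero(children_left != -1):
        assert np.array_equal(
            value[i], value[children_left[i]] + value[children_right[i]]
        )

    # Per-leaf class fractions and majority class for output 0.
    leaves = np.flatnonzero(children_left == -1)
    fractions = value[leaves, 0] / value[leaves, 0].sum(axis=1, keepdims=True)
    majority = fractions.argmax(axis=1)
    print(dict(zip(leaves.tolist(), majority.tolist())))  # {1: 0, 3: 1, 4: 2}

With the counts above, leaves 1, 3 and 4 have majority classes 0, 1 and 2,
respectively.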

.. GENERATED FROM PYTHON SOURCE LINES 150-151

We can compare the above output to the plot of the decision tree.

.. GENERATED FROM PYTHON SOURCE LINES 151-155

.. code-block:: default


    tree.plot_tree(clf)
    plt.show()




.. image-sg:: /auto_examples/tree/images/sphx_glr_plot_unveil_tree_structure_001.png
   :alt: plot unveil tree structure
   :srcset: /auto_examples/tree/images/sphx_glr_plot_unveil_tree_structure_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 156-173

Decision path
-------------

We can also retrieve the decision path of samples of interest. The
``decision_path`` method outputs an indicator matrix that allows us to
retrieve the nodes the samples of interest traverse. A non-zero element in
the indicator matrix at position ``(i, j)`` indicates that sample ``i`` goes
through node ``j``. Equivalently, for one sample ``i``, the positions of the
non-zero elements in row ``i`` of the indicator matrix are the ids of the
nodes that sample goes through.
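
The two readings of the indicator matrix can be illustrated with NumPy
alone. The sketch below writes a small dense indicator by hand for two
hypothetical samples on the five-node tree (one ending in leaf 4, one in
leaf 3), then builds CSR-style ``indptr`` and ``indices`` arrays from it
manually. This manual construction only illustrates the layout; it is not
how scikit-learn builds the matrix.

.. code-block:: python

    import numpy as np

    # Hand-written indicator for two hypothetical samples on the 5-node tree:
    # sample 0 follows 0 -> 2 -> 4, sample 1 follows 0 -> 2 -> 3.
    indicator = np.array(
        [
            [1, 0, 1, 0, 1],
            [1, 0, 1, 1, 0],
        ]
    )

    # Reading 1: a non-zero entry (i, j) means sample i goes through node j.
    print(indicator[1, 3] != 0)  # True: sample 1 goes through node 3

    # Reading 2: the column positions of the non-zeros in row i are node ids.
    rows = [np.flatnonzero(row) for row in indicator]

    # Equivalent CSR layout: row i's column indices live in
    # indices[indptr[i]:indptr[i + 1]].
    indices = np.concatenate(rows)
    indptr = np.concatenate([[0], np.cumsum([len(r) for r in rows])])
    for i in range(indicator.shape[0]):
        assert np.array_equal(indices[indptr[i] : indptr[i + 1]], rows[i])

    print(indptr, indices)  # [0 3 6] [0 2 4 0 2 3]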

The leaf ids reached by samples of interest can be obtained with the
``apply`` method, which returns an array of the node ids of the leaves
reached by each sample of interest. Using the leaf ids and the
``decision_path``, we can obtain the splitting conditions that were used to
predict a sample or a group of samples. First, let's do it for one sample.
Note that ``node_indicator`` is a sparse matrix in CSR format: the node ids
of one sample are read by slicing its ``indices`` array with ``indptr``.

.. GENERATED FROM PYTHON SOURCE LINES 173-207

.. code-block:: default


    node_indicator = clf.decision_path(X_test)
    leaf_id = clf.apply(X_test)

    sample_id = 0
    # obtain ids of the nodes `sample_id` goes through, i.e., row `sample_id`
    node_index = node_indicator.indices[
        node_indicator.indptr[sample_id] : node_indicator.indptr[sample_id + 1]
    ]

    print("Rules used to predict sample {id}:\n".format(id=sample_id))
    for node_id in node_index:
        # continue to the next node if it is a leaf node
        if leaf_id[sample_id] == node_id:
            continue

        # check whether this sample's split feature value is at or below the threshold
        if X_test[sample_id, feature[node_id]] <= threshold[node_id]:
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print(
            "decision node {node} : (X_test[{sample}, {feature}] = {value}) "
            "{inequality} {threshold}".format(
                node=node_id,
                sample=sample_id,
                feature=feature[node_id],
                value=X_test[sample_id, feature[node_id]],
                inequality=threshold_sign,
                threshold=threshold[node_id],
            )
        )





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Rules used to predict sample 0:

    decision node 0 : (X_test[0, 3] = 2.4) > 0.800000011920929
    decision node 2 : (X_test[0, 2] = 5.1) > 4.950000047683716

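The rules above can also be reproduced by walking a single sample down the
tree by hand. The sketch below defines a hypothetical ``route`` helper (not
part of scikit-learn) on the hand-copied five-node tree: the list of visited
nodes plays the role of one row of ``decision_path``, the returned leaf plays
the role of ``apply``, and a sample goes left when its feature value is
``<=`` the threshold, as in the printed rules. Only features 2 and 3 of the
sample are taken from ``X_test[0]`` above; features 0 and 1 are not used by
this tree and are set to placeholder zeros.

.. code-block:: python

    import numpy as np

    # Hand-copied 5-node tree (see the printed structure); -1 marks a leaf.
    children_left = np.array([1, -1, 3, -1, -1])
    children_right = np.array([2, -1, 4, -1, -1])
    feature = np.array([3, -2, 2, -2, -2])
    threshold = np.array([0.800000011920929, -2.0, 4.950000047683716, -2.0, -2.0])


    def route(x):
        """Hypothetical helper: return (visited node ids, leaf id, rules) for one sample."""
        node, path, rules = 0, [], []
        while True:
            path.append(node)
            if children_left[node] == -1:  # reached a leaf
                return path, node, rules
            f, t = int(feature[node]), float(threshold[node])
            if x[f] <= t:
                rules.append(f"node {node}: X[{f}] = {x[f]} <= {t}")
                node = int(children_left[node])
            else:
                rules.append(f"node {node}: X[{f}] = {x[f]} > {t}")
                node = int(children_right[node])


    # Features 2 and 3 match X_test[0] above; features 0 and 1 are placeholders.
    x = np.array([0.0, 0.0, 5.1, 2.4])
    path, leaf, rules = route(x)
    print(path, leaf)  # [0, 2, 4] 4
    print("\n".join(rules))

A sample with a smaller value for feature 2 (at or below 4.95) would instead
end in leaf 3.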



.. GENERATED FROM PYTHON SOURCE LINES 208-210

For a group of samples, we can determine the common nodes the samples go
through.

.. GENERATED FROM PYTHON SOURCE LINES 210-223

.. code-block:: default


    sample_ids = [0, 1]
    # boolean array indicating the nodes both samples go through
    common_nodes = node_indicator.toarray()[sample_ids].sum(axis=0) == len(sample_ids)
    # obtain node ids using position in array
    common_node_id = np.arange(n_nodes)[common_nodes]

    print(
        "\nThe following samples {samples} share the node(s) {nodes} in the tree.".format(
            samples=sample_ids, nodes=common_node_id
        )
    )
    print("This is {prop}% of all nodes.".format(prop=100 * len(common_node_id) / n_nodes))




.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    The following samples [0, 1] share the node(s) [0 2] in the tree.
    This is 40.0% of all nodes.

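Converting the whole indicator with ``toarray()`` is fine for this small
example, but it allocates a dense array with one row per sample and one
column per node. A sketch of an alternative intersects the per-sample node
id lists directly, which is what the ``indices[indptr[i]:indptr[i + 1]]``
slice used earlier returns for each sample. Here the two paths are written
by hand, consistent with the shared nodes ``[0 2]`` printed above, rather
than read from a real ``decision_path`` result.

.. code-block:: python

    from functools import reduce

    import numpy as np

    n_nodes = 5
    # Node ids visited by each sample, i.e. the column indices of its row in
    # the indicator matrix (hand-written, consistent with the output above).
    paths = {0: np.array([0, 2, 4]), 1: np.array([0, 2, 3])}

    sample_ids = [0, 1]
    # Intersect the paths pairwise; the result is the set of shared nodes.
    common_node_id = reduce(np.intersect1d, (paths[s] for s in sample_ids))

    print(common_node_id)  # [0 2]
    print("This is {prop}% of all nodes.".format(prop=100 * len(common_node_id) / n_nodes))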




.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.077 seconds)


.. _sphx_glr_download_auto_examples_tree_plot_unveil_tree_structure.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/1.3.X?urlpath=lab/tree/notebooks/auto_examples/tree/plot_unveil_tree_structure.ipynb
        :alt: Launch binder
        :width: 150 px



    .. container:: lite-badge

      .. image:: images/jupyterlite_badge_logo.svg
        :target: ../../lite/lab/?path=auto_examples/tree/plot_unveil_tree_structure.ipynb
        :alt: Launch JupyterLite
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_unveil_tree_structure.py <plot_unveil_tree_structure.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_unveil_tree_structure.ipynb <plot_unveil_tree_structure.ipynb>`

