Developing with the Plotting API

Scikit-learn defines a simple API for creating visualizations for machine learning. The key features of this API are to run calculations once and to keep the flexibility to adjust the visualizations after the fact. This section is intended for developers who wish to develop or maintain plotting tools. For usage, refer to the User Guide.

Plotting API Overview

This logic is encapsulated in a display object, which stores the computed data and does the plotting in a plot method. The display object’s __init__ method contains only the data needed to create the visualization. The plot method takes only parameters related to visualization, such as a matplotlib axes, and stores the matplotlib artists as attributes so that styles can be adjusted through the display object.

The Display class should define one or both class methods: from_estimator and from_predictions. These methods create the Display object either from an estimator and some data, or from the true and predicted values. After creating the display object with the computed values, these class methods call the display’s plot method. Because plot defines attributes related to matplotlib, such as the line artist, the visualization can still be customized after plot has been called.

For example, the RocCurveDisplay defines the following methods and attributes:

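import matplotlib.pyplot as plt
from sklearn.metrics import auc, roc_curve


class RocCurveDisplay:
    def __init__(self, fpr, tpr, roc_auc, estimator_name):
        # Store only the data needed to create the visualization.
        self.fpr = fpr
        self.tpr = tpr
        self.roc_auc = roc_auc
        self.estimator_name = estimator_name

    @classmethod
    def from_estimator(cls, estimator, X, y):
        # Get the predicted probabilities of the positive class.
        y_pred = estimator.predict_proba(X)[:, 1]
        return cls.from_predictions(y, y_pred, estimator.__class__.__name__)

    @classmethod
    def from_predictions(cls, y, y_pred, estimator_name):
        # Do the ROC computation once, from y and y_pred.
        fpr, tpr, _ = roc_curve(y, y_pred)
        roc_auc = auc(fpr, tpr)
        viz = cls(fpr, tpr, roc_auc, estimator_name)
        return viz.plot()

    def plot(self, ax=None, name=None, **kwargs):
        # Only visualization-related parameters are accepted here.
        if ax is None:
            _, ax = plt.subplots()
        name = self.estimator_name if name is None else name
        # Store the matplotlib artists so they can be restyled later.
        (self.line_,) = ax.plot(
            self.fpr, self.tpr, label=f"{name} (AUC = {self.roc_auc:.2f})", **kwargs
        )
        ax.legend(loc="lower right")
        self.ax_ = ax
        self.figure_ = ax.figure
        return self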

Read more in the ROC Curve with Visualization API example and in the User Guide.

Plotting with Multiple Axes

Some plotting tools, such as PartialDependenceDisplay.from_estimator and PartialDependenceDisplay, support plotting on multiple axes. Two different scenarios are supported:

1. If a list of axes is passed in, plot checks that the number of axes matches the number it expects and then draws on those axes.
2. If a single axes is passed in, that axes defines a space in which multiple axes are placed. In this case, we suggest using matplotlib’s GridSpecFromSubplotSpec (from matplotlib.gridspec) to split up the space:

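import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpecFromSubplotSpec

# The single axes defines the bounding space for the new axes.
fig, ax = plt.subplots()
# Hide the bounding axes so its frame and ticks are not drawn behind the new axes.
ax.set_axis_off()
# Split the space occupied by ax into a 2 x 2 grid.
gs = GridSpecFromSubplotSpec(2, 2, subplot_spec=ax.get_subplotspec())

# Place the new axes in regions of the grid; the bottom one spans both columns.
ax_top_left = fig.add_subplot(gs[0, 0])
ax_top_right = fig.add_subplot(gs[0, 1])
ax_bottom = fig.add_subplot(gs[1, :])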

By default, the ax keyword in plot is None. In this case, a single axes is created and the gridspec API is used to divide it into the regions to plot in.
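The three cases (ax is None, a single axes, a list of axes) can be handled by one helper. The following is a minimal sketch: resolve_axes and its n_cols parameter are hypothetical names chosen for illustration, not scikit-learn API. It hides the bounding axes, subdivides it with GridSpecFromSubplotSpec, and leaves unused grid positions as None.

```python
import math

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpecFromSubplotSpec


def resolve_axes(ax, n_plots, n_cols=3):
    """Hypothetical helper returning (bounding_ax, axes) for n_plots plots."""
    if ax is None:
        # Default case: create the single axes that acts as the bounding box.
        _, ax = plt.subplots()

    if isinstance(ax, Axes):
        # Single axes: hide it and split its space with a gridspec.
        ax.set_axis_off()
        n_cols = min(n_cols, n_plots)
        n_rows = math.ceil(n_plots / n_cols)
        # An object array starts filled with None, so unused slots stay None.
        axes = np.empty((n_rows, n_cols), dtype=object)
        gs = GridSpecFromSubplotSpec(
            n_rows, n_cols, subplot_spec=ax.get_subplotspec()
        )
        for i in range(n_plots):
            row, col = divmod(i, n_cols)
            axes[row, col] = ax.figure.add_subplot(gs[row, col])
        return ax, axes

    # List of axes: check the count matches what we expect, then use them as-is.
    axes = np.asarray(ax, dtype=object)
    if axes.size != n_plots:
        raise ValueError(f"Expected {n_plots} axes, got {axes.size}.")
    return None, axes
```

For example, resolve_axes(None, 4) yields a 2 x 3 grid whose last two positions are None, while passing a list of two axes returns them as a 1d array with no bounding axes. Keeping unused slots as None lets the array stay indexable by grid position.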

See, for example, PartialDependenceDisplay.from_estimator, which plots multiple lines and contours using this API. The axes defining the bounding box is saved in a bounding_ax_ attribute. The individual axes created are stored in an axes_ ndarray, corresponding to the axes position on the grid. Positions that are not used are set to None. Furthermore, the matplotlib Artists are stored in lines_ and contours_, indexed by their position on the grid. When a list of axes is passed in, axes_, lines_, and contours_ are 1d ndarrays corresponding to the list of axes passed in.