Implementing callback support in estimators#
Adding callback support in an estimator boils down to enabling the registration of
callbacks, expressing fit as a tree of tasks, and invoking the callbacks at the
beginning and end of each of these tasks. To achieve this, scikit-learn provides the
following helpers from the callback module:
CallbackSupportMixin, which enables callback registration and initializes callback handling at the beginning of fit.CallbackContext, which represents tasks and is the central object for managing callbacks during fit.with_callbacks, to guarantee proper callback teardown at the end of fit.
The CallbackSupportMixin class#
To support callbacks, an estimator must inherit from the CallbackSupportMixin
class, which exposes the following methods:
set_callbacks, a public method to be called by the user to register callbacks on the estimator._init_callback_context, which should be called at the beginning of fit to create the rootCallbackContext, corresponding to the task that represents the entire execution offit. This method also sets up the callbacks that are registered on the estimator.Note
While the leading underscore signals that
_init_callback_contextis intended for internal use and should not appear in auto-completion suggestions for end users, it is made available to developers building third-party estimators and should be considered part of the public API contract.
The CallbackContext class#
The CallbackContext objects are responsible for invoking the callbacks at the
right time during fit. They track the different tasks of the estimator, with one context
instance representing each task, and capture the tree structure of the tasks involved in
the execution of the fit method.
A task is an arbitrary unit of work defined by the estimator. Usually, a task corresponds to an iteration of the estimator’s learning algorithm. They can also correspond to steps of a pipeline, cross-validation folds, etc. As tasks can be decomposed into subtasks, the tasks (and therefore callback contexts) have a natural tree structure, with the root task being the whole fit task.
The callback context objects follow this tree structure, holding references to their
parent and children contexts, and are dynamically built during fit. The root context
must be created by the _init_callback_context method.
examples of task / context trees#
As an example, KMeans has two nested loops: the outer loop
is controlled by the n_init parameter, and the inner loop is controlled by the
max_iter parameter. Therefore its task tree looks like this:
KMeans fit (root)
├── init 0
│ ├── iter 0
│ ├── iter 1
│ ├── ...
│ └── iter n
├── init 1
│ ├── iter 0
│ ├── ...
│ └── iter n
└── init 2
├── iter 0
├── ...
└── iter n
where each innermost iter j task corresponds to the computation of the labels
and centers for the full dataset. A callback registered on a KMeans estimator thus
will be invoked at the beginning and end of the fit task, each of the outer
init i tasks and each of the inner iter j tasks.
By convention, for performance reasons and consistency across estimators, the innermost tasks of scikit-learn estimators, i.e. the leaves of the task tree, correspond to operations on the full input data (or batches for incremental estimators).
When the estimator is a meta-estimator, a task leaf usually corresponds to fitting a
sub-estimator. Therefore, this leaf and the root task of the sub-estimator actually
represent the same task. In this case the leaf task of the meta-estimator and the
root task of the sub-estimator are merged into a single task. The task trees of the
meta-estimator and the sub-estimator are combined into a single task tree. For
instance, a Pipeline would have a task tree that looks
like this:
Pipeline fit (root)
├── step 0 | StandardScaler fit
│ └── <insert StandardScaler task tree here>
└── step 1 | LogisticRegression fit
└── <insert LogisticRegression task tree here>
To dynamically build the context tree and manage the callbacks during fit, the
CallbackContext class exposes the following methods:
-
This method should be used to create a context for a subtask. Callback contexts must not be created directly but through this method (or
_init_callback_contextfor the root context). call_on_fit_task_beginandcall_on_fit_task_enddef call_on_fit_task_begin( self, *, estimator, X=None, y=None, metadata=None, reconstruction_attributes=None ) -> None: ... def call_on_fit_task_end( self, *, estimator, X=None, y=None, metadata=None, reconstruction_attributes=None ) -> bool: ...
These two methods must be called respectively at the beginning and end of the task that the context is responsible for. As their name suggests, they call the
on_fit_task_beginandon_fit_task_endmethods of the callbacks registered on the estimator.In addition to the callback context that is implicitly passed to the registered callbacks, the keyword arguments of
call_on_fit_task_begin/endare used to pass additional information about the state of the fitting process at a given task. It is not expected to provide a value for all of them at every call of these methods. Estimators are expected to provide all the values that they are capable to produce. Callbacks then adapt their behavior based on the provided values for a given task.The
reconstruction_attributeskwarg#When
call_on_fit_task_begin/endis called, the state of the estimator at this task is likely to be incomplete and thus unable to predict, transform, etc … Thereconstruction_attributeskwarg expects a dictionary containing the necessary missing attributes to set on the estimator to ensure that it is ready to predict, transform, etc … as if fit had stopped at this task.The callback context will copy the state of the estimator at this task, set the reconstruction attributes and pass the resulting estimator to the callbacks as
fitted_estimator.If no additional attributes are needed to make the estimator ready, an empty dictionary should be passed instead of leaving the default value otherwise the callback context won’t pass a
fitted_estimatorto the callbacks.Lazy evaluation of the kwargs#
For each of these kwargs, a callable (with no arguments and returning the kwarg value) can be provided instead of the actual value. When it is the case, if a callback requires the kwarg, the callback context will evaluate the callable and forward the returned value to the callback. This mechanism enables lazy evaluation of the kwarg values, to avoid potentially costly computations when no callback requires a kwarg value.
To prevent performance degradations, estimators should lazily pass quantities that are expensive to compute.
Interrupting
fit#The
call_on_fit_task_endmethod returns a boolean, which can be used to interrupt the current level of iterations, to implement early stopping for instance. It returnsTrueif any callback signaled to stop thefitprocess at the end of this task andFalseotherwise.-
This method enables combining the context trees of individual estimators and meta-estimators in estimator compositions (e.g. a
GridSearchCVon aLogisticRegression) into a single context tree, rooted at the fit of the top level estimator.It should be used in a meta-estimator, on a context corresponding to the task of fitting a sub-estimator. This task is both a leaf task of the meta-estimator and the root task of the sub-estimator. Their corresponding contexts are thus merged into a single context in the combined tree.
In addition,
propagate_callback_contextis a context manager that propagates the auto-propagated callbacks from the meta-estimator to the sub-estimator such that they are called at the tasks of the sub-estimator as well. It also clears the propagated callbacks on exit such that the fitted sub-estimator no longer holds any locally registered callbacks.
The with_callbacks decorator#
For third-party estimators implementing callback support, the fit method should be
decorated with the with_callbacks decorator. This decorator guarantees that the
callbacks are torn down after fit finishes, even if it exits on an error.
For scikit-learn’s built-in estimators, the _fit_context decorator
already takes care of the callbacks teardown, thus with_callbacks should not be
used.
Minimal example#
Here is a typical implementation of callback support in a custom estimator:
from sklearn.callback import CallbackSupportMixin, with_callbacks
class MyEstimator(CallbackSupportMixin):
def __init__(self, max_iter):
self.max_iter = max_iter
@with_callbacks
def fit(self, X, y):
callback_ctx = self._init_callback_context(max_subtasks=self.max_iter)
callback_ctx.call_on_fit_task_begin(estimator=self, X=X, y=y)
for i in range(self.max_iter):
subcontext = callback_ctx.subcontext(task_name="iteration")
subcontext.call_on_fit_task_begin(estimator=self, X=X, y=y)
# Do something
if subcontext.call_on_fit_task_end(estimator=self, X=X, y=y):
break
callback_ctx.call_on_fit_task_end(estimator=self, X=X, y=y)
return self