.. currentmodule:: sklearn

.. TODO: update doc/conftest.py once document is updated and examples run.

.. _metadata_routing:

Metadata Routing
================

.. note::
  The Metadata Routing API is experimental and is not yet implemented for all
  estimators. Please refer to the :ref:`list of supported and unsupported
  models <metadata_routing_models>` for more information. The API may change
  without the usual deprecation cycle. This feature is disabled by default. You
  can enable it by setting the ``enable_metadata_routing`` flag to ``True``::

    >>> import sklearn
    >>> sklearn.set_config(enable_metadata_routing=True)

This guide demonstrates how metadata such as ``sample_weight`` can be routed
and passed along to estimators, scorers, and CV splitters through
meta-estimators such as :class:`~pipeline.Pipeline` and
:class:`~model_selection.GridSearchCV`. In order to pass metadata to a method
such as ``fit`` or ``score``, the object consuming the metadata must *request*
it. For estimators and splitters, this is done via ``set_*_request`` methods,
e.g. ``set_fit_request(...)``, and for scorers this is done via the
``set_score_request`` method. For grouped splitters such as
:class:`~model_selection.GroupKFold`, a ``groups`` parameter is requested by
default. This is best demonstrated by the following examples.

If you are developing a scikit-learn compatible estimator or meta-estimator,
you can check our related developer guide:
:ref:`sphx_glr_auto_examples_miscellaneous_plot_metadata_routing.py`.

.. note::
  The methods and requirements introduced in this document are only relevant
  if you want to pass :term:`metadata` (e.g. ``sample_weight``) to a method.
  If you're only passing ``X`` and ``y`` and no other parameter / metadata to
  methods such as :term:`fit`, :term:`transform`, etc., then you don't need to
  set anything.

Usage Examples
**************
Here we present a few examples to show different common use-cases. The examples
in this section require the following imports and data::

  >>> import numpy as np
  >>> from sklearn.metrics import make_scorer, accuracy_score
  >>> from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
  >>> from sklearn.model_selection import cross_validate, GridSearchCV, GroupKFold
  >>> from sklearn.feature_selection import SelectKBest
  >>> from sklearn.pipeline import make_pipeline
  >>> n_samples, n_features = 100, 4
  >>> rng = np.random.RandomState(42)
  >>> X = rng.rand(n_samples, n_features)
  >>> y = rng.randint(0, 2, size=n_samples)
  >>> my_groups = rng.randint(0, 10, size=n_samples)
  >>> my_weights = rng.rand(n_samples)
  >>> my_other_weights = rng.rand(n_samples)

Weighted scoring and fitting
----------------------------

Here :class:`~model_selection.GroupKFold` requests ``groups`` by default. However, we
need to explicitly request weights for our scorer and the internal cross validation of
:class:`~linear_model.LogisticRegressionCV`. Both of these *consumers* know how to use
metadata called ``sample_weight``::

  >>> weighted_acc = make_scorer(accuracy_score).set_score_request(
  ...     sample_weight=True
  ... )
  >>> lr = LogisticRegressionCV(
  ...     cv=GroupKFold(), scoring=weighted_acc,
  ... ).set_fit_request(sample_weight=True)
  >>> cv_results = cross_validate(
  ...     lr,
  ...     X,
  ...     y,
  ...     params={"sample_weight": my_weights, "groups": my_groups},
  ...     cv=GroupKFold(),
  ...     scoring=weighted_acc,
  ... )

Note that in this example, ``my_weights`` is passed to both the scorer and
:class:`~linear_model.LogisticRegressionCV`.

Error handling: if ``params={"sample_weigh": my_weights, ...}`` were passed
(note the typo), :func:`~model_selection.cross_validate` would raise an error,
since ``sample_weigh`` was not requested by any of its underlying objects.

Weighted scoring and unweighted fitting
---------------------------------------

When passing metadata such as ``sample_weight`` around, all ``sample_weight``
:term:`consumers <consumer>` require weights to be either explicitly requested
or not requested (i.e. ``True`` or ``False``) when used in another
:term:`router` such as a :class:`~pipeline.Pipeline` or a ``*GridSearchCV``. To
perform an unweighted fit, we need to configure
:class:`~linear_model.LogisticRegressionCV` to not request sample weights, so
that :func:`~model_selection.cross_validate` does not pass the weights along::

  >>> weighted_acc = make_scorer(accuracy_score).set_score_request(
  ...     sample_weight=True
  ... )
  >>> lr = LogisticRegressionCV(
  ...     cv=GroupKFold(), scoring=weighted_acc,
  ... ).set_fit_request(sample_weight=False)
  >>> cv_results = cross_validate(
  ...     lr,
  ...     X,
  ...     y,
  ...     cv=GroupKFold(),
  ...     params={"sample_weight": my_weights, "groups": my_groups},
  ...     scoring=weighted_acc,
  ... )

If :meth:`linear_model.LogisticRegressionCV.set_fit_request` had not been
called, :func:`~model_selection.cross_validate` would raise an error, because
``sample_weight`` is passed in but :class:`~linear_model.LogisticRegressionCV`
has not been explicitly configured to either request or ignore the weights.

Unweighted feature selection
----------------------------

Setting request values for metadata is only required if the object, e.g. estimator,
scorer, etc., is a consumer of that metadata. Unlike
:class:`~linear_model.LogisticRegressionCV`, :class:`~feature_selection.SelectKBest`
doesn't consume weights, so no request value for ``sample_weight`` is set on its
instance and ``sample_weight`` is not routed to it::

  >>> weighted_acc = make_scorer(accuracy_score).set_score_request(
  ...     sample_weight=True
  ... )
  >>> lr = LogisticRegressionCV(
  ...     cv=GroupKFold(), scoring=weighted_acc,
  ... ).set_fit_request(sample_weight=True)
  >>> sel = SelectKBest(k=2)
  >>> pipe = make_pipeline(sel, lr)
  >>> cv_results = cross_validate(
  ...     pipe,
  ...     X,
  ...     y,
  ...     cv=GroupKFold(),
  ...     params={"sample_weight": my_weights, "groups": my_groups},
  ...     scoring=weighted_acc,
  ... )

Advanced: Different scoring and fitting weights
-----------------------------------------------

Despite :func:`~metrics.make_scorer` and
:class:`~linear_model.LogisticRegressionCV` both expecting the key
``sample_weight``, we can use aliases to pass different weights to different
consumers. In this example, we pass ``scoring_weight`` to the scorer, and
``fitting_weight`` to :class:`~linear_model.LogisticRegressionCV`::

  >>> weighted_acc = make_scorer(accuracy_score).set_score_request(
  ...     sample_weight="scoring_weight"
  ... )
  >>> lr = LogisticRegressionCV(
  ...     cv=GroupKFold(), scoring=weighted_acc,
  ... ).set_fit_request(sample_weight="fitting_weight")
  >>> cv_results = cross_validate(
  ...     lr,
  ...     X,
  ...     y,
  ...     cv=GroupKFold(),
  ...     params={
  ...         "scoring_weight": my_weights,
  ...         "fitting_weight": my_other_weights,
  ...         "groups": my_groups,
  ...     },
  ...     scoring=weighted_acc,
  ... )
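
Conceptually, a router resolves each alias against the parameters supplied by
the user before calling the consumer. The following toy function is a
simplified illustration of that idea, not scikit-learn's implementation;
``resolve_requests`` and its arguments are hypothetical names.

```python
def resolve_requests(requests, params):
    """Toy resolver mapping user-supplied params to a consumer's own names.

    ``requests`` maps a consumer parameter (e.g. "sample_weight") to its
    request value: True, False, None, or a string alias.
    """
    routed = {}
    for name, request in requests.items():
        if request is True and name in params:
            # Requested under its own name.
            routed[name] = params[name]
        elif isinstance(request, str) and request in params:
            # Requested under an alias: look up the alias, pass as ``name``.
            routed[name] = params[request]
    return routed


params = {"scoring_weight": [1.0, 2.0], "fitting_weight": [0.5, 0.5]}
scorer_kwargs = resolve_requests({"sample_weight": "scoring_weight"}, params)
fit_kwargs = resolve_requests({"sample_weight": "fitting_weight"}, params)
print(scorer_kwargs)
print(fit_kwargs)
```

Both consumers still receive an argument named ``sample_weight``; only the
lookup key on the user's side differs.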

API Interface
*************

A :term:`consumer` is an object (estimator, meta-estimator, scorer, splitter)
which accepts and uses some :term:`metadata` in at least one of its methods
(``fit``, ``predict``, ``inverse_transform``, ``transform``, ``score``,
``split``). Meta-estimators which only forward the metadata to other objects
(the child estimator, scorers, or splitters) and don't use the metadata
themselves are not consumers. (Meta-)Estimators which route metadata to other
objects are :term:`routers <router>`. A (meta-)estimator can be a
:term:`consumer` and a :term:`router` at the same time.

(Meta-)Estimators and splitters expose a ``set_*_request`` method for each
method which accepts at least one piece of metadata. For instance, if an
estimator supports ``sample_weight`` in ``fit`` and ``score``, it exposes
``estimator.set_fit_request(sample_weight=value)`` and
``estimator.set_score_request(sample_weight=value)``. Here ``value`` can be:

- ``True``: method requests a ``sample_weight``. This means if the metadata is
  provided, it will be used; otherwise no error is raised.
- ``False``: method does not request a ``sample_weight``, so it is not passed
  to the method.
- ``None``: router will raise an error if ``sample_weight`` is passed. This is
  in almost all cases the default value when an object is instantiated and
  ensures the user sets the metadata requests explicitly when metadata is
  passed. The only exceptions are ``Group*Fold`` splitters, which request
  ``groups`` by default.
- ``"param_name"``: if this estimator is used in a meta-estimator, the
  meta-estimator should forward ``"param_name"`` as ``sample_weight`` to this
  estimator. This means the mapping between the metadata required by the
  object, e.g. ``sample_weight`` and what is provided by the user, e.g.
  ``my_weights`` is done at the router level, and not by the object, e.g.
  estimator, itself.

Metadata are requested in the same way for scorers using ``set_score_request``.

If a piece of metadata, e.g. ``sample_weight``, is passed by the user, the
metadata request for all objects which can potentially consume
``sample_weight`` should be set by the user; otherwise an error is raised by
the router object. For example, the following code raises an error, since it
hasn't been explicitly specified whether ``sample_weight`` should be passed to
the estimator's scorer or not::

    >>> param_grid = {"C": [0.1, 1]}
    >>> lr = LogisticRegression().set_fit_request(sample_weight=True)
    >>> try:
    ...     GridSearchCV(
    ...         estimator=lr, param_grid=param_grid
    ...     ).fit(X, y, sample_weight=my_weights)
    ... except ValueError as e:
    ...     print(e)
    [sample_weight] are passed but are not explicitly set as requested or not for
    LogisticRegression.score

The issue can be fixed by explicitly setting the request value::

    >>> lr = LogisticRegression().set_fit_request(
    ...     sample_weight=True
    ... ).set_score_request(sample_weight=False)

At the end we disable the configuration flag for metadata routing::

    >>> sklearn.set_config(enable_metadata_routing=False)

.. _metadata_routing_models:

Metadata Routing Support Status
*******************************
All consumers (i.e. simple estimators which only consume metadata and don't
route them) support metadata routing, meaning they can be used inside
meta-estimators which support metadata routing. However, development of support
for metadata routing for meta-estimators is in progress, and here is a list of
meta-estimators and tools which support and don't yet support metadata routing.


Meta-estimators and functions supporting metadata routing:

- :class:`sklearn.calibration.CalibratedClassifierCV`
- :class:`sklearn.compose.ColumnTransformer`
- :class:`sklearn.feature_selection.SelectFromModel`
- :class:`sklearn.linear_model.ElasticNetCV`
- :class:`sklearn.linear_model.LarsCV`
- :class:`sklearn.linear_model.LassoCV`
- :class:`sklearn.linear_model.LassoLarsCV`
- :class:`sklearn.linear_model.LogisticRegressionCV`
- :class:`sklearn.linear_model.MultiTaskElasticNetCV`
- :class:`sklearn.linear_model.MultiTaskLassoCV`
- :class:`sklearn.linear_model.OrthogonalMatchingPursuitCV`
- :class:`sklearn.model_selection.GridSearchCV`
- :class:`sklearn.model_selection.HalvingGridSearchCV`
- :class:`sklearn.model_selection.HalvingRandomSearchCV`
- :class:`sklearn.model_selection.RandomizedSearchCV`
- :func:`sklearn.model_selection.cross_validate`
- :func:`sklearn.model_selection.cross_val_score`
- :func:`sklearn.model_selection.cross_val_predict`
- :class:`sklearn.multiclass.OneVsOneClassifier`
- :class:`sklearn.multiclass.OneVsRestClassifier`
- :class:`sklearn.multiclass.OutputCodeClassifier`
- :class:`sklearn.multioutput.ClassifierChain`
- :class:`sklearn.multioutput.MultiOutputClassifier`
- :class:`sklearn.multioutput.MultiOutputRegressor`
- :class:`sklearn.multioutput.RegressorChain`
- :class:`sklearn.pipeline.Pipeline`

Meta-estimators and tools not supporting metadata routing yet:

- :class:`sklearn.compose.TransformedTargetRegressor`
- :class:`sklearn.covariance.GraphicalLassoCV`
- :class:`sklearn.ensemble.AdaBoostClassifier`
- :class:`sklearn.ensemble.AdaBoostRegressor`
- :class:`sklearn.ensemble.BaggingClassifier`
- :class:`sklearn.ensemble.BaggingRegressor`
- :class:`sklearn.ensemble.StackingClassifier`
- :class:`sklearn.ensemble.StackingRegressor`
- :class:`sklearn.ensemble.VotingClassifier`
- :class:`sklearn.ensemble.VotingRegressor`
- :class:`sklearn.feature_selection.RFE`
- :class:`sklearn.feature_selection.RFECV`
- :class:`sklearn.feature_selection.SequentialFeatureSelector`
- :class:`sklearn.impute.IterativeImputer`
- :class:`sklearn.linear_model.RANSACRegressor`
- :class:`sklearn.linear_model.RidgeClassifierCV`
- :class:`sklearn.linear_model.RidgeCV`
- :func:`sklearn.model_selection.learning_curve`
- :func:`sklearn.model_selection.permutation_test_score`
- :func:`sklearn.model_selection.validation_curve`
- :class:`sklearn.pipeline.FeatureUnion`
- :class:`sklearn.semi_supervised.SelfTrainingClassifier`