metric_at_thresholds

sklearn.metrics.metric_at_thresholds(y_true, y_score, metric_func, *, sample_weight=None, metric_params=None)

Compute metric_func per threshold for binary data.

This aids visualization of how the metric varies across thresholds when tuning the decision threshold.

Read more in the User Guide.

Added in version 1.9.
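The summary above can be made concrete with a minimal sketch in plain NumPy. It assumes, consistently with the Examples section but without confirming the library's actual implementation, that the thresholds are the distinct values of y_score in decreasing order and that sample_weight is forwarded to metric_func as a keyword argument of the same name. The names `metric_at_thresholds_sketch` and `accuracy` are hypothetical.

```python
import numpy as np


def metric_at_thresholds_sketch(y_true, y_score, metric_func, *,
                                sample_weight=None, metric_params=None):
    """Hypothetical re-implementation, for illustration only.

    Assumes the thresholds are the distinct values of ``y_score`` in
    decreasing order and that ``sample_weight`` is forwarded to
    ``metric_func`` as a keyword argument of the same name.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    params = {} if metric_params is None else dict(metric_params)
    if sample_weight is not None:
        params["sample_weight"] = np.asarray(sample_weight)

    # Distinct scores, highest first.
    thresholds = np.unique(y_score)[::-1]
    # Binarize the scores at each threshold and evaluate the metric.
    values = [
        metric_func(y_true, (y_score >= t).astype(int), **params)
        for t in thresholds
    ]
    return np.asarray(values), thresholds


def accuracy(y_true, y_pred, sample_weight=None):
    """Hypothetical metric: (optionally weighted) fraction of correct predictions."""
    return np.average(y_true == y_pred, weights=sample_weight)


values, thresholds = metric_at_thresholds_sketch(
    [0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8], accuracy
)
# Same numbers as the Examples section:
# thresholds 0.8, 0.4, 0.35, 0.1 and accuracies 0.75, 0.5, 0.75, 0.5
```

Because each threshold triggers one full call to metric_func, the cost grows with the number of distinct scores; this sketch makes no attempt to reuse work between thresholds.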

Parameters:
y_true : array-like of shape (n_samples,)

Ground truth (correct) target labels.

y_score : array-like of shape (n_samples,)

Continuous prediction scores, either estimated probabilities of the positive class or output of a decision_function.

metric_func : callable

The metric function to use. It is called as metric_func(y_true, y_pred, **metric_params), where y_pred holds the thresholded predictions, computed internally as y_pred = (y_score >= threshold). It should return either a single numeric value or a collection whose elements all have the same size.

sample_weight : array-like of shape (n_samples,), default=None

Sample weights. If not None, they are passed to metric_func.

metric_params : dict, default=None

Parameters to pass to metric_func.
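To illustrate the calling convention described for metric_func and metric_params, here is a hypothetical custom metric, `fbeta`, that accepts an extra `beta` keyword and an optional `sample_weight`. It is evaluated at a single threshold in the same way as metric_func(y_true, y_pred, **metric_params). This is a sketch of the contract only; it does not call metric_at_thresholds itself.

```python
import numpy as np


def fbeta(y_true, y_pred, *, beta=1.0, sample_weight=None):
    """Hypothetical custom metric: (weighted) F-beta score for binary labels."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    w = (np.ones(y_true.shape) if sample_weight is None
         else np.asarray(sample_weight, dtype=float))
    tp = w[y_true & y_pred].sum()   # weighted true positives
    fp = w[~y_true & y_pred].sum()  # weighted false positives
    fn = w[y_true & ~y_pred].sum()  # weighted false negatives
    denom = (1 + beta**2) * tp + beta**2 * fn + fp
    return 0.0 if denom == 0 else (1 + beta**2) * tp / denom


y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])
metric_params = {"beta": 2.0}

# One evaluation at one threshold, mirroring
# metric_func(y_true, y_pred, **metric_params) with y_pred = (y_score >= threshold).
threshold = 0.35
y_pred = y_score >= threshold
score = fbeta(y_true, y_pred, **metric_params)  # 10/11 for this data
```

The same pattern should apply to a library metric that takes keyword arguments, for instance passing sklearn.metrics.fbeta_score as metric_func together with metric_params={"beta": 2.0}.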

Returns:
metric_values : ndarray of shape (n_thresholds,) or (n_thresholds, *n_outputs)

The metric value at each threshold. If metric_func returns a collection (e.g., a tuple of floats), the output is an array of shape (n_thresholds, *n_outputs), which is 2D for a tuple of floats.

thresholds : ndarray of shape (n_thresholds,)

The thresholds used to compute the scores.
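When metric_func returns a tuple, stacking one tuple per threshold yields the 2D layout described above. The plain-NumPy sketch below uses a hypothetical `precision_recall` metric and assumes, as in the Examples section, that the thresholds are the distinct scores in decreasing order.

```python
import numpy as np


def precision_recall(y_true, y_pred):
    """Hypothetical metric returning a tuple of two floats."""
    tp = np.sum(y_true & y_pred)
    precision = tp / max(np.sum(y_pred), 1)  # guard against no positive predictions
    recall = tp / max(np.sum(y_true), 1)     # guard against no positive labels
    return precision, recall


y_true = np.array([0, 0, 1, 1], dtype=bool)
y_score = np.array([0.1, 0.4, 0.35, 0.8])
thresholds = np.unique(y_score)[::-1]  # assumed: distinct scores, highest first

# Stacking one tuple per threshold gives shape (n_thresholds, n_outputs) = (4, 2).
metric_values = np.asarray(
    [precision_recall(y_true, y_score >= t) for t in thresholds]
)
precision, recall = metric_values.T  # one 1D array per output
```

Transposing the result separates the outputs into one array per metric, which is convenient for plotting each of them against thresholds.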

See also

confusion_matrix_at_thresholds

Compute binary confusion matrix per threshold.

precision_recall_curve

Compute precision-recall pairs for different probability thresholds.

det_curve

Compute error rates for different probability thresholds.

roc_curve

Compute Receiver operating characteristic (ROC) curve.

Examples

>>> import numpy as np
>>> from sklearn.metrics import accuracy_score, metric_at_thresholds
>>> y_true = np.array([0, 0, 1, 1])
>>> y_score = np.array([0.1, 0.4, 0.35, 0.8])
>>> metric_values, thresholds = metric_at_thresholds(
...     y_true, y_score, accuracy_score)
>>> thresholds
array([0.8 , 0.4 , 0.35, 0.1 ])
>>> metric_values
array([0.75, 0.5 , 0.75, 0.5 ])
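Since the usual goal is to tune the decision threshold, a natural follow-up is to pick the threshold that maximizes the metric. This short sketch hard-codes the arrays printed in the example so that it runs on its own. It shows one simple selection rule and is not a recommendation from the library.

```python
import numpy as np

# Output of the example above, copied verbatim.
thresholds = np.array([0.8, 0.4, 0.35, 0.1])
metric_values = np.array([0.75, 0.5, 0.75, 0.5])

# np.argmax returns the first maximum; because thresholds are listed in
# decreasing order, ties resolve to the highest threshold.
best = np.argmax(metric_values)
best_threshold = thresholds[best]

# Apply the chosen threshold to the original scores.
y_score = np.array([0.1, 0.4, 0.35, 0.8])
y_pred = (y_score >= best_threshold).astype(int)
```

Here 0.8 and 0.35 tie at 0.75 accuracy, so the tie-breaking rule decides the result. Plotting metric_values against thresholds makes such plateaus visible before committing to a value.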