metric_at_thresholds
- sklearn.metrics.metric_at_thresholds(y_true, y_score, metric_func, *, sample_weight=None, metric_params=None)
Compute metric_func per threshold for binary data.
Aids visualization of metric values across thresholds when tuning the decision threshold.
Read more in the User Guide.
Added in version 1.9.
- Parameters:
- y_true : array-like of shape (n_samples,)
Ground truth (correct) target labels.
- y_score : array-like of shape (n_samples,)
Continuous prediction scores, either estimated probabilities of the positive class or output of a decision_function.
- metric_func : callable
The metric function to use. It will be called as metric_func(y_true, y_pred, **metric_params), where y_pred are thresholded predictions, internally computed as y_pred = (y_score >= threshold). The output should be a single numeric value or a collection in which each element has the same size.
- sample_weight : array-like of shape (n_samples,), default=None
Sample weights. If not None, they will be passed to metric_func.
- metric_params : dict, default=None
Parameters to pass to metric_func.
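The thresholding rule described for metric_func can be sketched in plain NumPy. This is a minimal illustration, not the library's implementation: the helper names accuracy and sketch_metric_per_threshold are hypothetical, and the choice of candidate thresholds (the distinct values of y_score, in decreasing order) is an assumption made for this sketch.

```python
import numpy as np


def accuracy(y_true, y_pred, sample_weight=None):
    """Weighted fraction of predictions that match the labels."""
    correct = (y_true == y_pred).astype(float)
    return float(np.average(correct, weights=sample_weight))


def sketch_metric_per_threshold(y_true, y_score, metric_func, *,
                                sample_weight=None, metric_params=None):
    """Hypothetical helper: evaluate metric_func at each distinct score."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    params = dict(metric_params or {})
    if sample_weight is not None:
        # Forward the weights only when they are given.
        params["sample_weight"] = np.asarray(sample_weight)
    # Assumption: candidate thresholds are the distinct scores, descending.
    thresholds = np.unique(y_score)[::-1]
    values = []
    for threshold in thresholds:
        # Thresholded predictions, as described for metric_func.
        y_pred = y_score >= threshold
        values.append(metric_func(y_true, y_pred, **params))
    return np.asarray(values), thresholds


y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])
values, thresholds = sketch_metric_per_threshold(y_true, y_score, accuracy)
weighted, _ = sketch_metric_per_threshold(
    y_true, y_score, accuracy, sample_weight=[1, 1, 2, 2])
```

Because the weights are forwarded as a keyword argument, a metric_func used this way needs to accept a sample_weight parameter whenever weights are supplied.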
- Returns:
- metric_values : ndarray of shape (n_thresholds,) or (n_thresholds, *n_outputs)
The scores associated with each threshold. If metric_func returns a collection (e.g., a tuple of floats), the output is a 2D array of shape (n_thresholds, *n_outputs).
- thresholds : ndarray of shape (n_thresholds,)
The thresholds used to compute the scores.
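The two return shapes can be illustrated with a NumPy-only sketch. The function precision_and_recall is a hypothetical metric written for this example, and the thresholds are assumed to be the distinct scores in decreasing order; neither is taken from the library's source.

```python
import numpy as np


def precision_and_recall(y_true, y_pred):
    """Hypothetical metric returning a tuple of two floats."""
    tp = np.sum((y_true == 1) & y_pred)
    fp = np.sum((y_true == 0) & y_pred)
    fn = np.sum((y_true == 1) & ~y_pred)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return float(precision), float(recall)


y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])
# Assumption: one threshold per distinct score, in decreasing order.
thresholds = np.unique(y_score)[::-1]

# A scalar metric gives one value per threshold: shape (n_thresholds,).
scalar_values = np.asarray(
    [np.mean(y_true == (y_score >= t)) for t in thresholds])

# A tuple-valued metric stacks into shape (n_thresholds, 2).
stacked_values = np.asarray(
    [precision_and_recall(y_true, y_score >= t) for t in thresholds])
```

Each row of stacked_values corresponds to one threshold, so column 0 holds precision and column 1 holds recall in this sketch.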
See also
confusion_matrix_at_thresholds : Compute binary confusion matrix per threshold.
precision_recall_curve : Compute precision-recall pairs for different probability thresholds.
det_curve : Compute error rates for different probability thresholds.
roc_curve : Compute Receiver operating characteristic (ROC) curve.
Examples
>>> import numpy as np
>>> from sklearn.metrics import accuracy_score, metric_at_thresholds
>>> y_true = np.array([0, 0, 1, 1])
>>> y_score = np.array([0.1, 0.4, 0.35, 0.8])
>>> metric_values, thresholds = metric_at_thresholds(
...     y_true, y_score, accuracy_score)
>>> thresholds
array([0.8 , 0.4 , 0.35, 0.1 ])
>>> metric_values
array([0.75, 0.5 , 0.75, 0.5 ])
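Since metric_values and thresholds are aligned element by element, one possible way to use them when tuning the decision threshold is to pick the threshold with the highest metric value. The sketch below hard-codes the two arrays from the example output so that it runs with NumPy alone; the variable names best_index, best_threshold, and tied_thresholds are illustrative.

```python
import numpy as np

# Arrays copied from the example output of this page.
thresholds = np.array([0.8, 0.4, 0.35, 0.1])
metric_values = np.array([0.75, 0.5, 0.75, 0.5])

# np.argmax returns the first maximum, so ties resolve to the
# earliest entry, which here is the highest threshold.
best_index = int(np.argmax(metric_values))
best_threshold = float(thresholds[best_index])
best_value = float(metric_values[best_index])

# Every threshold that reaches the maximum metric value.
tied_thresholds = thresholds[np.isclose(metric_values, metric_values.max())]
```

When several thresholds tie, as 0.8 and 0.35 do here, which one to prefer depends on the application, so inspecting tied_thresholds may be more informative than relying on the first maximum.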