AUGRC
- class torch_uncertainty.metrics.classification.AUGRC(**kwargs)
Area Under the Generalized Risk-Coverage curve.
The Area Under the Generalized Risk-Coverage curve (AUGRC) is a metric for Selective Classification (SC) performance assessment. Unlike the AURC, it avoids putting too much weight on the most confident samples.
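The difference between the two curves can be stated compactly. The notation below is ours, following the definition in Traub et al. rather than the library's source: sort the N samples by decreasing confidence, let e_i ∈ {0, 1} mark whether the i-th most confident prediction is wrong, and accept the k most confident samples, which gives coverage c_k = k/N.

```latex
\begin{aligned}
R(c_k) &= \frac{1}{k}\sum_{i=1}^{k} e_i && \text{(selective risk, used by the AURC)}\\
G(c_k) &= \frac{1}{N}\sum_{i=1}^{k} e_i = c_k\,R(c_k) && \text{(generalized risk)}\\
\mathrm{AUGRC} &= \int_0^1 G(c)\,\mathrm{d}c
\end{aligned}
```

Because G divides by N instead of k, a few errors among the most confident samples cannot inflate the risk at low coverage the way they inflate R.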
As input to forward and update the metric accepts the following input:
- preds (Tensor): A float tensor of shape (N, ...) containing probabilities for each observation.
- target (Tensor): An int tensor of shape (N, ...) containing ground-truth labels.
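The metric needs a confidence score and a correct/incorrect flag for each sample. The sketch below assumes, without confirmation from the library source, that the confidence is the maximum predicted probability and that an error is an argmax prediction differing from the target. The helper name confidence_and_errors is hypothetical, and plain Python lists stand in for tensors.

```python
def confidence_and_errors(preds, target):
    """Hypothetical helper: turn (N, C) probabilities and (N,) labels into
    per-sample confidences and 0/1 error flags.

    Assumption: confidence = max class probability, prediction = argmax.
    """
    confidences, errors = [], []
    for probs, label in zip(preds, target):
        predicted = max(range(len(probs)), key=lambda c: probs[c])
        confidences.append(probs[predicted])
        errors.append(int(predicted != label))
    return confidences, errors


preds = [[0.9, 0.1], [0.4, 0.6], [0.7, 0.3]]
target = [0, 0, 1]
confidences, errors = confidence_and_errors(preds, target)
print(confidences, errors)  # [0.9, 0.6, 0.7] [0, 1, 1]
```

Only the ordering of the confidences and the error flags enter the risk-coverage curves; the probability values themselves are not used further.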
- As output to forward and compute the metric returns the following output:
augrc (Tensor): A scalar tensor containing the area under the generalized risk-coverage curve.
- Parameters:
kwargs – Additional keyword arguments.
- Reference:
Traub et al. Overcoming Common Flaws in the Evaluation of Selective Classification Systems. arXiv.
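To see why the generalized risk puts less weight on the most confident samples, the sketch below computes both risks along the coverage axis for four samples whose single error is also the most confident prediction. It is plain Python with a hypothetical helper risk_curves, not the library's implementation.

```python
def risk_curves(confidences, errors):
    """Hypothetical helper: selective and generalized risk at each coverage k/N,
    accepting the most confident samples first."""
    n = len(errors)
    order = sorted(range(n), key=lambda i: confidences[i], reverse=True)
    coverages, selective, generalized = [], [], []
    n_failures = 0
    for k, i in enumerate(order, start=1):
        n_failures += errors[i]
        coverages.append(k / n)
        selective.append(n_failures / k)    # errors among accepted samples
        generalized.append(n_failures / n)  # accepted errors over all samples
    return coverages, selective, generalized


confidences = [0.99, 0.9, 0.8, 0.7]
errors = [1, 0, 0, 0]
coverages, selective, generalized = risk_curves(confidences, errors)
print(selective)    # [1.0, 0.5, 0.3333333333333333, 0.25]
print(generalized)  # [0.25, 0.25, 0.25, 0.25]
```

The selective risk starts at 1.0 because of a single sample, whereas the generalized risk stays at 0.25 at every coverage.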
Area Under the Risk-Coverage curve.
The Area Under the Risk-Coverage curve (AURC) is the main metric for Selective Classification (SC) performance assessment. It evaluates the quality of uncertainty estimates by measuring how well they discriminate between correct and incorrect predictions based on their ranks (not their values, in contrast with calibration).
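The rank-based nature of the AURC can be checked with a small sketch. The helper aurc_sketch below is hypothetical and uses a plain Riemann sum over the coverages k/N, which may differ from the library's integration scheme. It shows that a strictly increasing transform of the confidences leaves the value unchanged, since only their ordering matters.

```python
def aurc_sketch(confidences, errors):
    """Hypothetical AURC: mean selective risk over coverages k/N, k = 1..N."""
    n = len(errors)
    order = sorted(range(n), key=lambda i: confidences[i], reverse=True)
    total, n_failures = 0.0, 0
    for k, i in enumerate(order, start=1):
        n_failures += errors[i]
        total += n_failures / k  # selective risk at coverage k/N
    return total / n  # each coverage step has width 1/N


confidences = [0.9, 0.8, 0.6, 0.3]
errors = [0, 1, 0, 1]
squashed = [c ** 2 for c in confidences]  # different values, same ranking

print(round(aurc_sketch(confidences, errors), 4))  # 0.3333
print(aurc_sketch(confidences, errors) == aurc_sketch(squashed, errors))  # True
```

A calibration metric would react to squaring the confidences; the AURC does not.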
As input to forward and update the metric accepts the following input:
- preds (Tensor): A float tensor of shape (N, ...) containing probabilities for each observation.
- target (Tensor): An int tensor of shape (N, ...) containing ground-truth labels.
- As output to forward and compute the metric returns the following output:
aurc (Tensor): A scalar tensor containing the area under the risk-coverage curve.
- Parameters:
kwargs – Additional keyword arguments.
- Reference:
- Geifman & El-Yaniv. “Selective classification for deep neural networks.” In NeurIPS, 2017.
- compute()
Compute the Area Under the Generalized Risk-Coverage curve (AUGRC).
The AUGRC is normalized as if its support were [0, 1]. This normalization affects the AUGRC when the number of samples is small.
- Returns:
The AUGRC.
- Return type:
Tensor
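One possible reading of the normalization described for compute, stated as an assumption rather than the library's code: with N samples the coverages run from 1/N to 1, so a trapezoidal area over those points spans an interval of width 1 - 1/N, and dividing by that width rescales the area as if the support were [0, 1]. The hypothetical helper below makes the small-sample effect visible.

```python
def normalized_augrc_sketch(confidences, errors):
    """Hypothetical AUGRC: trapezoidal area under the generalized risk curve
    on coverages 1/N..1, rescaled by the width 1 - 1/N of that interval."""
    n = len(errors)
    if n < 2:
        raise ValueError("need at least two samples")
    order = sorted(range(n), key=lambda i: confidences[i], reverse=True)
    coverages, risks, n_failures = [], [], 0
    for k, i in enumerate(order, start=1):
        n_failures += errors[i]
        coverages.append(k / n)
        risks.append(n_failures / n)  # generalized risk
    raw = sum(
        (coverages[j + 1] - coverages[j]) * (risks[j] + risks[j + 1]) / 2
        for j in range(n - 1)
    )
    return raw, raw / (1 - 1 / n)


raw, normalized = normalized_augrc_sketch([0.9, 0.1], [1, 0])
print(raw, normalized)  # 0.25 0.5
```

With two samples the rescaling doubles the value; as N grows, 1 - 1/N approaches 1 and the correction vanishes.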
- plot(ax=None, plot_value=True, name=None)
Plot the generalized risk-coverage curve corresponding to the inputs passed to update.
- Parameters:
ax (Axes | None, optional) – A matplotlib axis object. If provided, the plot is added to this axis. Defaults to None.
plot_value (bool, optional) – Whether to print the AUGRC value on the plot. Defaults to True.
name (str | None, optional) – Name of the model. Defaults to None.
- Returns:
Figure object and Axes object
- Return type:
tuple[Figure | None, Axes]