AUGRC
- class torch_uncertainty.metrics.classification.AUGRC(**kwargs)
Calculate the Area Under the Generalized Risk-Coverage curve (AUGRC).
The AUGRC [1] assesses selective classification performance. It avoids putting too much weight on the most confident samples.
As input to forward and update, the metric accepts the following input:
- preds (Tensor): A float tensor of shape (N, ...) containing probabilities for each observation.
- target (Tensor): An int tensor of shape (N, ...) containing ground-truth labels.
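The sketch below shows one plausible way these inputs reduce to one confidence score and one correctness flag per sample. It assumes, as is common for selective classification metrics, that the confidence is the maximum predicted probability and that a prediction is correct when its arg-max matches the label. It handles only `(N, C)` inputs as nested lists; the helper name `to_scores_and_correctness` is hypothetical, and the library itself operates on tensors.

```python
def to_scores_and_correctness(preds, target):
    """Map (N, C) class probabilities and (N,) labels to (confidences, is_correct).

    Assumption: the confidence score is the maximum predicted probability.
    """
    scores, correct = [], []
    for probs, label in zip(preds, target):
        predicted = max(range(len(probs)), key=probs.__getitem__)  # arg-max class
        scores.append(probs[predicted])
        correct.append(predicted == label)
    return scores, correct


preds = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.4, 0.5, 0.1]]
target = [0, 2, 0]
print(to_scores_and_correctness(preds, target))
# ([0.7, 0.6, 0.5], [True, True, False])
```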
As output to forward and compute, the metric returns the following output:
- Augrc (Tensor): A scalar tensor containing the area under the generalized risk-coverage curve.
- Parameters:
kwargs – Additional keyword arguments.
References
[1] Traub et al. Overcoming Common Flaws in the Evaluation of Selective Classification Systems.
See also
AURC: Parent class, the AURC metric
- compute()
Normalize the AUGRC as if its support were between 0 and 1. This normalization mainly affects the AUGRC when the number of samples is small.
- Returns:
The AUGRC.
- Return type:
Tensor
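One way to read this normalization is sketched below. The sketch makes two assumptions: the area is integrated with the trapezoidal rule, and the empirical curve is defined only at coverages 1/N, 2/N, ..., 1. Under these assumptions the curve's support has width 1 - 1/N rather than 1, so dividing the area by that width rescales it to a unit support. The correction factor tends to 1 as N grows. The function name `normalized_augrc` is hypothetical, and the code is not the library's implementation.

```python
import math


def normalized_augrc(confidences, correct):
    """Trapezoidal area under the generalized risk-coverage curve, rescaled to unit support."""
    n = len(confidences)
    if n < 2:
        return math.nan  # the curve has zero width with fewer than two samples
    order = sorted(range(n), key=lambda i: confidences[i], reverse=True)
    coverages, risks, errors = [], [], 0
    for k, i in enumerate(order, start=1):
        errors += 0 if correct[i] else 1
        coverages.append(k / n)
        risks.append(errors / n)  # generalized risk at coverage k/n
    area = sum(
        (coverages[j + 1] - coverages[j]) * (risks[j] + risks[j + 1]) / 2
        for j in range(n - 1)
    )
    return area / (1 - 1 / n)  # support spans [1/n, 1], i.e. width 1 - 1/n


print(normalized_augrc([0.99, 0.90, 0.80, 0.70], [False, True, True, True]))  # 0.25
```

In this example the generalized risk is constant at 0.25. The raw trapezoidal area is only 0.1875 because the curve starts at coverage 0.25. After rescaling, the result equals the constant risk, which is the behavior one would expect from a curve on a unit support.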
- plot(ax=None, plot_value=True, name=None)
Plot the generalized risk-coverage curve corresponding to the inputs passed to update.
- Parameters:
ax (Axes | None) – A matplotlib axis object. If provided, the plot is added to this axis. Defaults to None.
plot_value (bool) – Whether to print the AUGRC value on the plot. Defaults to True.
name (str | None) – Name of the model. Defaults to None.
- Returns:
Figure object and axes object.
- Return type:
tuple[Figure | None, Axes]