SCODAURC

class torch_uncertainty.metrics.classification.SCODAURC(**kwargs)

Calculate the Area Under the Risk-Coverage curve.

The Area Under the Risk-Coverage curve (AURC) is the main metric for assessing Selective Classification (SC) performance. It evaluates the quality of uncertainty estimates by measuring how well they separate correct from incorrect predictions. Only the ranking of the confidence scores matters, not their values, which is what distinguishes AURC from calibration metrics.

Let \(\sigma\) be the permutation sorting the \(N\) samples by descending top-class confidence, so that \(\hat{p}_{\sigma(1)} \geq \cdots \geq \hat{p}_{\sigma(N)}\). The error rate at coverage \(\kappa = k/N\) is

\[r\!\left(\tfrac{k}{N}\right) = \frac{1}{k} \sum_{i=1}^{k} \mathbf{1}\!\left[\hat{y}_{\sigma(i)} \neq y_{\sigma(i)}\right]\]

and the AURC is

\[\text{AURC} = \int_0^1 r(\kappa)\,\mathrm{d}\kappa \approx \frac{1}{N} \sum_{k=1}^{N} r\!\left(\tfrac{k}{N}\right)\]
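The two formulas above can be traced by hand with a minimal pure-Python sketch. The helper names `risk_coverage_curve` and `aurc_riemann` are hypothetical and are not part of torch_uncertainty. The sketch implements only the Riemann-sum approximation written above, so its value need not match what the library's `compute` returns if the library uses a different integration rule.

```python
def risk_coverage_curve(confidences, errors):
    """Return r(k/N) for k = 1..N, with samples sorted by descending confidence."""
    # sigma: indices ordered from most to least confident
    order = sorted(range(len(confidences)), key=lambda i: confidences[i], reverse=True)
    risks = []
    cumulative_errors = 0
    for k, idx in enumerate(order, start=1):
        cumulative_errors += errors[idx]      # errors among the k most confident samples
        risks.append(cumulative_errors / k)   # error rate at coverage k/N
    return risks


def aurc_riemann(confidences, errors):
    """Approximate the AURC as the mean of r(k/N) over k = 1..N."""
    risks = risk_coverage_curve(confidences, errors)
    return sum(risks) / len(risks)


# Top-class confidences and 0/1 error indicators for three samples
confidences = [0.7, 0.6, 0.8]
errors = [0, 1, 0]
print(risk_coverage_curve(confidences, errors))  # error rates at coverage 1/3, 2/3, 1
print(aurc_riemann(confidences, errors))         # approximately 0.111 (= 1/9)
```

Here the only misclassified sample is also the least confident one, so the error rate stays at zero until full coverage is reached.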

The forward and update methods accept the following inputs:

  • preds (Tensor): A float tensor of shape (N, ...) containing probabilities for each observation.

  • target (Tensor): An int tensor of shape (N, ...) containing ground-truth labels.

The forward and compute methods return the following output:

  • aurc (Tensor): A scalar tensor containing the area under the risk-coverage curve.

Parameters:

kwargs – Additional keyword arguments.

Example

>>> import torch
>>> from torch_uncertainty.metrics.classification import AURC
>>> aurc = AURC()
>>> probs = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.8, 0.2]])
>>> targets = torch.tensor([0, 1, 0])
>>> aurc.update(probs, targets)
>>> result = aurc.compute()
>>> print(result)
tensor(0.0833)  # Example output

References

[1] Geifman & El-Yaniv. Selective classification for deep neural networks. In NeurIPS, 2017.

update(ood_scores, targets)

Store SCOD confidence scores and associated detection errors.

Parameters:
  • ood_scores (Tensor) – OOD scores where higher means more OOD-like.

  • targets (Tensor) – Binary labels, with 0 for ID and 1 for OOD.

Return type:

None
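As an illustration of how OOD scores and binary labels could feed a risk-coverage curve, the sketch below makes two assumptions that the text above does not confirm: samples are accepted in order of increasing OOD score (lowest score means most confident), and an accepted OOD sample (label 1) counts as a detection error. The function `scod_risk_coverage` is hypothetical and is not the library's implementation.

```python
def scod_risk_coverage(ood_scores, targets):
    """Detection error rate among the k samples with the lowest OOD scores, k = 1..N."""
    # Assumption: a lower OOD score means the sample is accepted earlier.
    order = sorted(range(len(ood_scores)), key=lambda i: ood_scores[i])
    risks = []
    accepted_ood = 0
    for k, idx in enumerate(order, start=1):
        accepted_ood += targets[idx]      # assumption: an accepted OOD sample is an error
        risks.append(accepted_ood / k)    # detection error rate at coverage k/N
    return risks


ood_scores = [0.1, 0.9, 0.4, 0.7]   # higher means more OOD-like
targets = [0, 1, 0, 1]              # 0 for ID, 1 for OOD
print(scod_risk_coverage(ood_scores, targets))  # error rates at coverage 1/4, 2/4, 3/4, 1
```

In this toy case both ID samples have the lowest scores, so the error rate is zero up to coverage 1/2 and rises only once OOD samples start being accepted.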