"""SCOD risk-coverage metrics (module ``torch_uncertainty.metrics.classification.scod_risk_coverage``)."""
from torch import Tensor
from .risk_coverage import AUGRC, AURC, CovAtxRisk, RiskAtxCov
[docs]
class SCODAURC(AURC):
    r"""Area under the risk-coverage curve in the SCOD setting.

    Selective classification is cast here as a binary
    in-distribution-vs-out-of-distribution (ID/OOD) decision problem.
    Every sample :math:`i` comes with an OOD score :math:`s_i` (higher means
    *more OOD*) and a SCOD error indicator :math:`e_i \in \{0,1\}`, where
    :math:`e_i = 1` marks an OOD sample and :math:`e_i = 0` an ID sample.

    Samples are ranked by decreasing ID acceptance confidence
    :math:`-s_i`. Writing :math:`\sigma` for that ranking, the risk at
    coverage :math:`\kappa = k/N` is:

    .. math::
        r\!\left(\tfrac{k}{N}\right) = \frac{1}{k}\sum_{i=1}^{k} e_{\sigma(i)}.

    The SCOD-AURC is then:

    .. math::
        \text{SCOD-AURC} = \int_0^1 r(\kappa)\,\mathrm{d}\kappa
        \approx \frac{1}{N}\sum_{k=1}^{N} r\!\left(\tfrac{k}{N}\right).

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - **ood_scores** (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)``
      containing OOD scores.
    - **targets** (:class:`~torch.Tensor`): A binary int tensor of shape ``(N, ...)``
      with ``0`` for ID and ``1`` for OOD.
    """

    def update(self, ood_scores: Tensor, targets: Tensor) -> None:  # pyrefly: ignore[bad-override]
        """Accumulate one batch of ID confidences and SCOD error indicators.

        Both inputs are flattened to 1-D. The negated OOD scores are appended
        to ``self.scores`` and the targets, cast to float, are appended to
        ``self.errors``.

        Args:
            ood_scores: OOD scores where higher means more OOD-like.
            targets: Binary labels, with 0 for ID and 1 for OOD.
        """
        # Negating the OOD score turns it into a confidence: higher = more ID-like.
        flat_scores = ood_scores.reshape(-1)
        self.scores.append(-flat_scores)

        # The OOD label itself serves as the error indicator.
        flat_targets = targets.reshape(-1)
        self.errors.append(flat_targets.float())
[docs]
class SCODAUGRC(AUGRC):
    r"""Area under the generalized risk-coverage curve in the SCOD setting.

    With the notation of :class:`SCODAURC` (risk :math:`r(\kappa)` at
    coverage :math:`\kappa`), the generalized SCOD risk is:

    .. math::
        \text{SCOD-AUGRC} = \int_0^1 \kappa \cdot r(\kappa)\,\mathrm{d}\kappa
        \approx \frac{1}{N}\sum_{k=1}^{N}\frac{k}{N}\cdot r\!\left(\tfrac{k}{N}\right).
    """

    def update(self, ood_scores: Tensor, targets: Tensor) -> None:  # pyrefly: ignore[bad-override]
        """Accumulate one batch of ID confidences and SCOD error indicators.

        Both inputs are flattened to 1-D. The negated OOD scores are appended
        to ``self.scores`` and the targets, cast to float, are appended to
        ``self.errors``.

        Args:
            ood_scores: OOD scores where higher means more OOD-like.
            targets: Binary labels, with 0 for ID and 1 for OOD.
        """
        # Negating the OOD score turns it into a confidence: higher = more ID-like.
        flat_scores = ood_scores.reshape(-1)
        self.scores.append(-flat_scores)

        # The OOD label itself serves as the error indicator.
        flat_targets = targets.reshape(-1)
        self.errors.append(flat_targets.float())
[docs]
class SCODCovAtxRisk(CovAtxRisk):
    r"""Coverage at x SCOD risk.

    Reports the maximum selective coverage for which the SCOD risk remains
    below a target threshold.
    """

    def update(self, ood_scores: Tensor, targets: Tensor) -> None:  # pyrefly: ignore[bad-override]
        """Accumulate one batch of ID confidences and SCOD error indicators.

        Both inputs are flattened to 1-D. The negated OOD scores are appended
        to ``self.scores`` and the targets, cast to float, are appended to
        ``self.errors``.

        Args:
            ood_scores: OOD scores where higher means more OOD-like.
            targets: Binary labels, with 0 for ID and 1 for OOD.
        """
        # Negating the OOD score turns it into a confidence: higher = more ID-like.
        flat_scores = ood_scores.reshape(-1)
        self.scores.append(-flat_scores)

        # The OOD label itself serves as the error indicator.
        flat_targets = targets.reshape(-1)
        self.errors.append(flat_targets.float())
[docs]
class SCODCovAt5Risk(SCODCovAtxRisk):
    r"""Coverage at 5% SCOD risk.

    Specialization of :class:`SCODCovAtxRisk` in which
    ``risk_threshold`` is fixed to ``0.05``.
    """

    def __init__(self, **kwargs) -> None:
        """Build the metric with the risk threshold pinned to ``0.05``.

        Args:
            **kwargs: Forwarded unchanged to :class:`SCODCovAtxRisk`.
        """
        super().__init__(**kwargs, risk_threshold=0.05)
[docs]
class SCODRiskAtxCov(RiskAtxCov):
    r"""SCOD risk at x coverage.

    Reports the SCOD error rate measured at a fixed selective coverage
    level.
    """

    def update(self, ood_scores: Tensor, targets: Tensor) -> None:  # pyrefly: ignore[bad-override]
        """Accumulate one batch of ID confidences and SCOD error indicators.

        Both inputs are flattened to 1-D. The negated OOD scores are appended
        to ``self.scores`` and the targets, cast to float, are appended to
        ``self.errors``.

        Args:
            ood_scores: OOD scores where higher means more OOD-like.
            targets: Binary labels, with 0 for ID and 1 for OOD.
        """
        # Negating the OOD score turns it into a confidence: higher = more ID-like.
        flat_scores = ood_scores.reshape(-1)
        self.scores.append(-flat_scores)

        # The OOD label itself serves as the error indicator.
        flat_targets = targets.reshape(-1)
        self.errors.append(flat_targets.float())
[docs]
class SCODRiskAt80Cov(SCODRiskAtxCov):
    r"""SCOD risk at 80% coverage.

    Specialization of :class:`SCODRiskAtxCov` in which
    ``cov_threshold`` is fixed to ``0.8``.
    """

    def __init__(self, **kwargs) -> None:
        """Build the metric with the coverage threshold pinned to ``0.8``.

        Args:
            **kwargs: Forwarded unchanged to :class:`SCODRiskAtxCov`.
        """
        super().__init__(**kwargs, cov_threshold=0.8)