Shortcuts

Source code for torch_uncertainty.metrics.classification.risk_coverage

import math

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.utilities.compute import _auc_compute
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.plot import _AX_TYPE


class AURC(Metric):
    is_differentiable: bool = False
    higher_is_better: bool = False
    full_state_update: bool = False

    scores: list[Tensor]
    errors: list[Tensor]

    def __init__(self, **kwargs) -> None:
        r"""Area Under the Risk-Coverage curve.

        The Area Under the Risk-Coverage curve (AURC) is the main metric for
        Selective Classification (SC) performance assessment. It evaluates the
        quality of uncertainty estimates by measuring the ability to
        discriminate between correct and incorrect predictions based on their
        rank (and not their values in contrast with calibration).

        As input to ``forward`` and ``update`` the metric accepts the following input:

        - ``probs`` (:class:`~torch.Tensor`): A float tensor of shape
          ``(N, C)`` containing class probabilities for each observation, or
          of shape ``(N,)`` containing the positive-class probability in the
          binary case.
        - ``targets`` (:class:`~torch.Tensor`): An int tensor of shape
          ``(N,)`` containing ground-truth labels.

        As output to ``forward`` and ``compute`` the metric returns the
        following output:

        - ``aurc`` (:class:`~torch.Tensor`): A scalar tensor containing the
          area under the risk-coverage curve

        Args:
            kwargs: Additional keyword arguments.

        Reference:
            Geifman & El-Yaniv. "Selective classification for deep neural
            networks." In NeurIPS, 2017.
        """
        super().__init__(**kwargs)
        self.add_state("scores", default=[], dist_reduce_fx="cat")
        self.add_state("errors", default=[], dist_reduce_fx="cat")

    def update(self, probs: Tensor, targets: Tensor) -> None:
        """Store the scores and their associated errors for later computation.

        Args:
            probs (Tensor): The predicted probabilities of shape :math:`(N, C)`,
                or :math:`(N,)` holding the positive-class probability in the
                binary case.
            targets (Tensor): The ground truth labels of shape :math:`(N,)`.
        """
        if probs.ndim == 1:
            # Binary case: rebuild the two-class probability vector.
            probs = torch.stack([1 - probs, probs], dim=-1)
        # Confidence score = maximum class probability; error = wrong argmax.
        self.scores.append(probs.max(-1).values)
        self.errors.append((probs.argmax(-1) != targets) * 1.0)

    def partial_compute(self) -> Tensor:
        """Compute the error rates of the risk-coverage curve.

        Returns:
            Tensor: The error rate among the ``k`` most confident samples, for
            ``k = 1..N`` (i.e. the risk at coverage ``k / N``).
        """
        scores = dim_zero_cat(self.scores)
        errors = dim_zero_cat(self.errors)
        return _aurc_rejection_rate_compute(scores, errors)

    def compute(self) -> Tensor:
        """Compute the Area Under the Risk-Coverage curve (AURC).

        Normalize the AURC as if its support was between 0 and 1. This has an
        impact on the AURC when the number of samples is small.

        Returns:
            Tensor: The AURC, or a one-element NaN tensor when fewer than two
            samples were seen.
        """
        error_rates = self.partial_compute()
        num_samples = error_rates.size(0)
        if num_samples < 2:
            return torch.tensor([float("nan")], device=self.device)
        cov = torch.arange(1, num_samples + 1, device=self.device) / num_samples
        # The trapezoid spans [1/N, 1]; rescale as if it spanned [0, 1].
        return _auc_compute(cov, error_rates) / (1 - 1 / num_samples)

    def plot(
        self,
        ax: _AX_TYPE | None = None,
        plot_value: bool = True,
        name: str | None = None,
    ) -> tuple[plt.Figure | None, plt.Axes]:
        """Plot the risk-coverage curve corresponding to the inputs passed to ``update``.

        Args:
            ax (Axes | None, optional): A matplotlib axis object. If provided,
                the plot is added to this axis. Defaults to None.
            plot_value (bool, optional): Whether to print the AURC value on the
                plot. Defaults to True.
            name (str | None, optional): Name of the model. Defaults to None.

        Returns:
            tuple[[Figure | None], Axes]: Figure object and Axes object
        """
        fig, ax = plt.subplots(figsize=(6, 6)) if ax is None else (None, ax)

        # Error rate among the k most confident samples, for k = 1..N.
        error_rates = self.partial_compute().cpu()
        num_samples = error_rates.size(0)
        coverages = torch.arange(1, num_samples + 1) / num_samples

        # Same normalized value as ``compute``, computed locally from the
        # already gathered error rates instead of calling ``self.compute()``.
        if num_samples < 2:
            aurc = float("nan")
        else:
            aurc = (_auc_compute(coverages, error_rates) / (1 - 1 / num_samples)).item()

        # Resample the curve on a fixed coverage grid expressed in percent;
        # both the grid and the sample positions use the same unit.
        plot_covs = np.linspace(0.01, 100, 10_000)
        plot_risks = 100 * np.interp(
            plot_covs,
            100 * coverages.double().numpy(),
            error_rates.double().numpy(),
        )

        ax.plot(
            plot_covs,
            plot_risks,
            label="Model" if name is None else name,
        )

        if plot_value:
            ax.text(
                0.02,
                0.95,
                f"AURC={aurc:.2%}",
                color="black",
                ha="left",
                va="bottom",
                transform=ax.transAxes,
            )
        # Draw the grid on the target axes, not on pyplot's "current" axes.
        ax.grid(True, linestyle="--", alpha=0.7, zorder=0)
        ax.set_xlabel("Coverage (%)", fontsize=16)
        ax.set_ylabel("Risk - Error Rate (%)", fontsize=16)
        ax.set_xlim(0, 100)
        # Keep a non-degenerate y-range when the model makes no error at all.
        ax.set_ylim(0, min(100, max(1, np.ceil(plot_risks.max()))))
        ax.legend(loc="upper right")
        return fig, ax
def _aurc_rejection_rate_compute( scores: Tensor, errors: Tensor, ) -> Tensor: """Compute the cumulative error rates for a given set of scores and errors. Args: scores (Tensor): uncertainty scores of shape :math:`(B,)` errors (Tensor): binary errors of shape :math:`(B,)` """ errors = errors[scores.argsort(descending=True)] return errors.cumsum(dim=-1) / torch.arange( 1, scores.size(0) + 1, dtype=scores.dtype, device=scores.device )
class AUGRC(AURC):
    """Area Under the Generalized Risk-Coverage curve.

    The Area Under the Generalized Risk-Coverage curve (AUGRC) for
    Selective Classification (SC) performance assessment. It avoids putting
    too much weight on the most confident samples.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``probs`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C)``
      containing class probabilities for each observation, or of shape
      ``(N,)`` containing the positive-class probability in the binary case.
    - ``targets`` (:class:`~torch.Tensor`): An int tensor of shape ``(N,)``
      containing ground-truth labels.

    As output to ``forward`` and ``compute`` the metric returns the
    following output:

    - ``augrc`` (:class:`~torch.Tensor`): A scalar tensor containing the
      area under the generalized risk-coverage curve

    Args:
        kwargs: Additional keyword arguments.

    Reference:
        Traub et al. Overcoming Common Flaws in the Evaluation of Selective
        Classification Systems. ArXiv.
    """

    def compute(self) -> Tensor:
        """Compute the Area Under the Generalized Risk-Coverage curve (AUGRC).

        Normalize the AUGRC as if its support was between 0 and 1. This has an
        impact on the AUGRC when the number of samples is small.

        Returns:
            Tensor: The AUGRC, or a one-element NaN tensor when fewer than two
            samples were seen.
        """
        error_rates = self.partial_compute()
        num_samples = error_rates.size(0)
        if num_samples < 2:
            return torch.tensor([float("nan")], device=self.device)
        cov = torch.arange(1, num_samples + 1, device=self.device) / num_samples
        # Generalized risk = error rate among kept samples times coverage.
        return _auc_compute(cov, error_rates * cov) / (1 - 1 / num_samples)

    def plot(
        self,
        ax: _AX_TYPE | None = None,
        plot_value: bool = True,
        name: str | None = None,
    ) -> tuple[plt.Figure | None, plt.Axes]:
        """Plot the generalized risk-coverage curve corresponding to the inputs passed to ``update``.

        Args:
            ax (Axes | None, optional): A matplotlib axis object. If provided,
                the plot is added to this axis. Defaults to None.
            plot_value (bool, optional): Whether to print the AUGRC value on
                the plot. Defaults to True.
            name (str | None, optional): Name of the model. Defaults to None.

        Returns:
            tuple[[Figure | None], Axes]: Figure object and Axes object
        """
        fig, ax = plt.subplots(figsize=(6, 6)) if ax is None else (None, ax)

        # Error rate among the k most confident samples, for k = 1..N.
        error_rates = self.partial_compute().cpu()
        num_samples = error_rates.size(0)
        coverages = torch.arange(1, num_samples + 1) / num_samples
        # Each error rate is weighted by its own coverage (not the rejection rate).
        generalized_risks = error_rates * coverages

        # Same normalized value as ``compute``, computed locally from the
        # already gathered error rates instead of calling ``self.compute()``.
        if num_samples < 2:
            augrc = float("nan")
        else:
            augrc = (
                _auc_compute(coverages, generalized_risks) / (1 - 1 / num_samples)
            ).item()

        # Resample the curve on a fixed coverage grid expressed in percent;
        # both the grid and the sample positions use the same unit.
        plot_covs = np.linspace(0.01, 100, 10_000)
        plot_risks = 100 * np.interp(
            plot_covs,
            100 * coverages.double().numpy(),
            generalized_risks.double().numpy(),
        )

        ax.plot(
            plot_covs,
            plot_risks,
            label="Model" if name is None else name,
        )

        if plot_value:
            ax.text(
                0.02,
                0.95,
                f"AUGRC={augrc:.2%}",
                color="black",
                ha="left",
                va="bottom",
                transform=ax.transAxes,
            )
        # Draw the grid on the target axes, not on pyplot's "current" axes.
        ax.grid(True, linestyle="--", alpha=0.7, zorder=0)
        ax.set_xlabel("Coverage (%)", fontsize=16)
        ax.set_ylabel("Generalized Risk (%)", fontsize=16)
        ax.set_xlim(0, 100)
        # plot_risks is already in percent: do not rescale it again here.
        ax.set_ylim(0, min(100, max(1, np.ceil(plot_risks.max()))))
        ax.legend(loc="upper right")
        return fig, ax
class CovAtxRisk(Metric):
    is_differentiable: bool = False
    # A larger coverage at a fixed risk level is better.
    higher_is_better: bool = True
    full_state_update: bool = False

    scores: list[Tensor]
    errors: list[Tensor]

    def __init__(self, risk_threshold: float, **kwargs) -> None:
        r"""Coverage at x Risk.

        If there are multiple coverage values corresponding to the given risk,
        i.e., the risk(coverage) is not monotonic, the coverage at x risk is
        the maximum coverage whose risk does not exceed the given risk. If
        there is no coverage value corresponding to the given risk, return
        float("nan").

        Args:
            risk_threshold (float): The risk threshold at which to compute the
                coverage.
            kwargs: Additional arguments to pass to the metric class.

        Raises:
            TypeError: If ``risk_threshold`` is not a float.
            ValueError: If ``risk_threshold`` is outside ``[0, 1]``.
        """
        super().__init__(**kwargs)
        self.add_state("scores", default=[], dist_reduce_fx="cat")
        self.add_state("errors", default=[], dist_reduce_fx="cat")
        _risk_coverage_checks(risk_threshold)
        self.risk_threshold = risk_threshold

    def update(self, probs: Tensor, targets: Tensor) -> None:
        """Store the scores and their associated errors for later computation.

        Args:
            probs (Tensor): The predicted probabilities of shape :math:`(N, C)`,
                or :math:`(N,)` holding the positive-class probability in the
                binary case.
            targets (Tensor): The ground truth labels of shape :math:`(N,)`.
        """
        if probs.ndim == 1:
            # Binary case: rebuild the two-class probability vector.
            probs = torch.stack([1 - probs, probs], dim=-1)
        self.scores.append(probs.max(-1).values)
        self.errors.append((probs.argmax(-1) != targets) * 1.0)

    def compute(self) -> Tensor:
        """Compute the coverage at x Risk.

        Returns:
            Tensor: The largest coverage whose risk is at most
            ``risk_threshold``, or a one-element NaN tensor when no coverage
            satisfies it (or no sample was seen).
        """
        scores = dim_zero_cat(self.scores)
        errors = dim_zero_cat(self.errors)
        num_samples = scores.size(0)
        if num_samples < 1:
            return torch.tensor([float("nan")], device=self.device)
        # error_rates[k - 1] is the risk when keeping the k most confident samples.
        error_rates = _aurc_rejection_rate_compute(scores, errors)
        admissible = error_rates <= self.risk_threshold
        if not admissible.any():
            return torch.tensor([float("nan")], device=self.device)
        # Largest number of kept samples whose risk stays within the threshold.
        max_kept = admissible.nonzero().max() + 1
        return max_kept / num_samples
class CovAt5Risk(CovAtxRisk):
    def __init__(self, **kwargs) -> None:
        r"""Coverage at 5% Risk.

        If there are multiple coverage values corresponding to 5% risk, the
        coverage at 5% risk is the maximum coverage value corresponding to 5%
        risk. If there is no coverage value corresponding to the given risk,
        this metric returns float("nan").

        Args:
            kwargs: Additional arguments to pass to the metric class.
        """
        super().__init__(risk_threshold=0.05, **kwargs)
class RiskAtxCov(Metric):
    is_differentiable: bool = False
    higher_is_better: bool = False
    full_state_update: bool = False

    scores: list[Tensor]
    errors: list[Tensor]

    def __init__(self, cov_threshold: float, **kwargs) -> None:
        r"""Risk at given Coverage.

        Args:
            cov_threshold (float): The coverage threshold at which to compute
                the risk.
            kwargs: Additional arguments to pass to the metric class.

        Raises:
            TypeError: If ``cov_threshold`` is not a float.
            ValueError: If ``cov_threshold`` is outside ``[0, 1]``.
        """
        super().__init__(**kwargs)
        self.add_state("scores", default=[], dist_reduce_fx="cat")
        self.add_state("errors", default=[], dist_reduce_fx="cat")
        _risk_coverage_checks(cov_threshold)
        self.cov_threshold = cov_threshold

    def update(self, probs: Tensor, targets: Tensor) -> None:
        """Store the scores and their associated errors for later computation.

        Args:
            probs (Tensor): The predicted probabilities of shape :math:`(N, C)`,
                or :math:`(N,)` holding the positive-class probability in the
                binary case.
            targets (Tensor): The ground truth labels of shape :math:`(N,)`.
        """
        if probs.ndim == 1:
            # Binary case: rebuild the two-class probability vector.
            probs = torch.stack([1 - probs, probs], dim=-1)
        self.scores.append(probs.max(-1).values)
        self.errors.append((probs.argmax(-1) != targets) * 1.0)

    def compute(self) -> Tensor:
        """Compute the risk at given coverage.

        Returns:
            Tensor: The error rate among the ``ceil(N * cov_threshold)`` most
            confident samples, or a one-element NaN tensor when that selection
            is empty (e.g. ``cov_threshold == 0``).
        """
        scores = dim_zero_cat(self.scores)
        errors = dim_zero_cat(self.errors)
        # Round before ceil to drop float noise such as
        # 100 * 0.07 == 7.000000000000001, which would keep one sample too many.
        num_kept = math.ceil(round(scores.size(0) * self.cov_threshold, 9))
        if num_kept == 0:
            # The risk of an empty selection is undefined; previously index -1
            # silently returned the full-coverage risk.
            return torch.tensor([float("nan")], device=self.device)
        error_rates = _aurc_rejection_rate_compute(scores, errors)
        return error_rates[num_kept - 1]
class RiskAt80Cov(RiskAtxCov):
    def __init__(self, **kwargs) -> None:
        r"""Risk at 80% Coverage.

        Convenience subclass of ``RiskAtxCov`` with the coverage threshold
        fixed to 0.8.

        Args:
            kwargs: Additional arguments to pass to the metric class.
        """
        super().__init__(0.8, **kwargs)
def _risk_coverage_checks(threshold: float) -> None: if not isinstance(threshold, float): raise TypeError(f"Expected threshold to be of type float, but got {type(threshold)}") if threshold < 0 or threshold > 1: raise ValueError(f"Threshold should be in the range [0, 1], but got {threshold}.")