Source code for torch_uncertainty.metrics.sparsification

from importlib import util

import matplotlib.pyplot as plt
import numpy as np
import torch

if util.find_spec("sklearn"):
    from sklearn.metrics import auc

    sklearn_installed = True
else:  # coverage: ignore
    sklearn_installed = False

from torch import Tensor
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.plot import _AX_TYPE


class AUSE(Metric):
    is_differentiable: bool = False
    higher_is_better: bool = False
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 100.0
    plot_legend_name: str = "Sparsification Curves"

    scores: list[Tensor]
    errors: list[Tensor]

    def __init__(self, **kwargs) -> None:
        r"""The Area Under the Sparsification Error curve (AUSE) metric to
        estimate the quality of the uncertainty estimates, i.e., how much
        they coincide with the true errors.

        Args:
            kwargs: Additional keyword arguments, see `Advanced metric settings
                <https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metric-kwargs>`_.

        Reference:
            From the paper `Uncertainty estimates and multi-hypotheses networks
            for optical flow <https://arxiv.org/abs/1802.07095>`_. In ECCV, 2018.

        Inputs:
            - :attr:`scores`: Uncertainty scores of shape :math:`(B,)`. A higher
              score means a higher uncertainty.
            - :attr:`errors`: Errors of shape :math:`(B,)`, where :math:`B` is
              the batch size.

        Note:
            A higher AUSE means a lower quality of the uncertainty estimates.
        """
        super().__init__(**kwargs)
        self.add_state("scores", default=[], dist_reduce_fx="cat")
        self.add_state("errors", default=[], dist_reduce_fx="cat")
        if not sklearn_installed:
            raise ImportError("Please install scikit-learn to use AUSE.")
    def update(self, scores: Tensor, errors: Tensor) -> None:
        """Store the scores and their associated errors for later computation.

        Args:
            scores (Tensor): uncertainty scores of shape :math:`(B,)`
            errors (Tensor): errors of shape :math:`(B,)`
        """
        self.scores.append(scores)
        self.errors.append(errors)
    def partial_compute(self) -> tuple[Tensor, Tensor]:
        """Compute the model and oracle sparsification curves."""
        scores = dim_zero_cat(self.scores)
        errors = dim_zero_cat(self.errors)
        error_rates = _ause_rejection_rate_compute(scores, errors)
        # The oracle curve ranks the samples by their true errors.
        optimal_error_rates = _ause_rejection_rate_compute(errors, errors)
        return error_rates.cpu(), optimal_error_rates.cpu()
    def compute(self) -> Tensor:
        """Compute the Area Under the Sparsification Error curve (AUSE) based
        on the inputs passed to ``update``.

        Returns:
            Tensor: The AUSE.
        """
        error_rates, optimal_error_rates = self.partial_compute()
        num_samples = error_rates.size(0)
        x = np.arange(1, num_samples + 1) / num_samples
        y = (error_rates - optimal_error_rates).numpy()
        return torch.tensor([auc(x, y)])
    def plot(
        self,
        ax: _AX_TYPE | None = None,
        plot_oracle: bool = True,
        plot_value: bool = True,
    ) -> tuple[plt.Figure | None, plt.Axes]:
        """Plot the sparsification curve corresponding to the inputs passed to
        ``update``, and the oracle sparsification curve.

        Args:
            ax (Axes | None, optional): A matplotlib axis object. If provided,
                the plot will be added to this axis. Defaults to None.
            plot_oracle (bool, optional): Whether to plot the oracle
                sparsification curve. Defaults to True.
            plot_value (bool, optional): Whether to display the AUSEC value on
                the plot. Defaults to True.

        Returns:
            tuple[Figure | None, Axes]: Figure object and Axes object
        """
        fig, ax = plt.subplots() if ax is None else (None, ax)

        # Computation of AUSEC
        error_rates, optimal_error_rates = self.partial_compute()
        num_samples = error_rates.size(0)
        x = np.arange(num_samples) / num_samples
        y = (error_rates - optimal_error_rates).numpy()
        ausec = auc(x, y)

        rejection_rates = (np.arange(num_samples) / num_samples) * 100

        ax.plot(rejection_rates, error_rates * 100, label="Model")
        if plot_oracle:
            ax.plot(rejection_rates, optimal_error_rates * 100, label="Oracle")

        ax.set_xlabel("Rejection Rate (%)")
        ax.set_ylabel("Error Rate (%)")
        ax.set_xlim(self.plot_lower_bound, self.plot_upper_bound)
        ax.set_ylim(self.plot_lower_bound, self.plot_upper_bound)
        ax.legend(loc="upper right")

        if plot_value:
            ax.text(
                0.02,
                0.02,
                f"AUSEC={ausec:.03}",
                color="black",
                ha="left",
                va="bottom",
                transform=ax.transAxes,
            )

        return fig, ax
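
For intuition on the two curves this metric compares, here is a hand-worked example (illustrative values, not taken from the library):

# Hand-worked intuition (illustrative values): take errors [1., 3., 2.] with
# scores [0.1, 0.3, 0.2], so the uncertainty ranking matches the error ranking
# exactly. Rejecting the most uncertain sample first leaves a total error of
# 6. -> 3. -> 1., i.e. a normalized curve of [1.0, 1.0, 0.5, 0.167]. The oracle
# (sorting by the true errors) yields the same curve, so the sparsification
# error is 0 everywhere and the AUSE is 0.
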
def _ause_rejection_rate_compute(
    scores: Tensor,
    errors: Tensor,
) -> Tensor:
    """Compute the cumulative error rates for a given set of scores and errors.

    Args:
        scores (Tensor): uncertainty scores of shape :math:`(B,)`
        errors (Tensor): errors of shape :math:`(B,)`
    """
    num_samples = errors.size(0)

    # Sort the errors by ascending uncertainty score.
    order = scores.argsort()
    errors = errors[order]

    # error_rates[k] (for k >= 1) is the error remaining after rejecting the
    # k - 1 most uncertain samples, normalized by the total error.
    error_rates = torch.zeros(num_samples + 1)
    error_rates[0] = errors.sum()
    error_rates[1:] = errors.cumsum(dim=-1).flip(0)
    return error_rates / error_rates[0]
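
A minimal usage sketch (illustrative only, not part of the module): random tensors stand in for real per-sample uncertainty scores and errors.

if __name__ == "__main__":
    # Synthetic data; in practice, pass model uncertainty scores and the
    # corresponding per-sample errors, both of shape (B,).
    metric = AUSE()
    metric.update(torch.rand(64), torch.rand(64))
    metric.update(torch.rand(64), torch.rand(64))  # batches accumulate
    print(metric.compute())  # 1-element tensor; lower is better
    fig, ax = metric.plot()  # model vs. oracle sparsification curves
    plt.show()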