Source code for torch_uncertainty.post_processing.conformal.raps

from typing import Literal

import torch
from torch import Tensor, nn

from .aps import ConformalClsAPS


class ConformalClsRAPS(ConformalClsAPS):
    def __init__(
        self,
        alpha: float,
        model: nn.Module | None = None,
        randomized: bool = True,
        penalty: float = 0.1,
        regularization_rank: int = 1,
        ts_init_val: float = 1.0,
        ts_lr: float = 0.1,
        ts_max_iter: int = 100,
        enable_ts: bool = False,
        device: Literal["cpu", "cuda"] | torch.device | None = None,
    ) -> None:
        r"""Conformal classification with Regularised Adaptive Prediction Sets (RAPS;
        Angelopoulos, Bates, Jordan & Malik, 2021).

        A regularised variant of :class:`ConformalClsAPS` that penalises the
        inclusion of classes ranked beyond :math:`k_\text{reg}` (i.e. classes the
        model considers unlikely) to produce *smaller* prediction sets without
        sacrificing coverage. The non-conformity score adds a rank-based
        regulariser to the APS score:

        .. math::

            s(\mathbf{x}, y) = \underbrace{\sum_{i=1}^{k} \hat{p}_{(i)}
            - U \cdot \hat{p}_{(k)}}_{\text{APS}}
            + \lambda \cdot (k - k_\text{reg})_{+},

        where :math:`k` is the rank of class :math:`y`, :math:`\lambda` is
        :attr:`penalty`, :math:`k_\text{reg}` is :attr:`regularization_rank`, and
        :math:`(\cdot)_+ = \max(\cdot, 0)`. Larger :math:`\lambda` and smaller
        :math:`k_\text{reg}` produce tighter sets at the cost of a coarser score.

        Args:
            alpha: Target mis-coverage level :math:`\alpha \in (0, 1)`.
            model: Trained classification model. Defaults to ``None``.
            randomized: Whether to use randomised tie-breaking. Defaults to
                ``True``.
            penalty: Regularisation weight :math:`\lambda`. Must be
                non-negative. Defaults to ``0.1``.
            regularization_rank: Rank threshold :math:`k_\text{reg}` above which
                the penalty is applied. Must be a non-negative integer.
                Defaults to ``1``.
            ts_init_val: Initial value for the temperature. Defaults to ``1.0``.
            ts_lr: Learning rate for the temperature scaling optimizer.
                Defaults to ``0.1``.
            ts_max_iter: Maximum number of iterations for the temperature
                scaling optimizer. Defaults to ``100``.
            enable_ts: Whether to apply temperature scaling before computing
                the conformal scores. Defaults to ``False``.
            device: Device to use. Defaults to ``None``.

        Raises:
            ValueError: If ``penalty`` or ``regularization_rank`` is negative.
            TypeError: If ``regularization_rank`` is not an integer.

        Warning:
            This implementation only works in the multiclass setting. Raise an
            issue if binary support is needed.

        Reference:
            - `Angelopoulos, A. N., Bates, S., Jordan, M., & Malik, J. (2021).
              Uncertainty Sets for Image Classifiers using Conformal Prediction.
              ICLR 2021 <https://arxiv.org/abs/2009.14193>`_.

        Code inspired by TorchCP.
        """
        super().__init__(
            alpha=alpha,
            model=model,
            randomized=randomized,
            ts_init_val=ts_init_val,
            ts_lr=ts_lr,
            ts_max_iter=ts_max_iter,
            enable_ts=enable_ts,
            device=device,
        )
        if penalty < 0:
            raise ValueError(f"penalty should be non-negative. Got {penalty}.")

        if not isinstance(regularization_rank, int):
            raise TypeError(f"regularization_rank should be an integer. Got {regularization_rank}.")

        if regularization_rank < 0:
            raise ValueError(
                f"regularization_rank should be non-negative. Got {regularization_rank}."
            )

        self.penalty = penalty
        self.regularization_rank = regularization_rank

    def _rank_penalty(self, ranks: Tensor) -> Tensor:
        """Return ``max(penalty * (ranks - regularization_rank), 0)`` element-wise.

        Args:
            ranks: Tensor of 1-based ranks.

        Returns:
            Tensor with the same shape as ``ranks`` holding the regularisation
            term added to the APS score.
        """
        # ``clamp(min=0)`` stays on the device of ``ranks`` and avoids building a
        # scalar tensor (and a host-to-device copy) on every call.
        return (self.penalty * (ranks - self.regularization_rank)).clamp(min=0)

    def _calculate_all_labels(self, probs: Tensor) -> Tensor:
        """Compute the RAPS score of every class for each row of ``probs``.

        Args:
            probs: Class probabilities; the last dimension indexes the classes.

        Returns:
            Tensor shaped like ``probs`` with the scores laid out in the
            original class order.
        """
        # NOTE(review): ``_sort_sum`` is inherited; presumably it returns the
        # sorting permutation, the sorted probabilities and their cumulative
        # sum along the last dimension -- confirm against the parent class.
        indices, ordered, cumsum = self._sort_sum(probs)

        if self.randomized:
            noise = torch.rand(probs.shape, device=probs.device)
        else:
            noise = torch.zeros_like(probs)

        # 1-based rank of each position in the sorted order; broadcasts over
        # the leading dimensions of ``cumsum``.
        ranks = torch.arange(1, probs.shape[-1] + 1, device=probs.device)
        ordered_scores = cumsum - ordered * noise + self._rank_penalty(ranks)

        # Sorting the permutation yields its inverse, which moves each score
        # from its sorted position back to its class index.
        _, sorted_indices = torch.sort(indices, descending=False, dim=-1)
        return ordered_scores.gather(dim=-1, index=sorted_indices)

    def _calculate_single_label(self, probs: Tensor, label: Tensor) -> Tensor:
        """Compute the RAPS score of the given label for each row of ``probs``.

        Args:
            probs: Class probabilities; the last dimension indexes the classes.
            label: One class index per row of ``probs``.

        Returns:
            1-D tensor with one score per matched ``(row, label)`` pair.
        """
        indices, ordered, cumsum = self._sort_sum(probs)

        if self.randomized:
            noise = torch.rand(indices.shape[0], device=probs.device)
        else:
            noise = torch.zeros(indices.shape[0], device=probs.device)

        # (row, position) coordinates of each label inside the sorted order;
        # the position is the label's 0-based rank.
        idx = torch.where(indices == label.view(-1, 1))
        reg = self._rank_penalty(idx[1] + 1)
        return cumsum[idx] - noise * ordered[idx] + reg