Source code for torch_uncertainty.post_processing.conformal.raps

from typing import Literal

import torch
from torch import Tensor, nn

from .aps import ConformalClsAPS


class ConformalClsRAPS(ConformalClsAPS):
    def __init__(
        self,
        alpha: float,
        model: nn.Module | None = None,
        randomized: bool = True,
        penalty: float = 0.1,
        regularization_rank: int = 1,
        ts_init_val: float = 1.0,
        ts_lr: float = 0.1,
        ts_max_iter: int = 100,
        enable_ts: bool = False,
        device: Literal["cpu", "cuda"] | torch.device | None = None,
    ) -> None:
        r"""Conformal prediction with RAPS scores.

        Args:
            alpha (float): The confidence level, meaning we allow :math:`1-\alpha` error.
            model (nn.Module): Trained classification model. Defaults to ``None``.
            randomized (bool): Whether to use randomized smoothing in RAPS. Defaults to ``True``.
            penalty (float): Regularization weight. Defaults to ``0.1``.
            regularization_rank (int): Rank threshold for regularization. Defaults to ``1``.
            ts_init_val (float, optional): Initial value for the temperature. Defaults to ``1.0``.
            ts_lr (float, optional): Learning rate for the optimizer. Defaults to ``0.1``.
            ts_max_iter (int, optional): Maximum number of iterations for the optimizer.
                Defaults to ``100``.
            enable_ts (bool): Whether to scale the logits. Defaults to ``False``.
            device (Literal["cpu", "cuda"] | torch.device | None, optional): device.
                Defaults to ``None``.

        Reference:
            - TODO:

        Code inspired by TorchCP.
        """
        super().__init__(
            alpha=alpha,
            model=model,
            randomized=randomized,
            ts_init_val=ts_init_val,
            ts_lr=ts_lr,
            ts_max_iter=ts_max_iter,
            enable_ts=enable_ts,
            device=device,
        )
        if penalty < 0:
            raise ValueError(f"penalty should be non-negative. Got {penalty}.")
        if not isinstance(regularization_rank, int):
            raise TypeError(
                f"regularization_rank should be an integer. Got {regularization_rank}."
            )
        if regularization_rank < 0:
            raise ValueError(
                f"regularization_rank should be non-negative. Got {regularization_rank}."
            )
        self.penalty = penalty
        self.regularization_rank = regularization_rank

    def _calculate_all_labels(self, probs: Tensor) -> Tensor:
        # Sort probabilities in decreasing order and compute their cumulative sums.
        indices, ordered, cumsum = self._sort_sum(probs)
        if self.randomized:
            noise = torch.rand(probs.shape, device=probs.device)
        else:
            noise = torch.zeros_like(probs)
        # Rank-dependent regularization: penalize labels ranked beyond regularization_rank.
        reg = torch.maximum(
            self.penalty
            * (
                torch.arange(1, probs.shape[-1] + 1, device=probs.device)
                - self.regularization_rank
            ),
            torch.tensor(0, device=probs.device),
        )
        ordered_scores = cumsum - ordered * noise + reg
        # Map the scores back to the original (unsorted) label order.
        _, sorted_indices = torch.sort(indices, descending=False, dim=-1)
        return ordered_scores.gather(dim=-1, index=sorted_indices)

    def _calculate_single_label(self, probs: Tensor, label: Tensor) -> Tensor:
        indices, ordered, cumsum = self._sort_sum(probs)
        if self.randomized:
            noise = torch.rand(indices.shape[0], device=probs.device)
        else:
            noise = torch.zeros(indices.shape[0], device=probs.device)
        # Position (row, rank) of the true label in the sorted probabilities.
        idx = torch.where(indices == label.view(-1, 1))
        reg = torch.maximum(
            self.penalty * (idx[1] + 1 - self.regularization_rank),
            torch.tensor(0, device=probs.device),
        )
        return cumsum[idx] - noise * ordered[idx] + reg
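
For illustration, here is a minimal standalone sketch of the RAPS nonconformity score that ``_calculate_single_label`` computes above, written against plain ``torch`` on dummy softmax probabilities. The helper name ``raps_score`` and the example values are not part of this module; they are assumptions introduced only for demonstration.

import torch


def raps_score(probs, labels, penalty=0.1, regularization_rank=1, randomized=True):
    # Hypothetical helper mirroring _calculate_single_label on raw probabilities.
    ordered, indices = torch.sort(probs, dim=-1, descending=True)
    cumsum = ordered.cumsum(dim=-1)
    # 0-based rank of the true label once probabilities are sorted.
    rows, ranks = torch.where(indices == labels.view(-1, 1))
    noise = torch.rand(probs.shape[0]) if randomized else torch.zeros(probs.shape[0])
    # Regularization term: penalize true labels ranked beyond regularization_rank.
    reg = torch.clamp(penalty * (ranks + 1 - regularization_rank), min=0)
    # Cumulative mass up to the true label, randomized by its own mass, plus the penalty.
    return cumsum[rows, ranks] - noise * ordered[rows, ranks] + reg


probs = torch.softmax(torch.randn(4, 10), dim=-1)
labels = torch.randint(0, 10, (4,))
print(raps_score(probs, labels))

The added ``reg`` term is what distinguishes RAPS from plain APS: labels whose rank exceeds ``regularization_rank`` receive an extra penalty proportional to ``penalty``, which discourages large prediction sets.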