Source code for torch_uncertainty.post_processing.conformal.abstract

from abc import abstractmethod
from typing import Literal

import torch
from torch import Tensor, nn

from torch_uncertainty.post_processing import TemperatureScaler
from torch_uncertainty.post_processing.abstract import PostProcessing


class Conformal(PostProcessing):
    # Conformal quantile \hat{q}; ``None`` until it has been computed on a
    # calibration set (see the ``quantile`` property, which refuses to return
    # it before then).
    q_hat: Tensor | None = None

    def __init__(
        self,
        alpha: float,
        model: nn.Module | None,
        ts_init_val: float,
        ts_lr: float,
        ts_max_iter: int,
        enable_ts: bool,
        device: Literal["cpu", "cuda"] | torch.device | None,
    ) -> None:
        r"""Abstract base class for split-conformal classification predictors.

        Builds prediction sets :math:`\mathcal{C}(\mathbf{x}) \subseteq \{1, \dots, C\}`
        with the *marginal coverage guarantee*

        .. math::

            \mathbb{P}\!\left[ Y \in \mathcal{C}(X) \right] \geq 1 - \alpha,

        provided that the calibration and test points are exchangeable.

        At fit time, non-conformity scores are computed on a held-out
        calibration set and the empirical :math:`(1 - \alpha)`-quantile
        :math:`\hat{q}` is stored in :attr:`q_hat`. At test time, the
        prediction set is built by including all classes whose conformal
        score is below :math:`\hat{q}`.

        See :class:`ConformalClsTHR`, :class:`ConformalClsAPS`, and
        :class:`ConformalClsRAPS` for concrete scores.

        Args:
            alpha: Target mis-coverage level :math:`\alpha \in (0, 1)`. A
                smaller :math:`\alpha` yields larger prediction sets.
            model: Underlying classifier.
            ts_init_val: Initial temperature used when :attr:`enable_ts` is
                ``True``.
            ts_lr: Learning rate for the temperature optimizer.
            ts_max_iter: Maximum number of iterations for the temperature
                optimizer.
            enable_ts: If ``True``, wraps the model in a
                :class:`TemperatureScaler` fit on the calibration set before
                computing the conformal scores.
            device: Device to run the post-processing on. Defaults to
                ``"cpu"`` when ``None``.

        Raises:
            ValueError: If :attr:`alpha` is not strictly between 0 and 1.

        Warning:
            This implementation only works in the multiclass setting. Raise
            an issue if binary support is needed.
        """
        # Validate early: an out-of-range alpha (e.g. 90 meant as a
        # percentage, or 0) would otherwise silently yield a degenerate
        # quantile and meaningless prediction sets.
        if not 0 < alpha < 1:
            raise ValueError(f"alpha must be in (0, 1). Got {alpha}.")
        super().__init__(model=model)
        self.alpha = alpha
        self.enable_ts = enable_ts
        if enable_ts:
            # The scaler receives the raw ``device`` (possibly ``None``);
            # only ``self.device`` below gets the "cpu" fallback.
            self.model = TemperatureScaler(
                model=model,
                init_temperature=ts_init_val,
                lr=ts_lr,
                max_iter=ts_max_iter,
                device=device,
            )
        else:
            self.model = model
        self.device = device or "cpu"

    def set_model(self, model: nn.Module) -> None:
        """Attach a new underlying classifier.

        When temperature scaling is enabled, the model is switched to eval
        mode and handed to the existing :class:`TemperatureScaler` wrapper;
        otherwise it replaces :attr:`model` directly.

        Args:
            model: The classifier to wrap.
        """
        if self.enable_ts:
            assert self.model is not None
            self.model.set_model(model=model.eval())
        else:
            self.model = model

    def model_forward(self, inputs: Tensor) -> Tensor:
        """Apply the model and return class probabilities.

        The model is put in eval mode, the inputs are moved to
        :attr:`device`, and a softmax is taken over the last dimension of
        the model output.

        Args:
            inputs: Batch of inputs for the underlying model.

        Returns:
            Softmax probabilities computed along the last dimension.
        """
        assert self.model is not None
        self.model.eval()
        return self.model(inputs.to(self.device)).softmax(-1)

    @abstractmethod
    def conformal(self, inputs: Tensor) -> Tensor:
        """Apply the conformal prediction rule to the inputs."""

    def forward(self, inputs: Tensor) -> Tensor:
        """Alias for :meth:`conformal` so the instance can be called directly."""
        return self.conformal(inputs)

    @property
    def quantile(self) -> Tensor:
        """The fitted conformal quantile :math:`\\hat{q}`.

        Raises:
            RuntimeError: If :attr:`q_hat` has not been set yet.
        """
        if self.q_hat is None:
            raise RuntimeError("Quantile q_hat is not set. Run `.fit()` first.")
        return self.q_hat

    @property
    def temperature(self) -> float:
        """The current temperature of the wrapped :class:`TemperatureScaler`.

        Raises:
            RuntimeError: If temperature scaling is disabled.
        """
        if self.enable_ts and (self.model is not None):
            # NOTE(review): assumes ``TemperatureScaler.temperature`` is an
            # indexable sequence whose first entry is a scalar tensor.
            return self.model.temperature[0].item()
        raise RuntimeError("Cannot return temperature when enable_ts is False.")