# Source code for torch_uncertainty.post_processing.conformal.thr

import math
from typing import Literal

import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader

from .abstract import Conformal


class ConformalClsTHR(Conformal):
    def __init__(
        self,
        alpha: float,
        model: nn.Module | None = None,
        ts_init_val: float = 1.0,
        ts_lr: float = 0.1,
        ts_max_iter: int = 100,
        enable_ts: bool = True,
        device: Literal["cpu", "cuda"] | torch.device | None = None,
    ) -> None:
        r"""Threshold-based conformal classifier (THR; Sadinle et al., 2019).

        The simplest conformal classification rule. Defines the non-conformity
        score of class :math:`c` as
        :math:`s(\mathbf{x}, c) = 1 - \hat{p}_c(\mathbf{x})`. On a held-out
        calibration set of :math:`n` samples it computes the scores at the true
        class and sets :math:`\hat{q}` to the finite-sample-corrected
        :math:`(1 - \alpha)`-quantile of those scores, i.e. the
        :math:`\lceil (n + 1)(1 - \alpha) \rceil`-th smallest score. At test
        time it outputs the prediction set

        .. math::

            \mathcal{C}(\mathbf{x}) = \{ c : \hat{p}_c(\mathbf{x}) \geq 1 - \hat{q} \},

        which guarantees marginal coverage of at least :math:`1 - \alpha` when
        calibration and test samples are exchangeable. The top-1 class is always
        included to avoid empty sets.

        Probabilities can optionally be calibrated first via temperature scaling
        (``enable_ts=True``), which usually yields smaller prediction sets.

        Args:
            alpha: Target mis-coverage level :math:`\alpha \in (0, 1)`.
            model: Model to be calibrated. Defaults to ``None``.
            ts_init_val: Initial value for the temperature.
                Defaults to ``1.0``.
            ts_lr: Learning rate for the temperature scaling optimizer.
                Defaults to ``0.1``.
            ts_max_iter: Maximum number of iterations for the temperature
                scaling optimizer. Defaults to ``100``.
            enable_ts: Whether to scale the logits via temperature scaling
                before computing the conformal scores. Defaults to ``True``.
            device: Device to use. Defaults to ``None``.

        Warning:
            This implementation only works in the multiclass setting. Raise an
            issue if binary support is needed.

        Reference:
            - `Sadinle, M., Lei, J., & Wasserman, L. (2019). Least Ambiguous
              Set-Valued Classifiers with Bounded Error Levels
              <https://arxiv.org/abs/1609.00451>`_.

        Code inspired by TorchCP.
        """
        super().__init__(
            alpha=alpha,
            model=model,
            ts_init_val=ts_init_val,
            ts_lr=ts_lr,
            ts_max_iter=ts_max_iter,
            enable_ts=enable_ts,
            device=device,
        )

    @staticmethod
    def _conformal_quantile(scores: Tensor, alpha: float) -> Tensor:
        """Return the finite-sample-corrected conformal quantile of ``scores``.

        This is the ``ceil((n + 1) * (1 - alpha))``-th smallest of the ``n``
        calibration scores. Using the plain empirical ``(1 - alpha)``-quantile
        (with interpolation) would slightly under-cover.

        Args:
            scores: 1-D tensor of calibration non-conformity scores, each
                equal to one minus a probability.
            alpha: Target mis-coverage level.

        Returns:
            A 0-dim tensor with the same dtype and device as ``scores``.

        Raises:
            ValueError: If ``scores`` is empty.
        """
        n = scores.numel()
        if n == 0:
            raise ValueError("Cannot calibrate the conformal threshold on an empty dataset.")
        rank = math.ceil((n + 1) * (1.0 - alpha))
        if rank > n:
            # Too few calibration samples for this alpha: only the full label
            # set is valid. Scores are 1 - p <= 1, so q_hat = 1 gives a
            # threshold of 0 and keeps every class.
            return scores.new_tensor(1.0)
        return torch.kthvalue(scores, rank).values

    def fit(self, dataloader: DataLoader) -> None:
        """Calibrate the conformal threshold ``q_hat`` on a calibration set.

        If temperature scaling is enabled, the wrapped model is fitted first
        on the same ``dataloader``. The scores are then collected and
        ``self.q_hat`` is set to their finite-sample-corrected quantile.

        Args:
            dataloader: Yields ``(inputs, labels)`` batches. Labels are
                class indices.
        """
        assert self.model is not None
        if self.enable_ts:
            # NOTE: the temperature and the threshold are calibrated on the
            # same data.
            self.model.fit(dataloader=dataloader)

        prob_list = []
        label_list = []
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                # model_forward outputs are used as class probabilities.
                prob_list.append(self.model_forward(inputs))
                label_list.append(labels)

        probs = torch.cat(prob_list)
        labels = torch.cat(label_list).long()
        # Non-conformity score: one minus the probability of the true class.
        true_class_probs = probs.gather(1, labels.unsqueeze(1)).squeeze(1)
        scores = 1.0 - true_class_probs
        self.q_hat = self._conformal_quantile(scores, self.alpha)

    @torch.no_grad()
    def conformal(self, inputs: Tensor) -> Tensor:
        """Perform conformal prediction on the test set.

        A class is kept when its probability is at least ``1 - quantile``.
        The argmax class is always kept, so no prediction set is empty.

        Args:
            inputs: Batch of inputs passed to ``model_forward``.

        Returns:
            A float tensor with the same shape as the model's probabilities.
            Each class in the prediction set gets ``1 / |set|``; all other
            classes get ``0``.
        """
        probs = self.model_forward(inputs)
        # self.quantile is provided by the base class; presumably it exposes
        # the q_hat set in fit() -- confirm in .abstract.
        pred_set = probs >= 1.0 - self.quantile
        top1 = torch.argmax(probs, dim=1, keepdim=True)
        pred_set.scatter_(1, top1, True)  # Always include top-1 class
        return pred_set.float() / pred_set.sum(dim=1, keepdim=True)