Source code for torch_uncertainty.post_processing.conformal.thr

import math
from typing import Literal

import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader

from .abstract import Conformal


class ConformalClsTHR(Conformal):
    def __init__(
        self,
        alpha: float,
        model: nn.Module | None = None,
        ts_init_val: float = 1.0,
        ts_lr: float = 0.1,
        ts_max_iter: int = 100,
        enable_ts: bool = True,
        device: Literal["cpu", "cuda"] | torch.device | None = None,
    ) -> None:
        r"""Conformal prediction post-processing for calibrated models.

        Args:
            alpha (float): The miscoverage (error) level: prediction sets are built
                to contain the true class with probability at least
                :math:`1-\alpha`.
            model (nn.Module, optional): Model to be calibrated. Defaults to ``None``.
            ts_init_val (float, optional): Initial value for the temperature.
                Defaults to ``1.0``.
            ts_lr (float, optional): Learning rate for the optimizer. Defaults to ``0.1``.
            ts_max_iter (int, optional): Maximum number of iterations for the
                optimizer. Defaults to ``100``.
            enable_ts (bool): Whether to scale the logits. Defaults to ``True``.
            device (Literal["cpu", "cuda"] | torch.device | None, optional): device.
                Defaults to ``None``.

        Reference:
            - `Least ambiguous set-valued classifiers with bounded error levels,
              Sadinle, M. et al., (2016) <https://arxiv.org/abs/1609.00451>`_.

        Code inspired by TorchCP.
        """
        super().__init__(
            alpha=alpha,
            model=model,
            ts_init_val=ts_init_val,
            ts_lr=ts_lr,
            ts_max_iter=ts_max_iter,
            enable_ts=enable_ts,
            device=device,
        )

    def fit(self, dataloader: DataLoader) -> None:
        r"""Calibrate the score threshold ``q_hat`` on a held-out dataloader.

        The nonconformity score of a calibration sample is one minus the value
        that ``model_forward`` assigns to its true class. ``q_hat`` is set to the
        :math:`\lceil (n+1)(1-\alpha) \rceil`-th smallest of the :math:`n`
        calibration scores, i.e. the finite-sample-corrected conformal quantile.
        When that rank exceeds :math:`n` (calibration set too small for the
        requested level), ``q_hat`` is ``+inf`` so that every class is kept.

        Args:
            dataloader (DataLoader): Calibration data yielding ``(inputs, labels)``
                batches.
        """
        if self.enable_ts:
            # NOTE(review): `self.model` comes from the base class (not shown here);
            # presumably this fits the temperature scaler — confirm.
            self.model.fit(dataloader=dataloader)

        prob_list = []
        label_list = []
        with torch.no_grad():
            for images, labels in dataloader:
                images, labels = images.to(self.device), labels.to(self.device)
                # NOTE(review): `model_forward` is assumed to return per-class
                # probabilities of shape (batch, classes) — confirm in the base class.
                prob_list.append(self.model_forward(images))
                label_list.append(labels)

        probs = torch.cat(prob_list)
        labels = torch.cat(label_list).long()
        true_class_probs = probs.gather(1, labels.unsqueeze(1)).squeeze(1)
        scores = 1.0 - true_class_probs

        # Finite-sample correction: the plain (1 - alpha) empirical quantile
        # under-covers; the conformal quantile is the ceil((n + 1)(1 - alpha))-th
        # smallest score. The lower clamp keeps alpha == 1 on the minimum score.
        n_calib = scores.numel()
        rank = max(1, math.ceil((n_calib + 1) * (1.0 - self.alpha)))
        if rank > n_calib:
            # Not enough calibration samples for this level: fall back to the
            # trivial threshold so that `conformal` keeps every class.
            self.q_hat = float("inf")
        else:
            # torch.kthvalue is 1-indexed: k=rank returns the rank-th smallest score.
            self.q_hat = torch.kthvalue(scores, rank).values.item()

    @torch.no_grad()
    def conformal(self, inputs: Tensor) -> Tensor:
        """Perform conformal prediction on the test set.

        Args:
            inputs (Tensor): Batch of inputs passed to ``model_forward``.

        Returns:
            Tensor: Float tensor with the same shape as the ``model_forward``
            output; each row is uniform over the classes kept in the prediction
            set and zero elsewhere.
        """
        probs = self.model_forward(inputs)
        # Keep the classes whose score (1 - prob) is at most the threshold.
        # NOTE(review): `self.quantile` is not defined in this class; presumably a
        # base-class accessor for the `q_hat` set in `fit` — confirm.
        pred_set = probs >= 1.0 - self.quantile
        top1 = torch.argmax(probs, dim=1, keepdim=True)
        pred_set.scatter_(1, top1, True)  # Always include top-1 class
        return pred_set.float() / pred_set.sum(dim=1, keepdim=True)