Source code for torch_uncertainty.post_processing.conformal.aps
from typing import Literal
import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader
from .abstract import Conformal
[docs]
class ConformalClsAPS(Conformal):
    def __init__(
        self,
        alpha: float,
        model: nn.Module | None = None,
        randomized: bool = True,
        ts_init_val: float = 1,
        ts_lr: float = 0.1,
        ts_max_iter: int = 100,
        enable_ts: bool = True,
        device: Literal["cpu", "cuda"] | torch.device | None = None,
    ) -> None:
        r"""Conformal prediction with APS (Adaptive Prediction Sets) scores.

        Args:
            alpha (float): Miscoverage level. The calibration threshold is the
                :math:`1-\alpha` quantile of the calibration scores, i.e. the
                prediction sets target a coverage of :math:`1-\alpha`.
            model (nn.Module): Trained classification model. Defaults to ``None``.
            randomized (bool): Whether to randomize the APS scores by subtracting
                a uniform random fraction of the probability of the scored class.
                Defaults to ``True``.
            ts_init_val (float, optional): Initial value for the temperature.
                Defaults to ``1.0``.
            ts_lr (float, optional): Learning rate for the optimizer. Defaults to ``0.1``.
            ts_max_iter (int, optional): Maximum number of iterations for the
                optimizer. Defaults to ``100``.
            enable_ts (bool): Whether to scale the logits. Defaults to ``True``.
            device (Literal["cpu", "cuda"] | torch.device | None, optional): device.
                Defaults to ``None``.

        Reference:
            - Romano, Sesia and Candes, "Classification with Valid and Adaptive
              Coverage", NeurIPS 2020.

        Code inspired by TorchCP.
        """
        super().__init__(
            alpha=alpha,
            model=model,
            ts_init_val=ts_init_val,
            ts_lr=ts_lr,
            ts_max_iter=ts_max_iter,
            enable_ts=enable_ts,
            device=device,
        )
        self.randomized = randomized

    def model_forward(self, inputs: Tensor) -> Tensor:
        """Apply the model and return the scores.

        The model is switched to evaluation mode, the inputs are moved to
        ``self.device`` and a softmax over the last dimension is applied to the
        model outputs.

        Args:
            inputs (Tensor): Batch of inputs for the model.

        Returns:
            Tensor: Softmax probabilities computed from the model outputs.
        """
        self.model.eval()
        return self.model(inputs.to(self.device)).softmax(-1)

    def _sort_sum(self, probs: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Sort probabilities and compute cumulative sums.

        Args:
            probs (Tensor): Probabilities, sorted along the last dimension.

        Returns:
            tuple[Tensor, Tensor, Tensor]: The sorting indices, the
            probabilities in descending order and their cumulative sums.
        """
        ordered, indices = torch.sort(probs, dim=-1, descending=True)
        cumsum = torch.cumsum(ordered, dim=-1)
        return indices, ordered, cumsum

    def _calculate_all_labels(self, probs: Tensor) -> Tensor:
        """Calculate APS scores for all labels.

        Args:
            probs (Tensor): Probabilities, with the classes on the last dimension.

        Returns:
            Tensor: APS scores with the same shape as ``probs``, in the original
            class order.
        """
        indices, ordered, cumsum = self._sort_sum(probs)
        if self.randomized:
            noise = torch.rand(probs.shape, device=probs.device)
        else:
            noise = torch.zeros_like(probs)
        # Score at each rank: mass of all classes up to and including this one,
        # minus a random fraction of this class's own probability.
        ordered_scores = cumsum - ordered * noise
        # Sorting the permutation yields its inverse, which maps the scores
        # from rank order back to the original class order.
        _, sorted_indices = torch.sort(indices, descending=False, dim=-1)
        return ordered_scores.gather(dim=-1, index=sorted_indices)

    def _calculate_single_label(self, probs: Tensor, label: Tensor) -> Tensor:
        """Calculate APS score for a single label.

        Args:
            probs (Tensor): Probabilities, with the classes on the last dimension.
            label (Tensor): Class index to score for each row of ``probs``.
                NOTE(review): assumes ``probs`` is 2D and each label is a valid
                class index, so that exactly one entry matches per row - confirm.

        Returns:
            Tensor: One APS score per matched (row, label) pair.
        """
        indices, ordered, cumsum = self._sort_sum(probs)
        if self.randomized:
            noise = torch.rand(indices.shape[0], device=probs.device)
        else:
            noise = torch.zeros(indices.shape[0], device=probs.device)
        # Locate the rank at which each row's label appears in the sorted order.
        idx = torch.where(indices == label.view(-1, 1))
        return cumsum[idx] - noise * ordered[idx]

    @torch.no_grad()
    def fit(self, dataloader: DataLoader) -> None:
        """Calibrate the APS threshold q_hat on a calibration set.

        When ``self.enable_ts`` is true, ``self.model.fit`` is first called on
        the same dataloader. The APS scores of the true labels are then
        collected over all batches and their ``1 - alpha`` quantile is stored
        as a Python float in ``self.q_hat``.

        Args:
            dataloader (DataLoader): Calibration data yielding
                ``(inputs, labels)`` batches.
        """
        if self.enable_ts:
            self.model.fit(dataloader=dataloader)
        aps_scores = []
        for images, labels in dataloader:
            images, labels = images.to(self.device), labels.to(self.device)
            probs = self.model_forward(images)
            scores = self._calculate_single_label(probs, labels)
            aps_scores.append(scores)
        # NOTE(review): this is the plain empirical (1 - alpha) quantile; the
        # usual split-conformal guarantee relies on the finite-sample level
        # ceil((n + 1) * (1 - alpha)) / n - confirm whether this is intended.
        self.q_hat = torch.quantile(torch.cat(aps_scores), 1 - self.alpha).item()

    @torch.no_grad()
    def conformal(self, inputs: Tensor) -> Tensor:
        """Compute the prediction set for each input.

        Args:
            inputs (Tensor): Batch of inputs for the model.

        Returns:
            Tensor: For each input, a uniform distribution over the classes of
            its prediction set (zero for the classes outside of it). A row
            whose prediction set is empty is all zeros.
        """
        probs = self.model_forward(inputs)
        # NOTE(review): ``self.quantile`` is not defined in this class;
        # presumably the ``Conformal`` base exposes the calibrated ``q_hat``
        # under this name - confirm.
        pred_set = self._calculate_all_labels(probs) <= self.quantile
        # With randomized scores even the most likely class can exceed the
        # threshold, leaving an empty set; clamping the size avoids the 0 / 0
        # division that would otherwise fill the row with NaNs.
        set_sizes = pred_set.sum(dim=1, keepdim=True).clamp(min=1)
        return pred_set.float() / set_sizes