# Source code for torch_uncertainty.post_processing.conformal.abstract
from abc import abstractmethod
from typing import Literal
import torch
from torch import Tensor, nn
from torch_uncertainty.post_processing import TemperatureScaler
from torch_uncertainty.post_processing.abstract import PostProcessing
# [docs]
class Conformal(PostProcessing):
    """Conformal base class.

    Holds the state shared by conformal post-processing methods: the
    ``alpha`` level, the wrapped model (optionally behind a
    ``TemperatureScaler``), the target device and the fitted quantile
    ``q_hat``. Subclasses implement :meth:`conformal`.
    """

    # Fitted quantile; stays ``None`` until a subclass assigns it.
    q_hat: Tensor | float | None = None

    def __init__(
        self,
        alpha: float,
        model: nn.Module | None,
        ts_init_val: float,
        ts_lr: float,
        ts_max_iter: int,
        enable_ts: bool,
        device: Literal["cpu", "cuda"] | torch.device | None,
    ) -> None:
        """Store the configuration and optionally wrap the model.

        Args:
            alpha: Level stored on the instance as ``self.alpha``
                (presumably the target miscoverage rate -- confirm against
                the subclasses that consume it).
            model: Model to post-process. May be ``None`` and supplied later
                through :meth:`set_model`.
            ts_init_val: Forwarded to ``TemperatureScaler`` as ``init_val``.
            ts_lr: Forwarded to ``TemperatureScaler`` as ``lr``.
            ts_max_iter: Forwarded to ``TemperatureScaler`` as ``max_iter``.
            enable_ts: When true, ``self.model`` is a ``TemperatureScaler``
                wrapping ``model``; otherwise it is ``model`` itself.
            device: Device the inputs are moved to in :meth:`model_forward`.
                Falls back to ``"cpu"`` when falsy. The raw value (possibly
                ``None``) is what gets forwarded to ``TemperatureScaler``.
        """
        super().__init__(model=model)
        self.alpha = alpha
        self.enable_ts = enable_ts
        if enable_ts:
            self.model = TemperatureScaler(
                model=model,
                init_val=ts_init_val,
                lr=ts_lr,
                max_iter=ts_max_iter,
                device=device,
            )
        else:
            self.model = model
        self.device = device or "cpu"

    def set_model(self, model: nn.Module | None) -> None:
        """Replace the underlying model.

        With temperature scaling enabled, the model is switched to eval mode
        and handed to the wrapping scaler; otherwise it replaces
        ``self.model`` directly.

        Args:
            model: The new model, or ``None`` to clear it.
        """
        if self.enable_ts:
            # ``None`` is an accepted value per the signature, so only real
            # modules are switched to eval mode before being forwarded.
            # NOTE(review): assumes the scaler's ``set_model`` accepts
            # ``None`` like its constructor does -- confirm.
            self.model.set_model(
                model=model.eval() if model is not None else None
            )
        else:
            self.model = model

    def model_forward(self, inputs: Tensor) -> Tensor:
        """Apply the model and return the scores.

        The model is put in eval mode, the inputs are moved to
        ``self.device`` and a softmax over the last dimension is applied to
        the model output.
        """
        self.model.eval()
        return self.model(inputs.to(self.device)).softmax(-1)

    @abstractmethod
    def conformal(self, inputs: Tensor) -> Tensor:
        """Run the conformal procedure on ``inputs`` (subclass hook)."""

    def forward(self, inputs: Tensor) -> Tensor:
        """Delegate to :meth:`conformal`."""
        return self.conformal(inputs)

    @property
    def quantile(self) -> Tensor:
        """Return ``q_hat``.

        Raises:
            RuntimeError: If ``q_hat`` is still ``None``.
        """
        if self.q_hat is None:
            raise RuntimeError("Quantile q_hat is not set. Run `.fit()` first.")
        return self.q_hat

    @property
    def temperature(self) -> float:
        """Return the first temperature of the wrapping scaler as a float.

        Raises:
            RuntimeError: If temperature scaling is disabled.
        """
        if self.enable_ts:
            return self.model.temperature[0].item()
        raise RuntimeError("Cannot return temperature when enable_ts is False.")