Source code for torch_uncertainty.post_processing.calibration.temperature_scaler

import logging
from typing import Literal

import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader

from .scaler import Scaler


class TemperatureScaler(Scaler):
    def __init__(
        self,
        model: nn.Module | None = None,
        init_temperature: float | Tensor = 1,
        lr: float = 0.1,
        max_iter: int = 100,
        eps: float = 1e-8,
        device: Literal["cpu", "cuda"] | torch.device | None = None,
    ) -> None:
        r"""Temperature scaling post-processing for calibrated probabilities.

        Rescales the model's logits by a single learnable scalar :math:`T > 0`
        (the *temperature*) before the softmax:

        .. math::
            \tilde{\mathbf{p}}(\mathbf{x}) = \mathrm{softmax}\!\left(\mathbf{z}(\mathbf{x}) / T\right).

        :math:`T` is fit by minimising the cross-entropy on a held-out calibration set.
        Despite being a single-parameter transformation, temperature scaling is a
        remarkably effective recipe for fixing the overconfidence of modern neural
        networks (Guo et al., 2017).

        Args:
            model: Model to calibrate.
            init_temperature: Initial value for the temperature :math:`T`. Must be
                strictly positive. Defaults to ``1``.
            lr: Learning rate for the optimizer. Defaults to ``0.1``.
            max_iter: Maximum number of iterations for the optimizer. Defaults to ``100``.
            eps: Small value for stability. Defaults to ``1e-8``.
            device: Device to use for optimization. Defaults to ``None``.

        Raises:
            ValueError: If ``init_temperature`` is zero, negative or NaN.

        References:
            [1] `Guo, C., Pleiss, G., Sun, Y., & Weinberger, K. Q. (2017). On calibration
            of modern neural networks. ICML 2017 <https://arxiv.org/abs/1706.04599>`_.

        Warning:
            For binary models, a sigmoid is applied before the prediction is transposed
            to the corresponding 2-class logits.

        Note:
            The scaler will log an error if the temperature converges to a negative value.
        """
        super().__init__(model=model, lr=lr, max_iter=max_iter, eps=eps, device=device)

        # ``not x > 0`` instead of ``x <= 0``: every ordered comparison with NaN is
        # false, so the latter would silently accept a NaN temperature.
        if not init_temperature > 0:
            raise ValueError(
                f"Initial temperature value must be positive. Got {init_temperature}."
            )

        self.set_temperature(init_temperature)

    def fit(
        self,
        dataloader: DataLoader,
        save_logits: bool = False,
        progress: bool = True,
    ) -> None:
        """Fit the scaler through the parent class, then sanity-check the result.

        All arguments are forwarded unchanged to the parent ``fit``. Once it returns,
        an error is logged if the learned inverse temperature is not strictly positive.

        Args:
            dataloader: Forwarded to the parent ``fit``.
            save_logits: Forwarded to the parent ``fit``. Defaults to ``False``.
            progress: Forwarded to the parent ``fit``. Defaults to ``True``.
        """
        super().fit(dataloader=dataloader, save_logits=save_logits, progress=progress)

        if self.inv_temp.item() <= 0:  # coverage: ignore
            # Hand the logger a plain Python float rather than a tensor that is still
            # attached to the autograd graph. The reciprocal is taken on the tensor so
            # that an inverse temperature of exactly 0 yields ``inf`` and not a
            # ``ZeroDivisionError``.
            temperature = self.inv_temp.detach().reciprocal().item()
            logging.error(
                "TemperatureScaler converged to a negative temperature %.3f.",
                temperature,
            )

    def set_temperature(self, val: float | Tensor) -> None:
        """Set the temperature to a fixed value.

        The temperature is stored as its inverse in the ``inv_temp`` parameter and the
        scaler is flagged as not trained.

        Args:
            val: Temperature value. Must be strictly positive.

        Raises:
            ValueError: If ``val`` is zero, negative or NaN.
        """
        # Written as ``not val > 0`` so that NaN is rejected too (``nan <= 0`` is false).
        if not val > 0:
            raise ValueError(f"Temperature value must be strictly positive. Got {val}.")

        self.inv_temp = nn.Parameter(torch.ones(1, device=self.device) / val, requires_grad=True)
        self.trained = False

    def _scale(self, logits: Tensor) -> Tensor:
        """Return the logits multiplied by the inverse temperature.

        Args:
            logits: Logits to rescale.

        Returns:
            The product ``inv_temp * logits``.
        """
        return self.inv_temp * logits

    @property
    def inv_temperature(self) -> list[Tensor]:
        """Single-element list holding the inverse-temperature parameter."""
        return [self.inv_temp]

    @property
    def temperature(self) -> list[Tensor]:
        """Single-element list holding the temperature, computed as ``1 / inv_temp``."""
        return [1 / self.inv_temp]