# Source code for torch_uncertainty.post_processing.calibration.matrix_scaler
from typing import Literal
import torch
from torch import Tensor, device, nn
from torch.nn.functional import linear
from .scaler import Scaler
from .utils import _check_classes
class MatrixScaler(Scaler):
    def __init__(
        self,
        num_classes: int,
        model: nn.Module | None = None,
        init_weight_temperature: float | Tensor = 1,
        init_bias_temperature: float | Tensor | None = None,
        lr: float = 0.1,
        max_iter: int = 200,
        eps: float = 1e-8,
        device: Literal["cpu", "cuda"] | device | None = None,
    ) -> None:
        """Matrix scaling post-processing for calibrated probabilities.

        Logits are rescaled by the affine map ``linear(logits, W, b)``, i.e.
        ``logits @ W.T + b``. Here ``W`` is a learnable
        ``(num_classes, num_classes)`` inverse-temperature matrix and ``b`` is
        a learnable ``(num_classes,)`` bias vector.

        Args:
            num_classes (int): Number of classes.
            model (nn.Module | None, optional): Model to calibrate. Defaults to
                ``None``.
            init_weight_temperature (float | Tensor, optional): Initial
                temperature for the weights. The weight matrix is initialized to
                ``eye(num_classes) / init_weight_temperature``. Defaults to ``1``.
            init_bias_temperature (float | Tensor | None, optional): Initial
                temperature for the bias. The bias is initialized to
                ``ones(num_classes) / init_bias_temperature``, or to the ``0``
                vector if set to ``None``. Defaults to ``None``.
            lr (float, optional): Learning rate for the optimizer. Defaults to
                ``0.1``.
            max_iter (int, optional): Maximum number of iterations for the
                optimizer. Defaults to ``200``.
            eps (float, optional): Small value for stability. Defaults to
                ``1e-8``.
            device (Literal["cpu", "cuda"] | torch.device | None, optional):
                Device to use for optimization. Defaults to ``None``.

        References:
            [1] `On calibration of modern neural networks. In ICML 2017
            <https://arxiv.org/abs/1706.04599>`_.

        Warning:
            If the model is binary, the sigmoid is applied by default before
            transposing the prediction to the 2-class case.
        """
        super().__init__(model=model, lr=lr, max_iter=max_iter, eps=eps, device=device)
        _check_classes(num_classes)
        self.num_classes = num_classes
        self.set_temperature(init_weight_temperature, init_bias_temperature)

    def set_temperature(
        self,
        val_weight: float | Tensor,
        val_bias: float | Tensor | None = None,
    ) -> None:
        """Re-initialize the inverse-temperature weight matrix and bias.

        Both parameters are replaced by fresh ``nn.Parameter`` objects created
        on ``self.device``. The ``trained`` flag is reset to ``False``.

        Args:
            val_weight (float | Tensor): Weight temperature value. The weight
                matrix becomes ``eye(num_classes) / val_weight``. Division is
                elementwise, so only the diagonal is affected by a scalar.
            val_bias (float | Tensor | None, optional): Bias temperature value.
                The bias becomes ``ones(num_classes) / val_bias``, or the zero
                vector when ``None``. Defaults to ``None``.
        """
        eye = torch.eye(self.num_classes, device=self.device)
        self.inv_temperature_weight = nn.Parameter(
            eye / val_weight,
            requires_grad=True,
        )
        if val_bias is None:
            bias = torch.zeros(self.num_classes, device=self.device)
        else:
            bias = torch.ones(self.num_classes, device=self.device) / val_bias
        self.inv_temperature_bias = nn.Parameter(
            bias,
            requires_grad=True,
        )
        # The parameters were just re-created, so any previous fit is discarded.
        self.trained = False

    def _scale(self, logits: Tensor) -> Tensor:
        """Apply the affine matrix scaling ``logits @ W.T + b``.

        ``torch.nn.functional.linear`` requires the last dimension of
        ``logits`` to equal ``num_classes``.
        """
        return linear(logits, self.inv_temperature_weight, self.inv_temperature_bias)

    @property
    def inv_temperature(self) -> list[Tensor]:
        """Return ``[inverse-temperature weight matrix, bias vector]``."""
        return [self.inv_temperature_weight, self.inv_temperature_bias]

    @property
    def temperature(self) -> list[Tensor]:
        """Return ``[inverse of the weight matrix, elementwise reciprocal of the bias]``.

        Note:
            The bias entry is ``1 / bias`` taken elementwise. With the default
            zero-bias initialization (``val_bias=None``), this yields ``inf``
            entries.
        """
        return [torch.inverse(self.inv_temperature_weight), 1 / self.inv_temperature_bias]