DirichletScaler

class torch_uncertainty.post_processing.DirichletScaler(num_classes, model=None, init_weight_temperature=1, init_bias_temperature=None, lr=0.1, max_iter=200, lambda_reg=None, mu_reg=None, eps=1e-08, device=None)

Dirichlet scaling post-processing for obtaining calibrated probabilities [1].

Parameters:
  • num_classes (int) – Number of classes.

  • model (nn.Module | None) – Model to calibrate. Defaults to None.

  • init_weight_temperature (float, optional) – Initial value for the weight matrix. Defaults to 1.

  • init_bias_temperature (float | None, optional) – Initial value for the bias. If None, the inverse bias is initialized to the zero vector. Defaults to None.

  • lr (float, optional) – Learning rate for the optimizer. Defaults to 0.1.

  • max_iter (int, optional) – Maximum number of iterations for the optimizer. Defaults to 200.

  • lambda_reg (float | None, optional) – Regularization coefficient applied to the off-diagonal elements of the weight matrix. Used to mitigate overfitting. Defaults to None.

  • mu_reg (float | None, optional) – Regularization coefficient applied to the bias vector. Defaults to None.

  • eps (float, optional) – Small value for numerical stability. Defaults to 1e-8.

  • device (Optional[Literal["cpu", "cuda"]], optional) – Device to use for optimization. Defaults to None.
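The parameters above configure the calibration map of [1]. Assuming the library follows the paper's formulation, the map replaces a probability vector p with softmax(W log p + b), and the regularization is the paper's ODIR penalty, lambda / (k(k - 1)) times the sum of squared off-diagonal weights plus mu / k times the sum of squared biases. The pure-Python sketch below illustrates both; `dirichlet_map` and `odir_penalty` are hypothetical helpers, not part of torch_uncertainty, and `eps` is used here only to guard the logarithm.

```python
import math

def softmax(z):
    # Subtract the max before exponentiating for numerical stability.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def dirichlet_map(p, weight, bias, eps=1e-8):
    # Hypothetical helper: softmax(W log p + b) for a single probability vector p.
    x = [math.log(v + eps) for v in p]
    return softmax([sum(w * xj for w, xj in zip(row, x)) + b
                    for row, b in zip(weight, bias)])

def odir_penalty(weight, bias, lambda_reg, mu_reg):
    # Hypothetical helper for the ODIR term of [1]:
    # lambda / (k (k - 1)) * sum_{i != j} w_ij^2 + mu / k * sum_j b_j^2.
    # The diagonal of W is deliberately left unpenalized.
    k = len(bias)
    off_diag = sum(weight[i][j] ** 2 for i in range(k) for j in range(k) if i != j)
    intercept = sum(b ** 2 for b in bias)
    off_term = lambda_reg * off_diag / (k * (k - 1)) if k > 1 else 0.0
    return off_term + mu_reg * intercept / k

# W = 2I, b = 0 squares and renormalizes p, i.e. sharpens the distribution.
weight = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
bias = [0.0, 0.0, 0.0]
print(dirichlet_map([0.6, 0.3, 0.1], weight, bias))
# A purely diagonal W with zero bias incurs no ODIR penalty.
print(odir_penalty(weight, bias, lambda_reg=0.01, mu_reg=0.01))
```

Since log p differs from the logits only by a per-sample constant, W = I / T with b = 0 reproduces temperature scaling; ODIR leaves that diagonal part free and only shrinks cross-class interactions and the bias.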

References

[1] Beyond temperature scaling: Obtaining well-calibrated multiclass probabilities with Dirichlet calibration.

Warning

If the model is binary, the sigmoid is applied by default before converting the prediction to the two-class case.
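The binary conversion in this warning can be sketched as follows, assuming the binary model emits a single logit: the sigmoid of that logit is read as the probability of class 1 and its complement as the probability of class 0, giving a two-class vector that a Dirichlet map with two classes can consume. `binary_to_two_class` is a hypothetical helper; the exact tensor layout the library uses is not documented here.

```python
import math

def binary_to_two_class(logit):
    # Hypothetical helper: map one binary logit to [P(y=0), P(y=1)].
    # Numerically stable sigmoid: avoid exp() overflow for large |logit|.
    if logit >= 0:
        p1 = 1.0 / (1.0 + math.exp(-logit))
    else:
        e = math.exp(logit)
        p1 = e / (1.0 + e)
    return [1.0 - p1, p1]

print(binary_to_two_class(0.0))  # [0.5, 0.5]
print(binary_to_two_class(2.0))
```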

fit(dataloader, save_logits=False, progress=True)

Fit the temperature parameters to the calibration data.

Parameters:
  • dataloader (DataLoader) – Dataloader with the calibration data. If no model is given, the dataloader should yield the confidence scores directly instead of the logits.

  • save_logits (bool, optional) – Whether to save the logits and labels in memory. Defaults to False.

  • progress (bool, optional) – Whether to show a progress bar. Defaults to True.
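Conceptually, fitting searches for the weight matrix and bias that minimize the negative log-likelihood (NLL) of the calibrated probabilities on the calibration data, plus the regularization terms when `lambda_reg` or `mu_reg` is set. The pure-Python sketch below assumes precomputed probabilities (the no-model case) and runs plain full-batch gradient descent with step size `lr` for `max_iter` iterations; this optimizer is a stand-in, not necessarily the one the library uses. `fit_dirichlet`, `dirichlet_map` and `mean_nll` are hypothetical helpers, and the gradients use the softmax cross-entropy identity dNLL/dz_i = q_i - 1[i = y].

```python
import math

def softmax(z):
    # Subtract the max before exponentiating for numerical stability.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def dirichlet_map(p, weight, bias, eps=1e-8):
    # softmax(W log p + b) for a single probability vector p.
    x = [math.log(v + eps) for v in p]
    return softmax([sum(w * xj for w, xj in zip(row, x)) + b
                    for row, b in zip(weight, bias)])

def mean_nll(probs, labels, weight, bias):
    # Average negative log-likelihood of the calibrated probabilities.
    return -sum(math.log(dirichlet_map(p, weight, bias)[y])
                for p, y in zip(probs, labels)) / len(labels)

def fit_dirichlet(probs, labels, num_classes, lr=0.1, max_iter=200,
                  lambda_reg=0.0, mu_reg=0.0, eps=1e-8):
    # Hypothetical sketch: full-batch gradient descent on NLL + ODIR penalty.
    k, n = num_classes, len(labels)
    # Start from the identity map (W = I, b = 0), which leaves p unchanged.
    weight = [[1.0 if i == j else 0.0 for j in range(k)] for i in range(k)]
    bias = [0.0] * k
    for _ in range(max_iter):
        grad_w = [[0.0] * k for _ in range(k)]
        grad_b = [0.0] * k
        for p, y in zip(probs, labels):
            x = [math.log(v + eps) for v in p]
            q = dirichlet_map(p, weight, bias, eps)
            for i in range(k):
                # d(-log q_y)/d(logit_i) = q_i - 1[i == y], averaged over samples.
                delta = (q[i] - (1.0 if i == y else 0.0)) / n
                grad_b[i] += delta
                for j in range(k):
                    grad_w[i][j] += delta * x[j]
        for i in range(k):
            # ODIR gradients: only off-diagonal weights and the bias are penalized.
            grad_b[i] += 2.0 * mu_reg * bias[i] / k
            for j in range(k):
                if i != j:
                    grad_w[i][j] += 2.0 * lambda_reg * weight[i][j] / (k * (k - 1))
        # Gradient-descent step on every parameter.
        for i in range(k):
            bias[i] -= lr * grad_b[i]
            for j in range(k):
                weight[i][j] -= lr * grad_w[i][j]
    return weight, bias

# Overconfident toy data: always 90% on class 0, but only 7 of 10 labels are 0.
probs = [[0.9, 0.1]] * 10
labels = [0] * 7 + [1] * 3
weight, bias = fit_dirichlet(probs, labels, num_classes=2)
print(dirichlet_map([0.9, 0.1], weight, bias))  # close to [0.7, 0.3]
print(mean_nll(probs, labels, [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]),
      mean_nll(probs, labels, weight, bias))
```

On this toy data every sample has the same input, so the best any map can do is output the empirical label frequencies, which is where the fit converges.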

set_temperature(val_weight, val_bias)

Set the weight and bias temperatures to the given values.

Parameters:
  • val_weight (float | Tensor) – Weight temperature value.

  • val_bias (float | Tensor) – Bias temperature value.
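One plausible reading of `set_temperature`, stated here as an assumption rather than documented behavior, is that a scalar `val_weight` fills the diagonal of a `num_classes` by `num_classes` weight matrix and a scalar `val_bias` fills the bias vector, while tensors are used as given. The hypothetical `expand_params` helper below builds parameters that way and checks two properties of the map softmax(W log p + b) under that assumption: weight 1 with bias 0 returns p unchanged, and a constant bias has no effect because softmax is invariant to adding the same value to every logit.

```python
import math

def softmax(z):
    # Subtract the max before exponentiating for numerical stability.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def expand_params(num_classes, val_weight, val_bias):
    # Hypothetical helper (assumption): scalar weight -> diagonal matrix,
    # scalar bias -> constant vector; nested lists are taken as given.
    if isinstance(val_weight, (int, float)):
        weight = [[float(val_weight) if i == j else 0.0 for j in range(num_classes)]
                  for i in range(num_classes)]
    else:
        weight = [[float(w) for w in row] for row in val_weight]
    if isinstance(val_bias, (int, float)):
        bias = [float(val_bias)] * num_classes
    else:
        bias = [float(b) for b in val_bias]
    return weight, bias

def dirichlet_map(p, weight, bias, eps=1e-8):
    # softmax(W log p + b) for a single probability vector p.
    x = [math.log(v + eps) for v in p]
    return softmax([sum(w * xj for w, xj in zip(row, x)) + b
                    for row, b in zip(weight, bias)])

p = [0.5, 0.3, 0.2]
weight, bias = expand_params(3, 1.0, 0.0)
print(dirichlet_map(p, weight, bias))  # approximately [0.5, 0.3, 0.2]
# A constant bias shifts every logit equally, so softmax ignores it.
weight_b, bias_b = expand_params(3, 1.0, 5.0)
print(dirichlet_map(p, weight_b, bias_b))  # also approximately [0.5, 0.3, 0.2]
```

Under this assumption, only a non-constant bias vector or a non-unit weight actually changes the calibrated probabilities.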