DirichletScaler
- class torch_uncertainty.post_processing.DirichletScaler(num_classes, model=None, init_weight_temperature=1, init_bias_temperature=None, lr=0.1, max_iter=200, lambda_reg=None, mu_reg=None, eps=1e-08, device=None)
Dirichlet scaling post-processing to obtain calibrated probabilities.
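For orientation, here is the Dirichlet calibration map of Kull et al. (2019), on which this class is presumably based. It assumes k classes, an input probability vector p, a weight matrix W and a bias b. The objective shows where coefficients like lambda_reg (λ) and mu_reg (μ) enter under the off-diagonal and intercept regularization (ODIR) scheme. Whether this class uses exactly these normalizing constants is an assumption.

```latex
\hat{\mathbf{q}} = \operatorname{softmax}\!\left(\mathbf{W}\ln\mathbf{p} + \mathbf{b}\right),
\qquad \mathbf{W} \in \mathbb{R}^{k \times k},\ \mathbf{b} \in \mathbb{R}^{k}

\mathcal{L}(\mathbf{W}, \mathbf{b})
  = -\frac{1}{n} \sum_{i=1}^{n} \ln \hat{q}_{i, y_i}
  + \frac{\lambda}{k(k-1)} \sum_{j \neq l} W_{jl}^{2}
  + \frac{\mu}{k} \sum_{j=1}^{k} b_{j}^{2}
```

With W = I and b = 0, the map returns p unchanged, because softmax(ln p) = p for any p whose entries sum to 1. With W = I/T and b = 0, it reduces to temperature scaling.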
- Parameters:
  - num_classes (int) – Number of classes.
  - model (Module | None) – Model to calibrate. Defaults to None.
  - init_weight_temperature (float) – Initial value for the weight matrix. Defaults to 1.
  - init_bias_temperature (float | None) – Initial value for the bias. If None, the bias is initialized to the zero vector. Defaults to None.
  - lr (float) – Learning rate for the optimizer. Defaults to 0.1.
  - max_iter (int) – Maximum number of iterations for the optimizer. Defaults to 200.
  - lambda_reg (float | None) – Regularization coefficient applied to the off-diagonal elements of the weight matrix, used to mitigate overfitting. Defaults to None.
  - mu_reg (float | None) – Regularization coefficient applied to the bias vector. Defaults to None.
  - eps (float) – Small value for numerical stability. Defaults to 1e-8.
  - device (Union[Literal['cpu', 'cuda'], device, None]) – Device to use for optimization. Defaults to None.
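The sketch below is minimal, dependency-free Python showing what the weight matrix, bias, eps, lambda_reg and mu_reg parameters describe. It assumes the Dirichlet calibration form softmax(W · ln(p + eps) + b) together with the ODIR penalty of Kull et al. (2019). The helpers softmax, dirichlet_map and odir_penalty are hypothetical and not part of torch_uncertainty, which works on PyTorch tensors instead of lists.

```python
import math


def softmax(z):
    """Numerically stable softmax over a list of floats."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def dirichlet_map(probs, weight, bias, eps=1e-8):
    """Calibrate one probability vector: softmax(W @ log(p + eps) + b)."""
    k = len(probs)
    log_p = [math.log(p + eps) for p in probs]  # eps guards against log(0)
    z = [sum(weight[i][j] * log_p[j] for j in range(k)) + bias[i] for i in range(k)]
    return softmax(z)


def odir_penalty(weight, bias, lambda_reg=None, mu_reg=None):
    """Off-diagonal (lambda) and bias (mu) L2 penalties, normalized as in ODIR."""
    k = len(bias)
    penalty = 0.0
    if lambda_reg is not None:
        off_diag = sum(weight[i][j] ** 2 for i in range(k) for j in range(k) if i != j)
        penalty += lambda_reg * off_diag / (k * (k - 1))
    if mu_reg is not None:
        penalty += mu_reg * sum(b ** 2 for b in bias) / k
    return penalty


if __name__ == "__main__":
    identity = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]
    # Identity weight and zero bias leave the probabilities (almost exactly) unchanged.
    print(dirichlet_map([0.7, 0.2, 0.1], identity, [0.0, 0.0, 0.0]))
```

Passing a diagonal weight of 1/T with a zero bias reproduces temperature scaling. That is one way to see Dirichlet scaling as a generalization of it.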
Warning
If the model is binary, the sigmoid is applied by default before the prediction is expanded to the two-class case.
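Here is a minimal sketch of that binary-to-two-class conversion. It assumes the model emits a single logit per sample and that the resulting columns are ordered [negative class, positive class]. Both assumptions, and the helper names, are hypothetical and not taken from torch_uncertainty.

```python
import math


def sigmoid(x):
    """Logistic sigmoid, split by sign to avoid overflow in math.exp."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def binary_to_two_class(logit):
    """Map one binary logit to [P(class 0), P(class 1)]."""
    p_pos = sigmoid(logit)
    return [1.0 - p_pos, p_pos]


print(binary_to_two_class(0.0))  # [0.5, 0.5]
```

Once expanded, the two-element vector can go through the same k = 2 calibration map as any multiclass output.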
- fit(dataloader, save_logits=False, progress=True)
Fit the weight and bias temperature parameters to the calibration data.
- Parameters:
  - dataloader (DataLoader) – Dataloader with the calibration data. If there is no model, the dataloader should include the confidence scores directly, not the logits.
  - save_logits (bool) – Whether to save the logits and labels in memory. Defaults to False.
  - progress (bool) – Whether to show a progress bar. Defaults to True.
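This pure-Python sketch shows the overall shape of such a fit under stated assumptions. The dataloader is any iterable of (inputs, labels) batches. Each input is already a probability vector, which matches the no-model case. Plain full-batch gradient descent on mean NLL plus the ODIR penalty stands in for whatever optimizer the library actually drives with lr and max_iter. collect and fit_dirichlet are hypothetical names, not torch_uncertainty functions.

```python
import math


def softmax(z):
    """Numerically stable softmax over a list of floats."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def collect(dataloader, model=None):
    """Concatenate (optionally model-transformed) inputs and labels over all batches."""
    outputs, labels = [], []
    for inputs, targets in dataloader:
        for x, y in zip(inputs, targets):
            outputs.append(model(x) if model is not None else x)
            labels.append(y)
    return outputs, labels


def fit_dirichlet(dataloader, num_classes, model=None, lr=0.1, max_iter=200,
                  lambda_reg=None, mu_reg=None, eps=1e-8):
    """Fit W and b by full-batch gradient descent on mean NLL + ODIR penalty."""
    probs, labels = collect(dataloader, model)
    k, n = num_classes, len(labels)
    lam, mu = lambda_reg or 0.0, mu_reg or 0.0
    log_p = [[math.log(p + eps) for p in row] for row in probs]
    weight = [[1.0 if i == j else 0.0 for j in range(k)] for i in range(k)]  # identity start
    bias = [0.0] * k
    for _ in range(max_iter):
        grad_w = [[0.0] * k for _ in range(k)]
        grad_b = [0.0] * k
        for x, y in zip(log_p, labels):
            q = softmax([sum(weight[i][j] * x[j] for j in range(k)) + bias[i] for i in range(k)])
            for i in range(k):
                delta = (q[i] - (1.0 if i == y else 0.0)) / n  # dNLL/dz_i, averaged over samples
                grad_b[i] += delta
                for j in range(k):
                    grad_w[i][j] += delta * x[j]
        for i in range(k):
            grad_b[i] += 2.0 * mu * bias[i] / k  # bias penalty gradient
            for j in range(k):
                if i != j:
                    grad_w[i][j] += 2.0 * lam * weight[i][j] / (k * (k - 1))  # off-diagonal only
        for i in range(k):
            bias[i] -= lr * grad_b[i]
            for j in range(k):
                weight[i][j] -= lr * grad_w[i][j]
    return weight, bias


# Overconfident predictions: always 90% on class 0, but class 0 is right only 6 times out of 10.
loader = [([[0.9, 0.1]] * 5, [0, 0, 0, 1, 1]), ([[0.9, 0.1]] * 5, [0, 0, 0, 1, 1])]
weight, bias = fit_dirichlet(loader, num_classes=2)
```

After fitting, the calibrated confidence for class 0 on these inputs moves from 0.9 toward the empirical accuracy of 0.6.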
- Return type:
None
- set_model(model)
Attach a model to the post-processing module.
- Return type:
None
- set_temperature(val_weight, val_bias)
Set the weight and bias temperatures to the given values.
- Parameters:
  - val_weight (float | Tensor) – Weight temperature value.
  - val_bias (float | Tensor | None) – Bias temperature value.
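The float | Tensor types suggest that a scalar gets expanded to full-size parameters. The sketch below shows one plausible expansion, purely as an assumption:
- A scalar weight fills the diagonal of a k × k matrix, so 1 gives the identity map.
- A scalar bias is broadcast to every class.
- None yields the zero bias.

build_parameters is a hypothetical helper using plain lists rather than tensors.

```python
def build_parameters(num_classes, val_weight, val_bias=None):
    """Expand scalar or nested-list settings into a (k x k) weight and a length-k bias."""
    k = num_classes
    if isinstance(val_weight, (int, float)):
        # Assumed convention: a scalar sits on the diagonal, off-diagonal entries are 0.
        weight = [[float(val_weight) if i == j else 0.0 for j in range(k)] for i in range(k)]
    else:
        weight = [[float(v) for v in row] for row in val_weight]
        if len(weight) != k or any(len(row) != k for row in weight):
            raise ValueError(f"val_weight must be {k}x{k}")
    if val_bias is None:
        bias = [0.0] * k  # no bias value given: zero vector
    elif isinstance(val_bias, (int, float)):
        bias = [float(val_bias)] * k  # assumed: broadcast the scalar to every class
    else:
        bias = [float(b) for b in val_bias]
        if len(bias) != k:
            raise ValueError(f"val_bias must have length {k}")
    return weight, bias


weight, bias = build_parameters(3, 1.0)
print(weight)  # [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(bias)    # [0.0, 0.0, 0.0]
```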
- Return type:
None