ClassificationRoutine
- class torch_uncertainty.routines.ClassificationRoutine(model, num_classes, loss, is_ensemble=False, format_batch_fn=None, optim_recipe=None, mixup_params=None, eval_ood=False, eval_shift=False, eval_grouping_loss=False, ood_criterion='msp', post_processing=None, calibration_set='val', num_calibration_bins=15, log_plots=False, save_in_csv=False)
Routine for training & testing on classification tasks.
- Parameters:
model (torch.nn.Module) – Model to train.
num_classes (int) – Number of classes.
loss (torch.nn.Module) – Loss function to optimize the model.
is_ensemble (bool, optional) – Indicates whether the model is an ensemble at test time or not. Defaults to False.
format_batch_fn (torch.nn.Module, optional) – Function to format the batch. Defaults to torch.nn.Identity().
optim_recipe (dict or torch.optim.Optimizer, optional) – The optimizer and, optionally, the scheduler to use. Defaults to None.
mixup_params (dict, optional) – Mixup parameters. Can include the mixup type, mixup mode, distance similarity, kernel tau max, kernel tau std, mixup alpha, and cutmix alpha. If None, no mixup augmentation is applied. Defaults to None.
eval_ood (bool, optional) – Indicates whether to evaluate the OOD detection performance. Defaults to False.
eval_shift (bool, optional) – Indicates whether to evaluate the performance under distribution shift. Defaults to False.
eval_grouping_loss (bool, optional) – Indicates whether to evaluate the grouping loss or not. Defaults to False.
ood_criterion (str, optional) – OOD criterion. Available options are:
  - "msp" (default): Maximum softmax probability.
  - "logit": Maximum logit.
  - "energy": Logsumexp of the mean logits.
  - "entropy": Entropy of the mean prediction.
  - "mi": Mutual information of the ensemble.
  - "vr": Variation ratio of the ensemble.
post_processing (PostProcessing, optional) – Post-processing method to train on the calibration set. No post-processing if None. Defaults to None.
calibration_set (str, optional) – The post-hoc calibration dataset to use for the post-processing method. Defaults to "val".
num_calibration_bins (int, optional) – Number of bins to compute calibration metrics. Defaults to 15.
log_plots (bool, optional) – Indicates whether to log plots from metrics. Defaults to False.
save_in_csv (bool, optional) – Indicates whether to save the results in a CSV file. Defaults to False.
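The ood_criterion options can be made concrete with a small NumPy sketch. It assumes ensemble logits of shape (members, samples, classes), takes "msp" from the mean probabilities and "logit" from the mean logits, and applies no sign flip; torch_uncertainty may average or negate these quantities differently (for example so that a larger score means "more out-of-distribution"). The helper ood_criteria is hypothetical and not part of the library.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)


def entropy(probs: np.ndarray) -> np.ndarray:
    # Shannon entropy over the class axis; clipping avoids log(0).
    return -(probs * np.log(np.clip(probs, 1e-12, None))).sum(axis=-1)


def ood_criteria(member_logits: np.ndarray) -> dict:
    """Hypothetical helper. member_logits has shape (M members, N samples, C classes).

    Returns raw per-sample quantities; the sign convention is an assumption.
    """
    num_members, _, num_classes = member_logits.shape
    member_probs = softmax(member_logits)      # (M, N, C)
    mean_probs = member_probs.mean(axis=0)     # (N, C)
    mean_logits = member_logits.mean(axis=0)   # (N, C)

    # "energy": logsumexp of the mean logits, computed stably.
    max_logit = mean_logits.max(axis=-1)
    energy = max_logit + np.log(np.exp(mean_logits - max_logit[:, None]).sum(axis=-1))

    # "mi": entropy of the mean prediction minus the mean member entropy.
    pred_entropy = entropy(mean_probs)
    expected_entropy = entropy(member_probs).mean(axis=0)

    # "vr": 1 - (share of members voting for the most popular class).
    votes = member_probs.argmax(axis=-1)  # (M, N)
    counts = np.stack([(votes == c).sum(axis=0) for c in range(num_classes)], axis=-1)
    variation_ratio = 1.0 - counts.max(axis=-1) / num_members

    return {
        "msp": mean_probs.max(axis=-1),
        "logit": mean_logits.max(axis=-1),
        "energy": energy,
        "entropy": pred_entropy,
        "mi": pred_entropy - expected_entropy,
        "vr": variation_ratio,
    }


# Usage: 5 ensemble members, 3 samples, 10 classes.
rng = np.random.default_rng(0)
scores = ood_criteria(rng.normal(size=(5, 3, 10)))
```

With a single model (M = 1), "mi" and "vr" are always 0, which is consistent with their description above as ensemble-specific criteria.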
Warning
You must define optim_recipe if you do not use the Lightning CLI.
Note
optim_recipe can be anything that can be returned by LightningModule.configure_optimizers(). See the documentation of LightningModule.configure_optimizers() for more details.
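One of the return forms accepted by LightningModule.configure_optimizers() is a dictionary with an "optimizer" key and an optional "lr_scheduler" key, so an optim_recipe could plausibly be built as below. The helper make_optim_recipe is hypothetical, and the choice of SGD, the learning rate, and the MultiStepLR milestones are illustrative assumptions rather than torch_uncertainty defaults.

```python
import torch
from torch import nn


def make_optim_recipe(model: nn.Module) -> dict:
    """Hypothetical helper: returns an optimizer plus a scheduler in a dict form
    that Lightning's configure_optimizers() accepts."""
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4
    )
    # Divide the learning rate by 5 at epochs 60, 120 and 160 (illustrative values).
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[60, 120, 160], gamma=0.2
    )
    return {"optimizer": optimizer, "lr_scheduler": scheduler}


model = nn.Linear(32, 10)
recipe = make_optim_recipe(model)
```

The resulting dictionary would then be passed as the optim_recipe argument, alongside model, num_classes and loss, when constructing the routine without the Lightning CLI.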