ClassificationRoutine
- class torch_uncertainty.routines.ClassificationRoutine(model, num_classes, loss, is_ensemble=False, format_batch_fn=None, optim_recipe=None, mixup_params=None, eval_ood=False, eval_shift=False, eval_grouping_loss=False, ood_criterion='msp', post_processing=None, calibration_set='val', num_calibration_bins=15, log_plots=False, save_in_csv=False)
Routine for training & testing on classification tasks.
- Parameters:
model (torch.nn.Module) – Model to train.
num_classes (int) – Number of classes.
loss (torch.nn.Module) – Loss function to optimize the model.
is_ensemble (bool, optional) – Indicates whether the model is an ensemble at test time. Defaults to False.
format_batch_fn (torch.nn.Module, optional) – Function to format the batch. Defaults to torch.nn.Identity().
optim_recipe (dict or torch.optim.Optimizer, optional) – The optimizer and, optionally, the scheduler to use. Defaults to None.
mixup_params (dict, optional) – Mixup parameters. Can include the mixup type, mixup mode, distance similarity, kernel tau max, kernel tau std, mixup alpha, and cutmix alpha. If None, no mixup augmentation is applied. Defaults to None.
eval_ood (bool, optional) – Indicates whether to evaluate the OOD detection performance. Defaults to False.
eval_shift (bool, optional) – Indicates whether to evaluate the performance under distribution shift. Defaults to False.
eval_grouping_loss (bool, optional) – Indicates whether to evaluate the grouping loss. Defaults to False.
ood_criterion (str, optional) – OOD criterion. Available options are:
  - "msp" (default): Maximum softmax probability.
  - "logit": Maximum logit.
  - "energy": Logsumexp of the mean logits.
  - "entropy": Entropy of the mean prediction.
  - "mi": Mutual information of the ensemble.
  - "vr": Variation ratio of the ensemble.
post_processing (PostProcessing, optional) – Post-processing method to train on the calibration set. No post-processing if None. Defaults to None.
calibration_set (str, optional) – The post-hoc calibration dataset to use for the post-processing method. Defaults to "val".
num_calibration_bins (int, optional) – Number of bins used to compute the calibration metrics. Defaults to 15.
log_plots (bool, optional) – Indicates whether to log plots from metrics. Defaults to False.
save_in_csv (bool, optional) – Indicates whether to save the results in a CSV file. Defaults to False.
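The `ood_criterion` options above can be illustrated for a single sample with the pure-Python sketch below. It assumes the logits arrive as one row per ensemble member (a single row for a non-ensemble model). Following the descriptions, "energy" uses the mean logits and "entropy" the mean prediction; averaging probabilities for "msp" and logits for "logit" is an assumption. This is not the library's implementation, and it does not show how (or with which sign) the routine turns these values into OOD scores.

```python
import math


def softmax(row):
    # Subtract the max for numerical stability.
    top = max(row)
    exps = [math.exp(v - top) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def logsumexp(row):
    top = max(row)
    return top + math.log(sum(math.exp(v - top) for v in row))


def entropy(probs):
    # Shannon entropy in nats; zero probabilities contribute nothing.
    return -sum(p * math.log(p) for p in probs if p > 0)


def ood_scores(logits):
    """logits[m][k]: logit of class k from ensemble member m (hypothetical layout)."""
    num_members = len(logits)
    num_classes = len(logits[0])
    probs = [softmax(row) for row in logits]
    mean_probs = [sum(p[k] for p in probs) / num_members for k in range(num_classes)]
    mean_logits = [sum(row[k] for row in logits) / num_members for k in range(num_classes)]
    # Class predicted by each member, used for the variation ratio.
    votes = [max(range(num_classes), key=p.__getitem__) for p in probs]
    majority = max(votes.count(c) for c in set(votes))
    return {
        "msp": max(mean_probs),
        "logit": max(mean_logits),
        "energy": logsumexp(mean_logits),
        "entropy": entropy(mean_probs),
        # Entropy of the mean minus the mean of the member entropies.
        "mi": entropy(mean_probs) - sum(entropy(p) for p in probs) / num_members,
        # Fraction of members that disagree with the majority vote.
        "vr": 1.0 - majority / num_members,
    }


scores = ood_scores([[2.0, 0.0], [0.0, 2.0]])  # two members that disagree
print(scores)
```

With two disagreeing members, "mi" and "vr" are positive; with identical members they both drop to zero, which is why they are described as ensemble-specific criteria.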
Warning
You must define optim_recipe if you do not use the Lightning CLI.
Note
optim_recipe can be anything that can be returned by LightningModule.configure_optimizers(). See the Lightning documentation of configure_optimizers() for more details.
- forward(inputs, save_feats=False)
Forward pass of the inner model.
- Parameters:
inputs (Tensor) – Input tensor.
save_feats (bool, optional) – Whether to store the features. Defaults to False.
Note
The features are stored in the self.features attribute.
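The `save_feats` behaviour can be pictured with the hypothetical pure-Python wrapper below. The split into a `feature_extractor` and a `head`, and the class name itself, are assumptions made for illustration; the real routine wraps a `torch.nn.Module` and this sketch does not reproduce how it extracts features.

```python
class FeatureStoringWrapper:
    """Hypothetical stand-in for a forward pass that can keep its features."""

    def __init__(self, feature_extractor, head):
        # Plain callables in this sketch; in the routine they would be parts
        # of the wrapped model.
        self.feature_extractor = feature_extractor
        self.head = head
        self.features = None

    def forward(self, inputs, save_feats=False):
        feats = self.feature_extractor(inputs)
        if save_feats:
            # Keep the intermediate representation on the instance.
            self.features = feats
        return self.head(feats)


wrapper = FeatureStoringWrapper(lambda xs: [2 * x for x in xs], sum)
print(wrapper.forward([1, 2], save_feats=True), wrapper.features)
```

Storing the features as an attribute rather than returning them keeps the return value of `forward` unchanged whether or not features are requested.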
- on_test_epoch_end()
Compute, log, and plot the values of the metrics collected in test_step.
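The `num_calibration_bins` parameter controls binned calibration metrics such as the expected calibration error (ECE). The pure-Python sketch below computes the textbook equal-width-bin ECE: the bin-size-weighted mean gap between accuracy and average confidence. It is an illustration under that assumption, not the metric code the routine actually calls, and bin-edge conventions may differ from it.

```python
def expected_calibration_error(confidences, correct, num_bins=15):
    """Binned ECE over equal-width confidence bins on [0, 1]."""
    n = len(confidences)
    bins = [[] for _ in range(num_bins)]
    for conf, ok in zip(confidences, correct):
        # Bin k covers [k / num_bins, (k + 1) / num_bins); 1.0 goes in the last bin.
        idx = min(int(conf * num_bins), num_bins - 1)
        bins[idx].append((conf, ok))
    ece = 0.0
    for members in bins:
        if not members:
            continue
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(1 for _, ok in members if ok) / len(members)
        # Weight each bin's |accuracy - confidence| gap by its share of samples.
        ece += len(members) / n * abs(accuracy - avg_conf)
    return ece


# Four predictions at 90% confidence, of which only three are correct.
print(expected_calibration_error([0.9, 0.9, 0.9, 0.9], [True, True, True, False]))
```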
- on_test_start()
Prepare the test step.
Set up the post-processing dataset and fit the post-processing method if needed, prepare the storage lists for logit plotting, and update the batchnorms if needed.
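As one example of fitting a post-processing method on the calibration set, the sketch below implements temperature scaling: all logits are divided by a single scalar T chosen to minimize the negative log-likelihood on the calibration data. The use of temperature scaling, the grid search, and the candidate range are illustrative assumptions; this does not reproduce the library's PostProcessing classes.

```python
import math


def nll_at_temperature(logits, labels, temperature):
    # Average negative log-likelihood of the labels after dividing logits by T.
    total = 0.0
    for row, label in zip(logits, labels):
        scaled = [v / temperature for v in row]
        top = max(scaled)
        log_norm = top + math.log(sum(math.exp(v - top) for v in scaled))
        total += log_norm - scaled[label]
    return total / len(labels)


def fit_temperature(logits, labels, candidates=None):
    # Grid search for the temperature minimizing the calibration-set NLL.
    if candidates is None:
        candidates = [0.05 * i for i in range(1, 101)]  # 0.05 to 5.0
    return min(candidates, key=lambda t: nll_at_temperature(logits, labels, t))


# Overconfident calibration set: the model always predicts class 0 with a
# large margin, but only three of the four labels are class 0.
calib_logits = [[4.0, 0.0]] * 4
calib_labels = [0, 0, 0, 1]
print(fit_temperature(calib_logits, calib_labels))
```

Here the fitted temperature is above 1, softening the overconfident predictions toward the observed 75% accuracy.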
- on_validation_epoch_end()
Compute and log the values of the metrics collected in validation_step.
- on_validation_start()
Prepare the validation step.
Update the model’s wrapper and the batchnorms if needed.
- save_results_to_csv(results)
Save the metric results to a CSV file.
- Parameters:
results (dict[str, float]) – The dictionary mapping each metric name to its value.
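A minimal sketch of writing a `dict[str, float]` of metrics to CSV with Python's standard `csv` module. The layout (one header row of metric names followed by one row of values), the function name, and the metric names are hypothetical; the file path and format used by `save_results_to_csv` are not specified here.

```python
import csv
import io


def results_to_csv(results, stream):
    # Hypothetical layout: a header row with metric names, then one row of values.
    writer = csv.writer(stream)
    names = list(results)
    writer.writerow(names)
    writer.writerow([results[name] for name in names])


buffer = io.StringIO()
results_to_csv({"Acc": 0.9, "ECE": 0.05}, buffer)
print(buffer.getvalue())
```

Writing to a stream instead of a path keeps the sketch testable; passing an opened file object (with `newline=""`, as the `csv` docs recommend) works the same way.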
- test_step(batch, batch_idx, dataloader_idx=0)
Perform a single test step based on the input tensors.
Compute the prediction of the model and the value of the metrics on the test batch. Also handle OOD and distribution-shifted images.
- Parameters:
batch (tuple[Tensor, Tensor]) – The test data and their corresponding targets.
batch_idx (int) – The index of the current batch (unused).
dataloader_idx (int, optional) – 0 if in-distribution, 1 if out-of-distribution, and 2 if distribution-shifted. Defaults to 0.
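The `dataloader_idx` convention can be sketched as a router that keeps per-sample OOD scores from each source separately, followed by the rank form of AUROC computed from in-distribution versus OOD scores. The `EvalCollector` class is hypothetical, and the sketch assumes scores are oriented so that higher means "more likely OOD"; it does not reproduce how the routine stores or evaluates these values.

```python
def auroc(id_scores, ood_scores):
    # Probability that a random OOD sample scores higher than a random
    # in-distribution sample, counting ties as one half.
    wins = 0.0
    for o in ood_scores:
        for i in id_scores:
            if o > i:
                wins += 1.0
            elif o == i:
                wins += 0.5
    return wins / (len(ood_scores) * len(id_scores))


class EvalCollector:
    """Hypothetical accumulator following the dataloader_idx convention."""

    def __init__(self):
        self.id_scores, self.ood_scores, self.shift_scores = [], [], []

    def test_step(self, scores, dataloader_idx=0):
        # 0: in-distribution, 1: out-of-distribution, 2: distribution-shifted.
        if dataloader_idx == 0:
            self.id_scores.extend(scores)
        elif dataloader_idx == 1:
            self.ood_scores.extend(scores)
        elif dataloader_idx == 2:
            self.shift_scores.extend(scores)
        else:
            raise ValueError(f"unexpected dataloader_idx: {dataloader_idx}")


collector = EvalCollector()
collector.test_step([0.1, 0.2], dataloader_idx=0)
collector.test_step([0.8, 0.15], dataloader_idx=1)
print(auroc(collector.id_scores, collector.ood_scores))
```

The quadratic pairwise loop is only meant to make the definition explicit; a sort-based rank computation gives the same value more efficiently on large test sets.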