ClassificationRoutine
- class torch_uncertainty.routines.ClassificationRoutine(model, num_classes, loss, is_ensemble=False, format_batch_fn=None, optim_recipe=None, mixup_params=None, eval_ood=False, eval_shift=False, eval_grouping_loss=False, ood_criterion='msp', post_processing=None, num_bins_cal_err=15, log_plots=False, save_in_csv=False)
Routine for training & testing on classification tasks.
- Parameters:
model (torch.nn.Module) – Model to train.
num_classes (int) – Number of classes.
loss (torch.nn.Module) – Loss function to optimize the model.
is_ensemble (bool, optional) – Indicates whether the model is an ensemble at test time. Defaults to False.
format_batch_fn (torch.nn.Module, optional) – Function to format the batch. Defaults to torch.nn.Identity().
optim_recipe (dict or torch.optim.Optimizer, optional) – The optimizer and, optionally, the scheduler to use. Defaults to None.
mixup_params (dict, optional) – Mixup parameters. Can include the mixup type, mixup mode, distance similarity, kernel tau max, kernel tau std, mixup alpha, and cutmix alpha. If None, no mixup augmentation is applied. Defaults to None.
eval_ood (bool, optional) – Indicates whether to evaluate the OOD detection performance. Defaults to False.
eval_shift (bool, optional) – Indicates whether to evaluate the performance under distribution shift. Defaults to False.
eval_grouping_loss (bool, optional) – Indicates whether to evaluate the grouping loss. Defaults to False.
ood_criterion (str, optional) – OOD criterion. Available options are:
  - "msp" (default): Maximum softmax probability.
  - "logit": Maximum logit.
  - "energy": Logsumexp of the mean logits.
  - "entropy": Entropy of the mean prediction.
  - "mi": Mutual information of the ensemble.
  - "vr": Variation ratio of the ensemble.
post_processing (PostProcessing, optional) – Post-processing method to train on the calibration set. No post-processing if None. Defaults to None.
num_bins_cal_err (int, optional) – Number of bins used to compute the calibration error metrics. Defaults to 15.
log_plots (bool, optional) – Indicates whether to log plots from metrics. Defaults to False.
save_in_csv (bool, optional) – Indicates whether to save the results in a CSV file. Defaults to False.
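The ood_criterion options listed above can be made concrete with a minimal, dependency-free sketch for a single sample, assuming the ensemble's logits for that sample are available as M lists of C floats (M = 1 for a single model). It computes the raw quantity each option names and is not the routine's own code. Two points are assumptions: "msp" and "logit" are read as the maximum of the mean softmax probabilities and of the mean logits, and the sign convention the routine applies to turn these values into OOD scores is left out.

```python
import math
from collections import Counter


def softmax(logits):
    """Numerically stable softmax of one logit vector."""
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logsumexp(values):
    """Numerically stable log(sum(exp(values)))."""
    shift = max(values)
    return shift + math.log(sum(math.exp(v - shift) for v in values))


def entropy(probs):
    """Shannon entropy (natural log) of a probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def criterion_values(member_logits):
    """Raw quantity behind each ood_criterion option for ONE sample.

    member_logits: M logit vectors (one per estimator), each of length C.
    """
    num_members = len(member_logits)
    num_classes = len(member_logits[0])
    member_probs = [softmax(z) for z in member_logits]
    mean_logits = [sum(z[c] for z in member_logits) / num_members for c in range(num_classes)]
    mean_probs = [sum(p[c] for p in member_probs) / num_members for c in range(num_classes)]
    # Each estimator votes for its arg-max class; "vr" uses the size of the top vote.
    votes = Counter(max(range(num_classes), key=p.__getitem__) for p in member_probs)
    mean_member_entropy = sum(entropy(p) for p in member_probs) / num_members
    return {
        "msp": max(mean_probs),  # assumed: max of the mean softmax probabilities
        "logit": max(mean_logits),  # assumed: max of the mean logits
        "energy": logsumexp(mean_logits),  # logsumexp of the mean logits
        "entropy": entropy(mean_probs),  # entropy of the mean prediction
        "mi": entropy(mean_probs) - mean_member_entropy,  # mutual information
        "vr": 1.0 - votes.most_common(1)[0][1] / num_members,  # variation ratio
    }


# Two estimators that disagree completely on a 2-class sample.
values = criterion_values([[2.0, 0.0], [0.0, 2.0]])
print({name: round(v, 4) for name, v in values.items()})
```

With M = 1, "mi" is always 0 and "vr" is always 0, which is why the documentation describes those two options as properties of the ensemble.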
Warning
You must define optim_recipe if you do not use the Lightning CLI.
Warning
When using an ensemble model, you must:
1. Set is_ensemble to True.
2. Set format_batch_fn to torch_uncertainty.transforms.RepeatTarget(num_repeats=num_estimators).
3. Ensure that the model’s forward pass outputs a tensor of shape \((M \times B, C)\), where \(M\) is the number of estimators, \(B\) is the batch size, and \(C\) is the number of classes.
For automated batch handling, consider using the model wrappers available in torch_uncertainty.models.wrappers.
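The shape requirement in the warning above can be pictured with a plain-Python sketch. It assumes the \(M \times B\) output rows are stacked estimator by estimator (all \(B\) rows of estimator 0, then all \(B\) rows of estimator 1, and so on) and that RepeatTarget tiles the \(B\) targets \(M\) times in the same order. Both are assumptions about the layout, and repeat_targets and group_by_sample are hypothetical helpers, not torch_uncertainty functions.

```python
def repeat_targets(targets, num_repeats):
    """Tile the B targets num_repeats times so they align with M*B output rows.

    Hypothetical stand-in for what RepeatTarget is assumed to do to the targets.
    """
    return list(targets) * num_repeats


def group_by_sample(flat_outputs, num_estimators):
    """Regroup M*B estimator-major rows into B lists of M rows (one per sample)."""
    if len(flat_outputs) % num_estimators != 0:
        raise ValueError("the number of rows must be a multiple of num_estimators")
    batch_size = len(flat_outputs) // num_estimators
    return [
        [flat_outputs[m * batch_size + b] for m in range(num_estimators)]
        for b in range(batch_size)
    ]


# M = 2 estimators, B = 3 samples, C = 2 classes -> 6 rows of 2 logits.
targets = [0, 1, 1]
flat_logits = [
    [2.0, 0.1], [0.3, 1.5], [0.2, 0.9],  # estimator 0, samples 0..2
    [1.8, 0.2], [0.1, 1.7], [1.1, 0.4],  # estimator 1, samples 0..2
]
repeated = repeat_targets(targets, num_repeats=2)
assert len(repeated) == len(flat_logits)  # one target per output row
per_sample = group_by_sample(flat_logits, num_estimators=2)
print(per_sample[2])  # [[0.2, 0.9], [1.1, 0.4]]: both estimators' logits for sample 2
```

If the model stacked its outputs sample by sample instead, the tiled targets would no longer line up with the rows, which is one reason the wrappers mentioned above are the safer route.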
Note
If eval_ood is True, we perform a binary classification and update the OOD-related metrics twice:
- once during the test on in-distribution (ID) samples, whose binary label is 0;
- once during the test on out-of-distribution (OOD) samples, whose binary label is 1.
Note
optim_recipe can be anything that can be returned by LightningModule.configure_optimizers(). Find more details in the Lightning documentation of configure_optimizers().
- forward(inputs, save_feats=False)
Forward pass of the inner model.
- Parameters:
inputs (Tensor) – Input tensor.
save_feats (bool, optional) – Whether to store the features or not. Defaults to False.
Note
The features are stored in the self.features attribute.
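The feature-saving behaviour of forward can be sketched with a toy two-stage model. ToyRoutine, the split into a backbone and a head, and the choice to reset features to None when save_feats is False are assumptions made for illustration; only the self.features attribute comes from the documentation above.

```python
class ToyRoutine:
    """Hypothetical stand-in: wraps a backbone and a head, like an inner model."""

    def __init__(self, backbone, head):
        self.backbone = backbone  # maps inputs -> features (assumed split)
        self.head = head  # maps features -> logits
        self.features = None

    def forward(self, inputs, save_feats=False):
        feats = self.backbone(inputs)
        # Keep the features only when explicitly requested (assumed behaviour).
        self.features = feats if save_feats else None
        return self.head(feats)


routine = ToyRoutine(
    backbone=lambda x: [2 * v for v in x],
    head=lambda f: [sum(f), -sum(f)],
)
logits = routine.forward([1.0, 2.0], save_feats=True)
print(logits, routine.features)
```

Storing the features on the object rather than returning them keeps the return value of forward identical in both modes, so callers that only need logits are unaffected.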
- on_test_epoch_end()
Compute, log, and plot the values of the collected metrics in test_step.
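The metrics computed here include calibration errors binned according to num_bins_cal_err. The sketch below shows a standard equal-width binned expected calibration error (ECE) in plain Python. It is a generic illustration of the technique, not the metric class the routine actually uses, and expected_calibration_error is a hypothetical name.

```python
import math


def expected_calibration_error(confidences, correct, num_bins=15):
    """Equal-width binned ECE.

    confidences: top-class probabilities in [0, 1].
    correct: booleans, True when the top-class prediction was right.
    """
    n = len(confidences)
    conf_sum = [0.0] * num_bins
    acc_sum = [0.0] * num_bins
    count = [0] * num_bins
    for conf, ok in zip(confidences, correct):
        # Bins are (0, 1/K], (1/K, 2/K], ...; conf == 0 falls into the first bin.
        idx = min(max(math.ceil(conf * num_bins) - 1, 0), num_bins - 1)
        conf_sum[idx] += conf
        acc_sum[idx] += float(ok)
        count[idx] += 1
    ece = 0.0
    for k in range(num_bins):
        if count[k]:
            # |accuracy - mean confidence| in bin k, weighted by the bin's share of samples.
            gap = abs(acc_sum[k] - conf_sum[k]) / count[k]
            ece += (count[k] / n) * gap
    return ece


print(expected_calibration_error([0.95, 0.15], [True, False], num_bins=15))
```

In the example, the two predictions land in different bins, each contributing half of its confidence/accuracy gap (0.05 and 0.15).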
- on_test_start()
Prepare the test step.
Set up the post-processing dataset and fit the post-processing method if needed, prepare the storage lists for logit plotting, and update the batchnorms if needed.
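Temperature scaling is one common post-processing method that can be fitted on a calibration set. The sketch below fits a single temperature by grid search on the mean negative log-likelihood. It is a swapped-in illustration, not the library's PostProcessing implementation; the grid search is a simplification chosen to keep the example dependency-free, and nll and fit_temperature are hypothetical names.

```python
import math


def nll(logits_batch, labels, temperature):
    """Mean negative log-likelihood of the labels after dividing logits by T."""
    total = 0.0
    for logits, y in zip(logits_batch, labels):
        scaled = [z / temperature for z in logits]
        shift = max(scaled)
        log_norm = shift + math.log(sum(math.exp(s - shift) for s in scaled))
        total += log_norm - scaled[y]
    return total / len(labels)


def fit_temperature(logits_batch, labels, grid=None):
    """Pick the temperature on `grid` that minimises the calibration-set NLL."""
    if grid is None:
        grid = [0.05 + 0.01 * i for i in range(1000)]  # 0.05 .. 10.04
    return min(grid, key=lambda t: nll(logits_batch, labels, t))


# An over-confident model: same logits for every sample, but only 75% correct.
cal_logits = [[4.0, 0.0]] * 4
cal_labels = [0, 0, 0, 1]
t_star = fit_temperature(cal_logits, cal_labels)
print(round(t_star, 2))
```

Here the NLL is minimised when the predicted probability of class 0 equals the empirical accuracy of 0.75, i.e. when 4 / T = ln 3, so the fitted temperature should be close to 4 / ln 3 ≈ 3.64.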
- on_validation_epoch_end()
Compute and log the values of the collected metrics in validation_step.
- on_validation_start()
Prepare the validation step.
Update the model’s wrapper and the batchnorms if needed.
- save_results_to_csv(results)
Save the metric results in a CSV file.
- Parameters:
results (dict[str, float]) – the dictionary containing all the values of the metrics.
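A stand-alone sketch of writing a dict[str, float] of metric results to CSV with the standard csv module. The file layout (one header row with the metric names, one row with their values) and the extra path argument are assumptions; the method documented above takes only results, and its output location is not described here. The metric names and values below are placeholders, not real results.

```python
import csv
import os
import tempfile


def save_results_to_csv(results, path):
    """Write metric name/value pairs as a two-row CSV (header, then values).

    Hypothetical stand-alone helper; the real method's file location and layout may differ.
    """
    names = list(results)
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(names)
        writer.writerow([results[name] for name in names])


results = {"Acc": 0.93, "ECE": 0.021}  # placeholder values for illustration only
out_path = os.path.join(tempfile.mkdtemp(), "results.csv")
save_results_to_csv(results, out_path)
with open(out_path, newline="") as f:
    print(list(csv.reader(f)))
```

A single wide row keeps one run per line, which makes it easy to append results from several runs under the same header.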
- test_step(batch, batch_idx, dataloader_idx=0)
Perform a single test step based on the input tensors.
Compute the prediction of the model and the value of the metrics on the test batch. Also handle OOD and distribution-shifted images.
- Parameters:
batch (tuple[Tensor, Tensor]) – the test data and their corresponding targets.
batch_idx (int) – the index of the current batch (unused).
dataloader_idx (int) – 0 if in-distribution, 1 if out-of-distribution, and 2 if distribution-shifted. Defaults to 0.
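The dataloader_idx convention, combined with the binary ID/OOD labelling used for OOD evaluation, can be sketched as follows. The class OODScoreCollector, the use of AUROC as the example metric, and the pairwise AUROC computation are illustrative assumptions, not the routine's code. Per-sample scores are assumed to be oriented so that a higher score means "more likely OOD", and shifted samples are simply kept apart in this sketch.

```python
class OODScoreCollector:
    """Hypothetical collector mirroring the dataloader_idx convention of test_step."""

    ID, OOD, SHIFT = 0, 1, 2

    def __init__(self):
        self.scores, self.labels = [], []
        self.shift_scores = []

    def update(self, scores, dataloader_idx=0):
        if dataloader_idx == self.ID:
            self._add(scores, label=0)  # in-distribution -> binary label 0
        elif dataloader_idx == self.OOD:
            self._add(scores, label=1)  # out-of-distribution -> binary label 1
        elif dataloader_idx == self.SHIFT:
            self.shift_scores.extend(scores)  # kept apart from the ID/OOD task
        else:
            raise ValueError(f"unexpected dataloader_idx: {dataloader_idx}")

    def _add(self, scores, label):
        self.scores.extend(scores)
        self.labels.extend([label] * len(scores))

    def auroc(self):
        """Probability that an OOD sample outscores an ID one (ties count 1/2)."""
        pos = [s for s, y in zip(self.scores, self.labels) if y == 1]
        neg = [s for s, y in zip(self.scores, self.labels) if y == 0]
        wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
        return wins / (len(pos) * len(neg))


collector = OODScoreCollector()
collector.update([0.1, 0.6], dataloader_idx=0)  # ID test batch
collector.update([0.5, 0.9], dataloader_idx=1)  # OOD test batch
print(collector.auroc())  # 3 of the 4 (OOD, ID) pairs are ordered correctly -> 0.75
```

Because both ID and OOD batches feed the same metric state, the metric can only be computed once both dataloaders have been consumed, which matches computing the collected values at the end of the test epoch.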