

class torch_uncertainty.routines.ClassificationRoutine(model, num_classes, loss, is_ensemble=False, format_batch_fn=None, optim_recipe=None, mixup_params=None, eval_ood=False, eval_shift=False, eval_grouping_loss=False, ood_criterion='msp', post_processing=None, calibration_set='val', num_calibration_bins=15, log_plots=False, save_in_csv=False)

Routine for training & testing on classification tasks.

  • model (torch.nn.Module) – Model to train.

  • num_classes (int) – Number of classes.

  • loss (torch.nn.Module) – Loss function to optimize the model.

  • is_ensemble (bool, optional) – Indicates whether the model is an ensemble at test time or not. Defaults to False.

  • format_batch_fn (torch.nn.Module, optional) – Function to format the batch. Defaults to torch.nn.Identity().

  • optim_recipe (dict or torch.optim.Optimizer, optional) – The optimizer and optionally the scheduler to use. Defaults to None.

  • mixup_params (dict, optional) – Mixup parameters. Can include the mixup type, mixup mode, distance similarity, kernel tau max, kernel tau std, mixup alpha, and CutMix alpha. If None, no mixup augmentation is applied. Defaults to None.

  • eval_ood (bool, optional) – Indicates whether to evaluate the OOD detection performance. Defaults to False.

  • eval_shift (bool, optional) – Indicates whether to evaluate the distribution-shift performance. Defaults to False.

  • eval_grouping_loss (bool, optional) – Indicates whether to evaluate the grouping loss or not. Defaults to False.

  • ood_criterion (str, optional) – OOD criterion. Available options are:
      - "msp" (default): Maximum softmax probability.
      - "logit": Maximum logit.
      - "energy": Logsumexp of the mean logits.
      - "entropy": Entropy of the mean prediction.
      - "mi": Mutual information of the ensemble.
      - "vr": Variation ratio of the ensemble.

  • post_processing (PostProcessing, optional) – Post-processing method to fit on the calibration set. No post-processing is applied if None. Defaults to None.

  • calibration_set (str, optional) – The post-hoc calibration dataset used to fit the post-processing method. Defaults to "val".

  • num_calibration_bins (int, optional) – Number of bins to compute calibration metrics. Defaults to 15.

  • log_plots (bool, optional) – Indicates whether to log plots from metrics. Defaults to False.

  • save_in_csv (bool, optional) – Indicates whether to save the results in a CSV file. Defaults to False.
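The ood_criterion options can be illustrated with a minimal pure-Python sketch that scores one sample from the logits of an ensemble with M members (a single model is the case M = 1). It assumes that a higher score means "more likely OOD", so the confidence-based criteria (msp, logit, energy) are negated. The helper names are hypothetical and this is not the library's implementation.

```python
import math
from collections import Counter

# Hypothetical helpers. Assumed sign convention: higher score = more likely OOD.


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: list[float]) -> float:
    return -sum(p * math.log(p) for p in probs if p > 0)


def logsumexp(values: list[float]) -> float:
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def ood_score(member_logits: list[list[float]], criterion: str = "msp") -> float:
    """OOD score of one sample; member_logits has one logit vector per member."""
    num_members = len(member_logits)
    num_classes = len(member_logits[0])
    mean_logits = [sum(l[c] for l in member_logits) / num_members for c in range(num_classes)]
    member_probs = [softmax(l) for l in member_logits]
    mean_probs = [sum(p[c] for p in member_probs) / num_members for c in range(num_classes)]

    if criterion == "msp":
        return -max(mean_probs)
    if criterion == "logit":
        return -max(mean_logits)
    if criterion == "energy":
        return -logsumexp(mean_logits)
    if criterion == "entropy":
        return entropy(mean_probs)
    if criterion == "mi":
        # Entropy of the mean prediction minus the mean entropy of the members.
        return entropy(mean_probs) - sum(entropy(p) for p in member_probs) / num_members
    if criterion == "vr":
        # Fraction of members that disagree with the most voted class.
        votes = Counter(max(range(num_classes), key=lambda c: p[c]) for p in member_probs)
        return 1.0 - votes.most_common(1)[0][1] / num_members
    raise ValueError(f"Unknown criterion: {criterion}")
```

With two members that agree (e.g. both `[4.0, 0.0, 0.0]`), the mutual information and the variation ratio are 0; with two members that predict different classes, the mutual information is positive and the variation ratio is 0.5.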


You must define optim_recipe if you do not use the Lightning CLI.


optim_recipe can be anything that can be returned by LightningModule.configure_optimizers(). See the Lightning documentation of LightningModule.configure_optimizers() for more details.

forward(inputs, save_feats=False)

Forward pass of the inner model.

  • inputs (Tensor) – input tensor.

  • save_feats (bool, optional) – whether to store the features or not. Defaults to False.


The features are stored in the self.features attribute.
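The save_feats pattern can be pictured with the framework-free toy below. ToyRoutine, feature_extractor and head are hypothetical stand-ins for the routine and the two halves of the wrapped model; the real routine works on torch.Tensor objects, and how it splits the model into features and head is not described on this page.

```python
from typing import Callable


class ToyRoutine:
    """Hypothetical stand-in illustrating the save_feats pattern."""

    def __init__(self, feature_extractor: Callable, head: Callable) -> None:
        self.feature_extractor = feature_extractor
        self.head = head
        self.features = None  # set only when forward is called with save_feats=True

    def forward(self, inputs, save_feats: bool = False):
        feats = self.feature_extractor(inputs)
        if save_feats:
            # Keep the intermediate representation for later inspection.
            self.features = feats
        return self.head(feats)
```

For example, `ToyRoutine(lambda x: [2 * v for v in x], sum).forward([1, 2, 3], save_feats=True)` returns 12 and leaves `[2, 4, 6]` in `features`.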


Compute, log, and plot the values of the metrics collected in test_step.
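Calibration metrics group predictions into num_calibration_bins equal-width confidence bins. The plain-Python sketch below computes the standard top-label expected calibration error (ECE) to show what the bins are used for; it assumes ECE as the example metric and left-closed bins, and it is not the metric implementation used by the routine.

```python
def expected_calibration_error(
    confidences: list[float], correct: list[bool], num_bins: int = 15
) -> float:
    """Top-label ECE with equal-width bins over [0, 1]."""
    n = len(confidences)
    bins: list[list[int]] = [[] for _ in range(num_bins)]
    for i, conf in enumerate(confidences):
        # Bin k covers [k / num_bins, (k + 1) / num_bins); 1.0 goes to the last bin.
        idx = min(int(conf * num_bins), num_bins - 1)
        bins[idx].append(i)
    ece = 0.0
    for members in bins:
        if not members:
            continue
        avg_conf = sum(confidences[i] for i in members) / len(members)
        accuracy = sum(correct[i] for i in members) / len(members)
        # Weight each bin's |accuracy - confidence| gap by its share of samples.
        ece += len(members) / n * abs(accuracy - avg_conf)
    return ece
```

Predictions made with 0.75 confidence that are right 3 times out of 4 give an ECE of 0, whereas overconfident predictions inflate it.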


Prepare the test step.

Set up the calibration dataset and fit the post-processing method if needed, prepare the storage lists for logit plotting, and update the batch normalization layers if needed.
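As an illustration of fitting a post-processing method on the calibration set, the sketch below fits temperature scaling, a common post-hoc calibration method, by a simple grid search for the temperature that minimizes the negative log-likelihood (NLL) of the calibration logits. Both the choice of temperature scaling and the grid search are assumptions for illustration; the post_processing object passed to the routine defines its own fitting procedure.

```python
import math
from typing import Optional


def nll(logits_batch: list[list[float]], targets: list[int], temperature: float) -> float:
    """Mean negative log-likelihood of temperature-scaled logits."""
    total = 0.0
    for logits, target in zip(logits_batch, targets):
        scaled = [z / temperature for z in logits]
        m = max(scaled)
        log_norm = m + math.log(sum(math.exp(s - m) for s in scaled))
        total += log_norm - scaled[target]
    return total / len(targets)


def fit_temperature(
    logits_batch: list[list[float]],
    targets: list[int],
    grid: Optional[list[float]] = None,
) -> float:
    """Assumed fitting strategy: pick the grid temperature with the lowest NLL."""
    if grid is None:
        grid = [0.25 * k for k in range(1, 41)]  # 0.25, 0.5, ..., 10.0
    return min(grid, key=lambda t: nll(logits_batch, targets, t))
```

On overconfident logits such as `[10.0, 0.0]` for four samples of which only three belong to class 0, the fitted temperature lands near 9, which softens the predicted probability of class 0 toward 0.75.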


Log the hyperparameters to TensorBoard.


Compute and log the values of the metrics collected in validation_step.


Prepare the validation step.

Update the model’s wrapper and the batch normalization layers if needed.


Save the metric results to a CSV file.


  • results (dict[str, float]) – the dictionary containing all the metric values.
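A minimal sketch of writing such a results dictionary with Python's standard csv module, one column per metric and one row per call. The function name write_results_csv, the file path, and the choice to append rows across runs are assumptions; the routine chooses its own output location and layout.

```python
import csv
from pathlib import Path


def write_results_csv(results: dict[str, float], path: Path) -> None:
    """Hypothetical helper: append one row of metrics, writing a header for a new file."""
    write_header = not path.exists()
    with path.open("a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(results))
        if write_header:
            writer.writeheader()
        writer.writerow(results)
```

Keeping the metric names as the header makes the file easy to load later with `csv.DictReader` or a dataframe library.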

test_step(batch, batch_idx, dataloader_idx=0)

Perform a single test step based on the input tensors.

Compute the model's predictions and the metric values on the test batch. Also handle OOD and distribution-shifted images.

  • batch (tuple[Tensor, Tensor]) – the test data and their corresponding targets.

  • batch_idx (int) – the index of the current batch (unused).

  • dataloader_idx (int) – 0 if in-distribution, 1 if out-of-distribution and 2 if distribution-shifted.
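Because dataloader_idx tells where a batch comes from, OOD detection can be scored by pooling the ood_criterion scores of in-distribution batches (index 0) and OOD batches (index 1), then computing, for instance, the area under the ROC curve (AUROC) with OOD as the positive class. The plain-Python sketch below shows this bookkeeping; the ScoreCollector name is hypothetical, the O(n·m) pairwise AUROC is chosen for clarity, and the routine's actual set of OOD metrics is not listed on this page.

```python
class ScoreCollector:
    """Hypothetical helper that routes per-sample OOD scores by dataloader index."""

    def __init__(self) -> None:
        self.scores: dict[int, list[float]] = {0: [], 1: [], 2: []}

    def update(self, batch_scores: list[float], dataloader_idx: int = 0) -> None:
        # 0: in-distribution, 1: out-of-distribution, 2: distribution-shifted.
        self.scores[dataloader_idx].extend(batch_scores)

    def ood_auroc(self) -> float:
        """Probability that a random OOD sample outscores a random ID one (ties count 0.5)."""
        id_scores, ood_scores = self.scores[0], self.scores[1]
        wins = 0.0
        for o in ood_scores:
            for i in id_scores:
                if o > i:
                    wins += 1.0
                elif o == i:
                    wins += 0.5
        return wins / (len(id_scores) * len(ood_scores))
```

If every OOD sample scores above every in-distribution sample, the AUROC is 1.0; a criterion that cannot separate the two sets gives about 0.5.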


Perform a single training step based on the input tensors.


  • batch (tuple[Tensor, Tensor]) – the training data and their corresponding targets.


Returns: the loss corresponding to this training step.

Return type:


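Mixup augmentations configured through mixup_params are training-time augmentations applied to the batch before the loss is computed. The sketch below shows only vanilla mixup in plain Python: each sample is mixed with a randomly permuted partner from the same batch, inputs and one-hot targets alike, using a coefficient drawn from Beta(alpha, alpha). The function name mixup_batch is hypothetical, and the other options listed for mixup_params (mixup mode, kernel-based variants, CutMix) are not covered.

```python
import random
from typing import Optional


def mixup_batch(
    inputs: list[list[float]],
    targets: list[int],
    num_classes: int,
    alpha: float = 1.0,
    rng: Optional[random.Random] = None,
):
    """Hypothetical vanilla mixup: returns mixed inputs, soft targets, and lambda."""
    if rng is None:
        rng = random.Random()
    lam = rng.betavariate(alpha, alpha)  # mixing coefficient in [0, 1]
    perm = list(range(len(inputs)))
    rng.shuffle(perm)  # partner j = perm[i] for sample i
    one_hot = [[1.0 if c == t else 0.0 for c in range(num_classes)] for t in targets]
    mixed_x = [
        [lam * a + (1 - lam) * b for a, b in zip(inputs[i], inputs[j])]
        for i, j in enumerate(perm)
    ]
    mixed_y = [
        [lam * a + (1 - lam) * b for a, b in zip(one_hot[i], one_hot[j])]
        for i, j in enumerate(perm)
    ]
    return mixed_x, mixed_y, lam
```

Each mixed target row still sums to 1, so it can be used with a loss that accepts soft (probability) targets.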

Perform a single validation step based on the input tensors.

Compute the prediction of the model and the value of the metrics on the validation batch.


  • batch (tuple[Tensor, Tensor]) – the validation data and their corresponding targets.