
ClassificationRoutine

class torch_uncertainty.routines.ClassificationRoutine(model, num_classes, loss, is_ensemble=False, format_batch_fn=None, optim_recipe=None, mixup_params=None, eval_ood=False, eval_shift=False, eval_grouping_loss=False, ood_criterion='msp', post_processing=None, calibration_set='val', num_calibration_bins=15, log_plots=False, save_in_csv=False)

Routine for training & testing on classification tasks.

Parameters:
  • model (torch.nn.Module) – Model to train.

  • num_classes (int) – Number of classes.

  • loss (torch.nn.Module) – Loss function to optimize the model.

  • is_ensemble (bool, optional) – Indicates whether the model is an ensemble at test time or not. Defaults to False.

  • format_batch_fn (torch.nn.Module, optional) – Function to format the batch. Defaults to torch.nn.Identity().

  • optim_recipe (dict or torch.optim.Optimizer, optional) – The optimizer and optionally the scheduler to use. Defaults to None.

  • mixup_params (dict, optional) – Mixup parameters. Can include mixup type, mixup mode, distance similarity, kernel tau max, kernel tau std, mixup alpha, and cutmix alpha. If None, no mixup augmentation is applied. Defaults to None.

  • eval_ood (bool, optional) – Indicates whether to evaluate the out-of-distribution (OOD) detection performance. Defaults to False.

  • eval_shift (bool, optional) – Indicates whether to evaluate the performance under distribution shift. Defaults to False.

  • eval_grouping_loss (bool, optional) – Indicates whether to evaluate the grouping loss or not. Defaults to False.

  • ood_criterion (str, optional) – OOD criterion. Available options are:
      - "msp" (default): Maximum softmax probability.
      - "logit": Maximum logit.
      - "energy": Logsumexp of the mean logits.
      - "entropy": Entropy of the mean prediction.
      - "mi": Mutual information of the ensemble.
      - "vr": Variation ratio of the ensemble.

  • post_processing (PostProcessing, optional) – Post-processing method to fit on the calibration set. No post-processing if None. Defaults to None.

  • calibration_set (str, optional) – The post-hoc calibration dataset to use for the post-processing method. Defaults to "val".

  • num_calibration_bins (int, optional) – Number of bins to compute calibration metrics. Defaults to 15.

  • log_plots (bool, optional) – Indicates whether to log plots from metrics. Defaults to False.

  • save_in_csv (bool, optional) – Indicates whether to save the results to a CSV file. Defaults to False.
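To make the ood_criterion options concrete, here is a minimal stdlib-only sketch of what each one computes for a single sample. It assumes the model's output for that sample is available as a list of M logit vectors, one per ensemble member (M = 1 for a non-ensemble model). The function names are hypothetical and not part of torch_uncertainty, whose tensor-based implementation may use different sign conventions (for instance, negating confidence-style scores so that higher always means "more likely OOD").

```python
import math
from collections import Counter


def softmax(logits):
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    return -sum(p * math.log(p) for p in probs if p > 0)


def ood_score(member_logits, criterion="msp"):
    """Score one sample; member_logits holds M logit vectors of length C (M = 1 for a single model)."""
    m = len(member_logits)
    c = len(member_logits[0])
    member_probs = [softmax(row) for row in member_logits]
    mean_logits = [sum(row[k] for row in member_logits) / m for k in range(c)]
    mean_probs = [sum(p[k] for p in member_probs) / m for k in range(c)]
    # Confidence-style scores: higher means more likely in-distribution.
    if criterion == "msp":
        return max(mean_probs)
    if criterion == "logit":
        return max(mean_logits)
    if criterion == "energy":
        top = max(mean_logits)
        return top + math.log(sum(math.exp(z - top) for z in mean_logits))
    # Uncertainty-style scores: higher means more likely OOD.
    if criterion == "entropy":
        return entropy(mean_probs)
    if criterion == "mi":
        # entropy of the mean prediction minus the mean entropy of the members
        return entropy(mean_probs) - sum(entropy(p) for p in member_probs) / m
    if criterion == "vr":
        # 1 - (share of members voting for the most popular class)
        votes = Counter(max(range(c), key=p.__getitem__) for p in member_probs)
        return 1 - votes.most_common(1)[0][1] / m
    raise ValueError(f"unknown criterion: {criterion!r}")


# Two members that disagree: the ensemble criteria flag this sample as uncertain
print(ood_score([[2.0, 0.0], [0.0, 2.0]], "vr"))  # 0.5
```

Note that "mi" and "vr" are identically zero for a single model (M = 1), which is why they are described as properties of the ensemble.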

Warning

You must define optim_recipe if you do not use the Lightning CLI.

Note

optim_recipe can be anything that LightningModule.configure_optimizers() can return. See the Lightning documentation of configure_optimizers() for the supported formats.

forward(inputs, save_feats=False)

Forward pass of the inner model.

Parameters:
  • inputs (Tensor) – input tensor.

  • save_feats (bool, optional) – whether to store the features or not. Defaults to False.

Note

The features are stored in the self.features attribute.
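The save_feats flag follows a common pattern: run the feature extractor, optionally keep its output, then apply the classification head. The stdlib-only sketch below illustrates that pattern under the assumption that the model splits into a backbone and a head; FeatureStoringWrapper, backbone, and head are hypothetical names, not the library's internals.

```python
class FeatureStoringWrapper:
    """Hypothetical wrapper mimicking forward(inputs, save_feats=False)."""

    def __init__(self, backbone, head):
        self.backbone = backbone  # maps inputs to features
        self.head = head          # maps features to logits
        self.features = None      # filled only when save_feats=True

    def forward(self, inputs, save_feats=False):
        feats = self.backbone(inputs)
        if save_feats:
            # keep the intermediate features for later inspection
            self.features = feats
        return self.head(feats)


# Usage with toy callables standing in for real layers
model = FeatureStoringWrapper(backbone=lambda x: [2 * v for v in x], head=sum)
logits = model.forward([1, 2], save_feats=True)
print(logits, model.features)  # 6 [2, 4]
```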

on_test_epoch_end()

Compute, log, and plot the values of the collected metrics in test_step.
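Calibration metrics such as the expected calibration error (ECE) are typically computed over num_calibration_bins equal-width confidence bins. Assuming the routine's calibration error follows this standard binned definition, a stdlib-only sketch of the computation is shown below; expected_calibration_error is a hypothetical helper, and the library's exact bin-edge handling may differ.

```python
def expected_calibration_error(confidences, correct, num_bins=15):
    """Binned ECE: sum over bins of |accuracy - mean confidence|, weighted by bin size / n."""
    counts = [0] * num_bins
    conf_sums = [0.0] * num_bins
    correct_sums = [0.0] * num_bins
    for conf, is_correct in zip(confidences, correct):
        # bin k covers [k / num_bins, (k + 1) / num_bins); conf == 1.0 goes in the last bin
        k = min(int(conf * num_bins), num_bins - 1)
        counts[k] += 1
        conf_sums[k] += conf
        correct_sums[k] += float(is_correct)
    n = len(confidences)
    # |acc_k - conf_k| * count_k / n simplifies to |correct_sum_k - conf_sum_k| / n
    return sum(
        abs(correct_sums[k] - conf_sums[k]) / n
        for k in range(num_bins)
        if counts[k] > 0
    )


# confidence 0.9 vs. accuracy 0.5 in the same bin -> ECE of about 0.4
print(expected_calibration_error([0.9, 0.9], [True, False]))
```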

on_test_start()

Prepare the test step.

Set up the post-processing dataset and fit the post-processing method if needed, prepare the storage lists for logit plotting, and update the batch-normalization layers if needed.
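Temperature scaling is a common example of a post-processing method fitted on a calibration set: a single scalar T divides the logits and is chosen to minimize the negative log-likelihood on held-out data. The sketch below is stdlib-only and fits T by a coarse grid search purely for illustration; the function names are hypothetical, and this is not necessarily how torch_uncertainty's own post-processing classes optimize T.

```python
import math


def mean_nll(logit_rows, labels, temperature):
    """Average negative log-likelihood of the labels after dividing logits by T."""
    total = 0.0
    for logits, label in zip(logit_rows, labels):
        scaled = [z / temperature for z in logits]
        top = max(scaled)
        log_norm = top + math.log(sum(math.exp(s - top) for s in scaled))  # stable logsumexp
        total += log_norm - scaled[label]
    return total / len(labels)


def fit_temperature(logit_rows, labels, grid=None):
    """Pick the T on a grid (default 0.05 to 10.0) that minimizes the calibration-set NLL."""
    if grid is None:
        grid = [0.05 * k for k in range(1, 201)]
    return min(grid, key=lambda t: mean_nll(logit_rows, labels, t))


# Overconfident toy model: always predicts class 0 with margin 5, but is right only half the time
calib_logits = [[5.0, 0.0]] * 4
calib_labels = [0, 0, 1, 1]
print(fit_temperature(calib_logits, calib_labels))  # the largest grid value: softening always helps here
```

A fitted T > 1 softens overconfident predictions and T < 1 sharpens underconfident ones; the argmax, and hence accuracy, is unchanged.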

on_train_start()

Log the hyperparameters to TensorBoard.

on_validation_epoch_end()

Compute and log the values of the collected metrics in validation_step.

on_validation_start()

Prepare the validation step.

Update the model’s wrapper and the batch-normalization layers if needed.

save_results_to_csv(results)

Save the metric results to a CSV file.

Parameters:

results (dict[str, float]) – the dictionary containing all the values of the metrics.
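As an illustration of turning a dict[str, float] of metrics into CSV, here is a minimal sketch using Python's csv module. write_results_csv and the metric names are hypothetical, and the one-header-row, one-value-row layout is an assumption; the library's actual file layout and location may differ.

```python
import csv
import io


def write_results_csv(results, stream):
    """Write a {metric_name: value} dict as a header row followed by one row of values."""
    names = sorted(results)  # deterministic column order
    writer = csv.writer(stream)
    writer.writerow(names)
    writer.writerow([results[name] for name in names])


# Usage: write to an in-memory buffer; pass open("results.csv", "w", newline="") for a real file
buffer = io.StringIO()
write_results_csv({"accuracy": 0.93, "ece": 0.021}, buffer)
print(buffer.getvalue())
```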

test_step(batch, batch_idx, dataloader_idx=0)

Perform a single test step based on the input tensors.

Compute the prediction of the model and the value of the metrics on the test batch. Also handle OOD and distribution-shifted images.

Parameters:
  • batch (tuple[Tensor, Tensor]) – the test data and their corresponding targets.

  • batch_idx (int) – the index of the current batch (unused).

  • dataloader_idx (int) – 0 if in-distribution, 1 if out-of-distribution and 2 if distribution-shifted.
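Because of the dataloader_idx convention, a single test run can mix several datasets. As a stdlib-only illustration, the hypothetical collector below files per-sample uncertainty scores by dataloader index and then computes the OOD-detection AUROC between the in-distribution (index 0) and OOD (index 1) scores. It assumes higher scores mean "more likely OOD" and is not the library's implementation.

```python
from collections import defaultdict

# Roles of the test dataloaders, following the dataloader_idx convention above
DATALOADER_ROLES = {0: "in_distribution", 1: "out_of_distribution", 2: "distribution_shift"}


class TestScoreCollector:
    """Hypothetical collector: stores uncertainty scores per dataloader role."""

    def __init__(self):
        self.scores = defaultdict(list)

    def test_step(self, batch_scores, dataloader_idx=0):
        self.scores[DATALOADER_ROLES[dataloader_idx]].extend(batch_scores)

    def ood_auroc(self):
        # Probability that a random OOD sample scores higher than a random ID sample (ties count 1/2).
        # Requires at least one ID and one OOD score.
        id_scores = self.scores["in_distribution"]
        ood_scores = self.scores["out_of_distribution"]
        wins = sum((o > i) + 0.5 * (o == i) for o in ood_scores for i in id_scores)
        return wins / (len(id_scores) * len(ood_scores))


collector = TestScoreCollector()
collector.test_step([0.1, 0.2], dataloader_idx=0)
collector.test_step([0.9, 0.15], dataloader_idx=1)
collector.test_step([0.5], dataloader_idx=2)
print(collector.ood_auroc())  # 0.75
```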

training_step(batch)

Perform a single training step based on the input tensors.

Parameters:

batch (tuple[Tensor, Tensor]) – the training data and their corresponding targets.

Returns:

the loss corresponding to this training step.

Return type:

Tensor
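When mixup_params is set, the training batch is augmented before the loss is computed. The stdlib-only sketch below shows plain mixup only: a coefficient λ is drawn from Beta(α, α), and each sample and its one-hot target are blended with those of a randomly paired sample. The function name and list-based data layout are assumptions for illustration; the routine's other options (cutmix and the kernel- and distance-based variants hinted at by its parameters) are not covered.

```python
import random


def mixup_batch(inputs, targets, num_classes, alpha=0.2, rng=random):
    """Plain mixup on a list-based batch; returns mixed inputs, soft targets, and lambda."""
    lam = rng.betavariate(alpha, alpha)
    partner = list(range(len(inputs)))
    rng.shuffle(partner)  # pair each sample with a random partner
    one_hot = [[1.0 if c == t else 0.0 for c in range(num_classes)] for t in targets]
    mixed_inputs = [
        [lam * a + (1 - lam) * b for a, b in zip(inputs[i], inputs[j])]
        for i, j in enumerate(partner)
    ]
    mixed_targets = [
        [lam * a + (1 - lam) * b for a, b in zip(one_hot[i], one_hot[j])]
        for i, j in enumerate(partner)
    ]
    return mixed_inputs, mixed_targets, lam


rng = random.Random(0)
x, y, lam = mixup_batch([[0.0, 1.0], [1.0, 0.0]], [0, 1], num_classes=2, rng=rng)
print(lam, y)
```

The loss is then computed against the soft targets, for example with a cross-entropy that accepts probability vectors.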

validation_step(batch)

Perform a single validation step based on the input tensors.

Compute the prediction of the model and the value of the metrics on the validation batch.

Parameters:

batch (tuple[Tensor, Tensor]) – the validation data and their corresponding targets.