ClassificationRoutine

class torch_uncertainty.routines.ClassificationRoutine(model, num_classes, loss, is_ensemble=False, format_batch_fn=None, optim_recipe=None, mixup_params=None, eval_ood=False, eval_shift=False, eval_grouping_loss=False, ood_criterion='msp', post_processing=None, num_bins_cal_err=15, log_plots=False, save_in_csv=False)[source]

Routine for training & testing on classification tasks.

Parameters:
  • model (torch.nn.Module) – Model to train.

  • num_classes (int) – Number of classes.

  • loss (torch.nn.Module) – Loss function to optimize the model.

  • is_ensemble (bool, optional) – Indicates whether the model is an ensemble at test time or not. Defaults to False.

  • format_batch_fn (torch.nn.Module, optional) – Function to format the batch. If None, torch.nn.Identity() is used. Defaults to None.

  • optim_recipe (dict or torch.optim.Optimizer, optional) – The optimizer and optionally the scheduler to use. Defaults to None.

  • mixup_params (dict, optional) – Mixup parameters. Can include mixup type, mixup mode, distance similarity, kernel tau max, kernel tau std, mixup alpha, and cutmix alpha. If None, no mixup augmentations. Defaults to None.

  • eval_ood (bool, optional) – Indicates whether to evaluate the OOD detection performance. Defaults to False.

  • eval_shift (bool, optional) – Indicates whether to evaluate the performance under distribution shift. Defaults to False.

  • eval_grouping_loss (bool, optional) – Indicates whether to evaluate the grouping loss or not. Defaults to False.

  • ood_criterion (str, optional) – OOD criterion. Available options are:

      • "msp" (default): Maximum softmax probability.
      • "logit": Maximum logit.
      • "energy": Logsumexp of the mean logits.
      • "entropy": Entropy of the mean prediction.
      • "mi": Mutual information of the ensemble.
      • "vr": Variation ratio of the ensemble.

  • post_processing (PostProcessing, optional) – Post-processing method to train on the calibration set. No post-processing if None. Defaults to None.

  • num_bins_cal_err (int, optional) – Number of bins to compute calibration error metrics. Defaults to 15.

  • log_plots (bool, optional) – Indicates whether to log plots from metrics. Defaults to False.

  • save_in_csv (bool, optional) – Save the results in csv. Defaults to False.
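The options listed for ood_criterion can be illustrated with a small dependency-free sketch. It computes each score for a single sample from per-estimator logits of shape (M, C). The function names are hypothetical, the formulas are the textbook definitions, and the library's own implementation may differ in sign conventions or reductions.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats; zero-probability terms contribute 0."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def logsumexp(values):
    """Numerically stable log(sum(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def ood_scores(member_logits):
    """Scores for ONE sample from per-estimator logits of shape (M, C).

    Hypothetical helper: textbook definitions, not the library's code.
    """
    num_members = len(member_logits)
    num_classes = len(member_logits[0])
    member_probs = [softmax(row) for row in member_logits]
    mean_probs = [
        sum(p[c] for p in member_probs) / num_members for c in range(num_classes)
    ]
    mean_logits = [
        sum(row[c] for row in member_logits) / num_members for c in range(num_classes)
    ]
    # Each estimator votes for its arg-max class.
    votes = [max(range(num_classes), key=p.__getitem__) for p in member_probs]
    modal_count = max(votes.count(c) for c in set(votes))
    mean_member_entropy = sum(entropy(p) for p in member_probs) / num_members
    return {
        "msp": max(mean_probs),
        "logit": max(mean_logits),
        "energy": logsumexp(mean_logits),
        "entropy": entropy(mean_probs),
        # Entropy of the mean prediction minus the mean of the entropies.
        "mi": entropy(mean_probs) - mean_member_entropy,
        # Fraction of estimators that disagree with the majority vote.
        "vr": 1.0 - modal_count / num_members,
    }


scores = ood_scores([[5.0, 0.0, 0.0], [0.0, 5.0, 0.0]])
```

Note that "mi" and "vr" are only informative with several estimators: with a single estimator both are zero by construction.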

Warning

You must define optim_recipe if you do not use the Lightning CLI.

Warning

When using an ensemble model, you must:

  1. Set is_ensemble to True.

  2. Set format_batch_fn to torch_uncertainty.transforms.RepeatTarget(num_repeats=num_estimators).

  3. Ensure that the model’s forward pass outputs a tensor of shape \((M \times B, C)\), where \(M\) is the number of estimators, \(B\) is the batch size, and \(C\) is the number of classes.

For automated batch handling, consider using the available model wrappers in torch_uncertainty.models.wrappers.
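To make the \((M \times B, C)\) layout concrete, here is a minimal plain-Python sketch of how repeated targets line up with stacked ensemble outputs, and how those outputs reduce to one prediction per sample. Both helper names are hypothetical, and the sketch assumes estimator-major ordering (all \(B\) rows of estimator 0 first, then estimator 1, and so on), which is not stated in this page.

```python
def repeat_targets(targets, num_estimators):
    """Tile the B targets M times so they align with (M * B, C) outputs."""
    return list(targets) * num_estimators


def average_over_estimators(stacked_probs, num_estimators):
    """Average (M * B, C) per-estimator probabilities into (B, C).

    Assumes rows [m * B, (m + 1) * B) belong to estimator m.
    """
    if len(stacked_probs) % num_estimators != 0:
        raise ValueError("row count must be a multiple of num_estimators")
    batch_size = len(stacked_probs) // num_estimators
    num_classes = len(stacked_probs[0])
    return [
        [
            sum(stacked_probs[m * batch_size + b][c] for m in range(num_estimators))
            / num_estimators
            for c in range(num_classes)
        ]
        for b in range(batch_size)
    ]


# M = 2 estimators, B = 2 samples, C = 2 classes.
stacked = [
    [1.0, 0.0],  # estimator 0, sample 0
    [0.0, 1.0],  # estimator 0, sample 1
    [0.0, 1.0],  # estimator 1, sample 0
    [0.0, 1.0],  # estimator 1, sample 1
]
targets = repeat_targets([0, 1], num_estimators=2)
mean_probs = average_over_estimators(stacked, num_estimators=2)
```

In PyTorch, the same reduction is typically written as a reshape to (M, B, C) followed by a mean over the first dimension.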

Note

If eval_ood is True, we perform a binary classification and update the OOD-related metrics twice:

  • once during the test on ID values, where the given binary label is 0 (for ID);

  • once during the test on OOD values, where the given binary label is 1 (for OOD).
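The two-pass update described in this note can be sketched as follows. The accumulator class and the pairwise AUROC are illustrative stand-ins rather than the library's metric objects, and the sketch assumes scores where a higher value means "more likely OOD".

```python
def auroc(scores, labels):
    """Chance that a random OOD sample (label 1) outscores a random ID sample (label 0).

    Ties count as half a win. Quadratic in the number of samples, so only
    suitable for small illustrations.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one ID and one OOD sample")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


class OODAccumulator:
    """Hypothetical stand-in for an OOD metric updated once per test set."""

    def __init__(self):
        self.scores = []
        self.labels = []

    def update(self, scores, is_ood):
        # Every sample in the batch gets the same binary label:
        # 0 for the ID test set, 1 for the OOD test set.
        self.scores.extend(scores)
        self.labels.extend([1 if is_ood else 0] * len(scores))

    def compute(self):
        return auroc(self.scores, self.labels)


acc = OODAccumulator()
acc.update([0.1, 0.2], is_ood=False)   # pass over the ID test set
acc.update([0.8, 0.15], is_ood=True)   # pass over the OOD test set
result = acc.compute()
```

The metric is only meaningful once both passes have run, since it needs samples of both labels.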

Note

optim_recipe can be anything that can be returned by LightningModule.configure_optimizers(). See the Lightning documentation of that method for details.

forward(inputs, save_feats=False)[source]

Forward pass of the inner model.

Parameters:
  • inputs (Tensor) – input tensor.

  • save_feats (bool, optional) – whether to store the features or not. Defaults to False.

Note

The features are stored in the self.features attribute.

on_test_epoch_end()[source]

Compute, log, and plot the values of the collected metrics in test_step.
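The routine's num_bins_cal_err parameter sets the number of bins used by the calibration-error metrics. As an illustration of what such a binned metric computes, here is a minimal sketch of top-label expected calibration error with equal-width bins. The function name and binning details are assumptions made for the example, and the metric used by the routine may differ (for instance in bin edges or norm).

```python
def expected_calibration_error(confidences, correct, num_bins=15):
    """Top-label ECE with equal-width confidence bins.

    confidences: predicted probability of the chosen class, in [0, 1].
    correct: whether each prediction matched its target.
    """
    if len(confidences) != len(correct):
        raise ValueError("confidences and correct must have the same length")
    if not confidences:
        raise ValueError("need at least one prediction")
    bins = [[] for _ in range(num_bins)]
    for conf, hit in zip(confidences, correct):
        # Bin k covers [k / num_bins, (k + 1) / num_bins); 1.0 goes in the last bin.
        idx = min(int(conf * num_bins), num_bins - 1)
        bins[idx].append((conf, hit))
    total = len(confidences)
    ece = 0.0
    for members in bins:
        if not members:
            continue
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(1 for _, h in members if h) / len(members)
        # Weight each bin's confidence/accuracy gap by its share of samples.
        ece += len(members) / total * abs(avg_conf - accuracy)
    return ece


ece = expected_calibration_error(
    [0.9, 0.9, 0.6, 0.6], [True, True, True, False], num_bins=15
)
```

More bins give a finer picture of miscalibration but leave fewer samples per bin, which makes the estimate noisier on small test sets.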

on_test_start()[source]

Prepare the test step.

Set up the post-processing dataset and fit the post-processing method if needed, prepare the storage lists for logit plotting, and update the batch norms if needed.

on_train_start()[source]

Log the hyperparameters to TensorBoard.

on_validation_epoch_end()[source]

Compute and log the values of the collected metrics in validation_step.

on_validation_start()[source]

Prepare the validation step.

Update the model’s wrapper and the batch norms if needed.

save_results_to_csv(results)[source]

Save the metric results in a csv.

Parameters:

results (dict[str, float]) – the dictionary containing all the values of the metrics.

test_step(batch, batch_idx, dataloader_idx=0)[source]

Perform a single test step based on the input tensors.

Compute the prediction of the model and the value of the metrics on the test batch. Also handle OOD and distribution-shifted images.

Parameters:
  • batch (tuple[Tensor, Tensor]) – the test data and their corresponding targets.

  • batch_idx (int) – the index of the current batch (unused).

  • dataloader_idx (int) – 0 if in-distribution, 1 if out-of-distribution and 2 if distribution-shifted.

training_step(batch)[source]

Perform a single training step based on the input tensors.

Parameters:

batch (tuple[Tensor, Tensor]) – the training data and their corresponding targets.

Returns:

the loss corresponding to this training step.

Return type:

Tensor

validation_step(batch)[source]

Perform a single validation step based on the input tensors.

Compute the prediction of the model and the value of the metrics on the validation batch.

Parameters:

batch (tuple[Tensor, Tensor]) – the validation data and their corresponding targets.