
ClassificationRoutine

class torch_uncertainty.routines.ClassificationRoutine(model, num_classes, loss, is_ensemble=False, format_batch_fn=None, optim_recipe=None, mixup_params=None, eval_ood=False, eval_shift=False, eval_grouping_loss=False, ood_criterion='msp', post_processing=None, calibration_set='val', num_calibration_bins=15, log_plots=False, save_in_csv=False)

Routine for training & testing on classification tasks.

Parameters:
  • model (torch.nn.Module) – Model to train.

  • num_classes (int) – Number of classes.

  • loss (torch.nn.Module) – Loss function to optimize the model.

  • is_ensemble (bool, optional) – Indicates whether the model is an ensemble at test time or not. Defaults to False.

  • format_batch_fn (torch.nn.Module, optional) – Function to format the batch. Defaults to torch.nn.Identity().

  • optim_recipe (dict or torch.optim.Optimizer, optional) – The optimizer and optionally the scheduler to use. Defaults to None.

  • mixup_params (dict, optional) – Mixup parameters. Can include the mixup type, mixup mode, distance similarity, kernel tau max, kernel tau std, mixup alpha, and cutmix alpha. If None, no mixup augmentation is applied. Defaults to None.

  • eval_ood (bool, optional) – Indicates whether to evaluate the OOD detection performance. Defaults to False.

  • eval_shift (bool, optional) – Indicates whether to evaluate the performance under distribution shift. Defaults to False.

  • eval_grouping_loss (bool, optional) – Indicates whether to evaluate the grouping loss or not. Defaults to False.

  • ood_criterion (str, optional) – OOD criterion. Available options are:
      - "msp" (default): Maximum softmax probability.
      - "logit": Maximum logit.
      - "energy": Logsumexp of the mean logits.
      - "entropy": Entropy of the mean prediction.
      - "mi": Mutual information of the ensemble.
      - "vr": Variation ratio of the ensemble.

  • post_processing (PostProcessing, optional) – Post-processing method to train on the calibration set. No post-processing if None. Defaults to None.

  • calibration_set (str, optional) – The post-hoc calibration dataset to use for the post-processing method. Defaults to "val".

  • num_calibration_bins (int, optional) – Number of bins to compute calibration metrics. Defaults to 15.

  • log_plots (bool, optional) – Indicates whether to log plots from metrics. Defaults to False.

  • save_in_csv (bool, optional) – Whether to save the results in a CSV file. Defaults to False.
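The criteria accepted by ood_criterion can be sketched in pure Python for a single sample. This is an illustrative sketch, not the library's implementation. It assumes the logits of an ensemble arrive as a list of per-member logit vectors (one vector for a single model), averages the logits for "logit" and "energy" (an assumption for the ensemble case), and flips signs so that a higher score always means "more likely OOD", a convention chosen here rather than taken from the source. The helper name ood_score is hypothetical.

```python
import math
from collections import Counter


def softmax(logits):
    """Numerically stable softmax of one logit vector."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logsumexp(values):
    """Numerically stable log(sum(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def entropy(probs):
    """Shannon entropy (natural log) of a probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def ood_score(member_logits, criterion="msp"):
    """Hypothetical helper: OOD score of one sample, higher = more likely OOD.

    member_logits: list of M logit vectors (M = 1 for a single model).
    """
    num_members = len(member_logits)
    num_classes = len(member_logits[0])
    member_probs = [softmax(z) for z in member_logits]
    # Mean prediction and mean logits over ensemble members.
    mean_probs = [sum(p[c] for p in member_probs) / num_members for c in range(num_classes)]
    mean_logits = [sum(z[c] for z in member_logits) / num_members for c in range(num_classes)]

    if criterion == "msp":  # low max probability -> likely OOD
        return -max(mean_probs)
    if criterion == "logit":  # low max logit -> likely OOD (assumes mean logits)
        return -max(mean_logits)
    if criterion == "energy":  # energy score: negative logsumexp of mean logits
        return -logsumexp(mean_logits)
    if criterion == "entropy":  # entropy of the mean prediction
        return entropy(mean_probs)
    if criterion == "mi":  # entropy of mean minus mean of member entropies
        expected_entropy = sum(entropy(p) for p in member_probs) / num_members
        return entropy(mean_probs) - expected_entropy
    if criterion == "vr":  # fraction of members disagreeing with the modal class
        votes = Counter(max(range(num_classes), key=z.__getitem__) for z in member_logits)
        return 1.0 - votes.most_common(1)[0][1] / num_members
    raise ValueError(f"Unknown criterion: {criterion}")
```

For example, three members that all predict class 0 yield a mutual information and a variation ratio of zero, whereas members split between two classes yield positive values for both. The "mi" and "vr" criteria are only informative for ensembles, since a single model always agrees with itself.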
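The role of num_calibration_bins can be illustrated with the expected calibration error (ECE) computed over equal-width confidence bins. This is a minimal sketch under that assumption; the routine's own metric implementation is not shown in this section and may use different bin-boundary conventions. The function name expected_calibration_error is hypothetical.

```python
def expected_calibration_error(confidences, correct, num_bins=15):
    """Hypothetical helper: equal-width-bin ECE.

    ECE = sum over bins b of (|B_b| / n) * |accuracy(B_b) - mean_confidence(B_b)|.
    confidences: top-class probabilities in [0, 1]; correct: matching booleans.
    """
    n = len(confidences)
    bins = [[] for _ in range(num_bins)]
    for conf, is_correct in zip(confidences, correct):
        # Confidence in [k/B, (k+1)/B) goes to bin k; 1.0 goes to the last bin.
        idx = min(int(conf * num_bins), num_bins - 1)
        bins[idx].append((conf, is_correct))
    ece = 0.0
    for members in bins:
        if not members:
            continue  # empty bins contribute nothing
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(1 for _, ok in members if ok) / len(members)
        ece += len(members) / n * abs(accuracy - avg_conf)
    return ece
```

For instance, four predictions at confidence 0.9 of which three are correct fall into a single bin with accuracy 0.75, giving an ECE of about 0.15. More bins give a finer-grained estimate but fewer samples per bin.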

Warning

You must define optim_recipe if you do not use the Lightning CLI.

Note

optim_recipe can be anything that can be returned by LightningModule.configure_optimizers(); see the Lightning documentation of that method for the accepted formats.
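One of the formats that configure_optimizers() accepts is a dictionary with "optimizer" and "lr_scheduler" keys. The sketch below builds such a recipe with standard PyTorch classes; the stand-in linear model, learning rate, and milestones are illustrative assumptions, not recommended values.

```python
import torch
from torch import nn

# Stand-in model; in practice, this is the same module passed to the routine.
model = nn.Linear(16, 10)

# SGD with momentum, with the learning rate divided by 10 at epochs 25 and 50.
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 50], gamma=0.1)

# Dictionary form accepted by LightningModule.configure_optimizers().
optim_recipe = {"optimizer": optimizer, "lr_scheduler": scheduler}
```

The recipe would then be passed as ClassificationRoutine(model=model, num_classes=10, loss=nn.CrossEntropyLoss(), optim_recipe=optim_recipe). The optimizer must be built from the parameters of the same model instance given to the routine, otherwise training would update a different set of weights.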

forward(inputs, save_feats=False)

Forward pass of the model.

Parameters:
  • inputs (Tensor) – Input tensor.

  • save_feats (bool, optional) – Whether to store the features or not. Defaults to False.

Note

The features are stored in the self.features attribute.