ClassificationRoutine

class torch_uncertainty.routines.ClassificationRoutine(model, num_classes, loss, num_estimators=1, format_batch_fn=None, optim_recipe=None, mixtype='erm', mixmode='elem', dist_sim='emb', kernel_tau_max=1.0, kernel_tau_std=0.5, mixup_alpha=0, cutmix_alpha=0, eval_ood=False, eval_grouping_loss=False, ood_criterion='msp', log_plots=False, save_in_csv=False, calibration_set=None)

Routine for efficient training and testing on classification tasks using LightningModule.

Parameters:
  • model (torch.nn.Module) – Model to train.

  • num_classes (int) – Number of classes.

  • loss (torch.nn.Module) – Loss function to optimize the model.

  • num_estimators (int, optional) – Number of estimators for the ensemble. Defaults to 1 (single model).

  • format_batch_fn (torch.nn.Module, optional) – Function to format the batch. Defaults to torch.nn.Identity().

  • optim_recipe (dict or torch.optim.Optimizer, optional) – The optimizer and optionally the scheduler to use. Defaults to None.

  • mixtype (str, optional) – Mixup type. Defaults to "erm".

  • mixmode (str, optional) – Mixup mode. Defaults to "elem".

  • dist_sim (str, optional) – Distance similarity. Defaults to "emb".

  • kernel_tau_max (float, optional) – Maximum value for the kernel tau. Defaults to 1.0.

  • kernel_tau_std (float, optional) – Standard deviation for the kernel tau. Defaults to 0.5.

  • mixup_alpha (float, optional) – Alpha parameter for Mixup. Defaults to 0.

  • cutmix_alpha (float, optional) – Alpha parameter for Cutmix. Defaults to 0.

  • eval_ood (bool, optional) – Indicates whether to evaluate the OOD detection performance or not. Defaults to False.

  • eval_grouping_loss (bool, optional) – Indicates whether to evaluate the grouping loss or not. Defaults to False.

  • ood_criterion (str, optional) –

    OOD criterion. Available options are:

    • "msp" (default): Maximum softmax probability.

    • "logit": Maximum logit.

    • "energy": Logsumexp of the mean logits.

    • "entropy": Entropy of the mean prediction.

    • "mi": Mutual information of the ensemble.

    • "vr": Variation ratio of the ensemble.

  • log_plots (bool, optional) – Indicates whether to log plots from metrics. Defaults to False.

  • save_in_csv (bool, optional) – Save the results in csv. Defaults to False.

  • calibration_set (str, optional) – The calibration dataset to use for scaling: "val" selects the validation set and "test" selects the test set. Defaults to None.
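To make the ood_criterion options above more concrete, here is a minimal standard-library sketch of how each score could be computed for a single input from the logits of an ensemble. It follows the usual textbook definitions and is not the library's implementation. The function names (softmax, entropy, ood_scores), the choice to compute "msp" and "logit" from the averaged prediction, and the sign conventions are all assumptions made for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy (in nats) of a probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def ood_scores(ensemble_logits):
    """Hypothetical helper: ensemble_logits holds one list of logits per estimator."""
    n = len(ensemble_logits)
    k = len(ensemble_logits[0])
    mean_logits = [sum(l[c] for l in ensemble_logits) / n for c in range(k)]
    member_probs = [softmax(l) for l in ensemble_logits]
    mean_probs = [sum(p[c] for p in member_probs) / n for c in range(k)]

    # Logsumexp of the mean logits, shifted by the maximum for stability.
    top = max(mean_logits)
    energy = top + math.log(sum(math.exp(z - top) for z in mean_logits))

    # Each estimator votes for its argmax class; count the most common vote.
    votes = [max(range(k), key=p.__getitem__) for p in member_probs]
    modal_count = max(votes.count(c) for c in range(k))

    mean_entropy = sum(entropy(p) for p in member_probs) / n
    return {
        "msp": max(mean_probs),
        "logit": max(mean_logits),
        "energy": energy,
        "entropy": entropy(mean_probs),
        # Mutual information: entropy of the mean minus the mean entropy.
        "mi": entropy(mean_probs) - mean_entropy,
        # Variation ratio: share of estimators disagreeing with the modal vote.
        "vr": 1 - modal_count / n,
    }


# Two estimators that disagree on the predicted class.
scores = ood_scores([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
```

Note that "mi" and "vr" only carry information when num_estimators is greater than 1: with a single estimator both are zero under these definitions.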

Warning

You must define optim_recipe if you do not use the CLI.

Note

optim_recipe can be anything that can be returned by LightningModule.configure_optimizers(). See the Lightning documentation of that method for details.
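As an illustration of the note above, one of the return formats that LightningModule.configure_optimizers() accepts is a dict with an "optimizer" key and an optional "lr_scheduler" key. The sketch below builds such a dict with plain PyTorch. The helper name build_optim_recipe, the toy model, and all hyperparameter values are assumptions chosen for the example, not recommendations from the library.

```python
import torch
from torch import nn


def build_optim_recipe(model: nn.Module) -> dict:
    """Hypothetical helper returning an optimizer/scheduler dict."""
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4
    )
    # Decay the learning rate by 10x at epochs 25 and 50 (illustrative values).
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[25, 50], gamma=0.1
    )
    return {"optimizer": optimizer, "lr_scheduler": scheduler}


# Toy model standing in for the network to be trained.
model = nn.Linear(8, 3)
optim_recipe = build_optim_recipe(model)
```

The resulting dict would then be passed as optim_recipe=optim_recipe when constructing the routine, together with model, num_classes, and loss.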

forward(inputs, save_feats=False)

Forward pass of the model.

Parameters:
  • inputs (Tensor) – Input tensor.

  • save_feats (bool, optional) – Whether to store the features or not. Defaults to False.

Note

The features are stored in the self.features attribute.