
SegmentationRoutine

class torch_uncertainty.routines.SegmentationRoutine(model, num_classes, loss, num_estimators=1, optim_recipe=None, format_batch_fn=None)

LightningModule-based routine for efficiently training and testing models on segmentation tasks.

Parameters:
  • model (torch.nn.Module) – Model to train.

  • num_classes (int) – Number of classes in the segmentation task.

  • loss (torch.nn.Module) – Loss function minimized when training the model.

  • num_estimators (int, optional) – The number of estimators in the ensemble. Defaults to 1 (single model).

  • optim_recipe (dict or Optimizer, optional) – The optimizer and optionally the scheduler to use. Defaults to None.

  • format_batch_fn (torch.nn.Module, optional) – The function to format the batch. Defaults to None.
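The interplay between num_estimators, loss and format_batch_fn can be illustrated with a minimal sketch. It assumes (this is not taken from the torch_uncertainty source) that an ensemble stacks its estimators' logits along the batch dimension, estimator-major, producing a tensor of shape (num_estimators * B, C, H, W). Under that assumption, a format_batch_fn must tile the targets the same way so the loss lines up, and test-time predictions can be obtained by averaging per-pixel softmax probabilities over the estimators. RepeatTargets and ensemble_probs are hypothetical names used only for illustration.

```python
import torch
from torch import nn


class RepeatTargets(nn.Module):
    """Hypothetical format_batch_fn: repeats the targets once per estimator.

    Assumes the model stacks the predictions of its num_estimators
    estimators along the batch dimension (estimator-major), so the
    targets must be tiled the same way for the loss to line up.
    """

    def __init__(self, num_estimators: int) -> None:
        super().__init__()
        self.num_estimators = num_estimators

    def forward(
        self, batch: tuple[torch.Tensor, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        inputs, targets = batch
        # (B, H, W) -> (M * B, H, W): tile the whole batch M times along dim 0.
        repeats = (self.num_estimators,) + (1,) * (targets.dim() - 1)
        return inputs, targets.repeat(*repeats)


def ensemble_probs(logits: torch.Tensor, num_estimators: int) -> torch.Tensor:
    """Average per-pixel softmax probabilities over the estimators.

    logits: (M * B, C, H, W), stacked estimator-major as assumed above.
    Returns a tensor of shape (B, C, H, W).
    """
    probs = logits.softmax(dim=1)
    return probs.reshape(num_estimators, -1, *probs.shape[1:]).mean(dim=0)


# Toy example: 2 RGB images of 4x4 pixels, 3 classes, an ensemble of 2 estimators.
num_estimators, batch_size, num_classes = 2, 2, 3
images = torch.randn(batch_size, 3, 4, 4)
masks = torch.randint(0, num_classes, (batch_size, 4, 4))

inputs, targets = RepeatTargets(num_estimators)((images, masks))
# Random logits stand in for the output of a real ensemble model.
logits = torch.randn(num_estimators * batch_size, num_classes, 4, 4)

loss = nn.CrossEntropyLoss()(logits, targets)  # shapes now line up
probs = ensemble_probs(logits, num_estimators)
print(targets.shape, probs.shape)
```

Tiling the targets rather than splitting the logits keeps the loss a single call on one tensor, which is why a batch-formatting hook of this kind is convenient when num_estimators > 1.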

Warning

You must define optim_recipe if you do not use the CLI.

Note

optim_recipe can be anything that can be returned by LightningModule.configure_optimizers(). See the Lightning documentation of configure_optimizers() for the supported formats.
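As a sketch of the Note above, the snippet below builds an optim_recipe in the dictionary form accepted by configure_optimizers(), with an "optimizer" key and an optional "lr_scheduler" key. The single-convolution model, the SGD hyperparameters and the StepLR schedule are illustrative choices, not defaults of the library.

```python
import torch
from torch import nn

# Stand-in segmentation network (illustrative only): a 1x1 convolution
# mapping 3 input channels to 5 class logits per pixel.
model = nn.Conv2d(3, 5, kernel_size=1)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Divide the learning rate by 10 every 30 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# One of the formats accepted by LightningModule.configure_optimizers():
# a dict holding the optimizer and, optionally, the learning-rate scheduler.
optim_recipe = {"optimizer": optimizer, "lr_scheduler": scheduler}
```

The resulting dict would then be passed as optim_recipe=optim_recipe when constructing SegmentationRoutine(model, num_classes, loss, ...). When no scheduler is needed, passing the optimizer object on its own is the simpler of the two forms listed for this parameter.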