SegmentationRoutine
- class torch_uncertainty.routines.SegmentationRoutine(model, num_classes, loss, num_estimators=1, optim_recipe=None, format_batch_fn=None)
Routine for efficient training and testing on segmentation tasks using LightningModule.
- Parameters:
model (torch.nn.Module) – Model to train.
num_classes (int) – Number of classes in the segmentation task.
loss (torch.nn.Module) – Loss function to optimize the model.
num_estimators (int, optional) – The number of estimators for the ensemble. Defaults to 1 (single model).
optim_recipe (dict or Optimizer, optional) – The optimizer and, optionally, the scheduler to use. Defaults to None.
format_batch_fn (torch.nn.Module, optional) – The function used to format the batch. Defaults to None.
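When num_estimators > 1, a common way to build a segmentation ensemble is to stack the M estimators' outputs along the batch dimension, so the model returns logits of shape (M·B, C, H, W). The sketch below assumes that estimator-major layout (an assumption, not something this page confirms) and shows two pieces: RepeatTargets, a hypothetical format_batch_fn module that tiles the targets so each estimator is scored against its own copy, and ensemble_probs, a hypothetical helper that averages per-estimator softmax probabilities into one per-pixel prediction. Neither is the library's own implementation.

```python
import torch
from torch import nn


class RepeatTargets(nn.Module):
    """Hypothetical format_batch_fn: repeats targets once per estimator.

    Assumes the model stacks estimators along the batch dimension, so
    targets of shape (B, H, W) must become (M * B, H, W).
    """

    def __init__(self, num_estimators: int) -> None:
        super().__init__()
        self.num_estimators = num_estimators

    def forward(self, batch: tuple) -> tuple:
        inputs, targets = batch
        # Tile along dim 0: M consecutive copies of the B target maps.
        return inputs, targets.repeat(self.num_estimators, 1, 1)


def ensemble_probs(logits: torch.Tensor, num_estimators: int) -> torch.Tensor:
    """Average per-estimator softmax probabilities (hypothetical helper).

    logits: (M * B, C, H, W), estimator-major (assumed layout).
    Returns: (B, C, H, W) mean class probabilities per pixel.
    """
    mb, c, h, w = logits.shape
    batch_size = mb // num_estimators
    # (M * B, C, H, W) -> (M, B, C, H, W), matching the tiling above,
    # then softmax over the class dimension of each estimator.
    per_est = logits.view(num_estimators, batch_size, c, h, w).softmax(dim=2)
    return per_est.mean(dim=0)


# Example: 3 estimators, batch of 2, 4 classes, 8x8 images.
inputs, targets = torch.randn(2, 3, 8, 8), torch.randint(0, 4, (2, 8, 8))
_, tiled = RepeatTargets(3)((inputs, targets))
probs = ensemble_probs(torch.randn(6, 4, 8, 8), num_estimators=3)
print(tiled.shape, probs.shape)  # torch.Size([6, 8, 8]) torch.Size([2, 4, 8, 8])
```

Averaging probabilities rather than logits keeps each pixel's output a valid distribution over the C classes, which is what uncertainty metrics computed on the ensemble typically expect.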
Warning
You must define optim_recipe if you do not use the CLI.
Note
optim_recipe can be anything that can be returned by LightningModule.configure_optimizers(). See the Lightning documentation of configure_optimizers() for the accepted formats.
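As an illustration of the dictionary form that LightningModule.configure_optimizers() accepts (an "optimizer" key and an optional "lr_scheduler" key), the sketch below builds an optim_recipe for a stand-in model. The helper build_optim_recipe, the SGD hyperparameters, and the cosine schedule are hypothetical choices for illustration, not values recommended by torch_uncertainty.

```python
import torch
from torch import nn


def build_optim_recipe(model: nn.Module, lr: float = 0.01, epochs: int = 50) -> dict:
    """Hypothetical helper returning an optim_recipe dictionary.

    Follows the dict format accepted by LightningModule.configure_optimizers():
    an "optimizer" key and an optional "lr_scheduler" key.
    """
    optimizer = torch.optim.SGD(
        model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4
    )
    # Cosine annealing over the training run; an arbitrary choice for illustration.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    return {"optimizer": optimizer, "lr_scheduler": scheduler}


model = nn.Conv2d(3, 4, kernel_size=1)  # stand-in for a real segmentation network
recipe = build_optim_recipe(model)

# Not executed here; per the documented signature the recipe would be passed as:
# routine = SegmentationRoutine(model, num_classes=4, loss=nn.CrossEntropyLoss(),
#                               optim_recipe=recipe)
```

Returning a bare optimizer instead of a dict is also accepted by configure_optimizers(); the dict form is shown because it carries the scheduler alongside the optimizer.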