
SegmentationRoutine

class torch_uncertainty.routines.SegmentationRoutine(model, num_classes, loss, optim_recipe=None, eval_shift=False, format_batch_fn=None, metric_subsampling_rate=0.01, log_plots=False, num_samples_to_plot=3, num_calibration_bins=15)[source]

Routine for training & testing on segmentation tasks.

Parameters:
  • model (torch.nn.Module) – Model to train.

  • num_classes (int) – Number of classes in the segmentation task.

  • loss (torch.nn.Module) – Loss function to optimize the model.

  • optim_recipe (dict or Optimizer, optional) – The optimizer and optionally the scheduler to use. Defaults to None.

  • eval_shift (bool, optional) – Indicates whether to evaluate the distribution-shift performance. Defaults to False.

  • format_batch_fn (torch.nn.Module, optional) – The function to format the batch. Defaults to None.

  • metric_subsampling_rate (float, optional) – The subsampling rate for the memory-consuming metrics. Defaults to 1e-2.

  • log_plots (bool, optional) – Indicates whether to log plots from metrics. Defaults to False.

  • num_samples_to_plot (int, optional) – Number of samples to plot in the segmentation results. Defaults to 3.

  • num_calibration_bins (int, optional) – Number of bins to compute calibration metrics. Defaults to 15.

Warning

You must define optim_recipe if you do not use the CLI.

Note

optim_recipe can be anything that can be returned by LightningModule.configure_optimizers(). See the PyTorch Lightning documentation on configure_optimizers() for more details.
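
For illustration, here is a minimal sketch of instantiating the routine outside the CLI. The toy convolutional model, class count, and optimizer hyperparameters are placeholders, and the optim_recipe dict follows the LightningModule.configure_optimizers() format mentioned above.

from torch import nn, optim

from torch_uncertainty.routines import SegmentationRoutine

num_classes = 19  # placeholder class count

# Toy fully-convolutional model producing (B, num_classes, H, W) logits.
model = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, num_classes, kernel_size=1),
)

optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

routine = SegmentationRoutine(
    model=model,
    num_classes=num_classes,
    loss=nn.CrossEntropyLoss(),
    # Anything accepted by configure_optimizers() works here.
    optim_recipe={"optimizer": optimizer, "lr_scheduler": scheduler},
    log_plots=True,
)

Since the routine is a LightningModule, it can then be passed to a Lightning Trainer together with a segmentation datamodule.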

forward(inputs)[source]

Forward pass of the model.

Parameters:

inputs (torch.Tensor) – Input tensor.
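
Continuing the sketch above, calling the routine runs the wrapped model on the inputs; the batch shape is an assumption for an RGB segmentation batch.

import torch

x = torch.rand(4, 3, 128, 128)  # assumed (batch, channels, height, width)
logits = routine(x)             # forward() runs the wrapped model on the inputs
print(logits.shape)             # expected (4, num_classes, 128, 128) for the toy single-estimator model above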

log_segmentation_plots()[source]

Builds and logs examples of segmentation plots from the test set.