SegmentationRoutine

class torch_uncertainty.routines.SegmentationRoutine(model, num_classes, loss=None, optim_recipe=None, eval_shift=False, format_batch_fn=None, metric_subsampling_rate=0.01, eval_ood=False, ood_criterion='msp', post_processing=None, log_plots=False, num_samples_to_plot=3, num_bins_cal_err=15, save_in_csv=False, csv_filename='results.csv')

Routine for training & testing on segmentation tasks.

Parameters:
  • model (torch.nn.Module) – Model to train.

  • num_classes (int) – Number of classes in the segmentation task.

  • loss (torch.nn.Module) – Loss function to optimize the model. Defaults to None.

  • optim_recipe (dict or Optimizer, optional) – The optimizer and optionally the scheduler to use. Defaults to None.

  • eval_shift (bool, optional) – Indicates whether to evaluate the performance under distribution shift. Defaults to False.

  • format_batch_fn (torch.nn.Module, optional) – The function to format the batch. Defaults to None.

  • metric_subsampling_rate (float, optional) – The subsampling rate for the memory-consuming metrics. Defaults to 1e-2.

  • eval_ood (bool, optional) – Indicates whether to evaluate the OOD performance. Defaults to False.

  • ood_criterion (TUOODCriterion, optional) – Criterion for the binary OOD detection task. Defaults to "msp", i.e., the maximum softmax probability (MSP) score.

  • post_processing (PostProcessing, optional) – The post-processing technique to use. Defaults to None. Warning: no post-processing technique is implemented yet for segmentation tasks.

  • log_plots (bool, optional) – Indicates whether to log figures in the logger. Defaults to False.

  • num_samples_to_plot (int, optional) – Number of segmentation predictions and targets to plot in the logger. Only used if log_plots is True. Defaults to 3.

  • num_bins_cal_err (int, optional) – Number of bins used to compute the calibration error metrics. Defaults to 15.

  • save_in_csv (bool, optional) – Indicates whether to save the results in a CSV file. Defaults to False.

  • csv_filename (str, optional) – The name of the CSV file in which to save the results. Defaults to "results.csv".
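The ood_criterion and num_bins_cal_err parameters refer to standard per-pixel quantities. Below is a minimal sketch in plain PyTorch, not the library's implementation, assuming logits of shape (B, C, H, W). The MSP score is the maximum softmax probability of each pixel. The top-label expected calibration error (ECE) groups pixels into num_bins_cal_err equal-width confidence bins and averages the gap between confidence and accuracy, weighted by the fraction of pixels in each bin. The function names msp_score and expected_calibration_error are hypothetical.

```python
import torch


def msp_score(logits: torch.Tensor) -> torch.Tensor:
    """Maximum softmax probability per pixel; logits has shape (B, C, H, W)."""
    return logits.softmax(dim=1).amax(dim=1)  # shape (B, H, W)


def expected_calibration_error(
    probs: torch.Tensor, targets: torch.Tensor, num_bins: int = 15
) -> torch.Tensor:
    """Top-label ECE; probs has shape (N, C), targets has shape (N,)."""
    confidences, predictions = probs.max(dim=1)
    accuracies = (predictions == targets).float()
    # Equal-width bin edges over [0, 1].
    edges = torch.linspace(0, 1, num_bins + 1)
    ece = torch.zeros(())
    for lower, upper in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lower) & (confidences <= upper)
        if in_bin.any():
            weight = in_bin.float().mean()  # fraction of pixels in this bin
            gap = (confidences[in_bin].mean() - accuracies[in_bin].mean()).abs()
            ece += weight * gap
    return ece


logits = torch.randn(2, 4, 8, 8)  # (B, C, H, W)
targets = torch.randint(0, 4, (2, 8, 8))  # (B, H, W)
scores = msp_score(logits)
# Flatten pixels to (N, C) and (N,) before computing the ECE.
probs = logits.permute(0, 2, 3, 1).reshape(-1, 4).softmax(dim=1)
ece = expected_calibration_error(probs, targets.reshape(-1), num_bins=15)
```

Looping over bins keeps the sketch readable; a vectorized version would use torch.bucketize, at the cost of clarity.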

Warning

You must define optim_recipe if you do not use the CLI.

Note

optim_recipe can be anything that LightningModule.configure_optimizers() can return. See the Lightning documentation of configure_optimizers() for the supported formats.
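One format accepted by LightningModule.configure_optimizers() is a dict with an "optimizer" key and an optional "lr_scheduler" key. The sketch below builds such a dict with plain PyTorch. The stand-in model, learning rate, and scheduler settings are illustrative assumptions. The commented-out routine construction is only a hedged illustration of where optim_recipe would go and is not executed.

```python
import torch
from torch import nn

# Hypothetical stand-in segmentation model: a 1x1 convolution mapping
# 3 input channels to 4 class logits.
model = nn.Conv2d(3, 4, kernel_size=1)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# Dict format: "optimizer" is required, "lr_scheduler" is optional.
optim_recipe = {"optimizer": optimizer, "lr_scheduler": scheduler}

# Illustrative usage (requires torch_uncertainty; not executed here):
# routine = SegmentationRoutine(
#     model=model,
#     num_classes=4,
#     loss=nn.CrossEntropyLoss(),
#     optim_recipe=optim_recipe,
# )
```

Passing a bare optimizer instead of a dict is also a valid configure_optimizers() return value when no scheduler is needed.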

forward(inputs)

Forward pass of the model.

Parameters:

inputs (Tensor) – input tensor.

Returns:

the prediction of the model.

Return type:

Tensor
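As a hedged illustration of what "the prediction of the model" typically looks like in segmentation, the sketch below assumes the wrapped model is fully convolutional and maps images of shape (B, 3, H, W) to per-pixel logits of shape (B, num_classes, H, W). The class TinySegNet is hypothetical and only stands in for a real segmentation network.

```python
import torch
from torch import nn


class TinySegNet(nn.Module):
    """Hypothetical fully convolutional model producing per-pixel logits."""

    def __init__(self, in_channels: int = 3, num_classes: int = 5) -> None:
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, num_classes, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)  # (B, num_classes, H, W)


model = TinySegNet(num_classes=5).eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 32, 32))
preds = logits.argmax(dim=1)  # class index per pixel, shape (B, H, W)
```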

log_segmentation_plots()

Build and log examples of segmentation plots from the test set.

on_test_epoch_end()

Compute, log, and plot the values of the metrics collected in test_step.

on_validation_epoch_end()

Compute and log the values of the metrics collected in validation_step.
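The epoch-end hooks compute metrics from statistics that the step methods accumulate. A minimal sketch of this accumulate-then-compute pattern follows, using mean IoU over a confusion matrix as the example metric. The class MeanIoU and its update/compute/reset methods are hypothetical and are not the library's metric objects.

```python
import torch


class MeanIoU:
    """Hypothetical accumulator: update() per step, compute() at epoch end."""

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes
        # Rows index targets, columns index predictions.
        self.confmat = torch.zeros(num_classes, num_classes, dtype=torch.long)

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        # Encode each (target, pred) pair as one index, then count pairs.
        idx = targets.flatten() * self.num_classes + preds.flatten()
        counts = torch.bincount(idx, minlength=self.num_classes**2)
        self.confmat += counts.reshape(self.num_classes, self.num_classes)

    def compute(self) -> torch.Tensor:
        tp = self.confmat.diag().float()
        fp = self.confmat.sum(dim=0).float() - tp  # predicted as c, truly not c
        fn = self.confmat.sum(dim=1).float() - tp  # truly c, predicted otherwise
        union = tp + fp + fn
        iou = tp / union.clamp(min=1)
        # Average only over classes present in the predictions or targets.
        return iou[union > 0].mean()

    def reset(self) -> None:
        self.confmat.zero_()


metric = MeanIoU(num_classes=3)
# Two "steps" of (predictions, targets).
metric.update(torch.tensor([0, 1, 2, 2]), torch.tensor([0, 1, 2, 1]))
metric.update(torch.tensor([0, 0]), torch.tensor([0, 0]))
miou = metric.compute()  # "epoch end": per-class IoU is [1.0, 0.5, 0.5]
metric.reset()
```

Accumulating a confusion matrix rather than averaging per-batch scores makes the epoch-level result independent of how pixels are split across batches.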

subsample(pred, target)

Select a random sample of the data on which to compute the memory-consuming metrics.

Parameters:
  • pred (Tensor) – the prediction tensor.

  • target (Tensor) – the target tensor.

Returns:

the subsampled prediction and target tensors.

Return type:

Tuple[Tensor, Tensor]
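A sketch of such subsampling, assuming pred and target have already been flattened so that their first dimension indexes pixels, and that the kept fraction is the metric_subsampling_rate passed to the routine. This is an illustration under those assumptions, not necessarily the library's code.

```python
import torch


def subsample(pred: torch.Tensor, target: torch.Tensor, rate: float = 1e-2):
    """Keep a random fraction `rate` of the rows of pred and target."""
    total = target.size(0)
    num_kept = max(1, int(total * rate))  # always keep at least one row
    # Use the same random indices for both tensors so pairs stay aligned.
    indices = torch.randperm(total, device=pred.device)[:num_kept]
    return pred[indices], target[indices]


target = torch.arange(1000)
pred = target * 2  # paired values make the alignment easy to check
sub_pred, sub_target = subsample(pred, target, rate=0.05)
```

Sampling without replacement via torch.randperm avoids counting any pixel twice.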

test_step(batch, batch_idx, dataloader_idx=0)

Perform a single test step based on the input tensors.

Compute the prediction of the model and the value of the metrics on the test batch.

Parameters:
  • batch (tuple[Tensor, Tensor]) – the test images and their corresponding targets.

  • batch_idx (int) – the index of the batch in the test dataloader.

  • dataloader_idx (int, optional) – the index of the dataloader. Defaults to 0.
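Before pixel-wise metrics can be updated on a test batch, the dense outputs generally have to be reshaped into one row per pixel. The sketch below shows one way to do this. It assumes logits of shape (B, C, H, W), integer targets of shape (B, H, W), and an ignore label of 255 (a common convention, assumed here rather than taken from the library). The function flatten_for_metrics is hypothetical.

```python
import torch


def flatten_for_metrics(logits, targets, ignore_index=255):
    """Turn (B, C, H, W) logits and (B, H, W) targets into per-pixel rows.

    Returns probs of shape (N, C) and targets of shape (N,), where N counts
    only pixels whose target differs from ignore_index (assumed convention).
    """
    num_classes = logits.size(1)
    # (B, C, H, W) -> (B, H, W, C) -> (B*H*W, C), matching targets.flatten().
    probs = logits.permute(0, 2, 3, 1).reshape(-1, num_classes).softmax(dim=-1)
    targets = targets.flatten()
    valid = targets != ignore_index
    return probs[valid], targets[valid]


logits = torch.randn(2, 3, 4, 4)
targets = torch.randint(0, 3, (2, 4, 4))
targets[0, 0, :] = 255  # mark one row of 4 pixels as "ignore"
probs, valid_targets = flatten_for_metrics(logits, targets)
```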

training_step(batch)

Perform a single training step based on the input tensors.

Parameters:

batch (tuple[Tensor, Tensor]) – the training images and their corresponding targets.

Returns:

the loss corresponding to this training step.

Return type:

Tensor
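A minimal sketch of what a segmentation training step computes, assuming the loss is torch.nn.CrossEntropyLoss applied to (B, C, H, W) logits and (B, H, W) class-index targets, with 255 as an assumed ignore label. The function training_step_sketch and the 1x1-convolution model are hypothetical stand-ins, not the routine's code.

```python
import torch
from torch import nn


def training_step_sketch(model, loss_fn, batch):
    """Hypothetical stand-in: forward the images and return the loss."""
    images, targets = batch
    logits = model(images)  # (B, C, H, W)
    return loss_fn(logits, targets)  # targets: (B, H, W) class indices


model = nn.Conv2d(3, 4, kernel_size=1)
loss_fn = nn.CrossEntropyLoss(ignore_index=255)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

batch = (torch.randn(2, 3, 8, 8), torch.randint(0, 4, (2, 8, 8)))
loss = training_step_sketch(model, loss_fn, batch)

# With automatic optimization, Lightning runs the backward pass and the
# optimizer step on the returned loss itself; in plain PyTorch this is:
optimizer.zero_grad()
loss.backward()
optimizer.step()
```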

validation_step(batch)

Perform a single validation step based on the input tensors.

Compute the prediction of the model and the value of the metrics on the validation batch.

Parameters:

batch (tuple[Tensor, Tensor]) – the validation images and their corresponding targets.