SegmentationRoutine

class torch_uncertainty.routines.SegmentationRoutine(model, num_classes, loss=None, *, optim_recipe=None, eval_shift=False, format_batch_fn=None, metric_subsampling_rate=0.01, eval_ood=False, ood_criterion='msp', post_processing=None, log_plots=False, num_samples_to_plot=3, num_bins_calibration_error=15, save_to_csv=False, csv_filename='results.csv')

Routine for training & testing on segmentation tasks.

Parameters:
  • model (Module) – Model to train.

  • num_classes (int) – Number of classes in the segmentation task.

  • loss (Module | None) – Loss function to optimize the model. Defaults to None.

  • optim_recipe (Union[Callable[[Module], Union[Optimizer, Sequence[Optimizer], tuple[Sequence[Optimizer], Sequence[Union[LRScheduler, ReduceLROnPlateau, LRSchedulerConfig]]], OptimizerConfig, OptimizerLRSchedulerConfig, Sequence[OptimizerConfig], Sequence[OptimizerLRSchedulerConfig], None]], Optimizer, Sequence[Optimizer], tuple[Sequence[Optimizer], Sequence[Union[LRScheduler, ReduceLROnPlateau, LRSchedulerConfig]]], OptimizerConfig, OptimizerLRSchedulerConfig, Sequence[OptimizerConfig], Sequence[OptimizerLRSchedulerConfig], None]) – The optimizer and optionally the scheduler to use, or a callable that returns them. Defaults to None.

  • eval_shift (bool) – Whether to evaluate performance under distribution shift. Defaults to False.

  • format_batch_fn (Module | None) – The function to format the batch. Defaults to None.

  • metric_subsampling_rate (float) – The subsampling rate used for the memory-consuming metrics. Defaults to 1e-2.

  • eval_ood (bool) – Whether to evaluate out-of-distribution (OOD) detection performance. Defaults to False.

  • ood_criterion (TUOODCriterion | str) – Criterion for the binary OOD detection task. Defaults to "msp", the maximum softmax probability (MSP) score.

  • post_processing (PostProcessing | None) – The post-processing technique to use. Defaults to None. Warning: no post-processing technique is implemented yet for segmentation tasks.

  • log_plots (bool) – Whether to log figures to the logger. Defaults to False.

  • num_samples_to_plot (int) – Number of segmentation predictions and targets to plot in the logger. Only used if log_plots is True. Defaults to 3.

  • num_bins_calibration_error (int) – Number of bins used to compute the calibration error metrics. Defaults to 15.

  • save_to_csv (bool) – Whether to save the results to a CSV file. Defaults to False.

  • csv_filename (str) – The name of the CSV file to save the results in. Defaults to "results.csv".
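The metric_subsampling_rate parameter can be pictured as keeping only a random fraction of the pixels before they reach the memory-hungry metrics. The following is a minimal, stdlib-only sketch of uniform random pixel subsampling; the helper subsample_pixel_indices is hypothetical, and whether the routine samples pixels exactly this way is an assumption.

```python
import random


def subsample_pixel_indices(num_pixels: int, rate: float, seed: int = 0) -> list[int]:
    """Pick a random subset of flattened pixel indices (hypothetical helper).

    With rate=1e-2, roughly 1% of the pixels are kept.
    """
    if not 0.0 < rate <= 1.0:
        raise ValueError("rate must be in (0, 1]")
    # Keep at least one pixel so that the metric is always defined.
    k = max(1, int(num_pixels * rate))
    rng = random.Random(seed)
    # Sample without replacement, then sort for cache-friendly gathering.
    return sorted(rng.sample(range(num_pixels), k))


# A 512 x 1024 image has 524288 pixels; a rate of 1e-2 keeps 5242 of them.
indices = subsample_pixel_indices(512 * 1024, 1e-2)
print(len(indices))  # 5242
```

A lower rate reduces memory at the cost of noisier metric estimates.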

Warning

You must define optim_recipe if you do not use the CLI.

Note

optim_recipe can be anything that can be returned by LightningModule.configure_optimizers(). See the Lightning documentation of configure_optimizers() for more details.

forward(inputs)

Forward pass of the model.

Parameters:

inputs (Tensor) – the input tensor.

Returns:

the prediction of the model.

Return type:

Tensor
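For a segmentation network, the prediction is typically a (B, C, H, W) map of per-pixel class scores. Assuming forward returns such unnormalized logits (an assumption; this page only says it returns "the prediction of the model"), the sketch below shows, for a single image stored as nested lists, how per-pixel probabilities and a label map follow from them. The function names are illustrative.

```python
import math


def softmax(scores: list[float]) -> list[float]:
    # Subtract the max score before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def logits_to_label_map(logits: list[list[list[float]]]) -> list[list[int]]:
    """Convert C x H x W logits of one image into an H x W map of class indices."""
    num_classes = len(logits)
    height, width = len(logits[0]), len(logits[0][0])
    label_map = []
    for i in range(height):
        row = []
        for j in range(width):
            probs = softmax([logits[c][i][j] for c in range(num_classes)])
            # Most probable class; identical to the argmax of the raw logits.
            row.append(probs.index(max(probs)))
        label_map.append(row)
    return label_map


# Two classes on a 1 x 2 image: pixel 0 favours class 0, pixel 1 favours class 1.
logits = [
    [[2.0, -1.0]],  # class 0 scores
    [[0.5, 3.0]],   # class 1 scores
]
print(logits_to_label_map(logits))  # [[0, 1]]
```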

on_test_epoch_end()

Compute, log, and plot the values of the collected metrics in test_step.

Return type:

None
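Calibration metrics such as the expected calibration error (ECE) are presumably among the values aggregated at the end of the test epoch, with num_bins_calibration_error setting the number of confidence bins. A common top-label definition is ECE = Σ_b (|B_b| / n) · |acc(B_b) - conf(B_b)|, where each pixel counts as one sample. The sketch below implements that definition with equal-width bins; the bin-edge convention and the exact ECE variant used by the routine are assumptions.

```python
def expected_calibration_error(
    confidences: list[float], correct: list[bool], num_bins: int = 15
) -> float:
    """Top-label ECE with equal-width confidence bins over [0, 1]."""
    if len(confidences) != len(correct) or not confidences:
        raise ValueError("inputs must be non-empty and of equal length")
    n = len(confidences)
    bin_conf = [0.0] * num_bins  # sum of confidences per bin
    bin_acc = [0.0] * num_bins   # number of correct predictions per bin
    bin_count = [0] * num_bins
    for conf, ok in zip(confidences, correct):
        # Bin k covers [k / num_bins, (k + 1) / num_bins); conf == 1.0 joins the last bin.
        b = min(int(conf * num_bins), num_bins - 1)
        bin_conf[b] += conf
        bin_acc[b] += float(ok)
        bin_count[b] += 1
    ece = 0.0
    for b in range(num_bins):
        if bin_count[b]:
            # |accuracy - mean confidence| of the bin, weighted by its share of samples.
            gap = abs(bin_acc[b] - bin_conf[b]) / bin_count[b]
            ece += bin_count[b] / n * gap
    return ece


ece = expected_calibration_error(
    confidences=[0.9, 0.9, 0.65, 0.65],
    correct=[True, False, True, True],
    num_bins=15,
)
print(round(ece, 3))  # 0.375
```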

on_validation_epoch_end()

Compute and log the values of the collected metrics in validation_step.

Return type:

None

test_step(batch, batch_idx, dataloader_idx=0)

Perform a single test step based on the input tensors.

Compute the prediction of the model and the value of the metrics on the test batch.

Parameters:
  • batch (tuple[Tensor, Tensor]) – the test images and their corresponding targets.

  • batch_idx (int) – the index of the batch in the test dataloader.

  • dataloader_idx (int) – the index of the dataloader. Defaults to 0.

Return type:

None
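With the default ood_criterion="msp", each pixel's OOD evidence comes from its maximum softmax probability: a peaked prediction (high MSP) suggests an in-distribution pixel, while a flat one (low MSP) suggests an anomaly. The sketch below computes per-pixel MSP from flattened logits and flags pixels under a threshold. The helper names and the 0.5 threshold are hypothetical, OOD evaluation usually relies on threshold-free metrics instead, and the routine's own score sign convention is not stated on this page.

```python
import math


def max_softmax_probability(scores: list[float]) -> float:
    # Numerically stable softmax, keeping only its largest entry.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    return max(exps) / sum(exps)


def flag_ood_pixels(pixel_logits: list[list[float]], threshold: float) -> list[bool]:
    """Return True for pixels whose MSP falls below a (hypothetical) threshold."""
    return [max_softmax_probability(s) < threshold for s in pixel_logits]


pixel_logits = [
    [4.0, 0.0, 0.0],  # peaked scores: MSP is about 0.96
    [0.1, 0.0, 0.0],  # nearly flat scores: MSP is about 0.36
]
print(flag_ood_pixels(pixel_logits, threshold=0.5))  # [False, True]
```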

training_step(batch)

Perform a single training step based on the input tensors.

Parameters:

batch (tuple[Tensor, Tensor]) – the training images and their corresponding targets.

Returns:

the loss corresponding to this training step.

Return type:

Tensor
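The loss given to the constructor is presumably what training_step evaluates on each batch. For segmentation this is commonly a per-pixel cross-entropy averaged over labeled pixels, which is what torch.nn.CrossEntropyLoss computes with its default mean reduction. The stdlib sketch below spells out that computation on flattened pixels; the optional ignore_index mirrors the torch argument, but whether and how the routine uses it is an assumption.

```python
import math
from typing import Optional


def pixelwise_cross_entropy(
    logits: list[list[float]],
    targets: list[int],
    ignore_index: Optional[int] = None,
) -> float:
    """Mean cross-entropy over flattened pixels.

    logits[p] holds the class scores of pixel p and targets[p] its true class.
    """
    total, count = 0.0, 0
    for scores, target in zip(logits, targets):
        if ignore_index is not None and target == ignore_index:
            continue  # skip unlabeled pixels
        # Stable log-sum-exp: log(sum(exp(s))) = m + log(sum(exp(s - m))).
        m = max(scores)
        log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
        # Negative log-likelihood of the true class.
        total += log_norm - scores[target]
        count += 1
    if count == 0:
        raise ValueError("no labeled pixels to average over")
    return total / count


# Three pixels, two classes; the last pixel is unlabeled (255 is illustrative).
loss = pixelwise_cross_entropy(
    logits=[[0.0, 0.0], [2.0, 0.0], [0.0, 5.0]],
    targets=[0, 0, 255],
    ignore_index=255,
)
print(round(loss, 4))  # 0.41
```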

validation_step(batch)

Perform a single validation step based on the input tensors.

Compute the prediction of the model and the value of the metrics on the validation batch.

Parameters:

batch (tuple[Tensor, Tensor]) – the validation images and their corresponding targets.

Return type:

None
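Mean intersection-over-union (mIoU) is a standard segmentation metric, but which metrics validation_step actually updates is not listed on this page, so the following is only an illustration. It counts, per class, the pixels in the intersection and in the union of prediction and target, then averages the IoU over the classes that appear in either; skipping absent classes is one convention among several.

```python
def mean_iou(preds: list[int], targets: list[int], num_classes: int) -> float:
    """Mean IoU over the classes present in preds or targets."""
    inter = [0] * num_classes
    union = [0] * num_classes
    for p, t in zip(preds, targets):
        if p == t:
            # Correct pixel: in both the intersection and the union of its class.
            inter[p] += 1
            union[p] += 1
        else:
            # Wrong pixel: enlarges the union of both the predicted and true class.
            union[p] += 1
            union[t] += 1
    ious = [inter[c] / union[c] for c in range(num_classes) if union[c] > 0]
    if not ious:
        raise ValueError("no pixels to evaluate")
    return sum(ious) / len(ious)


# Class 0: IoU = 1/2; class 1: IoU = 2/3; mean = 7/12.
miou = mean_iou(preds=[0, 0, 1, 1], targets=[0, 1, 1, 1], num_classes=2)
print(round(miou, 4))  # 0.5833
```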