SegmentationRoutine
- class torch_uncertainty.routines.SegmentationRoutine(model, num_classes, loss, optim_recipe=None, eval_shift=False, format_batch_fn=None, metric_subsampling_rate=0.01, log_plots=False, num_samples_to_plot=3, num_calibration_bins=15)
Routine for training & testing on segmentation tasks.
- Parameters:
model (torch.nn.Module) – Model to train.
num_classes (int) – Number of classes in the segmentation task.
loss (torch.nn.Module) – Loss function to optimize the model.
optim_recipe (dict or Optimizer, optional) – The optimizer and, optionally, the scheduler to use. Defaults to None.
eval_shift (bool, optional) – Whether to evaluate performance under distribution shift. Defaults to False.
format_batch_fn (torch.nn.Module, optional) – The function used to format the batch. Defaults to None.
metric_subsampling_rate (float, optional) – The subsampling rate for memory-intensive metrics. Defaults to 1e-2.
log_plots (bool, optional) – Whether to log plots from metrics. Defaults to False.
num_samples_to_plot (int, optional) – Number of samples to plot in the segmentation results. Defaults to 3.
num_calibration_bins (int, optional) – Number of bins used to compute calibration metrics. Defaults to 15.
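As an illustration of the optim_recipe parameter above, the following minimal sketch builds a dictionary holding an optimizer and a scheduler, which is one of the forms that LightningModule.configure_optimizers() may return. The helper name make_optim_recipe, the stand-in model, and all hyperparameter values are assumptions chosen for the example, not values recommended by the library.

```python
import torch
from torch import nn


def make_optim_recipe(model: nn.Module) -> dict:
    """Hypothetical helper: build an optimizer/scheduler dictionary.

    The hyperparameters below are arbitrary placeholders.
    """
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4
    )
    # Multiply the learning rate by 0.1 every 30 scheduler steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    return {"optimizer": optimizer, "lr_scheduler": scheduler}


# Stand-in for a real segmentation network: 3 input channels, 19 classes.
model = nn.Conv2d(3, 19, kernel_size=1)
recipe = make_optim_recipe(model)
print(sorted(recipe))
```

Under these assumptions, the resulting dictionary would be passed as optim_recipe=recipe when constructing the routine; a bare Optimizer instance is the other form listed in the parameter description.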
Warning
You must define optim_recipe if you do not use the CLI.
Note
optim_recipe can be anything that can be returned by LightningModule.configure_optimizers(). See the Lightning documentation of that method for more details.
- forward(inputs)
Forward pass of the model.
- Parameters:
inputs (Tensor) – input tensor.
- Returns:
the prediction of the model.
- Return type:
Tensor
- on_test_epoch_end()
Compute, log, and plot the values of the metrics collected in test_step.
- on_validation_epoch_end()
Compute and log the values of the metrics collected in validation_step.
- subsample(pred, target)
Select a random subset of the predictions and targets on which to compute the loss.
- Parameters:
pred (Tensor) – the prediction tensor.
target (Tensor) – the target tensor.
- Returns:
the subsampled prediction and target tensors.
- Return type:
Tuple[Tensor, Tensor]
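The idea of random subsampling can be sketched as follows. This is not the library's implementation: the function name subsample_pixels, the assumed tensor shapes ((B, C, H, W) predictions and (B, H, W) targets), and the flattening strategy are all assumptions made for illustration. The rate argument plays the role of metric_subsampling_rate.

```python
import torch


def subsample_pixels(pred: torch.Tensor, target: torch.Tensor, rate: float = 0.01):
    """Hypothetical sketch: keep a random fraction `rate` of all pixels.

    pred:   (B, C, H, W) per-class scores (assumed layout)
    target: (B, H, W) integer class labels (assumed layout)
    Returns predictions of shape (N, C) and targets of shape (N,).
    """
    num_classes = pred.shape[1]
    # Move the class dimension last, then flatten so each row is one pixel.
    flat_pred = pred.permute(0, 2, 3, 1).reshape(-1, num_classes)
    flat_target = target.reshape(-1)
    total = flat_target.numel()
    keep = max(1, int(total * rate))
    # Draw `keep` distinct pixel indices uniformly at random.
    idx = torch.randperm(total, device=pred.device)[:keep]
    return flat_pred[idx], flat_target[idx]


torch.manual_seed(0)
pred = torch.randn(2, 4, 32, 32)
target = torch.randint(0, 4, (2, 32, 32))
sub_pred, sub_target = subsample_pixels(pred, target, rate=0.01)
print(sub_pred.shape, sub_target.shape)
```

With 2 × 32 × 32 = 2048 pixels and a rate of 0.01, the sketch keeps 20 pixels, which shows how a small rate reduces the memory needed by per-pixel computations.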
- test_step(batch)
Perform a single test step based on the input tensors.
Compute the prediction of the model and the value of the metrics on the test batch.
- Parameters:
batch (tuple[Tensor, Tensor]) – the test images and their corresponding targets.