SegmentationRoutine
- class torch_uncertainty.routines.SegmentationRoutine(model, num_classes, loss=None, *, optim_recipe=None, eval_shift=False, format_batch_fn=None, metric_subsampling_rate=0.01, eval_ood=False, ood_criterion='msp', post_processing=None, log_plots=False, num_samples_to_plot=3, num_bins_calibration_error=15, save_to_csv=False, csv_filename='results.csv')
Routine for training & testing on segmentation tasks.
- Parameters:
model (Module) – Model to train.
num_classes (int) – Number of classes in the segmentation task.
loss (Module | None) – Loss function to optimize the model. Defaults to None.
optim_recipe (Union[Callable[[Module], Union[Optimizer, Sequence[Optimizer], tuple[Sequence[Optimizer], Sequence[Union[LRScheduler, ReduceLROnPlateau, LRSchedulerConfig]]], OptimizerConfig, OptimizerLRSchedulerConfig, Sequence[OptimizerConfig], Sequence[OptimizerLRSchedulerConfig], None]], Optimizer, Sequence[Optimizer], tuple[Sequence[Optimizer], Sequence[Union[LRScheduler, ReduceLROnPlateau, LRSchedulerConfig]]], OptimizerConfig, OptimizerLRSchedulerConfig, Sequence[OptimizerConfig], Sequence[OptimizerLRSchedulerConfig], None]) – The optimizer and, optionally, the scheduler to use, or a callable that returns them. Defaults to None.
eval_shift (bool) – Whether to evaluate the performance under distribution shift. Defaults to False.
format_batch_fn (Module | None) – The function used to format the batch. Defaults to None.
metric_subsampling_rate (float) – The subsampling rate for memory-consuming metrics. Defaults to 1e-2.
eval_ood (bool) – Whether to evaluate the OOD detection performance. Defaults to False.
ood_criterion (TUOODCriterion | str) – Criterion for the binary OOD detection task. Defaults to "msp", the maximum softmax probability (MSP) score.
post_processing (PostProcessing | None) – The post-processing technique to use. Defaults to None. Warning: no post-processing technique is implemented yet for segmentation tasks.
log_plots (bool) – Whether to log figures in the logger. Defaults to False.
num_samples_to_plot (int) – Number of segmentation predictions and targets to plot in the logger. Only used if log_plots is True. Defaults to 3.
num_bins_calibration_error (int) – Number of bins used to compute the calibration error metrics. Defaults to 15.
save_to_csv (bool) – Whether to save the results to a CSV file. Defaults to False.
csv_filename (str) – Name of the CSV file in which to save the results. Defaults to "results.csv".
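Since optim_recipe accepts a callable that receives the model, one common shape for it is a function returning a Lightning-style dictionary with "optimizer" and "lr_scheduler" keys. The sketch below assumes plain SGD with a StepLR decay; the function name optim_recipe, the toy nn.Conv2d model and every hyperparameter value are illustrative assumptions, not recommendations from this library.

```python
import torch
from torch import nn


def optim_recipe(model: nn.Module) -> dict:
    """Hypothetical recipe: SGD with a step learning-rate decay.

    The returned dict follows the format accepted by Lightning's
    configure_optimizers(): an "optimizer" key and an "lr_scheduler" key.
    """
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4
    )
    # Divide the learning rate by 10 every 30 epochs (illustrative values).
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    return {"optimizer": optimizer, "lr_scheduler": scheduler}


# Toy 1x1 convolution standing in for a real segmentation model (5 classes).
model = nn.Conv2d(3, 5, kernel_size=1)
recipe = optim_recipe(model)
```

The function itself, rather than its result, can then be passed as optim_recipe=optim_recipe, since the parameter accepts a Callable[[Module], ...].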
Warning
You must define optim_recipe if you do not use the CLI.
Note
optim_recipe can be anything that can be returned by LightningModule.configure_optimizers(). See the Lightning documentation of LightningModule.configure_optimizers() for more details.
- forward(inputs)
Forward pass of the model.
- Parameters:
inputs (Tensor) – Input tensor.
- Returns:
The prediction of the model.
- Return type:
Tensor
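The docstring only says that forward returns "the prediction of the model". Assuming, as is usual for segmentation networks, that this is a tensor of per-pixel logits of shape (batch, num_classes, height, width), the sketch below derives per-pixel class probabilities, predicted labels and the maximum softmax probability, the score behind the default ood_criterion="msp". The 1x1 convolution is a hypothetical stand-in for a real model.

```python
import torch
from torch import nn

num_classes = 4
# Hypothetical stand-in for a segmentation model: maps RGB pixels to class logits.
model = nn.Conv2d(3, num_classes, kernel_size=1)

inputs = torch.randn(2, 3, 8, 8)  # (batch, channels, height, width)
with torch.no_grad():
    logits = model(inputs)  # (batch, num_classes, height, width)

probs = logits.softmax(dim=1)  # per-pixel class probabilities
preds = probs.argmax(dim=1)  # (batch, height, width) predicted labels
msp = probs.max(dim=1).values  # per-pixel maximum softmax probability
```

Low MSP values mark pixels where the model spreads its probability mass over several classes, which is why MSP can serve as a per-pixel OOD score.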
- on_test_epoch_end()
Compute, log, and plot the values of the metrics collected during test_step.
- Return type:
None
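The parameter list mentions calibration error metrics binned according to num_bins_calibration_error. As an illustration of a value computed from predictions collected over an epoch, here is a minimal sketch of the standard equal-width-bin expected calibration error (ECE) on flattened per-pixel confidences. It is not the library's implementation; the function name and the toy inputs are assumptions.

```python
import torch


def expected_calibration_error(
    confidences: torch.Tensor, correct: torch.Tensor, num_bins: int = 15
) -> float:
    """Equal-width-bin ECE (hypothetical helper).

    confidences: (N,) per-pixel max softmax probabilities in [0, 1].
    correct: (N,) booleans, True where the predicted label matches the target.
    """
    edges = torch.linspace(0.0, 1.0, num_bins + 1)
    ece = torch.zeros(())
    for low, high in zip(edges[:-1], edges[1:]):
        # Right-closed bins (low, high]; a softmax maximum is never exactly 0.
        in_bin = (confidences > low) & (confidences <= high)
        if in_bin.any():
            bin_weight = in_bin.float().mean()  # fraction of pixels in this bin
            gap = (confidences[in_bin].mean() - correct[in_bin].float().mean()).abs()
            ece += bin_weight * gap
    return ece.item()


# Two overconfident pixels (0.9 confidence, 50% accurate) and
# two underconfident ones (0.6 confidence, 100% accurate).
conf = torch.tensor([0.9, 0.9, 0.6, 0.6])
correct = torch.tensor([True, False, True, True])
ece = expected_calibration_error(conf, correct, num_bins=15)
```

Each bin holds half of the pixels with a confidence/accuracy gap of 0.4, so the ECE here is 0.4.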
- on_validation_epoch_end()
Compute and log the values of the metrics collected during validation_step.
- Return type:
None
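Epoch-end hooks such as this one compute metrics from statistics accumulated step by step. This page does not list which metrics are collected, so mean intersection-over-union (mIoU) is used below purely as a hypothetical example: a confusion matrix is updated on every batch and reduced once at epoch end. The class ConfusionMatrixIoU is an illustrative assumption, not part of torch_uncertainty.

```python
import torch


class ConfusionMatrixIoU:
    """Hypothetical accumulator: update() per step, compute() once per epoch."""

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes
        self.confmat = torch.zeros(num_classes, num_classes, dtype=torch.long)

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        # Rows index the target class, columns the predicted class.
        idx = targets.flatten() * self.num_classes + preds.flatten()
        counts = torch.bincount(idx, minlength=self.num_classes**2)
        self.confmat += counts.reshape(self.num_classes, self.num_classes)

    def compute(self) -> torch.Tensor:
        tp = self.confmat.diag().float()
        # Union = predicted count + target count - intersection, per class.
        union = self.confmat.sum(dim=0).float() + self.confmat.sum(dim=1).float() - tp
        valid = union > 0  # skip classes absent from both predictions and targets
        return (tp[valid] / union[valid]).mean()


metric = ConfusionMatrixIoU(num_classes=2)
metric.update(torch.tensor([[0, 1], [1, 1]]), torch.tensor([[0, 1], [0, 1]]))  # step 1
metric.update(torch.tensor([0, 0]), torch.tensor([0, 0]))  # step 2
miou = metric.compute().item()  # epoch end
```

Accumulating counts rather than per-batch IoU values keeps the epoch-level result independent of how pixels are split across batches.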
- test_step(batch, batch_idx, dataloader_idx=0)
Perform a single test step based on the input tensors.
Compute the prediction of the model and the value of the metrics on the test batch.
- Parameters:
batch (tuple[Tensor, Tensor]) – The test images and their corresponding targets.
batch_idx (int) – The index of the batch in the test dataloader.
dataloader_idx (int) – The index of the dataloader. Defaults to 0.
- Return type:
None
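test_step computes metrics on each test batch, and metric_subsampling_rate controls the subsampling applied to memory-consuming metrics. This page does not say how pixels are selected; the hypothetical helper subsample_pixels below sketches one plausible approach, uniform random selection of a fraction of the flattened pixels, under the assumption that predictions are per-pixel probabilities of shape (batch, num_classes, height, width).

```python
from typing import Optional, Tuple

import torch


def subsample_pixels(
    probs: torch.Tensor,
    targets: torch.Tensor,
    rate: float = 1e-2,
    generator: Optional[torch.Generator] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Keep a random fraction `rate` of the pixels (hypothetical helper).

    probs: (B, C, H, W) per-pixel class probabilities.
    targets: (B, H, W) integer labels.
    Returns (N, C) probabilities and (N,) targets, with N = int(rate * B * H * W).
    """
    num_classes = probs.shape[1]
    # Move the class dimension last so each row is one pixel, in (b, h, w) order.
    flat_probs = probs.permute(0, 2, 3, 1).reshape(-1, num_classes)
    flat_targets = targets.reshape(-1)
    num_keep = max(1, int(rate * flat_targets.numel()))
    idx = torch.randperm(flat_targets.numel(), generator=generator)[:num_keep]
    return flat_probs[idx], flat_targets[idx]


gen = torch.Generator().manual_seed(0)
targets = torch.randint(0, 5, (2, 32, 32), generator=gen)
# One-hot "probabilities" make it easy to check that pixels stay aligned with targets.
probs = torch.nn.functional.one_hot(targets, num_classes=5).permute(0, 3, 1, 2).float()
sub_probs, sub_targets = subsample_pixels(probs, targets, rate=1e-2, generator=gen)
```

With 2 × 32 × 32 = 2048 pixels and the default rate of 1e-2, 20 pixels are kept. Sampling the same indices for probabilities and targets is what keeps each retained prediction paired with its label.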