PixelRegressionRoutine
- class torch_uncertainty.routines.PixelRegressionRoutine(model, output_dim, loss=None, dist_family=None, dist_estimate='mean', *, is_ensemble=False, format_batch_fn=None, optim_recipe=None, eval_shift=False, num_image_plot=4, log_plots=False, save_to_csv=False, csv_filename='results.csv')
Routine for training & testing on pixel regression tasks.
- Parameters:
  - model (Module) – Model to train.
  - output_dim (int) – Number of outputs of the model.
  - loss (Module | None) – Loss function to optimize the model. Defaults to None.
  - dist_family (str | None) – Distribution family to use for probabilistic pixel regression. If None, the routine performs point-wise regression. Defaults to None.
  - dist_estimate (str) – Estimate of the predicted distribution to use when computing the point-wise metrics. Defaults to "mean".
  - is_ensemble (bool) – Whether the model is an ensemble. Defaults to False.
  - format_batch_fn (Module | None) – Function used to format the batch. Defaults to None.
  - optim_recipe (Union[Callable[[Module], Union[Optimizer, Sequence[Optimizer], tuple[Sequence[Optimizer], Sequence[Union[LRScheduler, ReduceLROnPlateau, LRSchedulerConfig]]], OptimizerConfig, OptimizerLRSchedulerConfig, Sequence[OptimizerConfig], Sequence[OptimizerLRSchedulerConfig], None]], Optimizer, Sequence[Optimizer], tuple[Sequence[Optimizer], Sequence[Union[LRScheduler, ReduceLROnPlateau, LRSchedulerConfig]]], OptimizerConfig, OptimizerLRSchedulerConfig, Sequence[OptimizerConfig], Sequence[OptimizerLRSchedulerConfig], None]) – The optimizer and, optionally, the scheduler to use, or a callable that takes the model and returns them. Defaults to None.
  - eval_shift (bool) – Whether to evaluate performance under distribution shift. Defaults to False.
  - num_image_plot (int) – Number of images to plot. Defaults to 4.
  - log_plots (bool) – Whether to log plots of the metrics. Defaults to False.
  - save_to_csv (bool) – Whether to save the results to a CSV file. Defaults to False.
  - csv_filename (str) – Name of the CSV file. Only used if save_to_csv is True. Defaults to "results.csv".
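optim_recipe also accepts a callable that receives the model. One way to supply it is a small function that returns a Lightning-style dict with "optimizer" and "lr_scheduler" keys, which is the OptimizerLRSchedulerConfig member of the annotation above. The sketch below uses AdamW and a cosine schedule only as an illustration. The function name and hyperparameters are hypothetical and are not library defaults.

```python
import torch
from torch import nn


def optim_recipe(model: nn.Module) -> dict:
    # Hypothetical recipe: AdamW with a cosine-annealed learning rate.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
    # Lightning-style dict with the "optimizer" and "lr_scheduler" keys.
    return {"optimizer": optimizer, "lr_scheduler": scheduler}


# Toy dense-prediction network: 3 input channels -> 1 output value per pixel.
model = nn.Conv2d(3, 1, kernel_size=3, padding=1)
recipe = optim_recipe(model)
```

Passing a callable lets the optimizer be built for whichever model the routine wraps. Returning a bare optimizer instead would also satisfy the annotation, but it would drop the scheduler.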
- evaluation_forward(inputs)
Get the prediction and handle the predicted distribution parameters, if any.
- Parameters:
  - inputs (Tensor) – The input data.
- Returns:
The prediction as a Tensor and the predicted distribution, or None when no distribution is predicted.
- Return type:
tuple[Tensor, Distribution | None]
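Here is a rough illustration of what handling distribution parameters can involve. The sketch assumes a hypothetical model that stacks a location and an unconstrained scale along the channel dimension. It wraps them in a torch.distributions.Normal and uses the distribution's mean as the point prediction, mirroring dist_estimate="mean". The channel layout, the softplus transform, and the name evaluation_forward_sketch are all assumptions, not the library's implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Distribution, Normal


def evaluation_forward_sketch(
    model: nn.Module, inputs: torch.Tensor, output_dim: int
) -> tuple[torch.Tensor, Distribution]:
    # Hypothetical layout: (B, 2 * output_dim, H, W) = [loc channels, raw scale channels].
    out = model(inputs)
    loc, raw_scale = out.split(output_dim, dim=1)
    # softplus (plus a small epsilon) keeps the scale strictly positive.
    dist = Normal(loc, F.softplus(raw_scale) + 1e-6)
    # dist_estimate="mean": the distribution mean serves as the point prediction.
    return dist.mean, dist


model = nn.Conv2d(3, 2, kernel_size=1)
preds, dist = evaluation_forward_sketch(model, torch.randn(4, 3, 8, 8), output_dim=1)
```

Pixel-wise metrics such as RMSE would use preds, while likelihood-based metrics would use dist.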
- forward(inputs)
Forward pass of the routine.
The forward pass automatically squeezes the output if the regression is one-dimensional and the routine contains a single model.
- Parameters:
  - inputs (Tensor) – The input tensor.
- Returns:
The output tensor.
- Return type:
Tensor
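A minimal sketch of the squeeze rule. It assumes predictions are laid out as (batch, channels, height, width) and reads "a single model" as is_ensemble=False. The wrapper class SqueezingForward is hypothetical.

```python
import torch
from torch import nn


class SqueezingForward(nn.Module):
    """Hypothetical wrapper mimicking the squeeze rule described above."""

    def __init__(self, model: nn.Module, output_dim: int, is_ensemble: bool = False):
        super().__init__()
        self.model = model
        self.output_dim = output_dim
        self.is_ensemble = is_ensemble

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        out = self.model(inputs)
        # One-dimensional regression with a single model: drop the channel axis.
        if self.output_dim == 1 and not self.is_ensemble:
            out = out.squeeze(1)
        return out


x = torch.randn(2, 3, 5, 5)
single = SqueezingForward(nn.Conv2d(3, 1, kernel_size=1), output_dim=1)
print(single(x).shape)  # torch.Size([2, 5, 5])
```

Squeezing only dimension 1 keeps a batch of size one intact, which a bare squeeze() would not.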
- on_test_epoch_end()
Compute and log the values of the metrics collected during test_step.
- Return type:
None
- on_validation_epoch_end()
Compute and log the values of the metrics collected during validation_step.
- Return type:
None
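Both epoch-end hooks follow an accumulate-then-compute pattern. Each step updates running statistics, and the epoch-end hook turns them into a final value. The hypothetical RunningRMSE class below sketches this pattern for a root-mean-square error. It stands in for whatever metric objects the routine actually uses.

```python
import torch


class RunningRMSE:
    """Hypothetical accumulator: update once per step, compute once per epoch."""

    def __init__(self) -> None:
        self.sum_sq = 0.0
        self.count = 0

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # Accumulate the total squared error and the number of pixels seen.
        self.sum_sq += torch.sum((preds - target) ** 2).item()
        self.count += target.numel()

    def compute(self) -> float:
        return (self.sum_sq / self.count) ** 0.5

    def reset(self) -> None:
        self.sum_sq, self.count = 0.0, 0


metric = RunningRMSE()
# Two "steps" with different batch sizes: 32 pixels off by 1, then 16 pixels off by 2.
metric.update(torch.zeros(2, 4, 4), torch.ones(2, 4, 4))
metric.update(torch.zeros(1, 4, 4), torch.full((1, 4, 4), 2.0))
rmse = metric.compute()  # sqrt((32 * 1 + 16 * 4) / 48) = sqrt(2)
metric.reset()  # ready for the next epoch
```

Pooling pixels across batches gives sqrt(2) here. Averaging the two per-batch RMSE values (1 and 2) would give 1.5 instead, which is why metrics are usually accumulated rather than averaged at epoch end.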
- on_validation_start()
Prepare the validation step.
Update the model wrapper and the batch normalization layers, if needed.
- Return type:
None
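The source does not say which wrapper is updated. Assume, for illustration only, that it is a weight-averaging wrapper such as torch.optim.swa_utils.AveragedModel. The sketch below then shows the two operations the description points to. First, the current weights are folded into the average. Second, batch normalization statistics are recomputed with torch.optim.swa_utils.update_bn, because averaged weights make the running statistics collected during training stale. The toy network and loader are hypothetical.

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, update_bn

# Toy network with a batch normalization layer (hypothetical).
net = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1), nn.BatchNorm2d(4))
swa_model = AveragedModel(net)  # keeps its own copy of the weights

# Fold the current weights into the running average (the "wrapper update").
swa_model.update_parameters(net)

# Recompute batch-norm statistics for the averaged weights.
# update_bn feeds the first element of each (inputs, targets) tuple to the model.
loader = [(torch.randn(8, 3, 6, 6), torch.randn(8, 1, 6, 6)) for _ in range(3)]
update_bn(loader, swa_model)
```

update_bn resets the running statistics before the pass and restores each layer's momentum afterwards, so the wrapper can be evaluated immediately.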
- test_step(batch, batch_idx, dataloader_idx=0)
Perform a single test step based on the input tensors.
Compute the prediction of the model and the value of the metrics on the test batch. Also handle out-of-distribution (OOD) and distribution-shifted images.
- Parameters:
  - batch (tuple[Tensor, Tensor]) – The test data and their corresponding targets.
  - batch_idx (int) – The index of the current batch (unused).
  - dataloader_idx (int) – 0 if in-distribution, 1 if out-of-distribution.
- Return type:
None
- training_step(batch)
Perform a single training step based on the input tensors.
- Parameters:
  - batch (tuple[Tensor, Tensor]) – The training data and their corresponding targets.
- Returns:
The loss corresponding to this training step.
- Return type:
Tensor
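The routine optimizes the loss passed at construction. As a stand-in, the sketch below contrasts two common choices. Point-wise regression uses mean squared error, and probabilistic regression uses the Gaussian negative log-likelihood of the targets. The function training_step_sketch and the (mean, raw scale) channel layout are assumptions.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Normal


def training_step_sketch(model: nn.Module, batch, probabilistic: bool) -> torch.Tensor:
    inputs, targets = batch  # targets: (B, 1, H, W)
    out = model(inputs)
    if not probabilistic:
        return F.mse_loss(out, targets)  # point-wise regression
    # Hypothetical layout: channel 0 is the mean, channel 1 an unconstrained scale.
    loc, raw_scale = out.split(1, dim=1)
    dist = Normal(loc, F.softplus(raw_scale) + 1e-6)
    return -dist.log_prob(targets).mean()  # Gaussian negative log-likelihood


batch = (torch.randn(4, 3, 8, 8), torch.randn(4, 1, 8, 8))
prob_model = nn.Conv2d(3, 2, kernel_size=1)
loss = training_step_sketch(prob_model, batch, probabilistic=True)
loss.backward()  # gradients reach both the mean and the scale channels
```

The likelihood loss also trains the scale channel, which is what lets a probabilistic model report per-pixel uncertainty.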
- validation_step(batch, batch_idx)
Perform a single validation step based on the input tensors.
Compute the prediction of the model and the value of the metrics on the validation batch.
- Parameters:
  - batch (tuple[Tensor, Tensor]) – The validation images and their corresponding targets.
  - batch_idx (int) – The index of the batch. Images and their predictions may optionally be plotted for the first batch.
- Return type:
None
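According to the parameter description, images and predictions may be plotted for the first validation batch. The hypothetical helper select_images_to_plot below sketches that selection. It keeps at most num_image_plot samples and omits the plotting and logging calls themselves.

```python
from typing import Optional

import torch


def select_images_to_plot(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    preds: torch.Tensor,
    batch_idx: int,
    num_image_plot: int = 4,
) -> Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Hypothetical helper: keep at most num_image_plot samples from the first batch."""
    if batch_idx != 0:
        return None  # only the first validation batch is considered for plotting
    # Slicing past the batch size is safe: a batch of 2 simply yields 2 samples.
    k = num_image_plot
    return inputs[:k], targets[:k], preds[:k]


inputs = torch.randn(6, 3, 8, 8)
targets = torch.randn(6, 1, 8, 8)
preds = torch.randn(6, 1, 8, 8)
first = select_images_to_plot(inputs, targets, preds, batch_idx=0)
later = select_images_to_plot(inputs, targets, preds, batch_idx=1)
```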