PixelRegressionRoutine
- class torch_uncertainty.routines.PixelRegressionRoutine(model, output_dim, loss, dist_family=None, dist_estimate='mean', is_ensemble=False, format_batch_fn=None, optim_recipe=None, eval_shift=False, num_image_plot=4, log_plots=False)
Routine for training & testing on pixel regression tasks.
- Parameters:
model (nn.Module) – Model to train.
output_dim (int) – Number of outputs of the model.
loss (nn.Module) – Loss function to optimize the model.
dist_family (str, optional) – The distribution family to use for probabilistic pixel regression. If None, the routine performs point-wise regression. Defaults to None.
dist_estimate (str, optional) – The estimate to use when computing the point-wise metrics. Defaults to "mean".
is_ensemble (bool, optional) – Whether the model is an ensemble. Defaults to False.
optim_recipe (dict or Optimizer, optional) – The optimizer and, optionally, the scheduler to use. Defaults to None.
eval_shift (bool, optional) – Whether to evaluate the performance under distribution shift. Defaults to False.
format_batch_fn (nn.Module, optional) – The function used to format the batch. Defaults to None.
num_image_plot (int, optional) – Number of images to plot. Defaults to 4.
log_plots (bool, optional) – Whether to log plots from metrics. Defaults to False.
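To make dist_family and dist_estimate more concrete, here is a minimal pure-Python sketch of how a point estimate could be read off a grid of predicted per-pixel distributions. It is not the library's code: the LogNormalPixel class, the point_estimate helper and the accepted values "mean" and "mode" are assumptions for illustration only. A log-normal is chosen because its mean and mode differ, so the choice of estimate visibly changes the resulting map.

```python
import math
from dataclasses import dataclass


# Hypothetical per-pixel distribution. A log-normal is used only because its
# mean and mode differ, which makes the choice of estimate visible; it is not
# claimed to be one of the routine's supported distribution families.
@dataclass
class LogNormalPixel:
    mu: float  # location of the underlying normal distribution
    sigma: float  # scale of the underlying normal distribution (> 0)

    @property
    def mean(self) -> float:
        return math.exp(self.mu + self.sigma**2 / 2)

    @property
    def mode(self) -> float:
        return math.exp(self.mu - self.sigma**2)


def point_estimate(dist_grid, estimate="mean"):
    """Turn an H x W grid of predicted distributions into an H x W grid of values."""
    # Assumed set of estimates; the real routine may accept others.
    if estimate not in ("mean", "mode"):
        raise ValueError(f"unsupported estimate: {estimate!r}")
    return [[getattr(dist, estimate) for dist in row] for row in dist_grid]


# A 1 x 2 "image" of predicted distributions.
grid = [[LogNormalPixel(0.0, 0.5), LogNormalPixel(1.0, 0.1)]]
mean_map = point_estimate(grid, "mean")
mode_map = point_estimate(grid, "mode")
```

Point-wise metrics such as MSE would then be computed on mean_map or mode_map, depending on the chosen estimate.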
- evaluation_forward(inputs)
Get the prediction and handle the predicted distribution parameters, if any.
- Parameters:
inputs (Tensor) – the input data.
- Returns:
The prediction as a Tensor and the predicted distribution, or None for point-wise regression.
- Return type:
tuple[Tensor, Distribution | None]
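The (prediction, distribution or None) pair can be illustrated for a single pixel of an ensemble with the pure-Python sketch below. It assumes, without confirmation from this page, that a point-wise ensemble averages its members and that a probabilistic ensemble is treated as an equal-weight mixture of the members' Normal predictions. EqualWeightMixture and evaluation_forward_sketch are hypothetical names; the real method operates on tensors and torch distributions.

```python
from dataclasses import dataclass
from statistics import NormalDist, fmean


@dataclass
class EqualWeightMixture:
    """Equal-weight mixture of the members' per-pixel Normal predictions."""

    components: list  # one NormalDist per ensemble member

    @property
    def mean(self) -> float:
        return fmean(c.mean for c in self.components)

    @property
    def variance(self) -> float:
        # Law of total variance: mean of the member variances plus the
        # variance of the member means.
        m = self.mean
        return fmean(c.variance + (c.mean - m) ** 2 for c in self.components)


def evaluation_forward_sketch(member_outputs, probabilistic):
    """Return (point prediction, distribution or None) for a single pixel."""
    if probabilistic:
        mixture = EqualWeightMixture(list(member_outputs))
        return mixture.mean, mixture
    # Point-wise ensemble: average the members; there is no distribution.
    return fmean(member_outputs), None


pred, dist = evaluation_forward_sketch([1.0, 3.0], probabilistic=False)
prob_pred, mixture = evaluation_forward_sketch(
    [NormalDist(0.0, 1.0), NormalDist(2.0, 1.0)], probabilistic=True
)
```

Keeping the mixture as the returned distribution lets downstream uncertainty metrics use its spread (here a variance of 2.0), which exceeds each member's unit variance because the members disagree.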
- forward(inputs)
Forward pass of the routine.
The forward pass automatically squeezes the output if the regression is one-dimensional and if the routine contains a single model.
- Parameters:
inputs (Tensor) – The input tensor.
- Returns:
The output tensor.
- Return type:
Tensor
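The squeeze described for forward can be sketched on shape tuples alone. This assumes an (N, C, H, W) output with C == output_dim and that the channel axis is the one removed; both are assumptions, as is the hypothetical num_estimators argument standing for "the routine contains a single model".

```python
def forward_output_shape(shape, output_dim, num_estimators=1, channel_dim=1):
    """Shape of the routine's output after the optional squeeze.

    `shape` is the model output shape, assumed to be (N, C, H, W) with
    C == output_dim. The channel axis is dropped only for one-dimensional
    regression with a single model.
    """
    channel_dim %= len(shape)  # accept negative axes, as PyTorch does
    if output_dim == 1 and num_estimators == 1:
        # Raise instead of silently keeping the axis (Tensor.squeeze(dim)
        # would leave a non-singleton axis unchanged).
        if shape[channel_dim] != 1:
            raise ValueError("expected a singleton channel axis")
        return shape[:channel_dim] + shape[channel_dim + 1 :]
    return shape


print(forward_output_shape((8, 1, 64, 64), output_dim=1))  # (8, 64, 64)
```

Multi-output regression and ensembles keep their full shape, so the extra axis is only removed when it carries no information.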
- on_validation_epoch_end()
Compute and log the values of the metrics collected in validation_step.
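The split between collecting metrics in validation_step and computing them at epoch end follows the update / compute / reset pattern used by libraries such as torchmetrics. The pure-Python stand-in below illustrates that pattern with a running MSE; RunningMSE and the "val/MSE" key are hypothetical, not the routine's actual metric names.

```python
class RunningMSE:
    """Minimal stand-in for a torchmetrics-style metric (update / compute / reset)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum_sq_err = 0.0
        self.count = 0

    def update(self, preds, targets):
        # Accumulate squared errors over all pixels of a (flattened) batch.
        if len(preds) != len(targets):
            raise ValueError("preds and targets must have the same length")
        self.sum_sq_err += sum((p - t) ** 2 for p, t in zip(preds, targets))
        self.count += len(preds)

    def compute(self):
        return self.sum_sq_err / self.count


metric = RunningMSE()
val_batches = [([1.0, 2.0], [1.0, 3.0]), ([0.0], [2.0])]
for preds, targets in val_batches:  # role of validation_step
    metric.update(preds, targets)
logged = {"val/MSE": metric.compute()}  # role of on_validation_epoch_end
metric.reset()  # start the next epoch from scratch
```

Accumulating sums rather than per-batch results makes the logged value a mean over all 3 pixels (5/3) instead of an average of batch-level MSEs (2.25), so uneven batch sizes do not bias the epoch metric.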
- on_validation_start()
Prepare the validation step.
Update the model’s wrapper and the batchnorms if needed.
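One common reason a model wrapper needs a BatchNorm refresh is weight averaging (as in SWA): the averaged weights no longer match the stored running statistics, so they are re-estimated from training data. PyTorch provides torch.optim.swa_utils.update_bn for this. The pure-Python sketch below shows the underlying computation for a single channel; whether this routine relies on update_bn or on its own helper is not stated here, and recompute_bn_stats is a hypothetical name.

```python
from statistics import fmean, variance


def recompute_bn_stats(batches):
    """Re-estimate one channel's BatchNorm running mean and variance.

    Starts from reset statistics and uses a cumulative (equal-weight) average
    over batches, which is what PyTorch BatchNorm does when momentum is None.
    """
    running_mean, running_var = 0.0, 1.0  # values after a reset
    for n, batch in enumerate(batches, start=1):
        # BatchNorm tracks the unbiased batch variance in its running estimate.
        running_mean += (fmean(batch) - running_mean) / n
        running_var += (variance(batch) - running_var) / n
    return running_mean, running_var


mean, var = recompute_bn_stats([[1.0, 2.0, 3.0], [5.0, 7.0]])
```

The cumulative average weights every batch equally, which avoids the bias toward recent batches that an exponential moving average would introduce during a one-off refresh.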
- test_step(batch, batch_idx, dataloader_idx=0)
Perform a single test step based on the input tensors.
Compute the prediction of the model and the value of the metrics on the test batch. Also handle OOD and distribution-shifted images.
- Parameters:
batch (tuple[Tensor, Tensor]) – the test data and their corresponding targets.
batch_idx (int) – the index of the current batch (unused).
dataloader_idx (int) – 0 if in-distribution, 1 if out-of-distribution.
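In Lightning, when several test dataloaders are supplied, dataloader_idx is the position of the batch's dataloader in that list. The pure-Python sketch below mimics this numbering and dispatches on it following the convention stated for this method (0 in-distribution, 1 out-of-distribution). run_test_loop and toy_test_step are hypothetical; the real test_step updates metrics on tensors rather than counting samples.

```python
def run_test_loop(test_step, dataloaders):
    """Mimic how Lightning numbers multiple test dataloaders."""
    for dataloader_idx, loader in enumerate(dataloaders):
        for batch_idx, batch in enumerate(loader):
            test_step(batch, batch_idx, dataloader_idx)


seen = {"in_distribution": 0, "out_of_distribution": 0}


def toy_test_step(batch, batch_idx, dataloader_idx=0):
    # Convention from the docstring: 0 is in-distribution, 1 is OOD.
    key = "in_distribution" if dataloader_idx == 0 else "out_of_distribution"
    inputs, _targets = batch
    seen[key] += len(inputs)


id_loader = [([0.1, 0.2], [0.0, 0.0])]
ood_loader = [([0.5], [1.0]), ([0.7], [1.0])]
run_test_loop(toy_test_step, [id_loader, ood_loader])
print(seen)  # {'in_distribution': 2, 'out_of_distribution': 2}
```

The order of the dataloaders passed to the test loop therefore determines which samples are treated as in-distribution.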
- training_step(batch)
Perform a single training step based on the input tensors.
- Parameters:
batch (tuple[Tensor, Tensor]) – the training data and their corresponding targets.
- Returns:
the loss corresponding to this training step.
- Return type:
Tensor
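A training step boils down to a forward pass followed by the loss passed to the routine. As a sketch of the probabilistic case, the code below assumes a per-pixel Normal prediction and a Gaussian negative log-likelihood loss; this is an assumption about one possible configuration, not the routine's fixed behaviour. torch.nn.GaussianNLLLoss computes the same quantity from a variance instead of a scale, and omits the constant term unless full=True. toy_model and training_step_sketch are hypothetical.

```python
import math


def gaussian_nll(locs, scales, targets):
    """Mean per-pixel negative log-likelihood under Normal(loc, scale)."""
    total = 0.0
    for mu, sigma, y in zip(locs, scales, targets):
        if sigma <= 0:
            raise ValueError("scale must be positive")
        total += 0.5 * math.log(2 * math.pi * sigma**2) + (y - mu) ** 2 / (2 * sigma**2)
    return total / len(targets)


def training_step_sketch(batch, model, loss_fn):
    """Forward the inputs, then return the scalar loss to be minimized."""
    inputs, targets = batch
    locs, scales = model(inputs)  # hypothetical model: per-pixel loc and scale
    return loss_fn(locs, scales, targets)


def toy_model(pixels):
    # Hypothetical model: predicts loc = 2 * x with a fixed unit scale.
    return [2.0 * x for x in pixels], [1.0] * len(pixels)


loss = training_step_sketch(([0.5, 1.0], [1.0, 2.0]), toy_model, gaussian_nll)
```

Here the predicted locations match the targets exactly, so the remaining loss is the constant 0.5 * log(2π) contributed by the unit scale.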
- validation_step(batch, batch_idx)
Perform a single validation step based on the input tensors.
Compute the prediction of the model and the value of the metrics on the validation batch.
- Parameters:
batch (tuple[Tensor, Tensor]) – the validation images and their corresponding targets.
batch_idx (int) – the index of the batch. For the first batch, images and their predictions can optionally be plotted.
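The plotting behaviour of validation_step (first batch only, at most num_image_plot images) can be sketched as a small selection function. select_images_to_plot is hypothetical, and any further gating of the plots (for example by the logger in use) is not described here; strings stand in for image tensors.

```python
def select_images_to_plot(inputs, targets, preds, batch_idx, num_image_plot=4):
    """Return up to num_image_plot (input, target, prediction) triples.

    Only the first validation batch (batch_idx == 0) contributes images.
    """
    if batch_idx != 0:
        return []
    n = min(num_image_plot, len(inputs))
    return list(zip(inputs[:n], targets[:n], preds[:n]))


# Stand-ins for images: strings labelling six samples of a batch.
imgs = [f"img{i}" for i in range(6)]
gts = [f"gt{i}" for i in range(6)]
outs = [f"pred{i}" for i in range(6)]
first = select_images_to_plot(imgs, gts, outs, batch_idx=0)
later = select_images_to_plot(imgs, gts, outs, batch_idx=1)
```

Restricting plots to the first batch keeps the logged figures comparable across epochs, since the same samples are shown each time when the validation loader is not shuffled.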