from typing import Literal
import matplotlib.cm as cm
import torch
from einops import rearrange
from lightning.pytorch import LightningModule
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import Tensor, nn
from torch.distributions import (
Categorical,
Distribution,
Independent,
MixtureSameFamily,
)
from torch.optim import Optimizer
from torchmetrics import MeanSquaredError, MetricCollection
from torchvision.transforms.v2 import functional as F
from torchvision.utils import make_grid
from torch_uncertainty.metrics import (
DistributionNLL,
Log10,
MeanAbsoluteErrorInverse,
MeanGTRelativeAbsoluteError,
MeanGTRelativeSquaredError,
MeanSquaredErrorInverse,
MeanSquaredLogError,
SILog,
ThresholdAccuracy,
)
from torch_uncertainty.models import (
EPOCH_UPDATE_MODEL,
STEP_UPDATE_MODEL,
)
from torch_uncertainty.utils.distributions import (
get_dist_class,
get_dist_estimate,
)
class PixelRegressionRoutine(LightningModule):
    # Arguments for ``F.normalize`` that invert a per-channel ``(x - mean) / std``
    # normalization, since ``normalize(x, -mean / std, 1 / std) == x * std + mean``.
    # The constants are the standard ImageNet statistics
    # (mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)). ``_plot_depth`` uses
    # them to render input images and assumes the inputs were normalized this way.
    inv_norm_params = {
        "mean": [-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
        "std": [1 / 0.229, 1 / 0.224, 1 / 0.225],
    }

    def __init__(
        self,
        model: nn.Module,
        output_dim: int,
        loss: nn.Module,
        dist_family: str | None = None,
        dist_estimate: str = "mean",
        is_ensemble: bool = False,
        format_batch_fn: nn.Module | None = None,
        optim_recipe: dict | Optimizer | None = None,
        eval_shift: bool = False,
        num_image_plot: int = 4,
        log_plots: bool = False,
    ) -> None:
        r"""Routine for training & testing on **pixel regression** tasks.

        Args:
            model (nn.Module): Model to train.
            output_dim (int): Number of outputs of the model.
            loss (nn.Module): Loss function to optimize the :attr:`model`.
            dist_family (str, optional): The distribution family to use for
                probabilistic pixel regression. If ``None`` then point-wise regression.
                Defaults to ``None``.
            dist_estimate (str, optional): The estimate to use when computing the
                point-wise metrics. Defaults to ``"mean"``.
            is_ensemble (bool, optional): Whether the model is an ensemble.
                Defaults to ``False``.
            format_batch_fn (nn.Module, optional): The function to format the
                training batch. Defaults to ``None`` (identity).
            optim_recipe (dict or Optimizer, optional): The optimizer and
                optionally the scheduler to use. Defaults to ``None``.
            eval_shift (bool, optional): Indicates whether to evaluate the Distribution
                shift performance. Defaults to ``False``.
            num_image_plot (int, optional): Number of images to plot. Defaults to ``4``.
            log_plots (bool, optional): Indicates whether to log plots from
                metrics. Defaults to ``False``.

        Raises:
            ValueError: If ``output_dim < 1``, or if ``log_plots`` is ``True`` and
                ``num_image_plot < 1``.
            NotImplementedError: If ``eval_shift`` is ``True``.
        """
        super().__init__()
        _depth_routine_checks(output_dim, num_image_plot, log_plots)
        if eval_shift:
            raise NotImplementedError(
                "Distribution shift evaluation not implemented yet. Raise an issue if needed."
            )

        self.model = model
        self.output_dim = output_dim
        # Single-channel (depth-like) targets get a channel axis added before use.
        self.one_dim_depth = output_dim == 1
        self.dist_family = dist_family
        self.dist_estimate = dist_estimate
        self.probabilistic = dist_family is not None
        self.loss = loss
        self.num_image_plot = num_image_plot
        self.is_ensemble = is_ensemble
        self.log_plots = log_plots

        # Some wrapped models must be updated once per epoch or once per step.
        self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL)
        self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL)

        if format_batch_fn is None:
            format_batch_fn = nn.Identity()
        self.optim_recipe = optim_recipe
        self.format_batch_fn = format_batch_fn
        self._init_metrics()

    def _init_metrics(self) -> None:
        """Initialize the metrics depending on the exact task.

        The point-wise metrics are always created. The NLL metric is created only for
        probabilistic routines. Each collection is cloned with a ``val/`` and a
        ``test/`` prefix.
        """
        depth_metrics = MetricCollection(
            {
                "reg/SILog": SILog(),
                "reg/log10": Log10(),
                "reg/ARE": MeanGTRelativeAbsoluteError(),
                "reg/RSRE": MeanGTRelativeSquaredError(squared=False),
                "reg/RMSE": MeanSquaredError(squared=False),
                "reg/RMSELog": MeanSquaredLogError(squared=False),
                "reg/iMAE": MeanAbsoluteErrorInverse(),
                "reg/iRMSE": MeanSquaredErrorInverse(squared=False),
                "reg/d1": ThresholdAccuracy(power=1),
                "reg/d2": ThresholdAccuracy(power=2),
                "reg/d3": ThresholdAccuracy(power=3),
            },
            compute_groups=False,
        )
        self.val_metrics = depth_metrics.clone(prefix="val/")
        self.test_metrics = depth_metrics.clone(prefix="test/")

        if self.probabilistic:
            depth_prob_metrics = MetricCollection({"reg/NLL": DistributionNLL(reduction="mean")})
            self.val_prob_metrics = depth_prob_metrics.clone(prefix="val/")
            self.test_prob_metrics = depth_prob_metrics.clone(prefix="test/")

    def configure_optimizers(self) -> Optimizer | dict:
        """Return the optimizer recipe given at construction time."""
        return self.optim_recipe

    def on_train_start(self) -> None:
        """Log the routine's hyperparameters with the attached logger, if any."""
        if self.logger is not None:  # coverage: ignore
            self.logger.log_hyperparams(
                self.hparams,
            )

    def on_validation_start(self) -> None:
        """Prepare the validation step.

        Update the model's wrapper (except during the sanity check) and the
        batchnorms if needed.
        """
        if self.needs_epoch_update and not self.trainer.sanity_checking:
            self.model.update_wrapper(self.current_epoch)
            if hasattr(self.model, "need_bn_update"):
                self.model.bn_update(self.trainer.train_dataloader, device=self.device)

    def on_test_start(self) -> None:
        """Prepare the test step.

        Update the batchnorms if needed.
        """
        if hasattr(self.model, "need_bn_update"):
            self.model.bn_update(self.trainer.train_dataloader, device=self.device)

    def forward(self, inputs: Tensor) -> Tensor | dict[str, Tensor]:
        """Forward pass of the routine.

        When the routine wraps a single model (``is_ensemble=False``), the trailing
        dimension of the output is squeezed. In the probabilistic case every
        distribution parameter is squeezed. The squeeze is a no-op unless that
        dimension has size 1.

        Args:
            inputs (Tensor): The input tensor.

        Returns:
            Tensor | dict[str, Tensor]: The model output, or the dict of distribution
            parameters when ``dist_family`` is set.
        """
        # NOTE(review): the model output is laid out as (b, c, h, w) (see
        # ``training_step``), so ``squeeze(-1)`` targets the width axis rather than
        # the channel axis — confirm this is intended.
        pred = self.model(inputs)
        if not self.is_ensemble:
            if self.probabilistic:
                pred = {k: v.squeeze(-1) for k, v in pred.items()}
            else:
                pred = pred.squeeze(-1)
        return pred

    def training_step(self, batch: tuple[Tensor, Tensor]) -> STEP_OUTPUT:
        """Perform a single training step based on the input tensors.

        The targets are resized (nearest neighbour) to the spatial size of the model
        output. Pixels whose target is NaN in any channel are treated as padding and
        excluded from the loss.

        Args:
            batch (tuple[Tensor, Tensor]): the training data and their corresponding
                targets, passed through ``format_batch_fn`` first.

        Returns:
            Tensor: the loss corresponding to this training step.
        """
        inputs, target = self.format_batch_fn(batch)
        if self.one_dim_depth:
            # Single-channel targets are (b, h, w): add the channel axis.
            target = target.unsqueeze(1)

        out = self.model(inputs)
        # In the probabilistic case the output is a dict of distribution parameters;
        # read the spatial size from the first one.
        out_shape = out[next(iter(out))].shape[-2:] if self.probabilistic else out.shape[-2:]
        target = F.resize(target, out_shape, interpolation=F.InterpolationMode.NEAREST)
        target = rearrange(target, "b c h w -> b h w c")
        # True on padded pixels, i.e. where the target is NaN in any channel.
        padding_mask = torch.isnan(target).any(dim=-1)

        if self.probabilistic:
            dist_params = {k: rearrange(v, "b c h w -> b h w c") for k, v in out.items()}
            # Adding the Independent wrapper to the distribution to compute correctly the
            # log-likelihood given a target. Here the last dimension is the event dimension.
            # When computing the log-likelihood, the values are summed over the event dimension.
            dists = Independent(get_dist_class(self.dist_family)(**dist_params), 1)
            loss = self.loss(dists, target, padding_mask)
        else:
            out = rearrange(out, "b c h w -> b h w c")
            # Keep only the valid pixels: indexing with ``padding_mask`` itself would
            # select exactly the NaN targets (or nothing when there is no padding).
            valid_mask = ~padding_mask
            loss = self.loss(out[valid_mask], target[valid_mask])

        if self.needs_step_update:
            self.model.update_wrapper(self.current_epoch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def evaluation_forward(self, inputs: Tensor) -> tuple[Tensor, Distribution | None]:
        """Get the prediction and handle predicted eventual distribution parameters.

        The model output is expected to stack the ``m`` members along the batch axis
        (``m = 1`` for a single model), as encoded by the ``"(m b) ..."`` patterns.

        Args:
            inputs (Tensor): the input data, batch on the first axis.

        Returns:
            tuple[Tensor, Distribution | None]: the point prediction of shape
            ``(b, h, w, c)``, averaged over the members. The second element is the
            uniform mixture of the members' predictive distributions in the
            probabilistic case, ``None`` otherwise.
        """
        batch_size = inputs.size(0)
        preds = self.model(inputs)

        if self.probabilistic:
            dist_params = {
                k: rearrange(v, "(m b) c h w -> b h w m c", b=batch_size) for k, v in preds.items()
            }
            # Adding the Independent wrapper to the distribution to create a MixtureSameFamily.
            # As required by the torch.distributions API, the last dimension is the event dimension.
            comp = Independent(get_dist_class(self.dist_family)(**dist_params), 1)
            # Equal (unnormalized) weights: a uniform mixture over the m members.
            mix = Categorical(torch.ones(comp.batch_shape, device=self.device))
            mixture = MixtureSameFamily(mix, comp)
            # Point estimate per member, averaged over the member axis (b, h, w, m, c).
            preds = get_dist_estimate(comp, self.dist_estimate).mean(-2)
            return preds, mixture

        preds = rearrange(preds, "(m b) c h w -> b m h w c", b=batch_size)
        return preds.mean(dim=1), None

    def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Perform a single validation step based on the input tensors.

        Compute the prediction of the model and update the validation metrics on the
        non-padded pixels of the batch.

        Args:
            batch (tuple[Tensor, Tensor]): the validation images and their corresponding targets.
            batch_idx (int): the id of the batch. Optionally plot images and the predictions with
                the first batch.
        """
        self._evaluation_step(batch, batch_idx, stage="val")

    def test_step(
        self,
        batch: tuple[Tensor, Tensor],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Perform a single test step based on the input tensors.

        Compute the prediction of the model and update the test metrics on the
        non-padded pixels of the batch. Only in-distribution data is supported.

        Args:
            batch (tuple[Tensor, Tensor]): the test data and their corresponding targets.
            batch_idx (int): the id of the batch. Optionally plot images and the
                predictions with the first batch.
            dataloader_idx (int): 0 if in-distribution, 1 if out-of-distribution.

        Raises:
            NotImplementedError: If ``dataloader_idx != 0``.
        """
        if dataloader_idx != 0:
            raise NotImplementedError(
                "Depth OOD detection not implemented yet. Raise an issue if needed."
            )
        self._evaluation_step(batch, batch_idx, stage="test")

    def _evaluation_step(
        self,
        batch: tuple[Tensor, Tensor],
        batch_idx: int,
        stage: Literal["val", "test"],
    ) -> None:
        """Shared body of ``validation_step`` and ``test_step``.

        Predict, optionally plot the first batch, and update the ``stage`` metrics.
        """
        inputs, targets = batch
        if self.one_dim_depth:
            targets = targets.unsqueeze(1)
        targets = rearrange(targets, "b c h w -> b h w c")
        preds, dist = self.evaluation_forward(inputs)

        if batch_idx == 0 and self.log_plots:
            num_images = min(self.num_image_plot, inputs.size(0))
            self._plot_depth(
                inputs[:num_images, ...],
                preds[:num_images, ...],
                targets[:num_images, ...],
                stage=stage,
            )

        # True on padded pixels (NaN target in any channel); metrics use the rest.
        padding_mask = torch.isnan(targets).any(dim=-1)
        valid_mask = ~padding_mask
        point_metrics = self.val_metrics if stage == "val" else self.test_metrics
        point_metrics.update(preds[valid_mask], targets[valid_mask])
        if isinstance(dist, Distribution):
            prob_metrics = self.val_prob_metrics if stage == "val" else self.test_prob_metrics
            # The NLL metric receives the full maps together with the padding mask.
            prob_metrics.update(dist, targets, padding_mask)

    def on_validation_epoch_end(self) -> None:
        """Compute and log the values of the collected metrics in `validation_step`."""
        res_dict = self.val_metrics.compute()
        self.log_dict(res_dict, logger=True, sync_dist=True)
        # Also show the RMSE in the progress bar (not sent to the logger twice).
        self.log(
            "RMSE",
            res_dict["val/reg/RMSE"],
            prog_bar=True,
            logger=False,
            sync_dist=True,
        )
        self.val_metrics.reset()
        if self.probabilistic:
            self.log_dict(
                self.val_prob_metrics.compute(),
                sync_dist=True,
            )
            self.val_prob_metrics.reset()

    def on_test_epoch_end(self) -> None:
        """Compute and log the values of the collected metrics in `test_step`."""
        self.log_dict(
            self.test_metrics.compute(),
            sync_dist=True,
        )
        self.test_metrics.reset()
        if self.probabilistic:
            self.log_dict(
                self.test_prob_metrics.compute(),
                sync_dist=True,
            )
            self.test_prob_metrics.reset()

    def _plot_depth(
        self,
        inputs: Tensor,
        preds: Tensor,
        target: Tensor,
        stage: Literal["val", "test"],
    ) -> None:
        """Log a grid of (input, prediction, target) triplets to TensorBoard.

        Does nothing unless the logger is a :class:`TensorBoardLogger` and the output
        is single-channel.

        Args:
            inputs (Tensor): normalized input images, ``(n, 3, h, w)``.
            preds (Tensor): predictions in channels-last layout ``(n, h, w, c)``.
            target (Tensor): targets in channels-last layout ``(n, h, w, c)``.
            stage (str): ``"val"`` or ``"test"``, used as the image tag prefix.
        """
        if not (isinstance(self.logger, TensorBoardLogger) and self.one_dim_depth):
            return

        all_imgs = []
        for img, pred, tgt in zip(inputs, preds, target):
            # Depth maps share the fixed [0, max_depth] color scale; assumes the
            # wrapped model exposes ``max_depth``.
            max_depth = self.model.max_depth
            all_imgs.extend(
                [
                    F.normalize(img.cpu(), **self.inv_norm_params),
                    # Channels-last maps: ``[..., 0]`` takes the depth channel
                    # (``[0, ...]`` would take the first image row).
                    colorize(pred[..., 0].cpu(), vmin=0, vmax=max_depth),
                    colorize(tgt[..., 0].cpu(), vmin=0, vmax=max_depth),
                ]
            )
        self.logger.experiment.add_image(
            f"{stage}/samples",
            make_grid(torch.stack(all_imgs, dim=0), nrow=3),
            self.current_epoch,
        )
def colorize(
value: Tensor,
vmin: float | None = None,
vmax: float | None = None,
cmap: str = "magma",
):
"""Colorize a tensor of depth values.
Args:
value (Tensor): The tensor of depth values.
vmin (float, optional): The minimum depth value. Defaults to None.
vmax (float, optional): The maximum depth value. Defaults to None.
cmap (str, optional): The colormap to use. Defaults to 'magma'.
"""
vmin = value.min().item() if vmin is None else vmin
vmax = value.max().item() if vmax is None else vmax
if vmin == vmax:
return torch.zeros_like(value)
value = (value - vmin) / (vmax - vmin)
cmapper = cm.get_cmap(cmap)
value = cmapper(value.numpy(), bytes=True)
img = value[:, :, :3]
return torch.as_tensor(img).permute(2, 0, 1).float() / 255.0
def _depth_routine_checks(output_dim: int, num_image_plot: int, log_plots: bool) -> None:
    """Validate the constructor arguments of the pixel regression routine.

    Args:
        output_dim (int): dimension of the regression output; must be at least 1.
        num_image_plot (int): number of images plotted at evaluation time; must be
            at least 1 when plotting is enabled.
        log_plots (bool): whether images and predictions are plotted during evaluation.

    Raises:
        ValueError: if ``output_dim < 1``, or if ``num_image_plot < 1`` while
            ``log_plots`` is enabled. ``output_dim`` is checked first.
    """
    bad_output_dim = output_dim < 1
    if bad_output_dim:
        message = f"output_dim must be positive, got {output_dim}."
        raise ValueError(message)

    bad_plot_count = num_image_plot < 1 and log_plots
    if bad_plot_count:
        message = f"num_image_plot must be positive, got {num_image_plot}."
        raise ValueError(message)