
Source code for torch_uncertainty.routines.pixel_regression

from typing import Literal

import matplotlib.cm as cm
import torch
from einops import rearrange
from lightning.pytorch import LightningModule
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import Tensor, nn
from torch.distributions import (
    Categorical,
    Distribution,
    Independent,
    MixtureSameFamily,
)
from torch.optim import Optimizer
from torchmetrics import MeanSquaredError, MetricCollection
from torchvision.transforms.v2 import functional as F
from torchvision.utils import make_grid

from torch_uncertainty.metrics import (
    DistributionNLL,
    Log10,
    MeanAbsoluteErrorInverse,
    MeanGTRelativeAbsoluteError,
    MeanGTRelativeSquaredError,
    MeanSquaredErrorInverse,
    MeanSquaredLogError,
    SILog,
    ThresholdAccuracy,
)
from torch_uncertainty.models import (
    EPOCH_UPDATE_MODEL,
    STEP_UPDATE_MODEL,
)
from torch_uncertainty.utils.distributions import (
    dist_rearrange,
    dist_size,
    dist_squeeze,
)


class PixelRegressionRoutine(LightningModule):
    # Inverse of the standard ImageNet normalization (mean/std), used to
    # recover displayable images when plotting.
    inv_norm_params = {
        "mean": [-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
        "std": [1 / 0.229, 1 / 0.224, 1 / 0.225],
    }

    def __init__(
        self,
        model: nn.Module,
        output_dim: int,
        probabilistic: bool,
        loss: nn.Module,
        is_ensemble: bool = False,
        format_batch_fn: nn.Module | None = None,
        optim_recipe: dict | Optimizer | None = None,
        eval_shift: bool = False,
        num_image_plot: int = 4,
        log_plots: bool = False,
    ) -> None:
        """Routine for training & testing on **pixel regression** tasks.

        Args:
            model (nn.Module): Model to train.
            output_dim (int): Number of outputs of the model.
            probabilistic (bool): Whether the model is probabilistic, i.e.,
                outputs a PyTorch distribution.
            loss (nn.Module): Loss function to optimize the :attr:`model`.
            is_ensemble (bool, optional): Whether the model is an ensemble.
                Defaults to ``False``.
            format_batch_fn (nn.Module, optional): The function to format the
                batch. Defaults to ``None``.
            optim_recipe (dict or Optimizer, optional): The optimizer and
                optionally the scheduler to use. Defaults to ``None``.
            eval_shift (bool, optional): Indicates whether to evaluate the
                performance under distribution shift. Defaults to ``False``.
            num_image_plot (int, optional): Number of images to plot. Defaults
                to ``4``.
            log_plots (bool, optional): Indicates whether to log plots from
                metrics. Defaults to ``False``.
        """
        super().__init__()
        _depth_routine_checks(output_dim, num_image_plot, log_plots)
        if eval_shift:
            raise NotImplementedError(
                "Distribution shift evaluation not implemented yet. Raise an issue if needed."
            )

        self.model = model
        self.output_dim = output_dim
        self.one_dim_depth = output_dim == 1
        self.probabilistic = probabilistic
        self.loss = loss
        self.num_image_plot = num_image_plot
        self.is_ensemble = is_ensemble
        self.log_plots = log_plots

        self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL)
        self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL)

        if format_batch_fn is None:
            format_batch_fn = nn.Identity()

        self.optim_recipe = optim_recipe
        self.format_batch_fn = format_batch_fn

        depth_metrics = MetricCollection(
            {
                "reg/SILog": SILog(),
                "reg/log10": Log10(),
                "reg/ARE": MeanGTRelativeAbsoluteError(),
                "reg/RSRE": MeanGTRelativeSquaredError(squared=False),
                "reg/RMSE": MeanSquaredError(squared=False),
                "reg/RMSELog": MeanSquaredLogError(squared=False),
                "reg/iMAE": MeanAbsoluteErrorInverse(),
                "reg/iRMSE": MeanSquaredErrorInverse(squared=False),
                "reg/d1": ThresholdAccuracy(power=1),
                "reg/d2": ThresholdAccuracy(power=2),
                "reg/d3": ThresholdAccuracy(power=3),
            },
            compute_groups=False,
        )
        self.val_metrics = depth_metrics.clone(prefix="val/")
        self.test_metrics = depth_metrics.clone(prefix="test/")

        if self.probabilistic:
            depth_prob_metrics = MetricCollection(
                {"reg/NLL": DistributionNLL(reduction="mean")}
            )
            self.val_prob_metrics = depth_prob_metrics.clone(prefix="val/")
            self.test_prob_metrics = depth_prob_metrics.clone(prefix="test/")

    def configure_optimizers(self) -> Optimizer | dict:
        return self.optim_recipe

    def on_train_start(self) -> None:
        if self.logger is not None:  # coverage: ignore
            self.logger.log_hyperparams(self.hparams)

    def on_validation_start(self) -> None:
        if self.needs_epoch_update and not self.trainer.sanity_checking:
            self.model.update_wrapper(self.current_epoch)
            if hasattr(self.model, "need_bn_update"):
                self.model.bn_update(self.trainer.train_dataloader, device=self.device)

    def on_test_start(self) -> None:
        if hasattr(self.model, "need_bn_update"):
            self.model.bn_update(self.trainer.train_dataloader, device=self.device)
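    # A minimal usage sketch (not part of the routine): ``MyDepthNet``,
    # ``train_loader``, and ``val_loader`` are hypothetical placeholders for a
    # user-provided backbone and Lightning dataloaders.
    #
    #     from lightning.pytorch import Trainer
    #     from torch import nn, optim
    #
    #     model = MyDepthNet()  # any nn.Module mapping (B, 3, H, W) to (B, 1, H', W')
    #     routine = PixelRegressionRoutine(
    #         model=model,
    #         output_dim=1,         # one depth value per pixel
    #         probabilistic=False,  # the model outputs point estimates
    #         loss=nn.L1Loss(),
    #         optim_recipe=optim.Adam(model.parameters(), lr=1e-4),
    #     )
    #     Trainer(max_epochs=10).fit(routine, train_loader, val_loader)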
    def forward(self, inputs: Tensor) -> Tensor | Distribution:
        """Forward pass of the routine.

        The forward pass automatically squeezes the output if the regression
        is one-dimensional and the routine contains a single model.

        Args:
            inputs (Tensor): The input tensor.

        Returns:
            Tensor | Distribution: The prediction, as a tensor or as a PyTorch
                distribution if the routine is probabilistic.
        """
        pred = self.model(inputs)
        if not self.is_ensemble:
            pred = dist_squeeze(pred, -1) if self.probabilistic else pred.squeeze(-1)
        return pred
    def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> STEP_OUTPUT:
        inputs, target = self.format_batch_fn(batch)
        if self.one_dim_depth:
            target = target.unsqueeze(1)

        dists = self.model(inputs)
        if self.probabilistic:
            out_shape = dist_size(dists)[-2:]
        else:
            out_shape = dists.shape[-2:]
        target = F.resize(target, out_shape, interpolation=F.InterpolationMode.NEAREST)
        padding_mask = torch.isnan(target)
        if self.probabilistic:
            loss = self.loss(dists, target, padding_mask)
        else:
            # keep only the valid (non-NaN) pixels
            loss = self.loss(dists[~padding_mask], target[~padding_mask])

        if self.needs_step_update:
            self.model.update_wrapper(self.current_epoch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None:
        inputs, targets = batch
        if self.one_dim_depth:
            targets = targets.unsqueeze(1)
        batch_size = targets.size(0)
        targets = rearrange(targets, "b c h w -> (b c h w)")
        preds = self.model(inputs)

        if self.probabilistic:
            ens_dist = Independent(
                dist_rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size),
                0,
            )
            mix = Categorical(
                torch.ones((dist_size(preds)[0] // batch_size), device=self.device)
            )
            mixture = MixtureSameFamily(mix, ens_dist)
            preds = mixture.mean
        else:
            preds = rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size)
            preds = preds.mean(dim=1)

        if batch_idx == 0 and self.log_plots:
            self._plot_depth(
                inputs[: self.num_image_plot, ...],
                preds[: self.num_image_plot, ...],
                targets[: self.num_image_plot, ...],
                stage="val",
            )

        padding_mask = torch.isnan(targets)
        # update the metrics on valid (non-NaN) pixels only
        self.val_metrics.update(preds[~padding_mask], targets[~padding_mask])
        if self.probabilistic:
            self.val_prob_metrics.update(mixture, targets, padding_mask)

    def test_step(
        self,
        batch: tuple[Tensor, Tensor],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        if dataloader_idx != 0:
            raise NotImplementedError(
                "Depth OOD detection not implemented yet. Raise an issue if needed."
            )
        inputs, targets = batch
        if self.one_dim_depth:
            targets = targets.unsqueeze(1)
        batch_size = targets.size(0)
        targets = rearrange(targets, "b c h w -> (b c h w)")
        preds = self.model(inputs)

        if self.probabilistic:
            ens_dist = dist_rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size)
            mix = Categorical(
                torch.ones((dist_size(preds)[0] // batch_size), device=self.device)
            )
            mixture = MixtureSameFamily(mix, ens_dist)
            preds = mixture.mean
        else:
            preds = rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size)
            preds = preds.mean(dim=1)

        if batch_idx == 0 and self.log_plots:
            num_images = min(self.num_image_plot, inputs.size(0))
            self._plot_depth(
                inputs[:num_images, ...],
                preds[:num_images, ...],
                targets[:num_images, ...],
                stage="test",
            )

        padding_mask = torch.isnan(targets)
        # update the metrics on valid (non-NaN) pixels only
        self.test_metrics.update(preds[~padding_mask], targets[~padding_mask])
        if self.probabilistic:
            self.test_prob_metrics.update(mixture, targets, padding_mask)

    def on_validation_epoch_end(self) -> None:
        res_dict = self.val_metrics.compute()
        self.log_dict(res_dict, logger=True, sync_dist=True)
        self.log(
            "RMSE",
            res_dict["val/reg/RMSE"],
            prog_bar=True,
            logger=False,
            sync_dist=True,
        )
        self.val_metrics.reset()
        if self.probabilistic:
            self.log_dict(self.val_prob_metrics.compute(), sync_dist=True)
            self.val_prob_metrics.reset()

    def on_test_epoch_end(self) -> None:
        self.log_dict(self.test_metrics.compute(), sync_dist=True)
        self.test_metrics.reset()
        if self.probabilistic:
            self.log_dict(self.test_prob_metrics.compute(), sync_dist=True)
            self.test_prob_metrics.reset()

    def _plot_depth(
        self,
        inputs: Tensor,
        preds: Tensor,
        target: Tensor,
        stage: Literal["val", "test"],
    ) -> None:
        if (
            self.logger is not None
            and isinstance(self.logger, TensorBoardLogger)
            and self.one_dim_depth
        ):
            all_imgs = []
            for i in range(inputs.size(0)):
                img = F.normalize(inputs[i, ...].cpu(), **self.inv_norm_params)
                pred = colorize(preds[i, 0, ...].cpu(), vmin=0, vmax=self.model.max_depth)
                tgt = colorize(target[i, 0, ...].cpu(), vmin=0, vmax=self.model.max_depth)
                all_imgs.extend([img, pred, tgt])

            self.logger.experiment.add_image(
                f"{stage}/samples",
                make_grid(torch.stack(all_imgs, dim=0), nrow=3),
                self.current_epoch,
            )
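The validation and test steps above reduce the ensemble's per-pixel predictive
distributions to a single prediction by wrapping the members in a uniformly
weighted mixture and taking its mean. A minimal, self-contained sketch of that
construction, with hypothetical ``Normal`` components standing in for the model
output (shapes and values are illustrative only):

import torch
from torch.distributions import Categorical, MixtureSameFamily, Normal

num_pixels, num_members = 6, 4
# One Normal per (pixel, member): component batch shape (num_pixels, num_members).
components = Normal(
    torch.randn(num_pixels, num_members),
    torch.rand(num_pixels, num_members) + 0.1,
)
# Uniform mixing weights; Categorical normalizes the ones to 1 / num_members.
mix = Categorical(torch.ones(num_members))
mixture = MixtureSameFamily(mix, components)
assert mixture.mean.shape == (num_pixels,)  # one fused prediction per pixel

With uniform weights the mixture mean coincides with the plain average of the
member means, but the mixture also supports density evaluation via
``log_prob``, which is what the ``DistributionNLL`` metric consumes.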
def colorize(
    value: Tensor,
    vmin: float | None = None,
    vmax: float | None = None,
    cmap: str = "magma",
):
    """Colorize a tensor of depth values.

    Args:
        value (Tensor): The tensor of depth values.
        vmin (float, optional): The minimum depth value. Defaults to ``None``.
        vmax (float, optional): The maximum depth value. Defaults to ``None``.
        cmap (str, optional): The colormap to use. Defaults to ``"magma"``.
    """
    vmin = value.min().item() if vmin is None else vmin
    vmax = value.max().item() if vmax is None else vmax
    if vmin == vmax:
        return torch.zeros_like(value)
    value = (value - vmin) / (vmax - vmin)
    cmapper = cm.get_cmap(cmap)
    value = cmapper(value.numpy(), bytes=True)
    img = value[:, :, :3]
    return torch.as_tensor(img).permute(2, 0, 1).float() / 255.0


def _depth_routine_checks(output_dim: int, num_image_plot: int, log_plots: bool) -> None:
    if output_dim < 1:
        raise ValueError(f"output_dim must be positive, got {output_dim}.")
    if num_image_plot < 1 and log_plots:
        raise ValueError(f"num_image_plot must be positive, got {num_image_plot}.")
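As a quick usage sketch for ``colorize`` (the depth map below is synthetic and
its shape is arbitrary):

import torch

depth = torch.rand(240, 320) * 10.0  # hypothetical metric depth map
rgb = colorize(depth, vmin=0.0, vmax=10.0)
assert rgb.shape == (3, 240, 320)  # CHW float image in [0, 1], ready for make_grid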