import torch
from einops import rearrange
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import Tensor, nn
from torch.distributions import (
Categorical,
Distribution,
Independent,
MixtureSameFamily,
)
from torch.optim import Optimizer
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
from torch_uncertainty.losses import ELBOLoss
from torch_uncertainty.metrics import (
DistributionNLL,
)
from torch_uncertainty.models import (
EPOCH_UPDATE_MODEL,
STEP_UPDATE_MODEL,
)
from torch_uncertainty.utils.distributions import (
get_dist_class,
get_dist_estimate,
)
class RegressionRoutine(LightningModule):
    def __init__(
        self,
        model: nn.Module,
        output_dim: int,
        loss: nn.Module,
        dist_family: str | None = None,
        dist_estimate: str = "mean",
        is_ensemble: bool = False,
        optim_recipe: dict | Optimizer | None = None,
        eval_shift: bool = False,
        format_batch_fn: nn.Module | None = None,
    ) -> None:
        r"""Routine for training & testing on **regression** tasks.

        Args:
            model (torch.nn.Module): Model to train.
            output_dim (int): Number of outputs of the model.
            loss (torch.nn.Module): Loss function to optimize the :attr:`model`.
            dist_family (str, optional): The distribution family to use for
                probabilistic regression. If ``None`` then point-wise regression.
                Defaults to ``None``.
            dist_estimate (str, optional): The estimate to use when computing the
                point-wise metrics. Defaults to ``"mean"``.
            is_ensemble (bool, optional): Whether the model is an ensemble.
                Defaults to ``False``.
            optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and
                optionally the scheduler to use. Defaults to ``None``.
            eval_shift (bool, optional): Indicates whether to evaluate the Distribution
                shift performance. Defaults to ``False``.
            format_batch_fn (torch.nn.Module, optional): The function to format the
                batch. Defaults to ``None``.

        Warning:
            If :attr:`probabilistic` is True, the model must output a `PyTorch
            distribution <https://pytorch.org/docs/stable/distributions.html>`_.

        Warning:
            You must define :attr:`optim_recipe` if you do not use
            the CLI.

        Note:
            :attr:`optim_recipe` can be anything that can be returned by
            :meth:`LightningModule.configure_optimizers()`. Find more details
            `here <https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers>`_.
        """
        super().__init__()
        _regression_routine_checks(output_dim)
        if eval_shift:
            raise NotImplementedError(
                "Distribution shift evaluation not implemented yet. Raise an issue if needed."
            )

        self.model = model
        self.dist_family = dist_family
        self.dist_estimate = dist_estimate
        # Point-wise regression when no distribution family is provided.
        self.probabilistic = dist_family is not None
        self.output_dim = output_dim
        self.loss = loss
        self.is_ensemble = is_ensemble

        # Wrappers such as SWA/EMA need explicit update calls at epoch/step boundaries.
        self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL)
        self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL)

        if format_batch_fn is None:
            format_batch_fn = nn.Identity()

        self.optim_recipe = optim_recipe
        self.format_batch_fn = format_batch_fn
        self.one_dim_regression = output_dim == 1
        self._init_metrics()

    def _init_metrics(self) -> None:
        """Initialize the metrics depending on the exact task."""
        reg_metrics = MetricCollection(
            {
                "reg/MAE": MeanAbsoluteError(),
                "reg/MSE": MeanSquaredError(squared=True),
                "reg/RMSE": MeanSquaredError(squared=False),
            },
            compute_groups=True,
        )
        self.val_metrics = reg_metrics.clone(prefix="val/")
        self.test_metrics = reg_metrics.clone(prefix="test/")
        if self.probabilistic:
            # NLL is only defined when the model predicts a distribution.
            reg_prob_metrics = MetricCollection({"reg/NLL": DistributionNLL(reduction="mean")})
            self.val_prob_metrics = reg_prob_metrics.clone(prefix="val/")
            self.test_prob_metrics = reg_prob_metrics.clone(prefix="test/")

    def configure_optimizers(self) -> Optimizer | dict:
        return self.optim_recipe

    def on_train_start(self) -> None:
        """Put the hyperparameters in tensorboard."""
        if self.logger is not None:  # coverage: ignore
            self.logger.log_hyperparams(
                self.hparams,
            )

    def on_validation_start(self) -> None:
        """Prepare the validation step.

        Update the model's wrapper and the batchnorms if needed.
        """
        if self.needs_epoch_update and not self.trainer.sanity_checking:
            self.model.update_wrapper(self.current_epoch)
            if hasattr(self.model, "need_bn_update"):
                self.model.bn_update(self.trainer.train_dataloader, device=self.device)

    def on_test_start(self) -> None:
        """Prepare the test step.

        Update the batchnorms if needed.
        """
        if hasattr(self.model, "need_bn_update"):
            self.model.bn_update(self.trainer.train_dataloader, device=self.device)

    def forward(self, inputs: Tensor) -> Tensor | dict[str, Tensor]:
        """Forward pass of the routine.

        The forward pass automatically squeezes the output if the regression
        is one-dimensional and if the routine contains a single model.

        Args:
            inputs (Tensor): The input tensor.

        Returns:
            Tensor | dict[str, Tensor]: The output tensor or the parameters of the output
                distribution.
        """
        pred = self.model(inputs)
        if self.probabilistic:
            if isinstance(pred, dict):
                # Drop the trailing output dimension, then the ensemble
                # dimension, when each is known to be of size one.
                if self.one_dim_regression:
                    pred = {k: v.squeeze(-1) for k, v in pred.items()}
                if not self.is_ensemble:
                    pred = {k: v.squeeze(-1) for k, v in pred.items()}
            else:
                # FIX: the original passed two separate strings to TypeError,
                # producing a tuple as the exception message.
                raise TypeError(
                    "If the model is probabilistic, the output must be a dictionary "
                    "of PyTorch distributions."
                )
        else:
            if self.one_dim_regression:
                pred = pred.squeeze(-1)
            if not self.is_ensemble:
                pred = pred.squeeze(-1)
        return pred

    def training_step(self, batch: tuple[Tensor, Tensor]) -> STEP_OUTPUT:
        """Perform a single training step based on the input tensors.

        Args:
            batch (tuple[Tensor, Tensor]): the training data and their corresponding targets

        Returns:
            Tensor: the loss corresponding to this training step.
        """
        inputs, targets = self.format_batch_fn(batch)
        if self.one_dim_regression:
            targets = targets.unsqueeze(-1)

        if isinstance(self.loss, ELBOLoss):
            # ELBO runs the model internally, so it takes the raw inputs.
            loss = self.loss(inputs, targets)
        else:
            out = self.model(inputs)
            if self.probabilistic:
                # Adding the Independent wrapper to the distribution to compute correctly the
                # log-likelihood given a target. Here the last dimension is the event dimension.
                # When computing the log-likelihood, the values are summed over the event
                # dimension.
                dists = Independent(get_dist_class(self.dist_family)(**out), 1)
                loss = self.loss(dists, targets)
            else:
                loss = self.loss(out, targets)

        if self.needs_step_update:
            self.model.update_wrapper(self.current_epoch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def evaluation_forward(self, inputs: Tensor) -> tuple[Tensor, Distribution | None]:
        """Get the prediction and handle predicted eventual distribution parameters.

        Args:
            inputs (Tensor): the input data.

        Returns:
            tuple[Tensor, Distribution | None]: the prediction as a Tensor and a distribution.
        """
        batch_size = inputs.size(0)
        preds = self.model(inputs)

        if self.probabilistic:
            # Ensemble members are stacked along the batch axis; split them out.
            dist_params = {
                k: rearrange(v, "(m b) c -> b m c", b=batch_size) for k, v in preds.items()
            }
            # Adding the Independent wrapper to the distribution to create a MixtureSameFamily.
            # As required by the torch.distributions API, the last dimension is the event dimension.
            comp = Independent(get_dist_class(self.dist_family)(**dist_params), 1)
            # Uniform mixture weights over the ensemble members.
            mix = Categorical(torch.ones(comp.batch_shape, device=self.device))
            dist = MixtureSameFamily(mix, comp)
            preds = get_dist_estimate(comp, self.dist_estimate).mean(1)
            return preds, dist

        preds = rearrange(preds, "(m b) c -> b m c", b=batch_size)
        return preds.mean(dim=1), None

    def validation_step(self, batch: tuple[Tensor, Tensor]) -> None:
        """Perform a single validation step based on the input tensors.

        Compute the prediction of the model and the value of the metrics on the validation batch.

        Args:
            batch (tuple[Tensor, Tensor]): the validation data and their corresponding targets.
        """
        inputs, targets = batch
        if self.one_dim_regression:
            targets = targets.unsqueeze(-1)
        preds, dist = self.evaluation_forward(inputs)
        self.val_metrics.update(preds, targets)
        if isinstance(dist, Distribution):
            self.val_prob_metrics.update(dist, targets)

    def test_step(
        self,
        batch: tuple[Tensor, Tensor],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Perform a single test step based on the input tensors.

        Compute the prediction of the model and the value of the metrics on the test batch. Also
        handle OOD and distribution-shifted images.

        Args:
            batch (tuple[Tensor, Tensor]): the test data and their corresponding targets.
            batch_idx (int): the number of the current batch (unused).
            dataloader_idx (int): 0 if in-distribution, 1 if out-of-distribution.
        """
        if dataloader_idx != 0:
            raise NotImplementedError(
                "Regression OOD detection not implemented yet. Raise an issue if needed."
            )
        inputs, targets = batch
        if self.one_dim_regression:
            targets = targets.unsqueeze(-1)
        preds, dist = self.evaluation_forward(inputs)
        self.test_metrics.update(preds, targets)
        if isinstance(dist, Distribution):
            self.test_prob_metrics.update(dist, targets)

    def on_validation_epoch_end(self) -> None:
        """Compute and log the values of the collected metrics in `validation_step`."""
        res_dict = self.val_metrics.compute()
        self.log_dict(res_dict, logger=True, sync_dist=True)
        # RMSE is surfaced on the progress bar without the metric-collection prefix.
        self.log(
            "RMSE",
            res_dict["val/reg/RMSE"],
            prog_bar=True,
            logger=False,
            sync_dist=True,
        )
        self.val_metrics.reset()
        if self.probabilistic:
            self.log_dict(self.val_prob_metrics.compute(), sync_dist=True)
            self.val_prob_metrics.reset()

    def on_test_epoch_end(self) -> None:
        """Compute and log the values of the collected metrics in `test_step`."""
        self.log_dict(
            self.test_metrics.compute(),
        )
        self.test_metrics.reset()
        if self.probabilistic:
            self.log_dict(
                self.test_prob_metrics.compute(),
            )
            self.test_prob_metrics.reset()
def _regression_routine_checks(output_dim: int) -> None:
"""Check the domains of the routine's parameters.
Args:
output_dim (int): the dimension of the output of the regression task.
"""
if output_dim < 1:
raise ValueError(f"output_dim must be positive, got {output_dim}.")