Source code for torch_uncertainty.routines.regression
import torch
from einops import rearrange
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import Tensor, nn
from torch.distributions import (
Categorical,
Distribution,
MixtureSameFamily,
)
from torch.optim import Optimizer
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
from torch_uncertainty.losses import ELBOLoss
from torch_uncertainty.metrics import (
DistributionNLL,
)
from torch_uncertainty.models import (
EPOCH_UPDATE_MODEL,
STEP_UPDATE_MODEL,
)
from torch_uncertainty.utils.distributions import (
dist_rearrange,
dist_size,
dist_squeeze,
)
class RegressionRoutine(LightningModule):
def __init__(
self,
model: nn.Module,
output_dim: int,
probabilistic: bool,
loss: nn.Module,
is_ensemble: bool = False,
optim_recipe: dict | Optimizer | None = None,
eval_shift: bool = False,
format_batch_fn: nn.Module | None = None,
) -> None:
r"""Routine for training & testing on **regression** tasks.
Args:
model (torch.nn.Module): Model to train.
output_dim (int): Number of outputs of the model.
probabilistic (bool): Whether the model is probabilistic, i.e.,
outputs a PyTorch distribution.
loss (torch.nn.Module): Loss function to optimize the :attr:`model`.
is_ensemble (bool, optional): Whether the model is an ensemble.
Defaults to ``False``.
optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and
optionally the scheduler to use. Defaults to ``None``.
            eval_shift (bool, optional): Indicates whether to evaluate the
                distribution shift performance. Defaults to ``False``.
format_batch_fn (torch.nn.Module, optional): The function to format the
batch. Defaults to ``None``.
Warning:
If :attr:`probabilistic` is True, the model must output a `PyTorch
distribution <https://pytorch.org/docs/stable/distributions.html>`_.
Warning:
You must define :attr:`optim_recipe` if you do not use
the CLI.
Note:
:attr:`optim_recipe` can be anything that can be returned by
:meth:`LightningModule.configure_optimizers()`. Find more details
`here <https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers>`_.
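        Example:
            A minimal sketch; the MLP, loss, and optimizer below are
            illustrative placeholders, not part of the routine:

            >>> from torch import nn, optim
            >>> model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))
            >>> routine = RegressionRoutine(
            ...     model=model,
            ...     output_dim=1,
            ...     probabilistic=False,
            ...     loss=nn.MSELoss(),
            ...     optim_recipe=optim.Adam(model.parameters(), lr=1e-3),
            ... )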
"""
super().__init__()
_regression_routine_checks(output_dim)
if eval_shift:
raise NotImplementedError(
"Distribution shift evaluation not implemented yet. Raise an issue "
"if needed."
)
self.model = model
self.probabilistic = probabilistic
self.output_dim = output_dim
self.loss = loss
self.is_ensemble = is_ensemble
self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL)
self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL)
if format_batch_fn is None:
format_batch_fn = nn.Identity()
self.optim_recipe = optim_recipe
self.format_batch_fn = format_batch_fn
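        # Shared regression metrics, cloned below for validation and test.
        # With `compute_groups=True`, torchmetrics can share internal state
        # between metrics with identical updates (MSE and RMSE here).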
reg_metrics = MetricCollection(
{
"reg/MAE": MeanAbsoluteError(),
"reg/MSE": MeanSquaredError(squared=True),
"reg/RMSE": MeanSquaredError(squared=False),
},
compute_groups=True,
)
self.val_metrics = reg_metrics.clone(prefix="val/")
self.test_metrics = reg_metrics.clone(prefix="test/")
if self.probabilistic:
reg_prob_metrics = MetricCollection(
{"reg/NLL": DistributionNLL(reduction="mean")}
)
self.val_prob_metrics = reg_prob_metrics.clone(prefix="val/")
self.test_prob_metrics = reg_prob_metrics.clone(prefix="test/")
self.one_dim_regression = output_dim == 1
def configure_optimizers(self) -> Optimizer | dict:
return self.optim_recipe
def on_train_start(self) -> None:
if self.logger is not None: # coverage: ignore
self.logger.log_hyperparams(
self.hparams,
)
def on_validation_start(self) -> None:
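        # Wrapper models that need a per-epoch update are refreshed here
        # (skipped during the sanity-checking pass); models exposing
        # `need_bn_update` recompute their BatchNorm statistics on the
        # training dataloader before evaluation.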
if self.needs_epoch_update and not self.trainer.sanity_checking:
self.model.update_wrapper(self.current_epoch)
if hasattr(self.model, "need_bn_update"):
self.model.bn_update(
self.trainer.train_dataloader, device=self.device
)
def on_test_start(self) -> None:
if hasattr(self.model, "need_bn_update"):
self.model.bn_update(
self.trainer.train_dataloader, device=self.device
)
    def forward(self, inputs: Tensor) -> Tensor | Distribution:
"""Forward pass of the routine.
        The forward pass squeezes the last dimension of the output when the
        regression is one-dimensional, and again when the routine wraps a
        single (non-ensemble) model.
Args:
inputs (Tensor): The input tensor.
Returns:
            Tensor | Distribution: The prediction tensor, or the predictive
                distribution if the routine is probabilistic.
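        Example:
            A shape sketch, assuming ``output_dim=1`` and the single
            non-probabilistic ``routine`` built in the class example::

                x = torch.randn(16, 8)  # hypothetical input batch
                out = routine(x)        # shape (16,): the trailing dim is squeezed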
"""
pred = self.model(inputs)
if self.probabilistic:
if self.one_dim_regression:
pred = dist_squeeze(pred, -1)
if not self.is_ensemble:
pred = dist_squeeze(pred, -1)
else:
if self.one_dim_regression:
pred = pred.squeeze(-1)
if not self.is_ensemble:
pred = pred.squeeze(-1)
return pred
def training_step(
self, batch: tuple[Tensor, Tensor], batch_idx: int
) -> STEP_OUTPUT:
inputs, targets = self.format_batch_fn(batch)
if self.one_dim_regression:
targets = targets.unsqueeze(-1)
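        # ELBOLoss performs the model's forward passes itself, so it receives
        # the raw inputs; any other loss is applied to the model's output.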
if isinstance(self.loss, ELBOLoss):
loss = self.loss(inputs, targets)
else:
dists = self.model(inputs)
loss = self.loss(dists, targets)
if self.needs_step_update:
self.model.update_wrapper(self.current_epoch)
self.log("train_loss", loss, prog_bar=True, logger=True)
return loss
def validation_step(
self, batch: tuple[Tensor, Tensor], batch_idx: int
) -> None:
inputs, targets = batch
if self.one_dim_regression:
targets = targets.unsqueeze(-1)
batch_size = targets.size(0)
targets = rearrange(targets, "b c -> (b c)")
preds = self.model(inputs)
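        # Ensemble members are stacked along the batch dimension as (m * b).
        # Probabilistic outputs are regrouped per target component and merged
        # into a uniformly weighted mixture; point predictions are averaged
        # over the m members instead.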
if self.probabilistic:
ens_dist = dist_rearrange(preds, "(m b) c -> (b c) m", b=batch_size)
mix = Categorical(
torch.ones(
dist_size(preds)[0] // batch_size, device=self.device
)
)
mixture = MixtureSameFamily(mix, ens_dist)
preds = mixture.mean
else:
preds = rearrange(preds, "(m b) c -> (b c) m", b=batch_size)
preds = preds.mean(dim=1)
self.val_metrics.update(preds, targets)
if self.probabilistic:
self.val_prob_metrics.update(mixture, targets)
def test_step(
self,
batch: tuple[Tensor, Tensor],
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
if dataloader_idx != 0:
raise NotImplementedError(
"Regression OOD detection not implemented yet. Raise an issue "
"if needed."
)
inputs, targets = batch
if self.one_dim_regression:
targets = targets.unsqueeze(-1)
batch_size = targets.size(0)
targets = rearrange(targets, "b c -> (b c)")
preds = self.model(inputs)
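        # Same ensemble handling as in validation_step: a uniform mixture for
        # probabilistic outputs, a plain average over members otherwise.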
if self.probabilistic:
ens_dist = dist_rearrange(preds, "(m b) c -> (b c) m", b=batch_size)
mix = Categorical(
torch.ones(
dist_size(preds)[0] // batch_size, device=self.device
)
)
mixture = MixtureSameFamily(mix, ens_dist)
preds = mixture.mean
else:
preds = rearrange(preds, "(m b) c -> (b c) m", b=batch_size)
preds = preds.mean(dim=1)
self.test_metrics.update(preds, targets)
if self.probabilistic:
self.test_prob_metrics.update(mixture, targets)
def on_validation_epoch_end(self) -> None:
res_dict = self.val_metrics.compute()
self.log_dict(res_dict, logger=True, sync_dist=True)
self.log(
"RMSE",
res_dict["val/reg/RMSE"],
prog_bar=True,
logger=False,
sync_dist=True,
)
self.val_metrics.reset()
if self.probabilistic:
self.log_dict(self.val_prob_metrics.compute(), sync_dist=True)
self.val_prob_metrics.reset()
def on_test_epoch_end(self) -> None:
self.log_dict(
self.test_metrics.compute(),
)
self.test_metrics.reset()
if self.probabilistic:
self.log_dict(
self.test_prob_metrics.compute(),
)
self.test_prob_metrics.reset()
def _regression_routine_checks(output_dim: int) -> None:
if output_dim < 1:
raise ValueError(f"output_dim must be positive, got {output_dim}.")