"""Compound checkpoint callback (``torch_uncertainty.callbacks.compound_checkpoint``)."""
from datetime import timedelta
from pathlib import Path
from typing import Literal
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch import Tensor
# [docs]  (Sphinx documentation link marker)
class CompoundCheckpoint(ModelCheckpoint):
    def __init__(
        self,
        compound_metric_dict: dict[str, float],
        dirpath: str | Path | None = None,
        verbose: bool = False,
        save_last: bool | Literal["link"] = False,
        save_top_k: int = 1,
        save_weights_only: bool = False,
        mode: str = "min",
        every_n_train_steps: int | None = None,
        train_time_interval: timedelta | None = None,
        every_n_epochs: int | None = None,
        save_on_train_epoch_end: bool | None = None,
        enable_version_counter: bool = True,
    ) -> None:
        r"""Save the checkpoints maximizing or minimizing a given linear form on the metric values.

        Args:
            compound_metric_dict (dict[str, float]): A dictionary mapping metric names (key) to
                their corresponding factors (value) in the linear form:

                .. math:: \sum_{i} \text{metric}_i \times \text{value}_i

            dirpath (str | Path | None, optional): The directory to save the checkpoints in.
                Defaults to ``None``.
            verbose (bool, optional): Whether to print verbose output. Defaults to ``False``.
            save_last (bool | Literal["link"], optional): Whether to save the last checkpoint.
                Defaults to ``False``.
            save_top_k (int, optional): The number of best checkpoints to save. Defaults to ``1``.
            save_weights_only (bool, optional): Whether to save only the weights. Defaults to
                ``False``.
            mode (str, optional): The mode to optimize the compound metric, ``"min"`` or
                ``"max"``. Defaults to ``"min"``.
            every_n_train_steps (int | None, optional): The number of training steps to wait
                between saving checkpoints. Defaults to ``None``.
            train_time_interval (timedelta | None, optional): The time interval to wait between
                saving checkpoints. Defaults to ``None``.
            every_n_epochs (int | None, optional): The number of epochs to wait between saving
                checkpoints. Defaults to ``None``.
            save_on_train_epoch_end (bool | None, optional): Whether to save the checkpoint at the
                end of each training epoch. Defaults to ``None``.
            enable_version_counter (bool, optional): Whether to enable the version counter for the
                saved checkpoints. Defaults to ``True``.
        """
        # Stored before super().__init__() so the attribute exists as soon as
        # the callback starts monitoring the synthetic "compound_metric".
        self.compound_metric_dict = compound_metric_dict
        super().__init__(
            dirpath=dirpath,
            # auto_insert_metric_name=False because the filename template below
            # already spells out the metric names explicitly.
            filename="epoch={epoch}-step={step}-compound={compound_metric:.3f}",
            monitor="compound_metric",
            verbose=verbose,
            save_last=save_last,
            save_top_k=save_top_k,
            save_weights_only=save_weights_only,
            mode=mode,
            auto_insert_metric_name=False,
            every_n_train_steps=every_n_train_steps,
            train_time_interval=train_time_interval,
            every_n_epochs=every_n_epochs,
            save_on_train_epoch_end=save_on_train_epoch_end,
            enable_version_counter=enable_version_counter,
        )

    def _monitor_candidates(self, trainer: Trainer) -> dict[str, Tensor]:
        """Extend the logged metric candidates with the weighted ``"compound_metric"``.

        Args:
            trainer (Trainer): The trainer whose logged metrics are collected.

        Returns:
            dict[str, Tensor]: The candidate metrics, augmented with the
            ``"compound_metric"`` entry monitored by this callback.

        Raises:
            KeyError: If one of the metrics configured in ``compound_metric_dict``
                was not logged.
        """
        monitor_candidates = super()._monitor_candidates(trainer)
        # Guard against an empty candidate dict: next(iter(...)) without a
        # default would raise StopIteration before any metric is accumulated.
        device = (
            next(iter(monitor_candidates.values())).device
            if monitor_candidates
            else torch.device("cpu")
        )
        result = torch.tensor(0.0, dtype=torch.float32, device=device)
        for metric, factor in self.compound_metric_dict.items():
            if metric not in monitor_candidates:
                # Fail with an actionable message instead of a bare KeyError.
                raise KeyError(
                    f"Metric '{metric}' required by CompoundCheckpoint was not logged. "
                    f"Available metrics: {sorted(monitor_candidates)}."
                )
            result += factor * monitor_candidates[metric].to(dtype=result.dtype, device=device)
        monitor_candidates["compound_metric"] = result
        return monitor_candidates