import warnings
from typing import Any, Literal

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import Checkpoint, ModelCheckpoint
from lightning.pytorch.utilities.types import STEP_OUTPUT
from typing_extensions import override
class _TUCheckpoint(Checkpoint):
    """Forward Lightning checkpoint hooks to several wrapped ``ModelCheckpoint`` callbacks.

    Subclasses fill ``callbacks`` with one ``ModelCheckpoint`` per tracked metric and implement
    ``best_model_path`` to expose the checkpoint of their primary metric. Every hook overridden
    here is forwarded to each wrapped callback, in the insertion order of ``callbacks``.
    """

    # Wrapped ModelCheckpoint instances keyed by a short metric identifier (e.g. "acc", "ece").
    # Subclasses are expected to assign this in ``__init__``.
    callbacks: dict[str, ModelCheckpoint]

    @override
    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Forward ``setup`` to every wrapped checkpoint callback."""
        for callback in self.callbacks.values():
            callback.setup(trainer=trainer, pl_module=pl_module, stage=stage)

    @override
    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Forward ``on_train_start`` to every wrapped checkpoint callback."""
        for callback in self.callbacks.values():
            callback.on_train_start(trainer=trainer, pl_module=pl_module)

    @override
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Forward ``on_train_batch_end`` (step/time-based saving) to every wrapped callback."""
        for callback in self.callbacks.values():
            callback.on_train_batch_end(
                trainer=trainer,
                pl_module=pl_module,
                outputs=outputs,
                batch=batch,
                batch_idx=batch_idx,
            )

    @override
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Forward ``on_train_epoch_end`` to every wrapped checkpoint callback."""
        for callback in self.callbacks.values():
            callback.on_train_epoch_end(trainer=trainer, pl_module=pl_module)

    @override
    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Forward ``on_validation_epoch_end`` to every wrapped checkpoint callback.

        Kept for backward compatibility. The validation-triggered save of ``ModelCheckpoint``
        happens in ``on_validation_end``, which is forwarded separately below.
        """
        for callback in self.callbacks.values():
            callback.on_validation_epoch_end(trainer=trainer, pl_module=pl_module)

    @override
    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Forward ``on_validation_end`` to every wrapped checkpoint callback.

        ``ModelCheckpoint`` saves in this hook whenever it decides not to save at the end of
        the training epoch, e.g. when ``check_val_every_n_epoch != 1`` or
        ``val_check_interval != 1.0``. Without this forwarding, no checkpoint would be written
        in those configurations.
        """
        for callback in self.callbacks.values():
            callback.on_validation_end(trainer=trainer, pl_module=pl_module)

    @override
    def state_dict(self) -> dict[str, dict[str, Any]]:
        """Return the state of each wrapped callback, keyed like ``callbacks``."""
        return {key: callback.state_dict() for key, callback in self.callbacks.items()}

    @override
    def load_state_dict(self, state_dict: dict[str, dict[str, Any]]) -> None:
        """Restore each wrapped callback from the matching entry of ``state_dict``.

        Trackers with no saved entry (e.g. resuming a probabilistic regression run from a
        checkpoint written without the probabilistic trackers) are left in their fresh state,
        and a warning is emitted instead of raising ``KeyError``.
        """
        for key, callback in self.callbacks.items():
            if key not in state_dict:
                warnings.warn(
                    f"No saved state found for the '{key}' checkpoint tracker; "
                    "it will start from a fresh state.",
                    stacklevel=2,
                )
                continue
            callback.load_state_dict(state_dict=state_dict[key])

    @property
    def best_model_path(self) -> str:
        """Return the path to the best model checkpoint based on the primary metric."""
        raise NotImplementedError(f"{type(self).__name__} must implement `best_model_path`.")
[docs]
class TUClsCheckpoint(_TUCheckpoint):
    def __init__(self, save_last: bool | Literal["link"] = False) -> None:
        """Track the best classification checkpoints for several validation metrics.

        One ``ModelCheckpoint`` is kept per metric: Accuracy (maximized), and Expected
        Calibration Error, Brier-Score and Negative Log-Likelihood (all minimized).

        Args:
            save_last (bool | "link", optional): When ``True``, saves a last.ckpt copy whenever a
                checkpoint file gets saved. Can be set to ``"link"`` on a local filesystem to
                create a symbolic link. This allows accessing the latest checkpoint in a
                deterministic manner. Only the accuracy tracker receives this option.
                Defaults to ``False``.
        """
        super().__init__()
        # Primary tracker: the only one that receives ``save_last``.
        accuracy_tracker = ModelCheckpoint(
            filename="epoch={epoch}-step={step}-val_acc={val/cls/Acc:.3f}",
            monitor="val/cls/Acc",
            mode="max",
            save_last=save_last,
            auto_insert_metric_name=False,
        )
        ece_tracker = ModelCheckpoint(
            filename="epoch={epoch}-step={step}-val_ece={val/cal/ECE:.3f}",
            monitor="val/cal/ECE",
            mode="min",
            auto_insert_metric_name=False,
        )
        brier_tracker = ModelCheckpoint(
            filename="epoch={epoch}-step={step}-val_brier={val/cls/Brier:.3f}",
            monitor="val/cls/Brier",
            mode="min",
            auto_insert_metric_name=False,
        )
        nll_tracker = ModelCheckpoint(
            filename="epoch={epoch}-step={step}-val_nll={val/cls/NLL:.3f}",
            monitor="val/cls/NLL",
            mode="min",
            auto_insert_metric_name=False,
        )
        self.callbacks = {
            "acc": accuracy_tracker,
            "ece": ece_tracker,
            "brier": brier_tracker,
            "nll": nll_tracker,
        }

    @property
    def best_model_path(self) -> str:
        """Path of the best checkpoint according to validation accuracy."""
        primary = self.callbacks["acc"]
        return primary.best_model_path
[docs]
class TUSegCheckpoint(_TUCheckpoint):
    def __init__(self, save_last: bool | Literal["link"] = False) -> None:
        """Track the best segmentation checkpoints for several validation metrics.

        One ``ModelCheckpoint`` is kept per metric: Mean Intersection over Union (maximized),
        and Expected Calibration Error, Brier-Score and Negative Log-Likelihood (all minimized).

        Args:
            save_last (bool | "link", optional): When ``True``, saves a last.ckpt copy whenever a
                checkpoint file gets saved. Can be set to ``"link"`` on a local filesystem to
                create a symbolic link. This allows accessing the latest checkpoint in a
                deterministic manner. Only the mIoU tracker receives this option.
                Defaults to ``False``.
        """
        super().__init__()
        # Primary tracker: the only one that receives ``save_last``.
        miou_tracker = ModelCheckpoint(
            filename="epoch={epoch}-step={step}-val_miou={val/seg/mIoU:.3f}",
            monitor="val/seg/mIoU",
            mode="max",
            save_last=save_last,
            auto_insert_metric_name=False,
        )
        ece_tracker = ModelCheckpoint(
            filename="epoch={epoch}-step={step}-val_ece={val/cal/ECE:.3f}",
            monitor="val/cal/ECE",
            mode="min",
            auto_insert_metric_name=False,
        )
        brier_tracker = ModelCheckpoint(
            filename="epoch={epoch}-step={step}-val_brier={val/seg/Brier:.3f}",
            monitor="val/seg/Brier",
            mode="min",
            auto_insert_metric_name=False,
        )
        nll_tracker = ModelCheckpoint(
            filename="epoch={epoch}-step={step}-val_nll={val/seg/NLL:.3f}",
            monitor="val/seg/NLL",
            mode="min",
            auto_insert_metric_name=False,
        )
        self.callbacks = {
            "miou": miou_tracker,
            "ece": ece_tracker,
            "brier": brier_tracker,
            "nll": nll_tracker,
        }

    @property
    def best_model_path(self) -> str:
        """Path of the best checkpoint according to validation mIoU."""
        primary = self.callbacks["miou"]
        return primary.best_model_path
[docs]
class TURegCheckpoint(_TUCheckpoint):
    def __init__(
        self, probabilistic: bool = False, save_last: bool | Literal["link"] = False
    ) -> None:
        """Track the best regression checkpoints for one or three validation metrics.

        The Mean Squared Error is always tracked. In probabilistic mode, the Negative
        Log-Likelihood and the Quantile Calibration Error are tracked as well. All metrics
        are minimized.

        Args:
            probabilistic (bool, optional): If ``True``, also tracks the Negative Log-Likelihood
                and the Quantile Calibration Error. Defaults to ``False``.
            save_last (bool | "link", optional): When ``True``, saves a last.ckpt copy whenever a
                checkpoint file gets saved. Can be set to ``"link"`` on a local filesystem to
                create a symbolic link. This allows accessing the latest checkpoint in a
                deterministic manner. Only the MSE tracker receives this option.
                Defaults to ``False``.
        """
        super().__init__()
        # Primary tracker: always present and the only one that receives ``save_last``.
        mse_tracker = ModelCheckpoint(
            filename="epoch={epoch}-step={step}-val_mse={val/reg/MSE:.3f}",
            monitor="val/reg/MSE",
            mode="min",
            auto_insert_metric_name=False,
            save_last=save_last,
        )
        self.callbacks = {"mse": mse_tracker}
        if not probabilistic:
            return

        nll_tracker = ModelCheckpoint(
            filename="epoch={epoch}-step={step}-val_nll={val/reg/NLL:.3f}",
            monitor="val/reg/NLL",
            mode="min",
            auto_insert_metric_name=False,
        )
        qce_tracker = ModelCheckpoint(
            filename="epoch={epoch}-step={step}-val_qce={val/cal/QCE:.3f}",
            monitor="val/cal/QCE",
            mode="min",
            auto_insert_metric_name=False,
        )
        # Keyword order fixes the insertion order: "nll" then "qce".
        self.callbacks.update(nll=nll_tracker, qce=qce_tracker)

    @property
    def best_model_path(self) -> str:
        """Path of the best checkpoint according to validation MSE."""
        primary = self.callbacks["mse"]
        return primary.best_model_path