Source code for torch_uncertainty.baselines.classification.resnet

from typing import Literal

from torch import nn
from torch.optim import Optimizer

from torch_uncertainty.models import mc_dropout
from torch_uncertainty.models.classification import (
    batched_resnet,
    lpbnn_resnet,
    masked_resnet,
    mimo_resnet,
    packed_resnet,
    resnet,
)
from torch_uncertainty.ood_criteria import TUOODCriterion
from torch_uncertainty.routines.classification import ClassificationRoutine
from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget

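# Versions that build multiple estimators; "std" is the only single-model
# variant. Note that "mc-dropout" first builds a standard ResNet and then
# wraps it with the mc_dropout decorator below.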
ENSEMBLE_METHODS = [
    "packed",
    "batched",
    "lpbnn",
    "masked",
    "mc-dropout",
    "mimo",
]


class ResNetBaseline(ClassificationRoutine):
    versions = {
        "std": resnet,
        "packed": packed_resnet,
        "batched": batched_resnet,
        "lpbnn": lpbnn_resnet,
        "masked": masked_resnet,
        "mimo": mimo_resnet,
        "mc-dropout": resnet,
    }
    archs = [18, 20, 34, 44, 50, 56, 101, 110, 152, 1202]

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        loss: nn.Module,
        version: Literal[
            "std",
            "mc-dropout",
            "packed",
            "batched",
            "lpbnn",
            "masked",
            "mimo",
        ],
        arch: int,
        style: str = "imagenet",
        normalization_layer: type[nn.Module] = nn.BatchNorm2d,
        num_estimators: int = 1,
        dropout_rate: float = 0.0,
        optim_recipe: dict | Optimizer | None = None,
        mixup_params: dict | None = None,
        last_layer_dropout: bool = False,
        width_multiplier: float = 1.0,
        groups: int = 1,
        conv_bias: bool = False,
        scale: float | None = None,
        alpha: int | None = None,
        gamma: int = 1,
        rho: float = 1.0,
        batch_repeat: int = 1,
        ood_criterion: TUOODCriterion | str = "msp",
        log_plots: bool = False,
        save_in_csv: bool = False,
        eval_ood: bool = False,
        eval_shift: bool = False,
        eval_grouping_loss: bool = False,
        num_bins_cal_err: int = 15,
        pretrained: bool = False,
    ) -> None:
        r"""ResNet backbone baseline for classification providing support for
        various versions and architectures.

        Args:
            num_classes (int): Number of classes to predict.
            in_channels (int): Number of input channels.
            loss (nn.Module): Training loss.
            optim_recipe (dict | Optimizer, optional): Optimization recipe,
                corresponding to what the
                `LightningModule.configure_optimizers()
                <https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#configure-optimizers>`_
                method expects. Defaults to ``None``.
            version (str): Determines which ResNet version to use:

                - ``"std"``: original ResNet
                - ``"packed"``: Packed-Ensembles ResNet
                - ``"batched"``: BatchEnsemble ResNet
                - ``"lpbnn"``: LP-BNN ResNet
                - ``"masked"``: Masksembles ResNet
                - ``"mimo"``: MIMO ResNet
                - ``"mc-dropout"``: Monte-Carlo Dropout ResNet

            arch (int): Determines which ResNet architecture to use, one of
                ``18``, ``20``, ``34``, ``44``, ``50``, ``56``, ``101``,
                ``110``, ``152``, or ``1202``.
            style (str, optional): Which ResNet style to use. Defaults to
                ``imagenet``.
            normalization_layer (type[nn.Module], optional): Normalization
                layer to use. Defaults to ``nn.BatchNorm2d``.
            num_estimators (int, optional): Number of estimators in the
                ensemble. Only used if :attr:`version` is one of the ensemble
                methods. Defaults to ``1``.
            dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``.
            mixup_params (dict, optional): Mixup parameters. Can include
                mixtype, mixmode, dist_sim, kernel_tau_max, kernel_tau_std,
                mixup_alpha, and cutmix_alpha. If ``None``, no augmentation is
                applied. Defaults to ``None``.
            width_multiplier (float, optional): Expansion factor affecting the
                width of the estimators. Defaults to ``1.0``.
            groups (int, optional): Number of groups in convolutions. Defaults
                to ``1``.
            scale (float, optional): Expansion factor affecting the width of
                the estimators. Only used if :attr:`version` is ``"masked"``.
                Defaults to ``None``.
            last_layer_dropout (bool): Whether to apply dropout to the last
                layer only. Only used if :attr:`version` is ``"mc-dropout"``.
                Defaults to ``False``.
            conv_bias (bool, optional): Whether to include bias in the
                convolutional layers. Defaults to ``False``.
            alpha (int, optional): Expansion factor affecting the width of the
                estimators. Only used if :attr:`version` is ``"packed"``.
                Defaults to ``None``.
            gamma (int, optional): Number of groups within each estimator.
                Only used if :attr:`version` is ``"packed"`` and scales with
                :attr:`groups`. Defaults to ``1``.
            rho (float, optional): Probability that all estimators share the
                same input. Only used if :attr:`version` is ``"mimo"``.
                Defaults to ``1.0``.
            batch_repeat (int, optional): Number of times to repeat the batch.
                Only used if :attr:`version` is ``"mimo"``. Defaults to ``1``.
            ood_criterion (TUOODCriterion | str, optional): Criterion for the
                binary OOD detection task. Defaults to ``"msp"``, the maximum
                softmax probability score.
            log_plots (bool, optional): Indicates whether to log the plots or
                not. Defaults to ``False``.
            save_in_csv (bool, optional): Indicates whether to save the
                results in a csv file or not. Defaults to ``False``.
            eval_ood (bool, optional): Indicates whether to evaluate the OOD
                detection performance or not. Defaults to ``False``.
            eval_shift (bool): Whether to evaluate on shifted data. Defaults
                to ``False``.
            eval_grouping_loss (bool, optional): Indicates whether to evaluate
                the grouping loss or not. Defaults to ``False``.
            num_bins_cal_err (int, optional): Number of calibration bins.
                Defaults to ``15``.
            pretrained (bool, optional): Indicates whether to use the
                pretrained weights or not. Only used if :attr:`version` is
                ``"packed"``. Defaults to ``False``.

        Raises:
            ValueError: If :attr:`version` is not one of ``"std"``,
                ``"packed"``, ``"batched"``, ``"lpbnn"``, ``"masked"``,
                ``"mimo"``, or ``"mc-dropout"``.

        Returns:
            LightningModule: ResNet baseline ready for training and
            evaluation.
        """
        params = {
            "arch": arch,
            "conv_bias": conv_bias,
            "dropout_rate": dropout_rate,
            "groups": groups,
            "width_multiplier": width_multiplier,
            "in_channels": in_channels,
            "num_classes": num_classes,
            "style": style,
            "normalization_layer": normalization_layer,
        }

        format_batch_fn = nn.Identity()

        if version not in self.versions:
            raise ValueError(f"Unknown version: {version}")

        if version in ENSEMBLE_METHODS:
            params |= {
                "num_estimators": num_estimators,
            }

            if version != "mc-dropout":
                format_batch_fn = RepeatTarget(num_repeats=num_estimators)

            if version == "packed":
                params |= {
                    "alpha": alpha,
                    "gamma": gamma,
                    "pretrained": pretrained,
                }
            elif version == "masked":
                params |= {
                    "scale": scale,
                }
            elif version == "mimo":
                format_batch_fn = MIMOBatchFormat(
                    num_estimators=num_estimators,
                    rho=rho,
                    batch_repeat=batch_repeat,
                )

        if version == "mc-dropout":  # std ResNets don't have `num_estimators`
            del params["num_estimators"]

        model = self.versions[version](**params)
        if version == "mc-dropout":
            model = mc_dropout(
                model=model,
                num_estimators=num_estimators,
                last_layer=last_layer_dropout,
            )

        super().__init__(
            num_classes=num_classes,
            model=model,
            loss=loss,
            is_ensemble=version in ENSEMBLE_METHODS,
            optim_recipe=optim_recipe,
            format_batch_fn=format_batch_fn,
            mixup_params=mixup_params,
            eval_ood=eval_ood,
            eval_shift=eval_shift,
            eval_grouping_loss=eval_grouping_loss,
            ood_criterion=ood_criterion,
            log_plots=log_plots,
            save_in_csv=save_in_csv,
            num_bins_cal_err=num_bins_cal_err,
        )
        self.save_hyperparameters(ignore=["loss"])
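

# Illustrative usage sketch (not part of the library API): constructs a
# standard ResNet-18 baseline and a Packed-Ensembles variant. The loss and
# the hyperparameter values below are assumptions chosen for demonstration
# only, not recommended settings.
if __name__ == "__main__":
    std_baseline = ResNetBaseline(
        num_classes=10,
        in_channels=3,
        loss=nn.CrossEntropyLoss(),
        version="std",
        arch=18,
    )
    packed_baseline = ResNetBaseline(
        num_classes=10,
        in_channels=3,
        loss=nn.CrossEntropyLoss(),
        version="packed",
        arch=18,
        num_estimators=4,  # ensemble size (illustrative)
        alpha=2,  # Packed-Ensembles width expansion factor (illustrative)
        gamma=1,  # groups within each estimator
    )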