# Source code for torch_uncertainty.post_processing.mc_batch_norm
from copy import deepcopy
from typing import Literal

import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset

from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d
from torch_uncertainty.post_processing import PostProcessing


class MCBatchNorm(PostProcessing):
    counter: int = 0
    mc_batch_norm_layers: list[MCBatchNorm2d] = []
    trained = False

    def __init__(
        self,
        model: nn.Module | None = None,
        num_estimators: int = 16,
        convert: bool = True,
        mc_batch_size: int = 32,
        device: Literal["cpu", "cuda"] | torch.device | None = None,
    ) -> None:
"""Monte Carlo Batch Normalization wrapper.
Args:
model (nn.Module): model to be converted.
num_estimators (int): number of estimators.
convert (bool): whether to convert the model.
mc_batch_size (int, optional): Monte Carlo batch size. Defaults to 32.
device (Literal["cpu", "cuda"] | torch.device | None, optional): device.
Defaults to None.
Note:
This wrapper will be stochastic in eval mode only.
Reference:
Teye M, Azizpour H, Smith K. Bayesian uncertainty estimation for
batch normalized deep networks. In ICML 2018.
"""
        super().__init__()
        self.mc_batch_size = mc_batch_size
        self.convert = convert
        self.num_estimators = num_estimators
        self.device = device

        if model is not None:
            self._setup_model(model)

    def _setup_model(self, model: nn.Module) -> None:
        _mcbn_checks(model, self.num_estimators, self.mc_batch_size, self.convert)
        self.model = deepcopy(model)  # Is it necessary?
        self.model = self.model.eval()
        if self.convert:
            self._convert()
            if not has_mcbn(self.model):
                raise ValueError("model does not contain any MCBatchNorm2d after conversion.")

    def set_model(self, model: nn.Module) -> None:
        self.model = model
        self._setup_model(model)

    def fit(self, dataset: Dataset) -> None:
        """Fit the model on the dataset.

        Args:
            dataset (Dataset): dataset to be used for fitting; use the training dataset.

        Note:
            This method populates the statistics of the MC BatchNorm layers. The
            dataset must provide at least ``num_estimators`` mini-batches of
            ``mc_batch_size`` samples, otherwise a ValueError is raised.
        """
        self.dl = DataLoader(dataset, batch_size=self.mc_batch_size, shuffle=True)
        self.counter = 0
        self.reset_counters()
        self.set_accumulate(True)
        self.eval()
        for x, _ in self.dl:
            self.model(x.to(self.device))
            self.raise_counters()
            if self.counter == self.num_estimators:
                self.set_accumulate(False)
                self.trained = True
                return
        raise ValueError("The dataset is too small to populate the MC BatchNorm statistics.")

    def _est_forward(self, x: Tensor) -> Tensor:
        """Forward pass of a single estimator."""
        logit = self.model(x)
        self.raise_counters()
        return logit

    def forward(
        self,
        x: Tensor,
    ) -> Tensor:
        if self.training:
            return self.model(x)
        if not self.trained:
            raise RuntimeError("MCBatchNorm has not been trained. Call .fit() first.")
        self.reset_counters()
        return torch.cat([self._est_forward(x) for _ in range(self.num_estimators)], dim=0)
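
    # Note: this comment is illustrative and not part of the original source. In eval
    # mode, the num_estimators predictions are stacked along the batch dimension;
    # assuming an input batch of size B and C output logits, the per-estimator views
    # can be recovered with samples.reshape(num_estimators, B, C) and averaged over dim 0.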

    def _convert(self) -> None:
        """Convert all BatchNorm2d layers to MCBatchNorm2d layers."""
        self.replace_layers(self.model)

    def reset_counters(self) -> None:
        """Reset all counters to 0."""
        self.counter = 0
        for layer in self.mc_batch_norm_layers:
            layer.set_counter(0)

    def raise_counters(self) -> None:
        """Raise all counters by 1."""
        self.counter += 1
        for layer in self.mc_batch_norm_layers:
            layer.set_counter(self.counter)

    def set_accumulate(self, accumulate: bool) -> None:
        """Set the accumulate flag for all MCBatchNorm2d layers.

        Args:
            accumulate (bool): accumulate flag.
        """
        for layer in self.mc_batch_norm_layers:
            layer.accumulate = accumulate

    def replace_layers(self, model: nn.Module) -> None:
        """Replace all BatchNorm2d layers with MCBatchNorm2d layers.

        Args:
            model (nn.Module): model to be converted.
        """
        for name, module in model.named_children():
            if len(list(module.children())) > 0:
                self.replace_layers(module)
            if isinstance(module, nn.BatchNorm2d):
                mc_layer = MCBatchNorm2d(
                    num_features=module.num_features,
                    num_estimators=self.num_estimators,
                    eps=module.eps,
                    momentum=module.momentum,
                    affine=module.affine,
                    track_running_stats=module.track_running_stats,
                    device=module.weight.device,
                    dtype=module.weight.dtype,
                )
                mc_layer.training = module.training
                mc_layer.weight = module.weight
                mc_layer.bias = module.bias
                setattr(model, name, mc_layer)
                # Save pointers to the MC BatchNorm layers
                self.mc_batch_norm_layers.append(mc_layer)


def has_mcbn(model: nn.Module) -> bool:
    """Check if the model contains any MCBatchNorm2d layers."""
    return any(isinstance(module, MCBatchNorm2d) for module in model.modules())


def _mcbn_checks(model: nn.Module, num_estimators: int, mc_batch_size: int, convert: bool) -> None:
    if num_estimators < 1 or not isinstance(num_estimators, int):
        raise ValueError(f"num_estimators must be a positive integer, got {num_estimators}.")
    if mc_batch_size < 1 or not isinstance(mc_batch_size, int):
        raise ValueError(f"mc_batch_size must be a positive integer, got {mc_batch_size}.")
    if not convert and not has_mcbn(model):
        raise ValueError("model does not contain any MCBatchNorm2d layer and is not to be converted.")