Shortcuts

Source code for torch_uncertainty.models.wrappers.batch_ensemble

import torch
from einops import repeat
from torch import nn

from torch_uncertainty.layers import BatchConv2d, BatchLinear


class BatchEnsemble(nn.Module):
    def __init__(
        self,
        model: nn.Module,
        num_estimators: int,
        repeat_training_inputs: bool = False,
        convert_layers: bool = False,
    ) -> None:
        """Wrap a BatchEnsemble model to ensure correct batch replication.

        In a BatchEnsemble architecture, each estimator operates on a
        **sub-batch** of the input. This means that the input batch must be
        **repeated** :attr:`num_estimators` times before being processed.

        This wrapper automatically **duplicates the input batch** along the
        first axis, ensuring that each estimator receives the correct data
        format.

        Args:
            model (nn.Module): The BatchEnsemble model.
            num_estimators (int): Number of ensemble members.
            repeat_training_inputs (optional, bool): Whether to repeat the
                input batch during training. If ``True``, the input batch is
                repeated during both training and evaluation. If ``False``,
                the input batch is repeated only during evaluation. Default
                is ``False``.
            convert_layers (optional, bool): Whether to convert the model's
                layers to BatchEnsemble layers. If ``True``, the wrapper will
                convert all ``nn.Linear`` and ``nn.Conv2d`` layers to their
                BatchEnsemble counterparts. Default is ``False``.

        Raises:
            ValueError: If neither ``BatchLinear`` nor ``BatchConv2d`` layers
                are found in the model at the end of initialization.
            ValueError: If ``num_estimators`` is less than or equal to ``0``.
            ValueError: If ``convert_layers=True`` and neither ``nn.Linear``
                nor ``nn.Conv2d`` layers are found in the model.

        Warning:
            If ``convert_layers==True``, the wrapper will attempt to convert
            all ``nn.Linear`` and ``nn.Conv2d`` layers in the model to their
            BatchEnsemble counterparts. If the model contains other types of
            layers, the conversion won't happen for these layers. If you
            don't have any ``nn.Linear`` or ``nn.Conv2d`` layers in the
            model, the wrapper will raise an error during conversion.

        Warning:
            If ``repeat_training_inputs==True`` and you want to use one of
            the ``torch_uncertainty.routines`` for training, be sure to set
            ``format_batch_fn=RepeatTarget(num_repeats=num_estimators)`` when
            initializing the routine.

        Example:
            >>> model = nn.Sequential(
            ...     nn.Linear(10, 5),
            ...     nn.ReLU(),
            ...     nn.Linear(5, 2)
            ... )
            >>> model = BatchEnsemble(model, num_estimators=4, convert_layers=True)
            >>> model
            BatchEnsemble(
              (model): Sequential(
                (0): BatchLinear(in_features=10, out_features=5, num_estimators=4)
                (1): ReLU()
                (2): BatchLinear(in_features=5, out_features=2, num_estimators=4)
              )
            )
        """
        super().__init__()
        self.model = model
        self.num_estimators = num_estimators
        self.repeat_training_inputs = repeat_training_inputs

        if convert_layers:
            self._convert_layers()

        # Validate on the final (possibly converted) model.
        filtered_modules = [
            module
            for module in self.model.modules()
            if isinstance(module, BatchLinear | BatchConv2d)
        ]
        _batch_ensemble_checks(filtered_modules, num_estimators)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat the input if needed and pass it through the model.

        The input is repeated ``num_estimators`` times along the first axis
        when ``self.training==False`` or ``repeat_training_inputs==True``;
        otherwise it is forwarded unchanged.

        Args:
            x (torch.Tensor): The input batch.

        Returns:
            torch.Tensor: The output of the wrapped model.
        """
        if not self.training or self.repeat_training_inputs:
            x = repeat(x, "b ... -> (m b) ...", m=self.num_estimators)
        return self.model(x)

    def _convert_layers(self) -> None:
        """Convert the model's layers to BatchEnsemble layers.

        Every ``nn.Linear`` / ``nn.Conv2d`` found in ``self.model``, at any
        nesting depth, is replaced in its parent module by the layer built
        with ``BatchLinear.from_linear`` / ``BatchConv2d.from_conv2d``. If
        the wrapped model is itself such a layer, ``self.model`` is replaced.

        Raises:
            ValueError: If the model contains no ``nn.Linear`` or
                ``nn.Conv2d`` layer.
        """
        # Snapshot the targets first so that the module tree is not modified
        # while ``named_modules()`` is still walking it.
        targets = [
            (name, layer)
            for name, layer in self.model.named_modules()
            if isinstance(layer, nn.Linear | nn.Conv2d)
        ]
        if not targets:
            raise ValueError(
                "No valid layers found in the model. "
                "Please use `nn.Linear` or `nn.Conv2d` layers to apply BatchEnsemble."
            )

        for name, layer in targets:
            if isinstance(layer, nn.Linear):
                converted = BatchLinear.from_linear(layer, num_estimators=self.num_estimators)
            else:
                converted = BatchConv2d.from_conv2d(layer, num_estimators=self.num_estimators)

            if not name:
                # An empty name means the wrapped model is the layer itself.
                self.model = converted
                continue

            # ``name`` is a dotted path (e.g. "block.0.fc"): the replacement
            # has to be set on the direct parent, not on the root module.
            parent_name, _, child_name = name.rpartition(".")
            parent = self.model.get_submodule(parent_name)
            setattr(parent, child_name, converted)
def _batch_ensemble_checks(filtered_modules, num_estimators): """Check if the model contains the required number of dropout modules.""" if len(filtered_modules) == 0: raise ValueError( "No BatchEnsemble layers found in the model. " "Please use `BatchLinear` or `BatchConv2d` layers in your model " "or set `convert_layers=True` when initializing the wrapper." ) if num_estimators <= 0: raise ValueError("`num_estimators` must be greater than 0.")
def batch_ensemble(
    model: nn.Module,
    num_estimators: int,
    repeat_training_inputs: bool = False,
    convert_layers: bool = False,
) -> BatchEnsemble:
    """Build a :class:`BatchEnsemble` wrapper around ``model``.

    Args:
        model (nn.Module): model to wrap
        num_estimators (int): number of ensemble members
        repeat_training_inputs (bool, optional): whether to repeat the input
            batch during training. If ``True``, the input batch is repeated
            during both training and evaluation. If ``False``, the input
            batch is repeated only during evaluation. Default is ``False``.
        convert_layers (bool, optional): whether to convert the model's
            layers to BatchEnsemble layers. If ``True``, the wrapper will
            convert all ``nn.Linear`` and ``nn.Conv2d`` layers to their
            BatchEnsemble counterparts. Default is ``False``.

    Returns:
        BatchEnsemble: BatchEnsemble wrapper for the model
    """
    # Arguments are forwarded unchanged; all validation happens in the
    # wrapper's constructor.
    wrapper = BatchEnsemble(
        model,
        num_estimators,
        repeat_training_inputs=repeat_training_inputs,
        convert_layers=convert_layers,
    )
    return wrapper