BatchEnsemble

class torch_uncertainty.models.BatchEnsemble(model, num_estimators, repeat_training_inputs=False, convert_layers=False)

Wrap a BatchEnsemble model to ensure correct batch replication.

In a BatchEnsemble architecture, each estimator operates on a sub-batch of the input. This means that the input batch must be repeated num_estimators times before being processed.

This wrapper automatically duplicates the input batch along the first axis, ensuring that each estimator receives the correct data format.
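The replication described above can be illustrated with plain PyTorch. This is a minimal sketch, not the wrapper's own code: it assumes the batch is tiled as whole copies along dimension 0 (copy 1, then copy 2, and so on), and the tensor sizes are illustrative.

```python
import torch

num_estimators = 4  # illustrative value

# A batch of 3 samples with 2 features each.
x = torch.arange(6, dtype=torch.float32).reshape(3, 2)

# Tile the whole batch num_estimators times along dim 0,
# leaving every other dimension unchanged (assumed ordering).
repeated = x.repeat(num_estimators, *([1] * (x.dim() - 1)))

print(repeated.shape)  # torch.Size([12, 2])
```

Under this assumption, rows 0 to 2 are the original batch and each following block of 3 rows is another copy of it, one block per estimator.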

Parameters:
  • model (nn.Module) – The BatchEnsemble model.

  • num_estimators (int) – Number of ensemble members.

  • repeat_training_inputs (optional, bool) – Whether to repeat the input batch during training. If True, the input batch is repeated during both training and evaluation. If False, the input batch is repeated only during evaluation. Default is False.

  • convert_layers (optional, bool) – Whether to convert the model’s layers to BatchEnsemble layers. If True, the wrapper will convert all nn.Linear and nn.Conv2d layers to their BatchEnsemble counterparts. Default is False.

Raises:
  • ValueError – If neither BatchLinear nor BatchConv2d layers are found in the model at the end of initialization.

  • ValueError – If num_estimators is less than or equal to 0.

  • ValueError – If convert_layers=True and neither nn.Linear nor nn.Conv2d layers are found in the model.
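To catch two of these conditions before constructing the wrapper, a pre-check along the following lines may help. The helper name `validate_for_conversion` and its error messages are hypothetical and not part of torch_uncertainty; the sketch only mirrors the documented conditions on num_estimators and on the presence of nn.Linear or nn.Conv2d layers.

```python
from torch import nn


def validate_for_conversion(model: nn.Module, num_estimators: int) -> None:
    """Hypothetical pre-check; not part of the library."""
    if num_estimators <= 0:
        raise ValueError(f"num_estimators must be positive, got {num_estimators}.")
    # modules() walks the model recursively, including nested containers.
    has_convertible = any(
        isinstance(module, (nn.Linear, nn.Conv2d)) for module in model.modules()
    )
    if not has_convertible:
        raise ValueError("No nn.Linear or nn.Conv2d layers found to convert.")


# Passes silently: the model contains an nn.Linear layer.
validate_for_conversion(nn.Sequential(nn.Linear(10, 5), nn.ReLU()), num_estimators=4)
```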

Warning

If convert_layers==True, the wrapper will attempt to convert all nn.Linear and nn.Conv2d layers in the model to their BatchEnsemble counterparts. Layers of any other type are left unchanged. If the model does not contain any nn.Linear or nn.Conv2d layers, the wrapper will raise an error during conversion.

Warning

If repeat_training_inputs==True and you want to use one of the torch_uncertainty.routines for training, be sure to set format_batch_fn=RepeatTarget(num_repeats=num_estimators) when initializing the routine.
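The reason the targets need repeating is a shape mismatch: once the inputs are repeated, the model emits num_estimators times as many predictions as there are targets. The sketch below shows the idea with plain PyTorch. It does not use RepeatTarget itself, the random logits are a stand-in for the wrapped model's output, and the tiling order of the targets is assumed to match that of the inputs.

```python
import torch
import torch.nn.functional as F

num_estimators, batch_size, num_classes = 4, 3, 2  # illustrative sizes

# Stand-in for the output of a wrapper that repeated its inputs:
# one row of logits per (estimator, sample) pair.
logits = torch.randn(num_estimators * batch_size, num_classes)
targets = torch.tensor([0, 1, 1])  # one label per original sample

# Tile the targets so that they line up with the repeated predictions.
repeated_targets = targets.repeat(num_estimators)

loss = F.cross_entropy(logits, repeated_targets)
print(repeated_targets.shape)  # torch.Size([12])
```

Without the repetition, the loss call would fail because the batch sizes of the logits (12) and the targets (3) differ.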

Example

>>> from torch import nn
>>> from torch_uncertainty.models import BatchEnsemble
>>> model = nn.Sequential(
...     nn.Linear(10, 5),
...     nn.ReLU(),
...     nn.Linear(5, 2)
... )
>>> model = BatchEnsemble(model, num_estimators=4, convert_layers=True)
>>> model
BatchEnsemble(
  (model): Sequential(
    (0): BatchLinear(in_features=10, out_features=5, num_estimators=4)
    (1): ReLU()
    (2): BatchLinear(in_features=5, out_features=2, num_estimators=4)
  )
)

forward(x)

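Repeat the input num_estimators times if the model is in evaluation mode (self.training==False) or if repeat_training_inputs==True, then pass it through the model. In training mode with repeat_training_inputs==False, the input is passed through unchanged.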
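The condition can be sketched as follows. `RepeatingWrapper` is a hypothetical stand-in written for illustration, not the library's implementation; it wraps nn.Identity so that the effect on the batch dimension is directly visible, and it assumes the batch is tiled as whole copies along dimension 0.

```python
import torch
from torch import nn


class RepeatingWrapper(nn.Module):
    """Hypothetical stand-in that mimics the documented repeat condition."""

    def __init__(self, model, num_estimators, repeat_training_inputs=False):
        super().__init__()
        self.model = model
        self.num_estimators = num_estimators
        self.repeat_training_inputs = repeat_training_inputs

    def forward(self, x):
        # Repeat in evaluation mode, or in training mode when explicitly requested.
        if not self.training or self.repeat_training_inputs:
            x = x.repeat(self.num_estimators, *([1] * (x.dim() - 1)))
        return self.model(x)


wrapper = RepeatingWrapper(nn.Identity(), num_estimators=4)
x = torch.zeros(8, 10)

wrapper.train()
train_out = wrapper(x)  # input passed through unchanged: shape (8, 10)

wrapper.eval()
eval_out = wrapper(x)  # input repeated 4 times: shape (32, 10)
```

With repeat_training_inputs=True, the same stand-in would produce a batch of 32 in both modes.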