BatchEnsemble
- class torch_uncertainty.models.BatchEnsemble(model, num_estimators, repeat_training_inputs=False, convert_layers=False)
Wrap a BatchEnsemble model to ensure correct batch replication.
In a BatchEnsemble architecture, each estimator operates on its own sub-batch of the input, so the input batch must be repeated num_estimators times before being processed. This wrapper automatically repeats the input batch along the first axis: an input of batch size B reaches the model as a batch of num_estimators × B samples, so that each estimator receives the whole batch in the expected format.
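A minimal sketch of the repetition step, assuming the wrapper tiles the batch block-wise as torch.Tensor.repeat does, so that rows k·B to (k+1)·B - 1 hold the k-th copy. The helper repeat_batch is hypothetical and not part of torch_uncertainty; the library's exact tiling order is an assumption here.

```python
import torch


def repeat_batch(x: torch.Tensor, num_estimators: int) -> torch.Tensor:
    """Hypothetical helper: tile x along dim 0, (B, ...) -> (num_estimators * B, ...)."""
    # Repeat factor num_estimators for dim 0 and 1 for every remaining dim.
    return x.repeat(num_estimators, *([1] * (x.dim() - 1)))


x = torch.randn(8, 3, 32, 32)  # a batch of B = 8 images
tiled = repeat_batch(x, num_estimators=4)
print(tiled.shape)  # torch.Size([32, 3, 32, 32])

# Viewing the tiled batch as (num_estimators, B, ...) exposes one full copy per estimator.
blocks = tiled.view(4, 8, 3, 32, 32)
print(all(torch.equal(blocks[k], x) for k in range(4)))  # True
```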
- Parameters:
model (nn.Module) – The BatchEnsemble model.
num_estimators (int) – Number of ensemble members.
repeat_training_inputs (optional, bool) – Whether to repeat the input batch during training. If True, the input batch is repeated during both training and evaluation. If False, the input batch is repeated only during evaluation. Default is False.
convert_layers (optional, bool) – Whether to convert the model’s layers to BatchEnsemble layers. If True, the wrapper converts all nn.Linear and nn.Conv2d layers to their BatchEnsemble counterparts (BatchLinear and BatchConv2d). Default is False.
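The rule controlled by repeat_training_inputs can be sketched as a small wrapper. RepeatInputs is a hypothetical class written for illustration: it reproduces only the documented when-to-repeat behavior, not the layer checks or conversion performed by the real BatchEnsemble, and it assumes block-wise tiling with torch.Tensor.repeat.

```python
import torch
from torch import nn


class RepeatInputs(nn.Module):
    """Hypothetical wrapper mimicking the documented repetition rule."""

    def __init__(self, model: nn.Module, num_estimators: int,
                 repeat_training_inputs: bool = False) -> None:
        super().__init__()
        self.model = model
        self.num_estimators = num_estimators
        self.repeat_training_inputs = repeat_training_inputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Always repeat in eval mode; in train mode only when explicitly requested.
        if not self.training or self.repeat_training_inputs:
            x = x.repeat(self.num_estimators, *([1] * (x.dim() - 1)))
        return self.model(x)


wrapper = RepeatInputs(nn.Linear(10, 2), num_estimators=4)
x = torch.randn(8, 10)
print(wrapper.train()(x).shape)  # torch.Size([8, 2])
print(wrapper.eval()(x).shape)   # torch.Size([32, 2])
```

With the default repeat_training_inputs=False, a training batch passes through unchanged; assuming the batched layers split whatever batch they receive into num_estimators contiguous sub-batches, each estimator then trains on different samples rather than on copies of the same ones.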
- Raises:
ValueError – If neither BatchLinear nor BatchConv2d layers are found in the model at the end of initialization.
ValueError – If num_estimators is less than or equal to 0.
ValueError – If convert_layers=True and neither nn.Linear nor nn.Conv2d layers are found in the model.
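The conversion requested by convert_layers=True, together with the checks listed above, can be pictured with the following sketch. It assumes the rank-1 parametrization commonly used for BatchEnsemble (a shared weight W and per-estimator vectors r_i and s_i, giving y = ((x ⊙ r_i) Wᵀ) ⊙ s_i + b_i) and uses the hypothetical names ToyBatchLinear, convert_linear_layers and to_toy_batch_ensemble. It handles only nn.Linear and is not the library's BatchLinear, whose initialization and internals may differ.

```python
import torch
from torch import nn


class ToyBatchLinear(nn.Module):
    """Hypothetical rank-1 BatchEnsemble linear layer (illustration only)."""

    def __init__(self, linear: nn.Linear, num_estimators: int) -> None:
        super().__init__()
        self.num_estimators = num_estimators
        # Shared "slow" weight W, copied from the layer being converted.
        self.weight = nn.Parameter(linear.weight.detach().clone())
        # Per-estimator "fast" weights: r_i scales inputs, s_i scales outputs.
        self.r = nn.Parameter(torch.ones(num_estimators, linear.in_features))
        self.s = nn.Parameter(torch.ones(num_estimators, linear.out_features))
        # Per-estimator bias, initialized from the original bias (or zeros).
        bias = linear.bias.detach() if linear.bias is not None else torch.zeros(linear.out_features)
        self.bias = nn.Parameter(bias.expand(num_estimators, -1).clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_estimators * B, in_features); rows k*B .. (k+1)*B - 1 belong to estimator k.
        b = x.shape[0] // self.num_estimators
        r = self.r.repeat_interleave(b, dim=0)
        s = self.s.repeat_interleave(b, dim=0)
        bias = self.bias.repeat_interleave(b, dim=0)
        return ((x * r) @ self.weight.T) * s + bias


def convert_linear_layers(module: nn.Module, num_estimators: int) -> int:
    """Recursively replace nn.Linear children in place; return how many were swapped."""
    count = 0
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            setattr(module, name, ToyBatchLinear(child, num_estimators))
            count += 1
        else:
            # Other layer types (ReLU, BatchNorm, ...) are left unchanged.
            count += convert_linear_layers(child, num_estimators)
    return count


def to_toy_batch_ensemble(model: nn.Module, num_estimators: int) -> nn.Module:
    # Mirrors the documented checks; not the library's code.
    if num_estimators <= 0:
        raise ValueError(f"num_estimators must be positive, got {num_estimators}.")
    if convert_linear_layers(model, num_estimators) == 0:
        raise ValueError("No nn.Linear layer found to convert.")
    return model


model = to_toy_batch_ensemble(nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2)), 4)
out = model(torch.randn(3, 10).repeat(4, 1))
print(out.shape)  # torch.Size([12, 2])
```

Because r_i and s_i start at one and every estimator copies the original bias, all members of this toy initially reproduce the original layer; practical implementations usually initialize the fast weights randomly so that the members diverge.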
Warning
If convert_layers==True, the wrapper attempts to convert all nn.Linear and nn.Conv2d layers in the model to their BatchEnsemble counterparts. Layers of any other type are not converted and are left unchanged. If the model contains no nn.Linear or nn.Conv2d layers at all, the wrapper raises an error during conversion.
Warning
If repeat_training_inputs==True and you want to train with one of the torch_uncertainty.routines, be sure to set format_batch_fn=RepeatTarget(num_repeats=num_estimators) when initializing the routine, so that the targets are repeated to match the repeated inputs.
Example
>>> from torch import nn
>>> from torch_uncertainty.models import BatchEnsemble
>>> model = nn.Sequential(
...     nn.Linear(10, 5),
...     nn.ReLU(),
...     nn.Linear(5, 2)
... )
>>> model = BatchEnsemble(model, num_estimators=4, convert_layers=True)
>>> model
BatchEnsemble(
  (model): Sequential(
    (0): BatchLinear(in_features=10, out_features=5, num_estimators=4)
    (1): ReLU()
    (2): BatchLinear(in_features=5, out_features=2, num_estimators=4)
  )
)
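For the two-class model in the example, a batch of B inputs that gets repeated yields num_estimators × B rows of logits, which is why the targets must be repeated as well when repeat_training_inputs=True. The sketch below shows that shape bookkeeping with two hypothetical helpers, repeat_targets and average_probs (neither is the library's RepeatTarget), assuming block-wise tiling. It then averages softmax probabilities over the estimators, a common way to combine ensemble members; whether torch_uncertainty's routines aggregate in exactly this way is an assumption.

```python
import torch


def repeat_targets(y: torch.Tensor, num_estimators: int) -> torch.Tensor:
    """Hypothetical helper: tile targets along dim 0 exactly like the inputs."""
    return y.repeat(num_estimators, *([1] * (y.dim() - 1)))


def average_probs(logits: torch.Tensor, num_estimators: int) -> torch.Tensor:
    """Hypothetical helper: (num_estimators * B, C) logits -> (B, C) mean probabilities."""
    probs = logits.softmax(dim=-1)
    # Regroup rows into (num_estimators, B, C), then average over the estimators.
    return probs.view(num_estimators, -1, probs.shape[-1]).mean(dim=0)


y = torch.tensor([0, 1, 1])            # targets for B = 3 samples
print(repeat_targets(y, 4))            # tensor([0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1])
logits = torch.randn(12, 2)            # stand-in for the wrapped model's output
print(average_probs(logits, 4).shape)  # torch.Size([3, 2])
```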