BatchEnsemble
- class torch_uncertainty.models.BatchEnsemble(model, num_estimators, repeat_training_inputs=False, convert_layers=False)
Wrap a BatchEnsemble model to ensure correct batch replication.
In a BatchEnsemble architecture, each estimator operates on a sub-batch of the input. This means that the input batch must be repeated num_estimators times before being processed. This wrapper automatically duplicates the input batch along the first axis, ensuring that each estimator receives the correct data format.
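The replication step described above can be sketched in plain PyTorch. This is an illustration only, not the wrapper's actual code: the helper name repeat_batch is hypothetical, and it assumes the batch is tiled as a whole (copy 1, then copy 2, and so on) rather than interleaved sample by sample.

```python
import torch


def repeat_batch(x: torch.Tensor, num_estimators: int) -> torch.Tensor:
    """Tile the whole batch num_estimators times along the first axis.

    Hypothetical helper for illustration; the tiled layout is an assumption.
    """
    # Repeat dim 0 num_estimators times and leave every other dim unchanged.
    return x.repeat(num_estimators, *([1] * (x.dim() - 1)))


x = torch.randn(8, 10)  # batch of 8 samples with 10 features each
repeated = repeat_batch(x, num_estimators=4)
print(repeated.shape)  # torch.Size([32, 10])
```

Under this layout, rows 0 to 7 of the result are the sub-batch for the first estimator, rows 8 to 15 for the second, and so on.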
- Parameters:
model (nn.Module) – The BatchEnsemble model.
num_estimators (int) – Number of ensemble members.
repeat_training_inputs (optional, bool) – Whether to repeat the input batch during training. If True, the input batch is repeated during both training and evaluation. If False, the input batch is repeated only during evaluation. Default is False.
convert_layers (optional, bool) – Whether to convert the model’s layers to BatchEnsemble layers. If True, the wrapper will convert all nn.Linear and nn.Conv2d layers to their BatchEnsemble counterparts. Default is False.
- Raises:
ValueError – If neither BatchLinear nor BatchConv2d layers are found in the model at the end of initialization.
ValueError – If num_estimators is less than or equal to 0.
ValueError – If convert_layers=True and neither nn.Linear nor nn.Conv2d layers are found in the model.
Warning
If convert_layers==True, the wrapper will attempt to convert all nn.Linear and nn.Conv2d layers in the model to their BatchEnsemble counterparts. Layers of any other type are left unchanged. If the model contains no nn.Linear or nn.Conv2d layers, the wrapper will raise an error during conversion.
Warning
If repeat_training_inputs==True and you want to use one of the torch_uncertainty.routines for training, be sure to set format_batch_fn=RepeatTarget(num_repeats=num_estimators) when initializing the routine.
Example
>>> from torch import nn
>>> from torch_uncertainty.models import BatchEnsemble
>>> model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
>>> model = BatchEnsemble(model, num_estimators=4, convert_layers=True)
>>> model
BatchEnsemble(
  (model): Sequential(
    (0): BatchLinear(in_features=10, out_features=5, num_estimators=4)
    (1): ReLU()
    (2): BatchLinear(in_features=5, out_features=2, num_estimators=4)
  )
)
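The warning about repeated training inputs follows from a shape mismatch: once the inputs are repeated, the model emits num_estimators * batch_size predictions, so the targets need the same repetition before a loss can be computed. The sketch below shows that bookkeeping in plain PyTorch, together with one common way to average the estimators at evaluation time. It does not call the library: the random logits stand in for the wrapper's output, and the estimator-major layout (all rows of estimator 1, then estimator 2, and so on) is an assumption.

```python
import torch
import torch.nn.functional as F

num_estimators = 4
batch_size = 8
num_classes = 2

# Stand-in for the output on a repeated batch (assumed layout: estimator-major).
logits = torch.randn(num_estimators * batch_size, num_classes)
targets = torch.randint(0, num_classes, (batch_size,))

# Training: repeat the targets so each prediction row has a matching label.
repeated_targets = targets.repeat(num_estimators)
loss = F.cross_entropy(logits, repeated_targets)

# Evaluation: split the rows per estimator and average the probabilities.
probs = logits.softmax(dim=-1).reshape(num_estimators, batch_size, num_classes)
mean_probs = probs.mean(dim=0)
print(mean_probs.shape)  # torch.Size([8, 2])
```

When training through one of the torch_uncertainty.routines, the format_batch_fn=RepeatTarget(num_repeats=num_estimators) setting mentioned in the warning is the documented way to get the target repetition; the manual targets.repeat call here only illustrates the idea.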