BatchEnsemble
- class torch_uncertainty.methods.BatchEnsemble(core_model, num_estimators, repeat_training_inputs=False, convert_layers=False)
Wrap a BatchEnsemble model to ensure correct batch replication.
In a BatchEnsemble architecture, each estimator operates on a sub-batch of the input. This means that the input batch must be repeated num_estimators times before being processed.
This wrapper automatically duplicates the input batch along the first axis, ensuring that each estimator receives the correct data format.
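The replication described above can be pictured with plain PyTorch. This is a minimal sketch, not the wrapper's actual code: it assumes the batch is tiled as a whole along the first axis (so consecutive sub-batches are full copies of the original batch), and the tensor shapes and the value of num_estimators are illustrative.

```python
import torch

num_estimators = 4  # illustrative value
x = torch.arange(6, dtype=torch.float32).reshape(3, 2)  # batch of 3 samples, 2 features

# Tile the whole batch num_estimators times along dim 0 (assumed ordering);
# all remaining dimensions are left unchanged.
repeated = x.repeat(num_estimators, *([1] * (x.ndim - 1)))

# Split back into one sub-batch per estimator; each one is a copy of x.
sub_batches = repeated.chunk(num_estimators, dim=0)

print(repeated.shape)    # first dimension is num_estimators * batch size
print(len(sub_batches))  # one sub-batch per estimator
```

With this layout, a model that processes sub-batch k with estimator k sees every input sample once per estimator.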
- Parameters:
  - core_model (Module) – The BatchEnsemble model.
  - num_estimators (int) – Number of ensemble members.
  - repeat_training_inputs (bool) – Whether to repeat the input batch during training. If True, the input batch is repeated during both training and evaluation. If False, the input batch is repeated only during evaluation. Defaults to False.
  - convert_layers (bool) – Whether to convert the model’s layers to BatchEnsemble layers. If True, the wrapper will convert all nn.Linear and nn.Conv2d layers to their BatchEnsemble counterparts. Defaults to False.
- Raises:
  - ValueError – If neither BatchLinear nor BatchConv2d layers are found in the core_model at the end of initialization.
  - ValueError – If num_estimators is less than or equal to 0.
  - ValueError – If convert_layers=True and neither nn.Linear nor nn.Conv2d layers are found in the core_model.
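Two of the documented error conditions can be checked up front, before the wrapper is built. The helper below is hypothetical (check_wrapper_args is not part of torch_uncertainty) and only mirrors the num_estimators and convert_layers conditions listed above; it is a sketch of the documented behavior, not the library's implementation.

```python
from torch import nn


def check_wrapper_args(core_model: nn.Module, num_estimators: int,
                       convert_layers: bool = False) -> None:
    """Hypothetical pre-check mirroring two of the documented ValueError cases."""
    if num_estimators <= 0:
        raise ValueError(f"num_estimators must be positive, got {num_estimators}.")
    if convert_layers:
        # Conversion only targets nn.Linear and nn.Conv2d layers.
        convertible = any(
            isinstance(module, (nn.Linear, nn.Conv2d))
            for module in core_model.modules()
        )
        if not convertible:
            raise ValueError("No nn.Linear or nn.Conv2d layers to convert.")


# Passes: the model contains an nn.Linear layer.
check_wrapper_args(nn.Sequential(nn.Linear(10, 5), nn.ReLU()), num_estimators=4,
                   convert_layers=True)
```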
Warning
If convert_layers==True, the wrapper will attempt to convert all nn.Linear and nn.Conv2d layers in the core_model to their BatchEnsemble counterparts. Layers of other types are left unconverted. If the core_model does not contain any nn.Linear or nn.Conv2d layers, the wrapper will raise an error during conversion.
Warning
If repeat_training_inputs==True and you want to use one of the torch_uncertainty.routines for training, be sure to set format_batch_fn=RepeatTarget(num_repeats=num_estimators) when initializing the routine.
Example
>>> from torch import nn
>>> from torch_uncertainty.methods import BatchEnsemble
>>> core_model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
>>> model = BatchEnsemble(core_model, num_estimators=4, convert_layers=True)
>>> model
BatchEnsemble(
  (core_model): Sequential(
    (0): BatchLinear(in_features=10, out_features=5, num_estimators=4)
    (1): ReLU()
    (2): BatchLinear(in_features=5, out_features=2, num_estimators=4)
  )
)
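The warning about repeat_training_inputs comes down to a shape mismatch: once the inputs are repeated during training, the outputs have num_estimators times as many rows as the original targets. The sketch below illustrates this with plain PyTorch. It is not the RepeatTarget implementation; a single nn.Linear stands in for the wrapped model, and the shapes and tiling order are assumptions.

```python
import torch
from torch import nn

num_estimators = 4  # illustrative value
inputs = torch.randn(8, 10)          # batch of 8 samples
targets = torch.randint(0, 2, (8,))  # one class label per sample

# Stand-in for the wrapped model: any module mapping (N, 10) to (N, 2).
model = nn.Linear(10, 2)

# Inputs repeated along the first axis, as the wrapper does when
# repeat_training_inputs is True.
logits = model(inputs.repeat(num_estimators, 1))

# The targets must be repeated the same way so the loss sees matching shapes.
repeated_targets = targets.repeat(num_estimators)
loss = nn.functional.cross_entropy(logits, repeated_targets)

print(logits.shape, repeated_targets.shape)
```

Without the repeated targets, cross_entropy would receive 32 predictions for 8 labels and fail on the batch-size mismatch.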