torch_uncertainty.models.batch_ensemble
- torch_uncertainty.models.batch_ensemble(model, num_estimators, repeat_training_inputs=False, convert_layers=False)
BatchEnsemble wrapper for a model.
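For intuition about what BatchEnsemble layers compute, here is a minimal pure-Python sketch of a BatchEnsemble linear layer following the rank-1 scheme of Wen et al. (2020). All members share one weight matrix W. Member m rescales its inputs by a vector r[m] and its outputs by a vector s[m], which is equivalent to using the effective weight W * outer(s[m], r[m]) without storing it, and it adds its own bias. The helper name `batch_ensemble_linear`, its argument layout, and the assumption that the batch is split into equal contiguous chunks (one per member) are illustrative only. This is not torch_uncertainty's implementation or API.

```python
from typing import List, Sequence

Vector = List[float]


def batch_ensemble_linear(
    inputs: Sequence[Sequence[float]],
    weight: Sequence[Sequence[float]],
    r: Sequence[Sequence[float]],
    s: Sequence[Sequence[float]],
    bias: Sequence[Sequence[float]],
) -> List[Vector]:
    """Hypothetical sketch of a BatchEnsemble linear layer.

    weight: shared (out_features, in_features) matrix W, row-major.
    r:      per-member input factors, shape (num_estimators, in_features).
    s:      per-member output factors, shape (num_estimators, out_features).
    bias:   per-member biases, shape (num_estimators, out_features).

    Assumes the batch is split into num_estimators equal, contiguous chunks.
    """
    num_estimators = len(r)
    if num_estimators == 0 or len(inputs) % num_estimators != 0:
        raise ValueError("batch size must be a multiple of num_estimators")
    chunk = len(inputs) // num_estimators
    outputs: List[Vector] = []
    for idx, x in enumerate(inputs):
        m = idx // chunk  # index of the member that owns this example
        # Scale inputs by r[m], apply the shared weight, then scale by s[m].
        scaled = [xi * ri for xi, ri in zip(x, r[m])]
        y = [sum(w * v for w, v in zip(row, scaled)) for row in weight]
        outputs.append([yj * sj + bj for yj, sj, bj in zip(y, s[m], bias[m])])
    return outputs


# Two members, one shared 1x2 weight, the same input fed to each member.
out = batch_ensemble_linear(
    inputs=[[1.0, 1.0], [1.0, 1.0]],
    weight=[[1.0, 2.0]],
    r=[[1.0, 1.0], [2.0, 0.0]],
    s=[[1.0], [3.0]],
    bias=[[0.0], [1.0]],
)
print(out)  # [[3.0], [7.0]]
```

The same input yields different outputs because each member carries only two small vectors and a bias on top of the shared weight. This is what keeps a BatchEnsemble far cheaper in parameters than num_estimators independent copies of the model.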
- Parameters:
model (nn.Module) – model to wrap
num_estimators (int) – number of ensemble members
repeat_training_inputs (bool, optional) – whether to repeat the input batch during training. If True, the input batch is repeated during both training and evaluation. If False, the input batch is repeated only during evaluation. Default is False.
convert_layers (bool, optional) – whether to convert the model’s layers to BatchEnsemble layers. If True, the wrapper converts all nn.Linear and nn.Conv2d layers to their BatchEnsemble counterparts. Default is False.
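The two repeat_training_inputs settings differ only in what happens during training. The plain-Python sketch below mirrors the documented rule on a list of examples. `repeat_for_estimators` is a hypothetical helper, and tiling the whole batch once per member (rather than repeating each example in place) is an assumption about ordering, not a statement about torch_uncertainty's internals.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def repeat_for_estimators(
    batch: Sequence[T],
    num_estimators: int,
    training: bool,
    repeat_training_inputs: bool = False,
) -> List[T]:
    """Hypothetical helper mirroring the documented repeat rule."""
    if num_estimators < 1:
        raise ValueError("num_estimators must be at least 1")
    if training and not repeat_training_inputs:
        # Training with the default setting: the batch goes through unchanged.
        return list(batch)
    # Evaluation, or training with repeat_training_inputs=True:
    # tile the whole batch once per ensemble member (assumed ordering).
    return list(batch) * num_estimators


batch = ["x0", "x1"]
print(repeat_for_estimators(batch, 3, training=False))
# ['x0', 'x1', 'x0', 'x1', 'x0', 'x1']
print(repeat_for_estimators(batch, 3, training=True))
# ['x0', 'x1']
print(repeat_for_estimators(batch, 3, training=True, repeat_training_inputs=True))
# ['x0', 'x1', 'x0', 'x1', 'x0', 'x1']
```

With the default, a training step processes only the original batch, while evaluation processes num_estimators copies so that every member can produce a prediction for every input.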
- Returns:
BatchEnsemble wrapper for the model
- Return type: