batch_ensemble
- torch_uncertainty.models.batch_ensemble(model, num_estimators, repeat_training_inputs=False, convert_layers=False)
BatchEnsemble wrapper for a model.
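For background, BatchEnsemble (introduced by Wen et al., 2020) gives every ensemble member one shared weight matrix W plus its own rank-1 factors r_i and s_i. Member i therefore effectively uses the weight W_i = W ∘ (s_i r_iᵀ), which costs far less memory than storing `num_estimators` full copies. The sketch below is a hypothetical pure-Python forward pass for one such linear layer. The function name `batch_ensemble_linear`, the per-member bias, and the assumption that the batch is laid out as `num_estimators` contiguous chunks (chunk i belonging to member i) are illustrative choices, not torch_uncertainty's implementation.

```python
from typing import List

Matrix = List[List[float]]


def batch_ensemble_linear(
    x: Matrix,            # inputs, shape (batch_size, in_features)
    weight: Matrix,       # shared weight W, shape (out_features, in_features)
    r: Matrix,            # per-member input factors, shape (num_estimators, in_features)
    s: Matrix,            # per-member output factors, shape (num_estimators, out_features)
    bias: Matrix,         # per-member bias (assumption), shape (num_estimators, out_features)
    num_estimators: int,
) -> Matrix:
    """Hypothetical BatchEnsemble linear layer: y = ((x * r_i) @ W^T) * s_i + b_i.

    Assumes the batch is split into num_estimators contiguous chunks,
    with chunk i processed by ensemble member i.
    """
    batch_size = len(x)
    if batch_size % num_estimators != 0:
        raise ValueError("batch size must be divisible by num_estimators")
    chunk = batch_size // num_estimators
    out: Matrix = []
    for n, row in enumerate(x):
        i = n // chunk  # index of the ensemble member that owns this row
        # Scale the input element-wise by the member's fast weight r_i.
        scaled = [xj * rj for xj, rj in zip(row, r[i])]
        # Apply the shared weight matrix W.
        y = [sum(w * v for w, v in zip(w_row, scaled)) for w_row in weight]
        # Scale the output by s_i and add the member's bias.
        out.append([yk * sk + bk for yk, sk, bk in zip(y, s[i], bias[i])])
    return out
```

Because the scaling by r_i and s_i is element-wise, this is algebraically identical to multiplying by the full member weight W ∘ (s_i r_iᵀ), but it never materializes that matrix.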
- Parameters:
model (nn.Module) – model to wrap
num_estimators (int) – number of ensemble members
repeat_training_inputs (bool, optional) – whether to repeat the input batch during training. If True, the input batch is repeated during both training and evaluation. If False, the input batch is repeated only during evaluation. Default is False.
convert_layers (bool, optional) – whether to convert the model’s layers to BatchEnsemble layers. If True, the wrapper will convert all nn.Linear and nn.Conv2d layers to their BatchEnsemble counterparts. Default is False.
- Returns:
BatchEnsemble wrapper for the model
- Return type:
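The `repeat_training_inputs` rule can be illustrated with a small sketch. The helper `prepare_inputs` below is hypothetical (it is not part of torch_uncertainty) and models a batch as a plain Python list. It tiles the whole batch `num_estimators` times, so every member sees every sample; when inputs are not repeated during training, the assumption is that the members instead share the original batch between them.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def prepare_inputs(
    batch: Sequence[T],
    num_estimators: int,
    training: bool,
    repeat_training_inputs: bool = False,
) -> List[T]:
    """Hypothetical helper mirroring the documented repetition rule.

    - Evaluation: always tile the batch num_estimators times.
    - Training: tile only if repeat_training_inputs is True; otherwise
      pass the batch through unchanged.
    """
    if training and not repeat_training_inputs:
        return list(batch)
    # Tile the whole batch, e.g. [a, b] -> [a, b, a, b] for 2 estimators.
    return list(batch) * num_estimators
```

For example, `prepare_inputs([1, 2], 3, training=False)` tiles the batch into six items, while `prepare_inputs([1, 2], 3, training=True)` leaves it unchanged. On a PyTorch tensor, the analogous tiling along the batch dimension is `x.repeat(num_estimators, *([1] * (x.dim() - 1)))`.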