batch_ensemble
- torch_uncertainty.methods.batch_ensemble(core_model, num_estimators, repeat_training_inputs=False, convert_layers=False)
BatchEnsemble wrapper for a model.
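A BatchEnsemble model runs all of its members in a single forward pass by giving each member its own slice of an enlarged batch. The sketch below shows one way a wrapper can build that enlarged batch by tiling the inputs `num_estimators` times, and how a flag like `repeat_training_inputs` changes the behavior between training and evaluation. It assumes a member-major layout (rows `0..B-1` go to member 0, the next `B` rows to member 1, and so on). `RepeatInputsSketch` is a hypothetical name and is not part of torch_uncertainty. The library's actual implementation may differ.

```python
import torch
from torch import nn


class RepeatInputsSketch(nn.Module):
    """Hypothetical wrapper that illustrates input repetition for BatchEnsemble.

    Assumption: the wrapped model expects num_estimators consecutive copies
    of the batch (member-major ordering) whenever inputs are repeated.
    """

    def __init__(self, core_model: nn.Module, num_estimators: int,
                 repeat_training_inputs: bool = False) -> None:
        super().__init__()
        self.core_model = core_model
        self.num_estimators = num_estimators
        self.repeat_training_inputs = repeat_training_inputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Evaluation: always repeat, so every member sees every input.
        # Training: repeat only if repeat_training_inputs is True; otherwise
        # the batch is passed through unchanged (not modeled further here).
        if not self.training or self.repeat_training_inputs:
            # (B, ...) -> (num_estimators * B, ...), tiled along dim 0.
            x = x.repeat(self.num_estimators, *([1] * (x.dim() - 1)))
        return self.core_model(x)


# Usage with an identity "model" so the output shape shows the repetition.
wrapper = RepeatInputsSketch(nn.Identity(), num_estimators=4)
inputs = torch.randn(2, 3)
wrapper.train()
print(wrapper(inputs).shape)  # torch.Size([2, 3])
wrapper.eval()
print(wrapper(inputs).shape)  # torch.Size([8, 3])
```

Tiling along dimension 0 keeps each member's inputs contiguous, which lets per-member outputs be recovered afterwards with a simple reshape to `(num_estimators, B, ...)`.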
- Parameters:
  - core_model (Module) – Model to wrap.
  - num_estimators (int) – Number of ensemble members.
  - repeat_training_inputs (bool) – Whether to repeat the input batch during training. If `True`, the input batch is repeated during both training and evaluation. If `False`, the input batch is repeated only during evaluation. Defaults to `False`.
  - convert_layers (bool) – Whether to convert the model's layers to BatchEnsemble layers. If `True`, the wrapper converts all `nn.Linear` and `nn.Conv2d` layers to their BatchEnsemble counterparts. Defaults to `False`.
- Returns:
  BatchEnsemble wrapper for the `core_model`.
- Return type:
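To make the idea of a "BatchEnsemble counterpart" of `nn.Linear` concrete, here is a minimal sketch of the standard rank-1 formulation from the BatchEnsemble paper (Wen et al., 2020). It is not necessarily the layer that `convert_layers=True` installs. All members share one weight matrix `W`, and member `i` owns an input scaling vector, an output scaling vector, and a bias, so its effective weight is `W * outer(out_factors[i], in_factors[i])`. The class and attribute names (`BatchLinearSketch`, `in_factors`, `out_factors`) are hypothetical, and the batch is assumed to be member-major.

```python
import torch
from torch import nn


class BatchLinearSketch(nn.Module):
    """Hypothetical rank-1 BatchEnsemble linear layer (illustrative only)."""

    def __init__(self, in_features: int, out_features: int,
                 num_estimators: int) -> None:
        super().__init__()
        self.num_estimators = num_estimators
        # Shared "slow" weight W of shape (out_features, in_features).
        self.linear = nn.Linear(in_features, out_features, bias=False)
        # Per-member "fast" rank-1 factors and biases.
        self.in_factors = nn.Parameter(torch.ones(num_estimators, in_features))
        self.out_factors = nn.Parameter(torch.ones(num_estimators, out_features))
        self.bias = nn.Parameter(torch.zeros(num_estimators, out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_estimators * B, in_features), member-major layout assumed.
        if x.size(0) % self.num_estimators != 0:
            raise ValueError("batch size must be divisible by num_estimators")
        batch = x.size(0) // self.num_estimators
        # Expand each member's vectors to cover its B rows.
        r = self.in_factors.repeat_interleave(batch, dim=0)
        s = self.out_factors.repeat_interleave(batch, dim=0)
        b = self.bias.repeat_interleave(batch, dim=0)
        # s * (W (r * x)) equals (W * outer(s, r)) x, computed without
        # materializing one weight matrix per member.
        return self.linear(x * r) * s + b


layer = BatchLinearSketch(in_features=3, out_features=2, num_estimators=2)
print(layer(torch.randn(4, 3)).shape)  # torch.Size([4, 2])
```

The design keeps the parameter overhead per member at `in_features + 2 * out_features` values instead of a full weight matrix, which is why BatchEnsemble is much cheaper than a deep ensemble of independent copies.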