
torch_uncertainty.models.batch_ensemble

torch_uncertainty.models.batch_ensemble(model, num_estimators, repeat_training_inputs=False, convert_layers=False)

BatchEnsemble wrapper for a model.

Parameters:
  • model (nn.Module) – model to wrap

  • num_estimators (int) – number of ensemble members

  • repeat_training_inputs (bool, optional) – whether to repeat the input batch during training. If True, the input batch is repeated num_estimators times (once per ensemble member) during both training and evaluation. If False, it is repeated only during evaluation. Default is False.

  • convert_layers (bool, optional) – whether to convert the model’s layers to BatchEnsemble layers. If True, the wrapper will convert all nn.Linear and nn.Conv2d layers to their BatchEnsemble counterparts. Default is False.
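The repetition rule behind repeat_training_inputs can be sketched in pure Python, with a list standing in for the batch tensor. This is a minimal sketch under stated assumptions: the helper prepare_batch is hypothetical and not part of torch_uncertainty, and the member-major layout (the whole batch copied once per member, back to back) is an assumption about how the repeat is arranged.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def prepare_batch(
    batch: Sequence[T],
    num_estimators: int,
    training: bool,
    repeat_training_inputs: bool = False,
) -> List[T]:
    """Hypothetical helper mimicking the documented input-repetition rule.

    Assumption: repetition tiles the batch member-major, i.e. copy 0 goes
    to member 0, copy 1 to member 1, and so on.
    """
    if num_estimators < 1:
        raise ValueError("num_estimators must be >= 1")
    if training and not repeat_training_inputs:
        # Training without repetition: the batch passes through unchanged.
        return list(batch)
    # Evaluation, or training with repetition: every member sees every sample.
    return list(batch) * num_estimators


batch = ["x0", "x1"]
print(prepare_batch(batch, num_estimators=3, training=True))
print(prepare_batch(batch, num_estimators=3, training=False))
print(prepare_batch(batch, num_estimators=3, training=True, repeat_training_inputs=True))
```

In the no-repetition training case, the unrepeated batch is assumed to be split into consecutive slices, one per member, so its size would need to be a multiple of num_estimators.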

Returns:

BatchEnsemble wrapper for the model

Return type:

BatchEnsemble
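For intuition about what convert_layers=True swaps in, here is a pure-Python sketch of the rank-1 linear layer used in the BatchEnsemble formulation. Each member shares a weight matrix W and owns two vectors r (output side) and s (input side), and its effective weight is W ⊙ (r sᵀ). The sketch checks that the cheap form ((x ⊙ s) Wᵀ) ⊙ r matches the explicitly materialized member weight. The function names are hypothetical, biases are omitted, and torch_uncertainty's actual layers may differ in details such as bias handling and initialization.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(w: Matrix, x: Vector) -> Vector:
    # y[o] = sum_i w[o][i] * x[i]
    return [sum(w_oi * x_i for w_oi, x_i in zip(row, x)) for row in w]


def batch_ensemble_linear(x: Vector, w: Matrix, r: Vector, s: Vector) -> Vector:
    """Hypothetical member forward pass using the rank-1 trick: ((x * s) W^T) * r."""
    scaled_in = [x_i * s_i for x_i, s_i in zip(x, s)]
    out = matvec(w, scaled_in)
    return [o * r_o for o, r_o in zip(out, r)]


def explicit_member_weight(w: Matrix, r: Vector, s: Vector) -> Matrix:
    """Hypothetical helper materializing the member weight W * (r s^T)."""
    return [[w[o][i] * r[o] * s[i] for i in range(len(s))] for o in range(len(r))]


w = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # shared weight, shape (out=2, in=3)
r = [2.0, -1.0]                          # member-specific output scaling
s = [1.0, 0.5, 3.0]                      # member-specific input scaling
x = [1.0, 2.0, -1.0]

fast = batch_ensemble_linear(x, w, r, s)
slow = matvec(explicit_member_weight(w, r, s), x)
print(fast, slow)  # [-12.0, 9.0] [-12.0, 9.0]
```

The cheap form never builds a per-member weight matrix, which is why each additional member costs only two vectors of parameters rather than a full copy of W.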