MCBatchNorm
- class torch_uncertainty.post_processing.MCBatchNorm(model=None, num_estimators=16, convert=True, mc_batch_size=32, device=None)
Monte Carlo Batch Normalization wrapper.
- Parameters:
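model (nn.Module, optional) – model to be converted. Defaults to None.
num_estimators (int, optional) – number of estimators. Defaults to 16.
convert (bool, optional) – whether to convert the model's batch normalization layers to MC BatchNorm layers. Defaults to True.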
mc_batch_size (int, optional) – Monte Carlo batch size. Defaults to 32.
device (Literal["cpu", "cuda"] | torch.device | None, optional) – device. Defaults to None.
Note
This wrapper is stochastic in eval mode only.
- Reference:
Teye M, Azizpour H, Smith K. Bayesian uncertainty estimation for batch normalized deep networks. In ICML 2018.
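The following is a conceptual sketch of why the wrapper is stochastic in eval mode, following the idea in the reference above. It is not the library's implementation. It assumes that each estimator normalizes the same input with the statistics of a different training mini-batch. The helper names (`batch_norm`, `mc_predict`) and the statistics in `stored_stats` are hypothetical and chosen only for illustration.

```python
import math


def batch_norm(x, mean, var, eps=1e-5):
    """Normalize a scalar activation with the given batch statistics."""
    return (x - mean) / math.sqrt(var + eps)


def mc_predict(x, stored_stats):
    """Return one normalized output per stored (mean, variance) pair."""
    return [batch_norm(x, mean, var) for mean, var in stored_stats]


# Hypothetical statistics, as if computed from four training mini-batches.
stored_stats = [(0.9, 1.1), (1.2, 0.8), (1.0, 1.0), (0.7, 1.3)]

# The same input gives a different output for each set of statistics.
outputs = mc_predict(2.0, stored_stats)

# Averaging the estimators gives the prediction; their disagreement
# serves as a simple uncertainty signal.
mean_output = sum(outputs) / len(outputs)
spread = max(outputs) - min(outputs)
```

In this sketch the number of entries in `stored_stats` plays the role of `num_estimators`.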
- fit(dataset)
Fit the model on the dataset.
- Parameters:
dataset (Dataset) – dataset to be used for fitting.
Note
This method populates the MC BatchNorm layers. Call it with the training dataset.
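As a rough illustration of what populating the layers could involve, the sketch below draws mini-batches from training data and records the mean and variance of each one. This is an assumption based on the referenced paper, not the library's code. The function `collect_batch_stats` and the synthetic `train_values` are hypothetical; only the parameter names `num_estimators` and `mc_batch_size` come from the signature above.

```python
import random
from statistics import fmean, pvariance


def collect_batch_stats(train_values, num_estimators, mc_batch_size, seed=0):
    """Record (mean, variance) for `num_estimators` random mini-batches."""
    rng = random.Random(seed)
    stats = []
    for _ in range(num_estimators):
        # Sample one mini-batch without replacement from the training data.
        batch = rng.sample(train_values, mc_batch_size)
        # Batch normalization uses the biased (population) variance.
        stats.append((fmean(batch), pvariance(batch)))
    return stats


# Synthetic one-dimensional "training activations" for illustration only.
data_rng = random.Random(42)
train_values = [data_rng.gauss(1.0, 2.0) for _ in range(1000)]

stats = collect_batch_stats(train_values, num_estimators=16, mc_batch_size=32)
```

Each recorded pair differs slightly because it comes from a different mini-batch. That variation is the source of randomness at inference time, which is why the training dataset is the natural choice here.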