.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_tutorials/Bayesian_Methods/tutorial_bayesian.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_tutorials_Bayesian_Methods_tutorial_bayesian.py:


Training a Bayesian Neural Network in 20 seconds
================================================

In this tutorial, we will train a variational inference Bayesian Neural Network (viBNN) LeNet classifier on the MNIST dataset.

Foreword on Bayesian Neural Networks
------------------------------------

Bayesian Neural Networks (BNNs) are a class of neural networks that estimate the uncertainty of their predictions through uncertainty on their weights. This is achieved by treating the weights of the network as random variables and learning their posterior distribution. This is in contrast to standard neural networks, which learn a single set of weights (equivalently, a Dirac distribution on each weight).

For more information on Bayesian Neural Networks, we refer to the following resources:

- Weight Uncertainty in Neural Networks `ICML2015 `_
- Hands-on Bayesian Neural Networks - a Tutorial for Deep Learning Users `IEEE Computational Intelligence Magazine `_

Training a Bayesian LeNet using TorchUncertainty models and Lightning
---------------------------------------------------------------------

In this first part, we train a Bayesian LeNet, based on the model and routines already implemented in TorchUncertainty (TU).
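
Before loading the utilities, the idea of weights as random variables can be illustrated without any deep learning framework. The following is a minimal sketch, not TorchUncertainty's implementation. It assumes a factorized Gaussian posterior over the weights of a single hypothetical linear unit (``GaussianNeuron`` and ``predictive_stats`` are illustrative names). The standard deviation is parameterized as ``sigma = log(1 + exp(rho))``, as in Blundell et al. (2015). Weights are drawn with the reparameterization trick ``w = mu + sigma * eps``, with ``eps ~ N(0, 1)``. Making ``rho`` very negative shrinks ``sigma`` towards zero, which recovers the single set of weights (Dirac distribution) of a standard network.

.. code-block:: Python

    import math
    import random
    import statistics


    def softplus(x: float) -> float:
        """Numerically stable log(1 + exp(x)); keeps the standard deviation positive."""
        return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


    class GaussianNeuron:
        """Hypothetical linear unit (no bias) with weights w_i ~ N(mu_i, sigma_i^2)."""

        def __init__(self, mu: list[float], rho: list[float]) -> None:
            self.mu = mu  # posterior means
            self.rho = rho  # unconstrained parameters, sigma_i = softplus(rho_i)

        def sample_weights(self, rng: random.Random) -> list[float]:
            # Reparameterization trick: w = mu + sigma * eps, with eps ~ N(0, 1)
            return [m + softplus(r) * rng.gauss(0.0, 1.0) for m, r in zip(self.mu, self.rho)]

        def forward(self, x: list[float], rng: random.Random) -> float:
            # Every call uses a freshly sampled set of weights
            weights = self.sample_weights(rng)
            return sum(w * xi for w, xi in zip(weights, x))


    def predictive_stats(
        neuron: GaussianNeuron, x: list[float], num_samples: int, seed: int = 0
    ) -> tuple[float, float]:
        """Mean and standard deviation of the unit's output over posterior samples."""
        rng = random.Random(seed)
        outputs = [neuron.forward(x, rng) for _ in range(num_samples)]
        return statistics.fmean(outputs), statistics.stdev(outputs)


    x = [1.0, 2.0]
    broad = GaussianNeuron(mu=[0.5, -1.0], rho=[0.0, 0.0])  # sigma = log(2), about 0.69
    dirac = GaussianNeuron(mu=[0.5, -1.0], rho=[-20.0, -20.0])  # sigma about 2e-9
    print("broad posterior:", predictive_stats(broad, x, num_samples=2000))
    print("near-Dirac posterior:", predictive_stats(dirac, x, num_samples=2000))

Each call to ``forward`` uses a new weight sample, so the spread of the outputs reflects the weight uncertainty. A BNN applies the same idea at the scale of full layers.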
1. Loading the utilities
~~~~~~~~~~~~~~~~~~~~~~~~

To train a BNN using TorchUncertainty, we have to load the following modules:

- our TUTrainer to improve the display of our metrics
- the model: bayesian_lenet, which lies in the torch_uncertainty.models.classification.lenet module
- the classification training routine from the torch_uncertainty.routines module
- the Bayesian objective: the ELBOLoss, which lies in the torch_uncertainty.losses module
- the datamodule that handles dataloaders: MNISTDataModule from torch_uncertainty.datamodules

We will also need to define an optimizer using torch.optim and PyTorch's neural network utils from torch.nn.

.. GENERATED FROM PYTHON SOURCE LINES 43-57

.. code-block:: Python


    from pathlib import Path

    from torch import nn, optim

    from torch_uncertainty import TUTrainer
    from torch_uncertainty.datamodules import MNISTDataModule
    from torch_uncertainty.losses import ELBOLoss
    from torch_uncertainty.models.classification.lenet import bayesian_lenet
    from torch_uncertainty.routines import ClassificationRoutine

    # We also define the main hyperparameters, with just two epochs for the sake of time
    BATCH_SIZE = 512
    MAX_EPOCHS = 2


.. GENERATED FROM PYTHON SOURCE LINES 58-66

2. Creating the necessary variables
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In the following, we instantiate our trainer and define the root folder of the datasets and the logs. We also create the datamodule that handles the MNIST dataset, dataloaders and transforms. Please note that the datamodules can also handle OOD detection by setting the `eval_ood` parameter to True, as well as distribution shift with `eval_shift`. Finally, we create the model using the blueprint from torch_uncertainty.models.

.. GENERATED FROM PYTHON SOURCE LINES 66-76
.. code-block:: Python


    trainer = TUTrainer(accelerator="gpu", devices=1, enable_progress_bar=False, max_epochs=MAX_EPOCHS)

    # datamodule
    root = Path("data")
    datamodule = MNISTDataModule(root=root, batch_size=BATCH_SIZE, num_workers=8)

    # model
    model = bayesian_lenet(datamodule.num_channels, datamodule.num_classes)


.. GENERATED FROM PYTHON SOURCE LINES 77-87

3. The Loss and the Training Routine
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Next, we define the loss used during training. It is a bit special: the evidence lower bound (ELBO), which combines a data-fit term with a Kullback-Leibler (KL) divergence between the weight posterior and the prior. We use the hyperparameters proposed in the Blitz library. As we are training a classification model, we use the CrossEntropyLoss as the negative log-likelihood. We then define the training routine using the classification training routine from torch_uncertainty.routines, and provide it with the model, the ELBO loss and the optimizer. We use an Adam optimizer with a learning rate of 0.02.

.. GENERATED FROM PYTHON SOURCE LINES 87-103

.. code-block:: Python


    loss = ELBOLoss(
        model=model,
        inner_loss=nn.CrossEntropyLoss(),
        kl_weight=1 / 10000,
        num_samples=3,
    )

    routine = ClassificationRoutine(
        model=model,
        num_classes=datamodule.num_classes,
        loss=loss,
        optim_recipe=optim.Adam(model.parameters(), lr=2e-2),
        is_ensemble=True,
    )


.. GENERATED FROM PYTHON SOURCE LINES 104-113

4. Gathering Everything and Training the Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Now that we have prepared all of this, we just have to gather everything and train the model using our wrapper of the Lightning Trainer. Specifically, the trainer needs the routine, which includes the model as well as the training/evaluation logic, and the datamodule. The dataset will be downloaded automatically in the root/data folder, and the logs will be saved in the root/logs folder.

.. GENERATED FROM PYTHON SOURCE LINES 113-116

.. code-block:: Python

    trainer.fit(model=routine, datamodule=datamodule)
    trainer.test(model=routine, datamodule=datamodule)
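
To make the ELBO of step 3 more concrete, here is a framework-free sketch of a weighted negative ELBO. It is the negative log-likelihood averaged over ``num_samples`` posterior draws, plus ``kl_weight`` times the KL divergence between the weight posterior and the prior. It is not TorchUncertainty's ``ELBOLoss``, and the library's actual prior and KL estimator may differ. The sketch makes the following assumptions:

- ``negative_elbo`` and the other helpers are illustrative names for a hypothetical linear classifier;
- the posterior is a factorized Gaussian with ``sigma = log(1 + exp(rho))``;
- the prior is a standard normal, so that the KL term has a closed form.

.. code-block:: Python

    import math
    import random


    def softplus(x: float) -> float:
        """Numerically stable log(1 + exp(x)), mapping rho to a positive sigma."""
        return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


    def gaussian_kl(mu: float, sigma: float, prior_sigma: float = 1.0) -> float:
        """Closed-form KL(N(mu, sigma^2) || N(0, prior_sigma^2)) for one weight."""
        return math.log(prior_sigma / sigma) + (sigma**2 + mu**2) / (2 * prior_sigma**2) - 0.5


    def cross_entropy(logits: list[float], target: int) -> float:
        """Negative log-likelihood of `target` under softmax(logits), computed stably."""
        m = max(logits)
        log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
        return log_norm - logits[target]


    def negative_elbo(
        mu: list[list[float]],
        rho: list[list[float]],
        x: list[float],
        target: int,
        kl_weight: float,
        num_samples: int,
        rng: random.Random,
    ) -> float:
        """Monte Carlo weighted negative ELBO of a hypothetical Bayesian linear classifier.

        mu and rho have shape [num_classes][num_features]; sigma = softplus(rho).
        """
        sigma = [[softplus(r) for r in row] for row in rho]
        nll = 0.0
        for _ in range(num_samples):
            # Draw one set of weights from the posterior (reparameterization trick)
            weights = [
                [m + s * rng.gauss(0.0, 1.0) for m, s in zip(m_row, s_row)]
                for m_row, s_row in zip(mu, sigma)
            ]
            logits = [sum(w * xi for w, xi in zip(row, x)) for row in weights]
            nll += cross_entropy(logits, target)
        nll /= num_samples  # data-fit term averaged over the posterior samples
        kl = sum(
            gaussian_kl(m, s)
            for m_row, s_row in zip(mu, sigma)
            for m, s in zip(m_row, s_row)
        )
        return nll + kl_weight * kl  # loss to minimize: NLL + weighted KL


    mu = [[1.0, 0.0], [0.0, 1.0]]
    rho = [[-3.0, -3.0], [-3.0, -3.0]]
    loss = negative_elbo(
        mu, rho, x=[2.0, 0.5], target=0, kl_weight=1 / 10000, num_samples=3, rng=random.Random(0)
    )
    print(f"weighted negative ELBO: {loss:.4f}")

Here ``kl_weight`` and ``num_samples`` mirror the arguments of the ``ELBOLoss`` call above. A smaller ``kl_weight`` lets the data-fit term dominate the loss.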
.. code-block:: Python


    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    import torchvision
    from einops import rearrange


    def imshow(img) -> None:
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
        plt.axis("off")
        plt.tight_layout()
        plt.show()


    images, labels = next(iter(datamodule.val_dataloader()))

    # Show the first four images and print their labels
    imshow(torchvision.utils.make_grid(images[:4, ...]))
    print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4)))

    # Put the routine in eval mode so that it draws several posterior samples
    routine.eval()
    with torch.no_grad():
        logits = routine(images[:4, ...])
    print("Output logit shape (Num predictions x Batch) x Classes: ", logits.shape)
    # (num_estimators * batch_size, num_classes) -> (batch_size, num_estimators, num_classes)
    logits = rearrange(logits, "(m b) c -> b m c", b=4)

    # Apply the softmax on the classes, then average over the posterior samples
    probs = torch.nn.functional.softmax(logits, dim=-1)
    avg_probs = probs.mean(dim=1)
    std_probs = probs.std(dim=1)

    predicted = torch.argmax(avg_probs, -1)

    print("Predicted digits: ", " ".join(f"{predicted[j]}" for j in range(4)))
    print(
        "Std. dev. of the scores over the posterior samples",
        " ".join(f"{std_probs[j][predicted[j]]:.3f}" for j in range(4)),
    )


.. image-sg:: /auto_tutorials/Bayesian_Methods/images/sphx_glr_tutorial_bayesian_001.png
   :alt: tutorial bayesian
   :srcset: /auto_tutorials/Bayesian_Methods/images/sphx_glr_tutorial_bayesian_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Ground truth:  7 2 1 0
    Output logit shape (Num predictions x Batch) x Classes:  torch.Size([64, 10])
    Predicted digits:  7 2 1 0
    Std. dev. of the scores over the posterior samples 0.000 0.028 0.002 0.026


.. GENERATED FROM PYTHON SOURCE LINES 165-175

Here, we show the standard deviation, over the posterior samples, of the score of the top prediction. This is a non-standard but intuitive way to show the diversity of the predictions of the ensemble. Ideally, this standard deviation should be high when the prediction is incorrect.

References
----------

- **LeNet & MNIST:** LeCun, Y., Bottou, L., Bengio, Y., & Haffner, P. (1998). Gradient-based learning applied to document recognition.
  `Proceedings of the IEEE `_.
- **Bayesian Neural Networks:** Blundell, C., Cornebise, J., Kavukcuoglu, K., & Wierstra, D. (2015). Weight Uncertainty in Neural Networks. `ICML 2015 `_.
- **The Adam optimizer:** Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. `ICLR 2015 `_.
- **Blitz:** the Blitz `library `_ (for the hyperparameters).

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 15.227 seconds)


.. _sphx_glr_download_auto_tutorials_Bayesian_Methods_tutorial_bayesian.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tutorial_bayesian.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tutorial_bayesian.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tutorial_bayesian.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_