
Training a LeNet with Monte-Carlo Dropout

In this tutorial, we will train a LeNet classifier on the MNIST dataset using Monte-Carlo Dropout (MC Dropout), a computationally efficient Bayesian approximation method. To estimate the predictive mean and uncertainty (variance), we perform multiple forward passes through the network with dropout layers enabled in train mode.
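To make this concrete, here is a minimal, framework-agnostic sketch of MC Dropout inference (not the TorchUncertainty API used below): keep dropout active at test time, run several stochastic forward passes, and take the sample mean and variance of the softmax outputs as the prediction and its uncertainty. Note that model.train() is a shortcut that also affects layers such as batch normalization; a careful implementation would re-enable only the dropout layers.

import torch


@torch.no_grad()
def mc_dropout_predict(model, x, num_samples=16):
    """Toy MC Dropout inference: average num_samples stochastic passes."""
    model.train()  # keep dropout layers sampling at inference time
    probs = torch.stack(
        [model(x).softmax(dim=-1) for _ in range(num_samples)]
    )  # (num_samples, batch_size, num_classes)
    return probs.mean(dim=0), probs.var(dim=0)  # predictive mean and variance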

For more information on Monte-Carlo Dropout, we refer the reader to the following resources:

  • Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning (Gal & Ghahramani, ICML 2016)

  • What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision? (Kendall & Gal, NeurIPS 2017)

Training a LeNet with MC Dropout using TorchUncertainty models and PyTorch Lightning

In this part, we train a LeNet with dropout layers, based on the model and routines already implemented in TorchUncertainty (TU).

1. Loading the utilities

First, we have to load the following utilities from TorchUncertainty:

  • the trainer: TUTrainer from torch_uncertainty

  • the datamodule handling dataloaders: MNISTDataModule from torch_uncertainty.datamodules

  • the model: lenet from torch_uncertainty.models

  • the MC Dropout wrapper: mc_dropout from torch_uncertainty.models

  • the classification training & evaluation routine: ClassificationRoutine from torch_uncertainty.routines

  • an optimization recipe in the torch_uncertainty.optim_recipes module.

We also need to import the neural network utilities from torch.nn.

from pathlib import Path

from torch_uncertainty import TUTrainer
from torch import nn

from torch_uncertainty.datamodules import MNISTDataModule
from torch_uncertainty.models.lenet import lenet
from torch_uncertainty.models import mc_dropout
from torch_uncertainty.optim_recipes import optim_cifar10_resnet18
from torch_uncertainty.routines import ClassificationRoutine

2. Defining the Model and the Trainer

In the following, we first create the trainer and instantiate the datamodule that handles the MNIST dataset, dataloaders and transforms. We then create the model using the blueprint from torch_uncertainty.models and wrap it with mc_dropout. To use the mc_dropout wrapper, make sure that your model uses dropout modules (nn.Dropout) rather than the functional form (F.dropout), and that these modules are instantiated in the __init__ method.
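This matters because the wrapper can only find and control dropout that exists as a submodule of the model. As a hypothetical illustration (these two classes are not TorchUncertainty code), the first pattern below is compatible with such a wrapper, while the second is not:

from torch import nn
import torch.nn.functional as F


class CompatibleHead(nn.Module):
    """Dropout as a module, created in __init__: the wrapper can discover it."""

    def __init__(self, dropout_rate: float = 0.4) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(120, 10)

    def forward(self, x):
        return self.fc(self.dropout(x))


class IncompatibleHead(nn.Module):
    """Functional dropout: invisible to module traversal, so the wrapper cannot control it."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(120, 10)

    def forward(self, x):
        return self.fc(F.dropout(x, p=0.4))

With that constraint in mind, we can build the trainer, datamodule, and model: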

# trainer
trainer = TUTrainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False)

# datamodule
root = Path("data")
datamodule = MNISTDataModule(root=root, batch_size=128)


# model: a LeNet with dropout layers
model = lenet(
    in_channels=datamodule.num_channels,
    num_classes=datamodule.num_classes,
    dropout_rate=0.4,
)

# wrap the model so that evaluation aggregates 16 stochastic forward passes
mc_model = mc_dropout(model, num_estimators=16, last_layer=False)
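Conceptually, the wrapped model now behaves like an ensemble of num_estimators stochastic sub-networks. A simplified sketch of the idea (not the actual torch_uncertainty implementation) is:

import torch


def mc_forward(model, x, num_estimators=16):
    # With dropout kept active, each pass samples a different sub-network.
    # Outputs are concatenated along the batch dimension, which is why the
    # logits are reshaped to (num_estimators, batch_size, num_classes) later on.
    return torch.cat([model(x) for _ in range(num_estimators)], dim=0)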

3. The Loss and the Training Routine

This is a classification problem, so we use CrossEntropyLoss as the negative log-likelihood. We define the training routine using the classification routine from torch_uncertainty.routines. We provide the number of classes and the optimization recipe, and tell the routine that our model is an ensemble at evaluation time.

routine = ClassificationRoutine(
    num_classes=datamodule.num_classes,
    model=mc_model,
    loss=nn.CrossEntropyLoss(),
    optim_recipe=optim_cifar10_resnet18(mc_model),
    is_ensemble=True,
)
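As a quick sanity check of the claim that cross-entropy is the negative log-likelihood: CrossEntropyLoss applied to raw logits matches NLLLoss applied to log-probabilities.

import torch
from torch import nn

logits = torch.randn(4, 10)  # dummy logits for 4 samples and 10 classes
targets = torch.randint(0, 10, (4,))

ce = nn.CrossEntropyLoss()(logits, targets)
nll = nn.NLLLoss()(torch.log_softmax(logits, dim=-1), targets)
assert torch.allclose(ce, nll)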

4. Gathering Everything and Training the Model

We can now train the model using the trainer. We pass the routine and the datamodule to the fit and test methods of the trainer. It will automatically evaluate several uncertainty metrics, which you will find in the tables below.

trainer.fit(model=routine, datamodule=datamodule)
results = trainer.test(model=routine, datamodule=datamodule)
(output truncated) The four MNIST archives (train/test images and labels) fail to download from the primary yann.lecun.com URLs with an expired-certificate SSL error, are fetched from the https://ossci-datasets.s3.amazonaws.com/mnist/ mirror instead, and are extracted to data/MNIST/raw. Lightning also warns that the train/val/test dataloaders use few workers and suggests setting num_workers=3 in the DataLoader.
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃      Classification       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     Acc      │          83.28%           │
│    Brier     │          0.41875          │
│   Entropy    │          1.66528          │
│     NLL      │          0.93644          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃        Calibration        ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     ECE      │          0.37747          │
│     aECE     │          0.37747          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃ Selective Classification  ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    AUGRC     │           3.48%           │
│     AURC     │           4.33%           │
│  Cov@5Risk   │          63.91%           │
│  Risk@80Cov  │           9.34%           │
└──────────────┴───────────────────────────┘
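For reference, the Expected Calibration Error (ECE) reported above bins predictions by their confidence and averages the gap between confidence and accuracy over the bins. A minimal sketch of the standard top-label ECE follows; TorchUncertainty's own metric may differ in implementation details (for instance, aECE uses adaptive bins):

import torch


def expected_calibration_error(probs, labels, n_bins=15):
    """Top-label ECE: probs has shape (N, C), labels has shape (N,)."""
    conf, pred = probs.max(dim=-1)
    correct = pred.eq(labels).float()
    edges = torch.linspace(0, 1, n_bins + 1)
    ece = torch.zeros(())
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            # weight each bin's confidence/accuracy gap by its share of samples
            ece += in_bin.float().mean() * (conf[in_bin].mean() - correct[in_bin].mean()).abs()
    return ece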

5. Testing the Model

Now that the model is trained, let's test it on MNIST. Don't forget to call .eval() on the wrapped model: with the mc_dropout wrapper, evaluation mode is what keeps dropout active and produces the multiple (here 16) stochastic predictions.

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision


def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis("off")
    plt.tight_layout()
    plt.show()


dataiter = iter(datamodule.val_dataloader())
images, labels = next(dataiter)

# print images
imshow(torchvision.utils.make_grid(images[:6, ...], padding=0))
print("Ground truth labels: ", " ".join(f"{labels[j]}" for j in range(6)))

routine.eval()
logits = routine(images).reshape(16, 128, 10)  # (num_estimators, batch_size, num_classes)

probs = torch.nn.functional.softmax(logits, dim=-1)


for j in range(6):
    values, predicted = torch.max(probs[:, j], 1)
    print(
        f"MC-Dropout predictions for the image {j+1}: ",
        " ".join([str(image_id.item()) for image_id in predicted]),
    )
[Figure: a grid of the first six MNIST test images]
Ground truth labels:  7 2 1 0 4 1
MC-Dropout predictions for image 1:  7 7 7 1 7 3 7 7 7 7 7 7 7 7 7 2
MC-Dropout predictions for image 2:  6 2 2 2 2 3 2 2 8 0 8 2 2 1 6 2
MC-Dropout predictions for image 3:  1 1 6 1 1 1 1 1 1 1 1 1 1 1 1 1
MC-Dropout predictions for image 4:  0 6 8 6 0 8 0 0 8 0 0 8 9 0 6 1
MC-Dropout predictions for image 5:  6 9 6 4 4 4 9 4 4 4 4 9 4 9 4 6
MC-Dropout predictions for image 6:  1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1

For most images, the 16 stochastic forward passes disagree on the predicted class; this disagreement between samples of the dropout approximation of the posterior distribution is the signal MC Dropout uses to express predictive uncertainty.
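We can quantify this disagreement from the probs tensor computed above (shape (16, 128, 10)), for example with the predictive entropy and the mutual information between the prediction and the sampled sub-networks. A short sketch, reusing the variables defined in the previous block:

# average the 16 stochastic passes into a single predictive distribution
mean_probs = probs.mean(dim=0)  # (128, 10)

# predictive entropy: total uncertainty of the averaged prediction
pred_entropy = -(mean_probs * mean_probs.clamp_min(1e-12).log()).sum(dim=-1)

# mutual information (BALD): the epistemic part, i.e. how much the passes disagree
sample_entropy = -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1).mean(dim=0)
mutual_info = pred_entropy - sample_entropy

for j in range(6):
    print(f"image {j+1}: entropy={pred_entropy[j]:.3f}, mutual information={mutual_info[j]:.3f}")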

Total running time of the script: (0 minutes 50.671 seconds)
