
From a Standard Classifier to a Packed-Ensemble

This tutorial is heavily inspired by PyTorch’s Training a Classifier tutorial.

Let’s go step by step through the process of turning a standard classifier into a Packed-Ensemble classifier.

Dataset

In this tutorial we will use the CIFAR10 dataset available in the torchvision package. The CIFAR10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

Here is an example of what the data looks like:

[Figure: Sample of the CIFAR-10 dataset]

Training an image Packed-Ensemble classifier

Here is the outline of the process:

  1. Load and normalize the CIFAR10 training and test datasets using torchvision

  2. Define a Packed-Ensemble from a standard classifier

  3. Define a loss function and optimizer

  4. Train the Packed-Ensemble on the training data

  5. Test the Packed-Ensemble on the test data and evaluate its performance w.r.t. uncertainty quantification and out-of-distribution (OOD) detection

1. Load and normalize CIFAR10

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

torch.set_num_threads(1)

Torchvision datasets return PILImage images. ToTensor converts them to tensors with values in [0, 1], and Normalize then maps these values to the range [-1, 1].

Note

If you are running on Windows and get a BrokenPipeError, try setting the num_workers argument of torch.utils.data.DataLoader() to 0.

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

batch_size = 4

trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
trainloader = DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)
testloader = DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2
)

classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified

Let us show some of the training images, for fun.

import matplotlib.pyplot as plt

import numpy as np

# function to show an image


def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.figure(figsize=(10, 3))
    plt.axis("off")
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images, pad_value=1))
# print labels
print(" ".join(f"{classes[labels[j]]:5s}" for j in range(batch_size)))
deer  frog  dog   dog

2. Define a Packed-Ensemble from a standard classifier

First, we define a standard convolutional neural network classifier for CIFAR10, which will serve as a reference.

import torch.nn.functional as F
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # RGB input, 6 filters of size 5x5
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 feature maps of size 5x5
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # one logit per CIFAR10 class

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
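The input size of fc1, 16 * 5 * 5, follows from how the two convolution and pooling stages shrink a 32x32 image. The sketch below traces a dummy batch through copies of the same layers and also counts the parameters of the reference network, a useful baseline when comparing it with the Packed-Ensemble.

```python
import torch
from torch import nn

# Same layers as Net, re-created here so that this block runs on its own
conv1, conv2 = nn.Conv2d(3, 6, 5), nn.Conv2d(6, 16, 5)
pool = nn.MaxPool2d(2, 2)
fc1, fc2, fc3 = nn.Linear(16 * 5 * 5, 120), nn.Linear(120, 84), nn.Linear(84, 10)

x = torch.zeros(4, 3, 32, 32)  # one mini-batch of 4 CIFAR10-sized images
h1 = pool(conv1(x))  # 5x5 convolution: 32 -> 28, then 2x2 pooling: 28 -> 14
h2 = pool(conv2(h1))  # 5x5 convolution: 14 -> 10, then 2x2 pooling: 10 -> 5
print(tuple(h1.shape), tuple(h2.shape))  # (4, 6, 14, 14) (4, 16, 5, 5)
print(h2.flatten(1).shape[1])  # 400 features, i.e. the 16 * 5 * 5 inputs of fc1

# Total number of trainable parameters of the reference network
layers = [conv1, conv2, fc1, fc2, fc3]
num_params = sum(p.numel() for layer in layers for p in layer.parameters())
print(num_params)  # 62006
```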

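Let’s turn the standard classifier into a Packed-Ensemble classifier with parameters \(M=4,\ \alpha=2\text{ and }\gamma=1\), where \(M\) is the number of estimators, \(\alpha\) the factor by which the width of each layer is multiplied before being split among the estimators, and \(\gamma\) the number of groups within each estimator.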
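A Packed-Ensemble layer can be understood as a convolution whose width is multiplied by \(\alpha\) and whose channels are split into \(M \times \gamma\) independent groups, so that the \(M\) estimators share a single operation without exchanging information. The sketch below illustrates this idea with a plain nn.Conv2d and its groups argument, standing in for the second Packed layer of the network. It is an assumption-based illustration of the grouping, not the torch_uncertainty implementation, which may handle the first and last layers and the data layout differently.

```python
import torch
from torch import nn

torch.manual_seed(0)
M, alpha, gamma = 4, 2, 1
in_ch, out_ch = 6 * alpha, 16 * alpha  # widened channels: 12 and 32

# Hypothetical plain-PyTorch stand-in for the second Packed layer:
# one convolution whose channels are split into M * gamma independent groups
conv = nn.Conv2d(in_ch, out_ch, 5, groups=M * gamma)

x = torch.randn(1, in_ch, 14, 14)
y = conv(x)

# Perturb only the input channels that belong to estimator 0
x_perturbed = x.clone()
x_perturbed[:, : in_ch // M] += 1.0
y_perturbed = conv(x_perturbed)

# Check which estimators' output channels changed
out_per_est = out_ch // M
changed = [
    not torch.allclose(
        y[:, m * out_per_est : (m + 1) * out_per_est],
        y_perturbed[:, m * out_per_est : (m + 1) * out_per_est],
    )
    for m in range(M)
]
print(tuple(y.shape), changed)  # (1, 32, 10, 10) [True, False, False, False]
```

Only estimator 0 reacts to the perturbation, which is what makes the groups behave like separate members of an ensemble.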

from einops import rearrange

from torch_uncertainty.layers import PackedConv2d, PackedLinear


class PackedNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        M = 4  # number of estimators
        alpha = 2  # width multiplier of each layer
        gamma = 1  # number of groups within each estimator
        # first=True: the input image is shared by all estimators, so it is not split
        self.conv1 = PackedConv2d(
            3, 6, 5, alpha=alpha, num_estimators=M, gamma=gamma, first=True
        )
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = PackedConv2d(6, 16, 5, alpha=alpha, num_estimators=M, gamma=gamma)
        self.fc1 = PackedLinear(
            16 * 5 * 5, 120, alpha=alpha, num_estimators=M, gamma=gamma
        )
        self.fc2 = PackedLinear(120, 84, alpha=alpha, num_estimators=M, gamma=gamma)
        # last=True: each of the M estimators outputs its own 10 class logits
        self.fc3 = PackedLinear(
            84, 10, alpha=alpha, num_estimators=M, gamma=gamma, last=True
        )

        self.num_estimators = M

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # move the M estimators from the channel dimension to the batch dimension
        x = rearrange(x, "e (m c) h w -> (m e) c h w", m=self.num_estimators)
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # shape: (M * batch_size, 10)
        return x


packed_net = PackedNet()
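The rearrange call in forward moves the estimators from the channel dimension to the batch dimension before flattening, so that each row of the batch holds the features of one estimator for one image. The following sketch reproduces the same pattern with plain reshape and transpose calls (an equivalent formulation chosen here so the block runs without einops) on a random tensor, and checks where each estimator's channels end up.

```python
import torch

M = 4  # number of estimators
B, C, H, W = 2, 32, 5, 5  # 32 = 16 * alpha channels after conv2
x = torch.randn(B, C, H, W)
c = C // M  # channels per estimator

# Equivalent of rearrange(x, "e (m c) h w -> (m e) c h w", m=M)
y = x.reshape(B, M, c, H, W).transpose(0, 1).reshape(M * B, c, H, W)
print(tuple(y.shape))  # (8, 8, 5, 5)

# Row m * B + b of y holds the channels of estimator m for image b
m, b = 2, 1
print(torch.equal(y[m * B + b], x[b, m * c : (m + 1) * c]))  # True
```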

3. Define a loss function and optimizer

Let’s use a classification cross-entropy loss and SGD with momentum.

from torch import optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(packed_net.parameters(), lr=0.001, momentum=0.9)
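During training, the network returns M * batch_size rows of logits stacked estimator by estimator, and the labels are repeated M times to match. The sketch below, on random logits rather than real network outputs, checks that the resulting cross-entropy is simply the average of the losses of the individual estimators, so every estimator receives the same supervision.

```python
import torch
from torch import nn

torch.manual_seed(0)
M, B, C = 4, 3, 10  # estimators, batch size, classes
logits = torch.randn(M * B, C)  # rows m * B to (m + 1) * B - 1 belong to estimator m
labels = torch.randint(0, C, (B,))

criterion = nn.CrossEntropyLoss()  # averages over all M * B rows by default
loss = criterion(logits, labels.repeat(M))  # labels tiled M times: l0, l1, l2, l0, ...

# Same quantity computed estimator by estimator
per_estimator = torch.stack(
    [criterion(logits[m * B : (m + 1) * B], labels) for m in range(M)]
)
print(torch.allclose(loss, per_estimator.mean()))  # True
```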

4. Train the Packed-Ensemble on the training data

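We loop over the training data, feed the inputs to the network, and optimize. Since the Packed-Ensemble stacks the predictions of its M estimators along the batch dimension, the labels are repeated M times so that each estimator is compared with the same targets.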

for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()
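        # forward + backward + optimize
        outputs = packed_net(inputs)  # shape: (M * batch_size, num_classes)
        # repeat the labels so that every estimator is trained on the same targets
        loss = criterion(outputs, labels.repeat(packed_net.num_estimators))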
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
            running_loss = 0.0

print("Finished Training")
[1,  2000] loss: 2.631
[1,  4000] loss: 2.168
[1,  6000] loss: 2.048
[1,  8000] loss: 1.954
[1, 10000] loss: 1.863
[1, 12000] loss: 1.808
[2,  2000] loss: 1.722
[2,  4000] loss: 1.681
[2,  6000] loss: 1.643
[2,  8000] loss: 1.620
[2, 10000] loss: 1.609
[2, 12000] loss: 1.573
Finished Training
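The training log stops at 12000 mini-batches per epoch because of how the logging condition interacts with the dataset size. A quick calculation with the numbers used in this tutorial (50000 training images, a batch size of 4, and a log every 2000 mini-batches):

```python
import math

num_train, batch_size, log_every = 50_000, 4, 2000
steps_per_epoch = math.ceil(num_train / batch_size)
print(steps_per_epoch)  # 12500

# Mini-batch indices that trigger a print, with the same test as the training loop
logged = [i + 1 for i in range(steps_per_epoch) if i % log_every == log_every - 1]
print(logged)  # [2000, 4000, 6000, 8000, 10000, 12000]
```

The last 500 mini-batches of each epoch are therefore used for training but not reported in the log.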

Save our trained model:

PATH = "./cifar_packed_net.pth"
torch.save(packed_net.state_dict(), PATH)

5. Test the Packed-Ensemble on the test data

Let us display some images from the test set to get familiar with them.

dataiter = iter(testloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images, pad_value=1))
print(
    "GroundTruth: ",
    " ".join(f"{classes[labels[j]]:5s}" for j in range(batch_size)),
)
GroundTruth:  cat   ship  ship  plane

Next, let us load back in our saved model (note: saving and re-loading the model wasn’t necessary here, we only did it to illustrate how to do so):

packed_net = PackedNet()
packed_net.load_state_dict(torch.load(PATH, weights_only=True))
<All keys matched successfully>
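Recent PyTorch versions warn when torch.load is called without weights_only, because unpickling arbitrary objects can execute code. Since a state_dict only holds tensors, it can be loaded with weights_only=True. The round trip below shows this on a small nn.Linear, using an in-memory buffer instead of a file to keep the example self-contained (the weights_only argument requires PyTorch 1.13 or later).

```python
import io

import torch
from torch import nn

model = nn.Linear(3, 2)

# Save the state_dict to an in-memory buffer instead of a file on disk
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# weights_only=True restricts unpickling to tensors and plain containers
restored = nn.Linear(3, 2)
restored.load_state_dict(torch.load(buffer, weights_only=True))
print(torch.equal(model.weight, restored.weight))  # True
```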

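Let us see what the Packed-Ensemble predicts for the examples above. The logits of the M estimators are separated, converted to probabilities with a softmax, and averaged to obtain the ensemble prediction: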

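with torch.no_grad():  # no gradients are needed at inference time
    logits = packed_net(images)  # shape: (M * batch_size, 10)
    # split the estimators out of the batch dimension: (batch_size, M, 10)
    logits = rearrange(logits, "(n b) c -> b n c", n=packed_net.num_estimators)
    probs_per_est = F.softmax(logits, dim=-1)  # per-estimator class probabilities
    outputs = probs_per_est.mean(dim=1)  # ensemble prediction: average over estimators

_, predicted = torch.max(outputs, 1)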

print(
    "Predicted: ",
    " ".join(f"{classes[predicted[j]]:5s}" for j in range(batch_size)),
)
Predicted:  cat   ship  ship  ship

The Packed-Ensemble correctly classifies three of the four test images; it mistakes the plane for a ship.
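Beyond the predicted class, the per-estimator probabilities can be turned into uncertainty scores, which is what step 5 of the outline refers to. The sketch below uses random logits in place of packed_net(images), reproduces the reshape and averaging steps with plain torch calls, and computes two common scores: the predictive entropy of the averaged distribution and the mutual information between the prediction and the estimators (a measure of their disagreement). These are standard choices assumed here for illustration; this tutorial does not specify which scores torch_uncertainty computes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
M, B, C = 4, 2, 10  # estimators, batch size, classes
logits = torch.randn(M * B, C)  # stand-in for packed_net(images), estimator-major rows

# Equivalent of rearrange(logits, "(n b) c -> b n c", n=M): shape (B, M, C)
per_est = logits.reshape(M, B, C).transpose(0, 1)
probs_per_est = F.softmax(per_est, dim=-1)
probs = probs_per_est.mean(dim=1)  # ensemble prediction, shape (B, C)
predicted = probs.argmax(dim=1)


def entropy(p):
    # Shannon entropy along the class dimension, with a clamp to avoid log(0)
    return -(p * p.clamp_min(1e-12).log()).sum(dim=-1)


# Predictive entropy: high when the averaged prediction is spread over many classes
pred_entropy = entropy(probs)
# Mutual information: predictive entropy minus the mean entropy of the estimators,
# high when the estimators disagree with each other
mutual_info = pred_entropy - entropy(probs_per_est).mean(dim=1)
print(predicted, pred_entropy, mutual_info)
```

Because entropy is concave, the mutual information is never negative; it is zero only when all estimators output the same distribution.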
