From a Standard Classifier to a Packed-Ensemble¶
This tutorial is heavily inspired by PyTorch’s Training a Classifier tutorial.
Let’s go step by step through the process of turning a standard classifier into a Packed-Ensemble classifier.
Dataset¶
In this tutorial we will use the CIFAR10 dataset available in the torchvision package. The CIFAR10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
Here is an example of what the data looks like:
Training an image Packed-Ensemble classifier¶
Here is the outline of the process:
1. Load and normalize the CIFAR10 training and test datasets using torchvision
2. Define a Packed-Ensemble from a standard classifier
3. Define a loss function and optimizer
4. Train the Packed-Ensemble on the training data
5. Test the Packed-Ensemble on the test data and evaluate its performance w.r.t. uncertainty quantification and OOD detection
1. Load and normalize CIFAR10¶
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
torch.set_num_threads(1)
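torchvision datasets return PIL images. transforms.ToTensor() converts them to tensors with values in [0, 1], and transforms.Normalize, with a per-channel mean and standard deviation of 0.5, then maps these values to the range [-1, 1].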
Note
If you are running on Windows and get a BrokenPipeError, try setting the num_workers argument of torch.utils.data.DataLoader() to 0.
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
batch_size = 4
trainset = torchvision.datasets.CIFAR10(
root="./data", train=True, download=True, transform=transform
)
trainloader = DataLoader(
trainset, batch_size=batch_size, shuffle=True, num_workers=2
)
testset = torchvision.datasets.CIFAR10(
root="./data", train=False, download=True, transform=transform
)
testloader = DataLoader(
testset, batch_size=batch_size, shuffle=False, num_workers=2
)
classes = (
"plane",
"car",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck",
)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Let us show some of the training images, for fun.
import matplotlib.pyplot as plt
import numpy as np
# function to show an image
def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.figure(figsize=(10, 3))
    plt.axis("off")
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)
# show images
imshow(torchvision.utils.make_grid(images, pad_value=1))
# print labels
print(" ".join(f"{classes[labels[j]]:5s}" for j in range(batch_size)))
frog car truck horse
2. Define a Packed-Ensemble from a standard classifier¶
First, for reference, we define a standard classifier for CIFAR10: a small convolutional neural network.
import torch.nn.functional as F
from torch import nn
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
net = Net()
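The input size of fc1, 16 * 5 * 5, is the number of features left after the two convolution and pooling stages. The trace below is a sketch that re-creates only the convolutional part of Net with the same layer arguments: each 5x5 convolution without padding removes 4 pixels per side length, and each 2x2 max-pooling halves it.

```python
import torch
from torch import nn

# Same convolutional layers as Net (re-created here so the snippet runs on its own)
conv1 = nn.Conv2d(3, 6, 5)
conv2 = nn.Conv2d(6, 16, 5)
pool = nn.MaxPool2d(2, 2)

x = torch.zeros(4, 3, 32, 32)  # a batch of 4 CIFAR10-sized images
h1 = pool(conv1(x))  # 32 -> 28 (5x5 conv) -> 14 (2x2 pool)
h2 = pool(conv2(h1))  # 14 -> 10 (5x5 conv) -> 5 (2x2 pool)
print(tuple(h1.shape), tuple(h2.shape))  # (4, 6, 14, 14) (4, 16, 5, 5)
print(h2.flatten(1).shape[1])  # 400, i.e. 16 * 5 * 5
```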
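Let’s now turn the standard classifier into a Packed-Ensemble classifier with parameters \(M=4,\ \alpha=2\text{ and }\gamma=1\), where \(M\) is the number of estimators, \(\alpha\) the width-expansion factor and \(\gamma\) the number of groups within each estimator.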
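To build intuition for why packing M estimators into one network stays cheap, the sketch below uses a plain grouped nn.Conv2d as a stand-in for a packed version of Net.conv2. It assumes (this is not torch_uncertainty’s actual implementation) that such a layer widens both channel counts by \(\alpha\) and uses one group per estimator when \(\gamma=1\). It checks two properties of grouped convolutions: the parameter count stays close to the original layer, and perturbing the inputs of one estimator only changes that estimator’s outputs.

```python
import torch
from torch import nn

torch.manual_seed(0)
M, alpha = 4, 2

# Assumed stand-in for a packed Net.conv2: 6 * alpha -> 16 * alpha channels, M groups
grouped = nn.Conv2d(6 * alpha, 16 * alpha, 5, groups=M)  # 12 -> 32 channels, 4 groups
standard = nn.Conv2d(6, 16, 5)  # Net.conv2, for comparison


def num_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


print(num_params(standard), num_params(grouped))  # 2416 2432

x = torch.randn(1, 6 * alpha, 14, 14)
y = grouped(x)

# Perturb only the inputs of estimator 0: the first 12 // 4 = 3 channels
x2 = x.clone()
x2[:, :3] += 1.0
y2 = grouped(x2)

# Output channels whose values moved; only estimator 0's 32 // 4 = 8 channels should
changed = (y2 - y).abs().amax(dim=(0, 2, 3)) > 1e-6
print(changed.nonzero().flatten().tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]
```

Four independent subnetworks here cost almost the same number of parameters as the single original layer, because the grouping removes all connections between estimators.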
from einops import rearrange
from torch_uncertainty.layers import PackedConv2d, PackedLinear
class PackedNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        M = 4
        alpha = 2
        gamma = 1
        self.conv1 = PackedConv2d(
            3, 6, 5, alpha=alpha, num_estimators=M, gamma=gamma, first=True
        )
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = PackedConv2d(6, 16, 5, alpha=alpha, num_estimators=M, gamma=gamma)
        self.fc1 = PackedLinear(
            16 * 5 * 5, 120, alpha=alpha, num_estimators=M, gamma=gamma
        )
        self.fc2 = PackedLinear(120, 84, alpha=alpha, num_estimators=M, gamma=gamma)
        self.fc3 = PackedLinear(
            84, 10 * M, alpha=alpha, num_estimators=M, gamma=gamma, last=True
        )
        self.num_estimators = M
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # move the M estimators from the channel axis to the batch axis:
        # (B, M * C, H, W) -> (M * B, C, H, W)
        x = rearrange(x, "e (m c) h w -> (m e) c h w", m=self.num_estimators)
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
packed_net = PackedNet()
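The rearrange call in forward is what turns one batch of B images into M stacked batches, one per estimator. Here is a pure-PyTorch sketch of the same pattern on a small tensor. The channel count per estimator C = 8 assumes that conv2 outputs 16 × α = 32 channels split evenly across the M = 4 estimators; the values are dummy data.

```python
import torch

B, M, C, H, W = 2, 4, 8, 5, 5  # batch, estimators, channels per estimator, spatial size
x = torch.arange(B * M * C * H * W, dtype=torch.float32).view(B, M * C, H, W)

# Pure-PyTorch equivalent of rearrange(x, "e (m c) h w -> (m e) c h w", m=M)
y = x.view(B, M, C, H, W).permute(1, 0, 2, 3, 4).reshape(M * B, C, H, W)
print(tuple(y.shape))  # (8, 8, 5, 5)

# Row m * B + e holds estimator m's channel slice for sample e
e, m = 1, 2
assert torch.equal(y[m * B + e], x[e, m * C:(m + 1) * C])
```

Because m is the outer factor in (m e), the output is estimator-major: the first B rows belong to estimator 0, the next B rows to estimator 1, and so on.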
3. Define a Loss function and optimizer¶
Let’s use a classification cross-entropy loss and SGD with momentum.
from torch import optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(packed_net.parameters(), lr=0.001, momentum=0.9)
4. Train the Packed-Ensemble on the training data¶
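Let’s train the Packed-Ensemble on the training data. Since packed_net returns the logits of its M estimators stacked along the batch dimension, the labels are repeated num_estimators times so that every estimator is trained on the same targets.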
for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = packed_net(inputs)
        loss = criterion(outputs, labels.repeat(packed_net.num_estimators))
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
            running_loss = 0.0
print("Finished Training")
[1, 2000] loss: 2.626
[1, 4000] loss: 2.186
[1, 6000] loss: 2.039
[1, 8000] loss: 1.930
[1, 10000] loss: 1.847
[1, 12000] loss: 1.763
[2, 2000] loss: 1.703
[2, 4000] loss: 1.689
[2, 6000] loss: 1.659
[2, 8000] loss: 1.621
[2, 10000] loss: 1.601
[2, 12000] loss: 1.577
Finished Training
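Why does labels.repeat(num_estimators) line up with the outputs? The sketch below uses random logits as a hypothetical stand-in for packed_net(inputs), laid out estimator-major with shape (M * B, C). Repeating the labels tiles them in the same order, and the cross-entropy over the stacked batch equals the average of the M per-estimator losses, so each estimator receives an equal share of the training signal.

```python
import torch
from torch import nn

torch.manual_seed(0)
M, B, C = 4, 3, 10  # estimators, batch size, classes
labels = torch.tensor([6, 1, 9])
logits = torch.randn(M * B, C)  # hypothetical stand-in for packed_net(inputs)

criterion = nn.CrossEntropyLoss()
targets = labels.repeat(M)
print(targets.tolist())  # [6, 1, 9, 6, 1, 9, 6, 1, 9, 6, 1, 9]
stacked = criterion(logits, targets)

# Same value as averaging the loss of each estimator on its own slice of rows
per_est = torch.stack(
    [criterion(logits[m * B:(m + 1) * B], labels) for m in range(M)]
)
assert torch.allclose(stacked, per_est.mean())
```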
Save our trained model:
PATH = "./cifar_packed_net.pth"
torch.save(packed_net.state_dict(), PATH)
5. Test the Packed-Ensemble on the test data¶
Let us display a few images from the test set to get familiar with them.
dataiter = iter(testloader)
images, labels = next(dataiter)
# show images
imshow(torchvision.utils.make_grid(images, pad_value=1))
print(
"GroundTruth: ",
" ".join(f"{classes[labels[j]]:5s}" for j in range(batch_size)),
)
GroundTruth: cat ship ship plane
Next, let us load our saved model back in (saving and re-loading it was not necessary here; we only did it to illustrate how):
packed_net = PackedNet()
packed_net.load_state_dict(torch.load(PATH))
<All keys matched successfully>
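Let us see what the Packed-Ensemble predicts for the examples above. Each estimator’s logits are turned into probabilities with a softmax, and the ensemble’s prediction is the class with the highest average probability: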
logits = packed_net(images)
logits = rearrange(logits, "(n b) c -> b n c", n=packed_net.num_estimators)
probs_per_est = F.softmax(logits, dim=-1)
outputs = probs_per_est.mean(dim=1)
_, predicted = torch.max(outputs, 1)
print(
"Predicted: ",
" ".join(f"{classes[predicted[j]]:5s}" for j in range(batch_size)),
)
Predicted: cat ship ship ship
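For readers who prefer plain PyTorch, the sketch below reproduces the same averaging without einops. The random logits are a hypothetical stand-in for packed_net(images), laid out estimator-major as above; view followed by transpose plays the role of rearrange(logits, "(n b) c -> b n c").

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, b, c = 4, 3, 10  # estimators, batch size, classes
logits = torch.randn(n * b, c)  # hypothetical stand-in for packed_net(images)

# Pure-PyTorch equivalent of rearrange(logits, "(n b) c -> b n c", n=n)
per_sample = logits.view(n, b, c).transpose(0, 1)  # (b, n, c)
probs_per_est = F.softmax(per_sample, dim=-1)
mean_probs = probs_per_est.mean(dim=1)  # (b, c): the ensemble's predictive distribution
predicted = mean_probs.argmax(dim=1)  # same indices as torch.max(mean_probs, 1)[1]

# Each averaged row is still a probability distribution
assert torch.allclose(mean_probs.sum(dim=1), torch.ones(b))
# Sample 0 as seen by estimator 2 is row 2 * b + 0 of the stacked logits
assert torch.equal(per_sample[0, 2], logits[2 * b])
```

Averaging probabilities rather than logits keeps the result a proper mixture of the estimators’ predictive distributions.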
The results seem pretty good: three of the four test images are classified correctly, and only the plane is mistaken for a ship.
Total running time of the script: (1 minutes 54.281 seconds)