From a Standard Classifier to a Packed-Ensemble
This tutorial is heavily inspired by PyTorch’s Training a Classifier tutorial.
Let's walk step by step through the process of turning a standard classifier into a Packed-Ensemble classifier.
Dataset
In this tutorial we will use the CIFAR10 dataset available in the torchvision package. The CIFAR10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
Here is an example of what the data looks like:
Training an image Packed-Ensemble classifier
Here is the outline of the process:
1. Load and normalize the CIFAR10 training and test datasets using torchvision
2. Define a Packed-Ensemble from a standard classifier
3. Define a loss function
4. Train the Packed-Ensemble on the training data
5. Test the Packed-Ensemble on the test data and evaluate its performance w.r.t. uncertainty quantification and OOD detection
1. Load and normalize CIFAR10
```python
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

torch.set_num_threads(1)
```
The torchvision datasets yield PILImage images. ToTensor() converts them into tensors with values in the range [0, 1], and Normalize then maps those values to the range [-1, 1].
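As a minimal sketch of what this normalization does, the snippet applies the per-channel formula (x - mean) / std that transforms.Normalize computes, with mean = std = 0.5, directly in torch. The 3x2x2 tensor is made up for illustration and stands in for the output of ToTensor().

```python
import torch

mean, std = 0.5, 0.5  # the values passed to transforms.Normalize in this tutorial

# Hypothetical 3-channel 2x2 "image" covering the [0, 1] range produced by ToTensor()
img = torch.tensor([[0.0, 0.25], [0.5, 1.0]]).repeat(3, 1, 1)

# Normalize applies (x - mean) / std to every channel
normalized = (img - mean) / std
print(normalized[0].tolist())  # [[-1.0, -0.5], [0.0, 1.0]]
```

The endpoints 0 and 1 land exactly on -1 and 1, and the midpoint 0.5 lands on 0.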
Note: If you are running on Windows and get a BrokenPipeError, try setting the num_workers argument of torch.utils.data.DataLoader() to 0.
```python
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

batch_size = 4

trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
trainloader = DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)
testloader = DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2
)

classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)
```
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Let us show some of the training images, for fun.
```python
import matplotlib.pyplot as plt
import numpy as np


# function to show an image
def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.figure(figsize=(10, 3))
    plt.axis("off")
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images, pad_value=1))
# print labels
print(" ".join(f"{classes[labels[j]]:5s}" for j in range(batch_size)))
```
dog dog truck frog
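The first line of imshow undoes the normalization, since x / 2 + 0.5 is the inverse of (x - 0.5) / 0.5. Matplotlib also expects images laid out as height x width x channels, while PyTorch tensors are channels-first, which is why imshow calls np.transpose. A small check on a random tensor (a stand-in for a real CIFAR10 image, not actual data) illustrates both points:

```python
import numpy as np
import torch

img = torch.rand(3, 32, 32)  # random stand-in for a CHW image in [0, 1]
normalized = (img - 0.5) / 0.5  # what the Normalize transform computes
restored = normalized / 2 + 0.5  # the "unnormalize" step inside imshow

# matplotlib's imshow expects H x W x C, whereas PyTorch stores C x H x W
hwc = np.transpose(restored.numpy(), (1, 2, 0))
print(hwc.shape)  # (32, 32, 3)
```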
2. Define a Packed-Ensemble from a standard classifier
First, we define a standard convolutional neural network classifier for CIFAR10, which serves as a reference.
```python
import torch.nn.functional as F
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
```
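The input size of fc1, 16 * 5 * 5, follows from the spatial arithmetic of the feature extractor: each 5x5 convolution without padding shrinks the side length by 4 and each 2x2 max-pooling halves it, so 32 becomes 28, then 14, then 10, then 5. One way to check this is to push a dummy CIFAR10-sized tensor through freshly created layers configured like those of Net:

```python
import torch
from torch import nn

# Layers configured like Net's feature extractor (random weights, only shapes matter here)
conv1 = nn.Conv2d(3, 6, 5)
conv2 = nn.Conv2d(6, 16, 5)
pool = nn.MaxPool2d(2, 2)

x = torch.zeros(1, 3, 32, 32)  # one dummy CIFAR10-sized image
x = pool(conv1(x))  # 32 -> 28 (5x5 conv, no padding) -> 14 (2x2 max-pool)
print(tuple(x.shape))  # (1, 6, 14, 14)
x = pool(conv2(x))  # 14 -> 10 -> 5
print(tuple(x.shape))  # (1, 16, 5, 5)
print(x.flatten(1).shape[1])  # 400, i.e. 16 * 5 * 5
```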
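Let's turn the standard classifier into a Packed-Ensemble classifier with parameters \(M=4\), \(\alpha=2\) and \(\gamma=1\), where \(M\) is the number of estimators, \(\alpha\) the width expansion factor and \(\gamma\) the number of groups within each estimator.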
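To build some intuition for these parameters, here is a rough sketch with plain nn.Conv2d. It assumes that a packed layer behaves like a convolution that is \(\alpha\) times wider and split into \(M \times \gamma\) independent groups; torch_uncertainty's PackedConv2d may differ in its internals (for instance in how the first and last layers are handled), so treat this as an approximation rather than the library's implementation. Under that assumption, the packed version of conv2 stores exactly as many weights as the standard one:

```python
from torch import nn

M, alpha, gamma = 4, 2, 1

# conv2 of the standard Net
standard = nn.Conv2d(6, 16, 5)
# Assumed packed counterpart: alpha-times wider, split into M * gamma groups
packed_like = nn.Conv2d(6 * alpha, 16 * alpha, 5, groups=M * gamma)

# A grouped conv stores weights of shape (out, in / groups, k, k)
print(tuple(packed_like.weight.shape))  # (32, 3, 5, 5)
print(standard.weight.numel(), packed_like.weight.numel())  # 2400 2400
```

Under this assumption the weight count scales by \(\alpha^2 / (M\gamma)\) relative to the standard layer, so \(\alpha=2\) with \(M=4\) leaves it unchanged while still providing four separate groups, one per estimator.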
```python
from einops import rearrange

from torch_uncertainty.layers import PackedConv2d, PackedLinear


class PackedNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        M = 4
        alpha = 2
        gamma = 1
        self.conv1 = PackedConv2d(
            3, 6, 5, alpha=alpha, num_estimators=M, gamma=gamma, first=True
        )
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = PackedConv2d(6, 16, 5, alpha=alpha, num_estimators=M, gamma=gamma)
        self.fc1 = PackedLinear(
            16 * 5 * 5, 120, alpha=alpha, num_estimators=M, gamma=gamma
        )
        self.fc2 = PackedLinear(120, 84, alpha=alpha, num_estimators=M, gamma=gamma)
        self.fc3 = PackedLinear(
            84, 10 * M, alpha=alpha, num_estimators=M, gamma=gamma, last=True
        )

        self.num_estimators = M

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # split the packed channels into M estimators and stack them along the batch axis
        x = rearrange(x, "e (m c) h w -> (m e) c h w", m=self.num_estimators)
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


packed_net = PackedNet()
```
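The rearrange call in forward is where the packed channels are unfolded into separate estimators: the channel axis, which holds the M estimators' feature maps side by side, is split, and the estimator axis is merged into the batch axis with the estimator index varying slowest. A plain-torch equivalent, with sizes chosen only for illustration, makes the resulting row ordering explicit:

```python
import torch

# Illustrative sizes: M estimators, B images, C channels per estimator
M, B, C, H, W = 4, 2, 8, 5, 5
x = torch.randn(B, M * C, H, W)  # packed feature map, estimators side by side on the channel axis

# Plain-torch equivalent of rearrange(x, "e (m c) h w -> (m e) c h w", m=M)
y = x.reshape(B, M, C, H, W)  # split the channel axis into (m, c), m varying slowest
y = y.permute(1, 0, 2, 3, 4)  # move the estimator axis in front of the batch axis
y = y.reshape(M * B, C, H, W)  # merge (m e) into a single batch axis

# Row m * B + b of y now holds estimator m's channels for image b
print(tuple(y.shape))  # (8, 8, 5, 5)
```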
3. Define a loss function and optimizer
Let's use a classification cross-entropy loss and SGD with momentum.
```python
from torch import optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(packed_net.parameters(), lr=0.001, momentum=0.9)
```
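nn.CrossEntropyLoss compares each row of logits with one target class. The Packed-Ensemble returns its rows estimator-major, (m e), which is also the order the evaluation code assumes when it applies rearrange(logits, "(n b) c -> b n c"). Row m * B + b therefore holds estimator m's logits for image b, and the targets must follow the same order. labels.repeat(M), used in the training loop, tiles the whole batch M times and matches it; repeat_interleave would instead produce the (e m) order and pair logits with the wrong labels. A small check with made-up labels:

```python
import torch

M, B = 4, 3
labels = torch.tensor([7, 2, 5])  # made-up labels for a mini-batch of B images

targets = labels.repeat(M)  # order used in the training loop
wrong = labels.repeat_interleave(M)  # would only match an (e m) ordering
print(targets.tolist())  # [7, 2, 5, 7, 2, 5, 7, 2, 5, 7, 2, 5]
print(wrong.tolist())  # [7, 7, 7, 7, 2, 2, 2, 2, 5, 5, 5, 5]

# Row m * B + b of the (m e)-ordered outputs must be compared with labels[b]
for m in range(M):
    for b in range(B):
        assert targets[m * B + b] == labels[b]
```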
4. Train the Packed-Ensemble on the training data
Let’s train the Packed-Ensemble on the training data.
```python
for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = packed_net(inputs)
        # repeat the labels once per estimator so every estimator gets the same targets
        loss = criterion(outputs, labels.repeat(packed_net.num_estimators))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
            running_loss = 0.0

print("Finished Training")
```
[1, 2000] loss: 2.628
[1, 4000] loss: 2.197
[1, 6000] loss: 2.060
[1, 8000] loss: 1.987
[1, 10000] loss: 1.920
[1, 12000] loss: 1.847
[2, 2000] loss: 1.757
[2, 4000] loss: 1.707
[2, 6000] loss: 1.662
[2, 8000] loss: 1.638
[2, 10000] loss: 1.615
[2, 12000] loss: 1.591
Finished Training
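Each epoch reports six running averages because, with batch_size = 4, the 50000 training images form 12500 mini-batches, and the last 500 of them never reach the next 2000-batch checkpoint. The arithmetic is easy to reproduce:

```python
import math

num_train_images = 50000  # size of the CIFAR10 training set
batch_size = 4
log_every = 2000

batches_per_epoch = math.ceil(num_train_images / batch_size)
# Same condition as in the training loop: print when i % 2000 == 1999
logged_steps = [i + 1 for i in range(batches_per_epoch) if i % log_every == log_every - 1]
print(batches_per_epoch)  # 12500
print(logged_steps)  # [2000, 4000, 6000, 8000, 10000, 12000]
```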
Save our trained model:
```python
PATH = "./cifar_packed_net.pth"
torch.save(packed_net.state_dict(), PATH)
```
5. Test the Packed-Ensemble on the test data
Let us display a few images from the test set to get familiar with them.
```python
dataiter = iter(testloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images, pad_value=1))
print(
    "GroundTruth: ",
    " ".join(f"{classes[labels[j]]:5s}" for j in range(batch_size)),
)
```
GroundTruth: cat ship ship plane
Next, let us load our saved model back in (note: saving and reloading the model was not necessary here; we only did it to illustrate how to do so):
```python
packed_net = PackedNet()
# weights_only=True restricts unpickling to tensors and plain containers,
# which is the safe choice when loading a state dict
packed_net.load_state_dict(torch.load(PATH, weights_only=True))
```
<All keys matched successfully>
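The same save-and-reload round trip can be tried in isolation. This sketch uses a tiny nn.Linear and a temporary file, both chosen only for illustration, and loads the state dict with weights_only=True, which limits unpickling to tensors and plain containers:

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(4, 2)  # small stand-in for packed_net
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "linear.pth")
    torch.save(model.state_dict(), path)
    state = torch.load(path, weights_only=True)  # safe loading of a plain state dict

restored = nn.Linear(4, 2)  # fresh model with the same architecture
result = restored.load_state_dict(state)
print(result)  # <All keys matched successfully>
```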
Let us see what the Packed-Ensemble predicts for the examples above:
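```python
with torch.no_grad():  # inference only, no gradients needed
    logits = packed_net(images)
# regroup the (estimator, batch) rows into M sets of logits per image
logits = rearrange(logits, "(n b) c -> b n c", n=packed_net.num_estimators)
# turn each estimator's logits into class probabilities
probs_per_est = F.softmax(logits, dim=-1)
# average the M estimators' probabilities to get the ensemble prediction
outputs = probs_per_est.mean(dim=1)
_, predicted = torch.max(outputs, 1)

print(
    "Predicted: ",
    " ".join(f"{classes[predicted[j]]:5s}" for j in range(batch_size)),
)
```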
Predicted: cat ship ship ship
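The ensemble prediction averages the estimators' softmax probabilities, a common way of combining ensemble members. The sketch below reproduces that step in plain torch on random logits (a stand-in for packed_net(images), with sizes chosen for illustration) and replaces the einops call with an equivalent reshape and permute:

```python
import torch
import torch.nn.functional as F

M, B, C = 4, 2, 10
torch.manual_seed(0)
logits = torch.randn(M * B, C)  # random stand-in for packed_net(images), rows in (m e) order

# Plain-torch equivalent of rearrange(logits, "(n b) c -> b n c", n=M)
per_est = logits.reshape(M, B, C).permute(1, 0, 2)  # (B, M, C)
probs_per_est = F.softmax(per_est, dim=-1)  # one probability vector per estimator
ensemble_probs = probs_per_est.mean(dim=1)  # average the M vectors for each image
predicted = ensemble_probs.argmax(dim=1)  # same result as torch.max(..., 1)[1]

print(tuple(per_est.shape))  # (2, 4, 10)
```

Averaging probability vectors yields another probability vector, so ensemble_probs still sums to one for each image.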
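The predictions look reasonable: three of the four test images are classified correctly, and only the plane is mistaken for a ship.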
Total running time of the script: (1 minutes 55.816 seconds)