Improve Top-label Calibration with Temperature Scaling
In this tutorial, we use TorchUncertainty to improve the calibration of the top-label predictions and the reliability of the underlying neural network.
This tutorial provides extensive details on how to use the TemperatureScaler class; however, this is done automatically in the classification routine when setting calibration_set to val or test.
Through this tutorial, we also see how to use the datamodules outside any Lightning trainers, and how to use TorchUncertainty’s models.
1. Loading the Utilities
In this tutorial, we will need:
TorchUncertainty’s CalibrationError metric to evaluate the top-label calibration with the ECE and plot reliability diagrams
the CIFAR-100 datamodule to handle the data
a ResNet-18 as the starting model
the temperature scaler to improve the top-label calibration
a utility function to download HF models easily
If you use the classification routine, the plots are automatically added to the TensorBoard logs when you set the log_plots flag.
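As a rough sketch (not run in this tutorial), enabling this automatic calibration could look like the following. The ClassificationRoutine import path and the exact argument names are assumptions that may differ slightly between TorchUncertainty versions, so treat it as an outline rather than the definitive API.

# Sketch only: automatic temperature scaling inside the classification routine.
# The import path and argument names are assumptions and may vary by version.
from torch import nn

from torch_uncertainty.models.resnet import resnet
from torch_uncertainty.routines import ClassificationRoutine

routine = ClassificationRoutine(
    model=resnet(in_channels=3, num_classes=100, arch=18, style="cifar", conv_bias=False),
    num_classes=100,
    loss=nn.CrossEntropyLoss(),
    calibration_set="val",  # fit the temperature scaler on the validation set
    log_plots=True,  # log the reliability diagrams to TensorBoard
)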
from torch_uncertainty.datamodules import CIFAR100DataModule
from torch_uncertainty.metrics import CalibrationError
from torch_uncertainty.models.resnet import resnet
from torch_uncertainty.post_processing import TemperatureScaler
from torch_uncertainty.utils import load_hf
2. Loading a model from TorchUncertainty’s HF
To avoid training a model on CIFAR-100 from scratch, we load a model from Hugging Face. This can be done in a few lines:
# Build the model
model = resnet(in_channels=3, num_classes=100, arch=18, style="cifar", conv_bias=False)
# Download the weights (the config is not used here)
weights, config = load_hf("resnet18_c100")
# Load the weights in the pre-built model
model.load_state_dict(weights)
<All keys matched successfully>
3. Setting up the Datamodule and Dataloaders
To get the dataloader from the datamodule, just call prepare_data, setup, and extract the first element of the test dataloader list. This list contains more than one element if eval_ood is True: the dataloader of in-distribution data and the dataloader of out-of-distribution data. Otherwise, it contains a single element.
dm = CIFAR100DataModule(root="./data", eval_ood=False, batch_size=32)
dm.prepare_data()
dm.setup("test")
# Get the full test dataloader (unused in this tutorial)
dataloader = dm.test_dataloader()[0]
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to data/cifar-100-python.tar.gz
Extracting data/cifar-100-python.tar.gz to data
Files already downloaded and verified
4. Iterating on the Dataloader and Computing the ECE
We first split the original test set into a calibration set and a test set for proper evaluation.
When computing the ECE, you need to provide the predicted probabilities associated with the inputs. To do this, just call PyTorch’s softmax on the logits.
To avoid lengthy computations (without a GPU), we restrict the evaluation to a subset of the test set.
from torch.utils.data import DataLoader, random_split
# Split datasets
dataset = dm.test
cal_dataset, test_dataset, other = random_split(
    dataset, [1000, 1000, len(dataset) - 2000]
)
test_dataloader = DataLoader(test_dataset, batch_size=32)
# Initialize the ECE
ece = CalibrationError(task="multiclass", num_classes=100)
# Iterate on the test dataloader
for sample, target in test_dataloader:
    logits = model(sample)
    probs = logits.softmax(-1)
    ece.update(probs, target)
# Compute & print the calibration error
print(f"ECE before scaling - {ece.compute():.3%}.")
ECE before scaling - 11.154%.
We also compute and plot the top-label calibration figure. We see that the model is not well calibrated.
fig, ax = ece.plot()
fig.show()
5. Fitting the Scaler to Improve the Calibration
The TemperatureScaler has one parameter that can be used to temper the softmax. We minimize the tempered cross-entropy on a calibration set, defined here as a subset of the test set containing 1,000 samples. Look at the code of the TemperatureScaler fit method for more details.
# Fit the scaler on the calibration dataset
scaled_model = TemperatureScaler(model=model)
scaled_model.fit(calibration_set=cal_dataset)
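For intuition, here is a minimal sketch of what the fitted scaler applies at inference time: the logits are divided by the temperature before the softmax. The value of T below is hypothetical, not the one fitted above.

# Sketch only: temperature scaling divides the logits by a temperature T
# before the softmax. T > 1 softens the probabilities, T < 1 sharpens them,
# and the predicted class is unchanged. The value of T here is hypothetical.
import torch

dummy_logits = torch.randn(4, 100)  # 4 samples, 100 classes
T = 1.5  # hypothetical temperature; the real one is learned by the fit above

tempered_probs = (dummy_logits / T).softmax(-1)
# The argmax is unchanged: only the confidence is tempered.
assert torch.equal(tempered_probs.argmax(-1), dummy_logits.softmax(-1).argmax(-1))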
6. Iterating Again to Compute the Improved ECE
We can directly use the scaler as a calibrated model.
Note that you will need to first reset the ECE metric to avoid mixing the scores of the previous and current iterations.
# Reset the ECE
ece.reset()
# Iterate on the test dataloader
for sample, target in test_dataloader:
    logits = scaled_model(sample)
    probs = logits.softmax(-1)
    ece.update(probs, target)
print(f"ECE after scaling - {ece.compute():.3%}.")
ECE after scaling - 4.419%.
We finally compute and plot the scaled top-label calibration figure. We see that the model is now better calibrated.
fig, ax = ece.plot()
fig.show()
The top-label calibration should be improved.
Notes
Temperature scaling is very efficient when the calibration set is representative of the test set. In this case, we say that the calibration and test sets are drawn from the same distribution. However, this may not hold in real-world cases, where dataset shift can occur.
References
Expected Calibration Error: Naeini, M. P., Cooper, G. F., & Hauskrecht, M. (2015). Obtaining Well Calibrated Probabilities Using Bayesian Binning. In AAAI 2015.
Temperature Scaling: Guo, C., Pleiss, G., Sun, Y., & Weinberger, K. Q. (2017). On Calibration of Modern Neural Networks. In ICML 2017.