Conformal Prediction on CIFAR-10 with TorchUncertainty#
We evaluate the model’s performance both before and after applying different conformal predictors (THR, APS, RAPS), and visualize how conformal prediction estimates the prediction sets.
We use the pretrained ResNet models we provide on Hugging Face.
import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import CIFAR10DataModule
from torch_uncertainty.models.classification.resnet import resnet
from torch_uncertainty.post_processing import ConformalClsAPS, ConformalClsRAPS, ConformalClsTHR
from torch_uncertainty.routines import ClassificationRoutine
1. Load pretrained model from Hugging Face repository#
We use a ResNet18 model trained on CIFAR-10, provided by the TorchUncertainty team.
ckpt_path = hf_hub_download(repo_id="torch-uncertainty/resnet18_c10", filename="resnet18_c10.ckpt")
model = resnet(in_channels=3, num_classes=10, arch=18, conv_bias=False, style="cifar")
ckpt = torch.load(ckpt_path, weights_only=True)
model.load_state_dict(ckpt)
model = model.cuda().eval()
2. Load CIFAR-10 Dataset & Define Dataloaders#
We set eval_ood to True to evaluate how well the conformal scores detect out-of-distribution samples. Because the model was trained on the full training set, we use the test set both as the calibration set for the conformal methods and for their evaluation. This is not a proper way to evaluate coverage: the coverage rates reported below are measured on the same data that was used for calibration, so they say little about coverage on unseen samples.
BATCH_SIZE = 128
datamodule = CIFAR10DataModule(
    root="./data",
    batch_size=BATCH_SIZE,
    num_workers=8,
    eval_ood=True,
    postprocess_set="test",
)
datamodule.prepare_data()
datamodule.setup()
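To make the caveat above concrete, here is a minimal sketch of what a cleaner protocol could look like: the held-out samples are split into two disjoint index sets, one for calibration and one for evaluation. The helper name `split_calibration_evaluation` and the 50/50 ratio are assumptions for illustration, not part of TorchUncertainty.

```python
import numpy as np


def split_calibration_evaluation(num_samples, calib_fraction=0.5, seed=0):
    """Return two disjoint index arrays that together cover range(num_samples)."""
    rng = np.random.default_rng(seed)
    permutation = rng.permutation(num_samples)
    num_calib = int(num_samples * calib_fraction)
    return permutation[:num_calib], permutation[num_calib:]


# The CIFAR-10 test split contains 10,000 images.
calib_idx, eval_idx = split_calibration_evaluation(10_000)
print(len(calib_idx), len(eval_idx))
```

The resulting index arrays could then be wrapped with `torch.utils.data.Subset` to build separate calibration and evaluation dataloaders.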
3. Define the Lightning Trainer#
trainer = TUTrainer(accelerator="gpu", devices=1, max_epochs=5, enable_progress_bar=False)
4. Function to Visualize the Prediction Sets#
def visualize_prediction_sets(inputs, labels, confidence_scores, classes, num_examples=5) -> None:
    _, axs = plt.subplots(2, num_examples, figsize=(15, 5))
    for i in range(num_examples):
        ax = axs[0, i]
        img = np.clip(
            inputs[i].permute(1, 2, 0).cpu().numpy() * datamodule.std + datamodule.mean, 0, 1
        )
        ax.imshow(img)
        ax.set_title(f"True: {classes[labels[i]]}")
        ax.axis("off")
        ax = axs[1, i]
        for j in range(len(classes)):
            ax.barh(classes[j], confidence_scores[i, j], color="blue")
        ax.set_xlim(0, 1)
        ax.set_xlabel("Confidence Score")
    plt.tight_layout()
    plt.show()
5. Estimate Prediction Sets with ConformalClsTHR#
With alpha=0.01, we target a 1% miscoverage rate, i.e. prediction sets that contain the true label about 99% of the time.
print("[Phase 2]: ConformalClsTHR calibration")
conformal_model = ConformalClsTHR(alpha=0.01, device="cuda")
routine_thr = ClassificationRoutine(
    num_classes=10,
    model=model,
    loss=None,  # No loss needed for evaluation
    eval_ood=True,
    post_processing=conformal_model,
    ood_criterion="post_processing",
)
perf_thr = trainer.test(routine_thr, datamodule=datamodule)
[Phase 2]: ConformalClsTHR calibration
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ Acc │ 93.380% │
│ Brier │ 0.10812 │
│ Entropy │ 0.08849 │
│ NLL │ 0.26406 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Calibration ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ECE │ 3.537% │
│ MCE │ 23.671% │
│ SmECE │ 10.143% │
│ aECE │ 3.500% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ OOD Detection ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUPR │ 86.584% │
│ AUROC │ 79.252% │
│ Entropy │ 0.35549 │
│ FPR95 │ 100.000% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Selective Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUGRC │ 0.779% │
│ AURC │ 0.959% │
│ Cov@5Risk │ 96.510% │
│ Risk@80Cov │ 1.200% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Post-Processing ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ CoverageRate │ 0.99000 │
│ SetSize │ 1.52330 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Complexity ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ flops │ 142.19 G │
│ params │ 11.17 M │
└──────────────┴───────────────────────────┘
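For intuition about the CoverageRate and SetSize metrics above, the following is a rough NumPy sketch of threshold-based conformal prediction as it is commonly described in the literature. It is not TorchUncertainty's implementation: the helper names (`conformal_quantile`, `thr_calibrate`, `thr_predict`, `synthetic_probs`) and the synthetic softmax outputs are assumptions made for illustration, and alpha is set to 0.1 rather than 0.01 to keep the toy example small.

```python
import numpy as np


def conformal_quantile(scores, alpha):
    """Finite-sample-corrected (1 - alpha) quantile of the calibration scores."""
    n = len(scores)
    rank = int(np.ceil((n + 1) * (1 - alpha)))
    if rank > n:  # too few calibration points for this alpha: include everything
        return np.inf
    return np.sort(scores)[rank - 1]


def thr_calibrate(probs, labels, alpha):
    """Score each calibration sample by 1 - p(true class), then take the quantile."""
    scores = 1.0 - probs[np.arange(len(labels)), labels]
    return conformal_quantile(scores, alpha)


def thr_predict(probs, q_hat):
    """Boolean mask of shape (n, num_classes): True where a class is in the set."""
    return (1.0 - probs) <= q_hat


def synthetic_probs(rng, n, num_classes=10):
    """Hypothetical stand-in for the softmax outputs of a fairly accurate classifier."""
    labels = rng.integers(0, num_classes, size=n)
    logits = rng.normal(size=(n, num_classes))
    logits[np.arange(n), labels] += 3.0  # favor the true class
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True), labels


rng = np.random.default_rng(0)
cal_probs, cal_labels = synthetic_probs(rng, 2000)
test_probs, test_labels = synthetic_probs(rng, 2000)

alpha = 0.1
q_hat = thr_calibrate(cal_probs, cal_labels, alpha)
sets = thr_predict(test_probs, q_hat)

# Coverage: fraction of test samples whose true label is in the set.
coverage = sets[np.arange(len(test_labels)), test_labels].mean()
# Set size: average number of classes per prediction set.
set_size = sets.sum(axis=1).mean()
print(f"coverage={coverage:.3f}, mean set size={set_size:.2f}")
```

Unlike the tutorial run, this sketch calibrates and evaluates on two independent draws, so the measured coverage fluctuates around 1 - alpha instead of being pinned to it.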
6. Visualization of ConformalClsTHR prediction sets#
inputs, labels = next(iter(datamodule.test_dataloader()[0]))
conformal_model.cuda()
confidence_scores = conformal_model.conformal(inputs.cuda())
classes = datamodule.test.classes
visualize_prediction_sets(inputs, labels, confidence_scores[:5].cpu(), classes)

7. Estimate Prediction Sets with ConformalClsAPS#
print("[Phase 3]: ConformalClsAPS calibration")
conformal_model = ConformalClsAPS(alpha=0.01, device="cuda", enable_ts=True)
routine_aps = ClassificationRoutine(
    num_classes=10,
    model=model,
    loss=None,  # No loss needed for evaluation
    eval_ood=True,
    post_processing=conformal_model,
    ood_criterion="post_processing",
)
perf_aps = trainer.test(routine_aps, datamodule=datamodule)
conformal_model.cuda()
confidence_scores = conformal_model.conformal(inputs.cuda())
visualize_prediction_sets(inputs, labels, confidence_scores[:5].cpu(), classes)

[Phase 3]: ConformalClsAPS calibration
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ Acc │ 93.380% │
│ Brier │ 0.10812 │
│ Entropy │ 0.08849 │
│ NLL │ 0.26406 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Calibration ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ECE │ 3.537% │
│ MCE │ 23.671% │
│ SmECE │ 10.143% │
│ aECE │ 3.500% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ OOD Detection ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUPR │ 82.394% │
│ AUROC │ 73.286% │
│ Entropy │ 0.35549 │
│ FPR95 │ 100.000% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Selective Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUGRC │ 0.779% │
│ AURC │ 0.959% │
│ Cov@5Risk │ 96.510% │
│ Risk@80Cov │ 1.200% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Post-Processing ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ CoverageRate │ 0.99340 │
│ SetSize │ 2.25980 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Complexity ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ flops │ 142.19 G │
│ params │ 11.17 M │
└──────────────┴───────────────────────────┘
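The larger sets produced by APS follow from how it scores classes. As a hedged illustration, the sketch below implements the deterministic (non-randomized) variant of the APS score found in the conformal prediction literature: the score of a class is the total probability mass of all classes ranked at or above it. The function names and synthetic data are assumptions, temperature scaling is left out, and TorchUncertainty's `ConformalClsAPS` may differ in details such as randomization.

```python
import numpy as np


def conformal_quantile(scores, alpha):
    """Finite-sample-corrected (1 - alpha) quantile of the calibration scores."""
    n = len(scores)
    rank = int(np.ceil((n + 1) * (1 - alpha)))
    return np.inf if rank > n else np.sort(scores)[rank - 1]


def aps_scores(probs):
    """APS score of every class: mass of all classes ranked at or above it."""
    order = np.argsort(-probs, axis=1)  # classes from most to least likely
    sorted_probs = np.take_along_axis(probs, order, axis=1)
    cumulative = np.cumsum(sorted_probs, axis=1)
    # Undo the sort so that column k holds the score of class k.
    inverse = np.argsort(order, axis=1)
    return np.take_along_axis(cumulative, inverse, axis=1)


def synthetic_probs(rng, n, num_classes=10):
    """Hypothetical stand-in for the softmax outputs of a fairly accurate classifier."""
    labels = rng.integers(0, num_classes, size=n)
    logits = rng.normal(size=(n, num_classes))
    logits[np.arange(n), labels] += 3.0
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True), labels


rng = np.random.default_rng(0)
cal_probs, cal_labels = synthetic_probs(rng, 2000)
test_probs, test_labels = synthetic_probs(rng, 2000)

alpha = 0.1
cal_scores = aps_scores(cal_probs)[np.arange(len(cal_labels)), cal_labels]
q_hat = conformal_quantile(cal_scores, alpha)

# A class joins the prediction set when its APS score is within the threshold.
sets = aps_scores(test_probs) <= q_hat
coverage = sets[np.arange(len(test_labels)), test_labels].mean()
print(f"coverage={coverage:.3f}, mean set size={sets.sum(axis=1).mean():.2f}")
```

Because the score accumulates the mass of every more likely class, a sample with a flat predictive distribution needs many classes before the threshold is crossed, which is how APS adapts set size to difficulty.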
8. Estimate Prediction Sets with ConformalClsRAPS#
print("[Phase 4]: ConformalClsRAPS calibration")
conformal_model = ConformalClsRAPS(
    alpha=0.01, regularization_rank=3, penalty=0.002, model=model, device="cuda", enable_ts=True
)
routine_raps = ClassificationRoutine(
    num_classes=10,
    model=model,
    loss=None,  # No loss needed for evaluation
    eval_ood=True,
    post_processing=conformal_model,
    ood_criterion="post_processing",
)
perf_raps = trainer.test(routine_raps, datamodule=datamodule)
conformal_model.cuda()
confidence_scores = conformal_model.conformal(inputs.cuda())
visualize_prediction_sets(inputs, labels, confidence_scores[:5].cpu(), classes)

[Phase 4]: ConformalClsRAPS calibration
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ Acc │ 93.380% │
│ Brier │ 0.10812 │
│ Entropy │ 0.08849 │
│ NLL │ 0.26406 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Calibration ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ECE │ 3.537% │
│ MCE │ 23.671% │
│ SmECE │ 10.143% │
│ aECE │ 3.500% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ OOD Detection ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUPR │ 82.614% │
│ AUROC │ 73.309% │
│ Entropy │ 0.35549 │
│ FPR95 │ 100.000% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Selective Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUGRC │ 0.779% │
│ AURC │ 0.959% │
│ Cov@5Risk │ 96.510% │
│ Risk@80Cov │ 1.200% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Post-Processing ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ CoverageRate │ 0.99340 │
│ SetSize │ 2.14780 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Complexity ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ flops │ 142.19 G │
│ params │ 11.17 M │
└──────────────┴───────────────────────────┘
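The `regularization_rank` and `penalty` arguments used above can be read through the usual description of RAPS: the APS score plus a penalty for every rank beyond a chosen cutoff, which discourages very large sets. The sketch below shows that score in its deterministic form on a single toy probability vector. The function `raps_scores` is hypothetical and is not taken from TorchUncertainty, whose implementation may differ.

```python
import numpy as np


def raps_scores(probs, penalty, regularization_rank):
    """RAPS score of every class: APS cumulative mass plus a rank-based penalty."""
    order = np.argsort(-probs, axis=1)  # classes from most to least likely
    sorted_probs = np.take_along_axis(probs, order, axis=1)
    ranks = np.arange(1, probs.shape[1] + 1)  # 1-based rank of each sorted class
    # Only ranks beyond the cutoff are penalized.
    regularization = penalty * np.maximum(ranks - regularization_rank, 0)
    cumulative = np.cumsum(sorted_probs, axis=1) + regularization
    # Undo the sort so that column k holds the score of class k.
    inverse = np.argsort(order, axis=1)
    return np.take_along_axis(cumulative, inverse, axis=1)


# Toy softmax output over five classes (made-up values).
probs = np.array([[0.05, 0.40, 0.30, 0.15, 0.10]])

aps_like = raps_scores(probs, penalty=0.0, regularization_rank=3)
raps = raps_scores(probs, penalty=0.002, regularization_rank=3)
print(aps_like)
print(raps)
```

With the penalty set to zero the score reduces to the APS score; with a positive penalty only the classes ranked fourth and lower receive a higher score, so they are less likely to fall under the calibrated threshold.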
Summary#
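In this tutorial, we explored how to apply conformal prediction to a pretrained ResNet on CIFAR-10. We evaluated three methods: Thresholding (THR), Adaptive Prediction Sets (APS), and Regularized APS (RAPS). For each, we calibrated on the test set (see the caveat in step 2), evaluated OOD detection performance, and visualized prediction sets. In this run, all three methods reached at least the 99% target coverage, with THR producing the smallest sets on average (about 1.5 classes, versus about 2.3 for APS and 2.1 for RAPS). You can explore further by adjusting alpha, changing the model, or testing on other datasets.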