Conformal Prediction on CIFAR-10 with TorchUncertainty#
We evaluate the model’s performance both before and after applying different conformal predictors (THR, APS, RAPS), and visualize how conformal prediction estimates the prediction sets.
We use the pretrained ResNet models we provide on Hugging Face.
import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import CIFAR10DataModule
from torch_uncertainty.models.classification.resnet import resnet
from torch_uncertainty.post_processing import ConformalClsAPS, ConformalClsRAPS, ConformalClsTHR
from torch_uncertainty.routines import ClassificationRoutine
1. Load pretrained model from Hugging Face repository#
We use a ResNet18 model trained on CIFAR-10, provided by the TorchUncertainty team.
ckpt_path = hf_hub_download(repo_id="torch-uncertainty/resnet18_c10", filename="resnet18_c10.ckpt")
model = resnet(in_channels=3, num_classes=10, arch=18, conv_bias=False, style="cifar")
ckpt = torch.load(ckpt_path, weights_only=True)
model.load_state_dict(ckpt)
model = model.cuda().eval()
2. Load CIFAR-10 Dataset & Define Dataloaders#
We set eval_ood to True to evaluate how well the conformal scores detect out-of-distribution samples. Because the model was trained on the full training set, we use the test set both as the calibration set for the conformal methods and as the evaluation set. This is not a proper way to evaluate coverage: calibration and evaluation should use disjoint data, so the coverage rates reported below are in-sample estimates.
BATCH_SIZE = 128
datamodule = CIFAR10DataModule(
    root="./data",
    batch_size=BATCH_SIZE,
    num_workers=8,
    eval_ood=True,
    postprocess_set="test",
)
datamodule.prepare_data()
datamodule.setup()
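As noted above, calibrating and evaluating on the same test set does not give a valid coverage estimate. A minimal sketch of one alternative, assuming you simply partition the held-out indices into two disjoint halves. The helper name `split_calibration_evaluation` is hypothetical and not part of TorchUncertainty:

```python
import random


def split_calibration_evaluation(num_samples, calibration_fraction=0.5, seed=0):
    """Return two disjoint lists of indices: (calibration, evaluation)."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # seeded shuffle for reproducibility
    cut = int(num_samples * calibration_fraction)
    return indices[:cut], indices[cut:]


# The CIFAR-10 test set contains 10,000 images.
calibration_idx, evaluation_idx = split_calibration_evaluation(10_000)
```

The two index lists could then be wrapped with `torch.utils.data.Subset` to build separate calibration and evaluation datasets.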
3. Define the Lightning Trainer#
trainer = TUTrainer(accelerator="gpu", devices=1, max_epochs=5, enable_progress_bar=False)
4. Function to Visualize the Prediction Sets#
def visualize_prediction_sets(inputs, labels, confidence_scores, classes, num_examples=5) -> None:
    _, axs = plt.subplots(2, num_examples, figsize=(15, 5))
    for i in range(num_examples):
        # Top row: the de-normalized input image with its true label
        ax = axs[0, i]
        img = np.clip(
            inputs[i].permute(1, 2, 0).cpu().numpy() * datamodule.std + datamodule.mean, 0, 1
        )
        ax.imshow(img)
        ax.set_title(f"True: {classes[labels[i]]}")
        ax.axis("off")
        # Bottom row: the per-class confidence scores for the same image
        ax = axs[1, i]
        for j in range(len(classes)):
            ax.barh(classes[j], confidence_scores[i, j], color="blue")
        ax.set_xlim(0, 1)
        ax.set_xlabel("Confidence Score")
    plt.tight_layout()
    plt.show()
5. Estimate Prediction Sets with ConformalClsTHR#
Using alpha=0.01, we aim for a 1% error rate.
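To make the role of alpha concrete, here is a minimal pure-Python sketch of threshold-style split conformal prediction. It assumes the common formulation in which the nonconformity score is one minus the probability of the true class and the threshold is the ceil((n + 1)(1 - alpha))-th smallest calibration score. The function names are illustrative, and ConformalClsTHR may differ in its details:

```python
import math


def conformal_quantile(calibration_scores, alpha):
    """Finite-sample conformal quantile of the calibration scores."""
    n = len(calibration_scores)
    rank = math.ceil((n + 1) * (1 - alpha))
    # Simplification: clamp the rank when the calibration set is too small
    # for the requested alpha (strictly, the quantile would be infinite).
    rank = min(rank, n)
    return sorted(calibration_scores)[rank - 1]


def thr_prediction_set(probs, qhat):
    """Keep every class whose score 1 - p does not exceed the threshold."""
    return [k for k, p in enumerate(probs) if 1.0 - p <= qhat]


# Toy calibration scores (1 - probability of the true class).
calibration_scores = [0.02, 0.05, 0.1, 0.15, 0.2, 0.3, 0.45, 0.7, 0.9]
qhat = conformal_quantile(calibration_scores, alpha=0.2)
prediction_set = thr_prediction_set([0.55, 0.35, 0.10], qhat)
```

A smaller alpha pushes the threshold up, so more classes pass the test and the prediction sets grow.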
print("[Phase 2]: ConformalClsTHR calibration")
conformal_model = ConformalClsTHR(alpha=0.01, device="cuda")
routine_thr = ClassificationRoutine(
    num_classes=10,
    model=model,
    loss=None,  # No loss needed for evaluation
    eval_ood=True,
    post_processing=conformal_model,
    ood_criterion="post_processing",
)
perf_thr = trainer.test(routine_thr, datamodule=datamodule)
[Phase 2]: ConformalClsTHR calibration
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ Acc │ 93.380% │
│ Brier │ 0.10812 │
│ Entropy │ 0.08849 │
│ NLL │ 0.26406 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Calibration ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ECE │ 3.537% │
│ MCE │ 23.671% │
│ SmECE │ 10.143% │
│ aECE │ 3.500% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ OOD Detection ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUPR │ 86.584% │
│ AUROC │ 79.252% │
│ Entropy │ 0.08849 │
│ FPR95 │ 100.000% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Selective Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUGRC │ 0.779% │
│ AURC │ 0.959% │
│ Cov@5Risk │ 96.510% │
│ Risk@80Cov │ 1.200% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Post-Processing ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ CoverageRate │ 0.99000 │
│ SetSize │ 1.52330 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Complexity ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ flops │ 142.19 G │
│ params │ 11.17 M │
└──────────────┴───────────────────────────┘
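The Post-Processing table reports CoverageRate and SetSize. A minimal sketch of how such quantities are usually defined, assuming coverage is the fraction of samples whose true label lies in the prediction set and set size is the mean number of classes per set. The helper below is illustrative and is not the library's metric implementation:

```python
def coverage_and_set_size(prediction_sets, labels):
    """Return (empirical coverage, average prediction-set size)."""
    n = len(labels)
    covered = sum(1 for s, y in zip(prediction_sets, labels) if y in s)
    total_size = sum(len(s) for s in prediction_sets)
    return covered / n, total_size / n


# Toy example: four samples, one of which misses its true label.
toy_sets = [{0}, {1, 2}, {3}, {0, 4}]
toy_labels = [0, 2, 1, 4]
coverage, avg_size = coverage_and_set_size(toy_sets, toy_labels)
```

The two numbers trade off against each other: larger sets make coverage easier to reach but are less informative.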
6. Visualization of ConformalClsTHR prediction sets#
inputs, labels = next(iter(datamodule.test_dataloader()[0]))
conformal_model.cuda()
confidence_scores = conformal_model.conformal(inputs.cuda())
classes = datamodule.test.classes
visualize_prediction_sets(inputs, labels, confidence_scores[:5].cpu(), classes)

7. Estimate Prediction Sets with ConformalClsAPS#
print("[Phase 3]: ConformalClsAPS calibration")
conformal_model = ConformalClsAPS(alpha=0.01, device="cuda", enable_ts=True)
routine_aps = ClassificationRoutine(
    num_classes=10,
    model=model,
    loss=None,  # No loss needed for evaluation
    eval_ood=True,
    post_processing=conformal_model,
    ood_criterion="post_processing",
)
perf_aps = trainer.test(routine_aps, datamodule=datamodule)
conformal_model.cuda()
confidence_scores = conformal_model.conformal(inputs.cuda())
visualize_prediction_sets(inputs, labels, confidence_scores[:5].cpu(), classes)

[Phase 3]: ConformalClsAPS calibration
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ Acc │ 93.380% │
│ Brier │ 0.10812 │
│ Entropy │ 0.08849 │
│ NLL │ 0.26406 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Calibration ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ECE │ 3.537% │
│ MCE │ 23.671% │
│ SmECE │ 10.143% │
│ aECE │ 3.500% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ OOD Detection ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUPR │ 82.432% │
│ AUROC │ 73.377% │
│ Entropy │ 0.08849 │
│ FPR95 │ 100.000% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Selective Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUGRC │ 0.779% │
│ AURC │ 0.959% │
│ Cov@5Risk │ 96.510% │
│ Risk@80Cov │ 1.200% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Post-Processing ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ CoverageRate │ 0.99300 │
│ SetSize │ 2.20440 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Complexity ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ flops │ 142.19 G │
│ params │ 11.17 M │
└──────────────┴───────────────────────────┘
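For intuition about what APS computes, here is a minimal sketch of the deterministic (non-randomized) APS score: the score of a class is the total probability mass of all classes ranked at or above it. The function names are illustrative, and ConformalClsAPS may differ, for example by randomizing the score or applying temperature scaling first:

```python
def aps_score(probs, label):
    """Cumulative probability of the classes ranked at or above `label`."""
    order = sorted(range(len(probs)), key=lambda k: probs[k], reverse=True)
    position = order.index(label)  # 0-based rank of the class
    return sum(probs[k] for k in order[: position + 1])


def aps_prediction_set(probs, qhat):
    """Keep every class whose APS score does not exceed the threshold."""
    return [k for k in range(len(probs)) if aps_score(probs, k) <= qhat]


toy_probs = [0.5, 0.3, 0.1, 0.06, 0.04]
toy_set = aps_prediction_set(toy_probs, qhat=0.85)
```

Because the score accumulates mass down the ranking, APS sets adapt to how spread out the softmax output is, which is consistent with the larger average set size reported above compared with THR.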
8. Estimate Prediction Sets with ConformalClsRAPS#
print("[Phase 4]: ConformalClsRAPS calibration")
conformal_model = ConformalClsRAPS(
    alpha=0.01, regularization_rank=3, penalty=0.002, model=model, device="cuda", enable_ts=True
)
routine_raps = ClassificationRoutine(
    num_classes=10,
    model=model,
    loss=None,  # No loss needed for evaluation
    eval_ood=True,
    post_processing=conformal_model,
    ood_criterion="post_processing",
)
perf_raps = trainer.test(routine_raps, datamodule=datamodule)
conformal_model.cuda()
confidence_scores = conformal_model.conformal(inputs.cuda())
visualize_prediction_sets(inputs, labels, confidence_scores[:5].cpu(), classes)

[Phase 4]: ConformalClsRAPS calibration
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ Acc │ 93.380% │
│ Brier │ 0.10812 │
│ Entropy │ 0.08849 │
│ NLL │ 0.26406 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Calibration ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ECE │ 3.537% │
│ MCE │ 23.671% │
│ SmECE │ 10.143% │
│ aECE │ 3.500% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ OOD Detection ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUPR │ 82.678% │
│ AUROC │ 73.360% │
│ Entropy │ 0.08849 │
│ FPR95 │ 100.000% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Selective Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUGRC │ 0.779% │
│ AURC │ 0.959% │
│ Cov@5Risk │ 96.510% │
│ Risk@80Cov │ 1.200% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Post-Processing ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ CoverageRate │ 0.99420 │
│ SetSize │ 2.17680 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Complexity ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ flops │ 142.19 G │
│ params │ 11.17 M │
└──────────────┴───────────────────────────┘
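The regularization_rank and penalty arguments can be read through the usual RAPS formulation, sketched below in its deterministic form: the cumulative-probability score of a class is increased by penalty for every rank position beyond regularization_rank. This is an illustrative assumption about what the parameters mean, not the code of ConformalClsRAPS:

```python
def raps_score(probs, label, regularization_rank=3, penalty=0.002):
    """Cumulative-probability score plus a penalty on deep ranks."""
    order = sorted(range(len(probs)), key=lambda k: probs[k], reverse=True)
    rank = order.index(label) + 1  # 1-based rank of the class
    cumulative = sum(probs[k] for k in order[:rank])
    return cumulative + penalty * max(0, rank - regularization_rank)


toy_probs = [0.5, 0.3, 0.1, 0.06, 0.04]
top_score = raps_score(toy_probs, label=0)   # rank 1: no penalty applied
deep_score = raps_score(toy_probs, label=3)  # rank 4: one penalty step
```

The penalty discourages unlikely, low-ranked classes from entering the set, which tends to shrink the sets relative to plain APS.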
Summary#
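In this tutorial, we explored how to apply conformal prediction to a pretrained ResNet on CIFAR-10. We evaluated three methods: Thresholding (THR), Adaptive Prediction Sets (APS), and Regularized APS (RAPS). For each, we calibrated on the CIFAR-10 test set (which, as noted in step 2, is also the evaluation set here), evaluated OOD detection performance, and visualized the prediction sets. With alpha=0.01, all three methods reached a coverage rate of at least 0.99, and THR produced the smallest average set size (1.52, versus 2.20 for APS and 2.18 for RAPS). You can explore further by adjusting alpha, changing the model, or testing on other datasets.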