Conformal Prediction on CIFAR-10 with TorchUncertainty#

We evaluate the model’s performance both before and after applying different conformal predictors (THR, APS, RAPS), and visualize how conformal prediction estimates the prediction sets.

We use the pretrained ResNet models we provide on Hugging Face.

import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import hf_hub_download

from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import CIFAR10DataModule
from torch_uncertainty.models.classification.resnet import resnet
from torch_uncertainty.post_processing import ConformalClsAPS, ConformalClsRAPS, ConformalClsTHR
from torch_uncertainty.routines import ClassificationRoutine

1. Load pretrained model from Hugging Face repository#

We use a ResNet18 model trained on CIFAR-10, provided by the TorchUncertainty team.

ckpt_path = hf_hub_download(repo_id="torch-uncertainty/resnet18_c10", filename="resnet18_c10.ckpt")
model = resnet(in_channels=3, num_classes=10, arch=18, conv_bias=False, style="cifar")
ckpt = torch.load(ckpt_path, weights_only=True)
model.load_state_dict(ckpt)
model = model.cuda().eval()

2. Load CIFAR-10 Dataset & Define Dataloaders#

We set eval_ood to True to evaluate how well the conformal scores detect out-of-distribution samples. Because the model was trained on the full training set, we use the test set both as the calibration set for the conformal methods and for their evaluation. This is not a proper way to evaluate coverage: the coverage guarantee assumes that the evaluation data are disjoint from the calibration data, so the coverage rates reported below should not be read as held-out estimates.
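
Since calibrating and evaluating on the same samples is not a proper protocol, one option would be to hold out part of the test set for calibration and to measure coverage on the remainder. The sketch below is not part of the TorchUncertainty API: the helper name `split_calibration_test` and the 50/50 ratio are assumptions made purely for illustration, and it only shows how disjoint index sets could be drawn with NumPy.

```python
import numpy as np


def split_calibration_test(num_samples, calib_fraction=0.5, seed=0):
    """Return two disjoint index arrays: one for calibration, one for evaluation."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_samples)  # shuffled indices 0..num_samples-1
    num_calib = int(num_samples * calib_fraction)
    return perm[:num_calib], perm[num_calib:]


# The CIFAR-10 test set contains 10,000 images.
calib_idx, eval_idx = split_calibration_test(10_000)
```

Each index array could then be wrapped in `torch.utils.data.Subset` to build separate calibration and evaluation datasets.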

BATCH_SIZE = 128

datamodule = CIFAR10DataModule(
    root="./data",
    batch_size=BATCH_SIZE,
    num_workers=8,
    eval_ood=True,
    postprocess_set="test",
)
datamodule.prepare_data()
datamodule.setup()

3. Define the Lightning Trainer#

trainer = TUTrainer(accelerator="gpu", devices=1, max_epochs=5, enable_progress_bar=False)

4. Function to Visualize the Prediction Sets#

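def visualize_prediction_sets(inputs, labels, confidence_scores, classes, num_examples=5) -> None:
    """Plot the first `num_examples` images (top row) and their per-class scores (bottom row).

    Relies on the global `datamodule` defined above for the normalization statistics.
    """
    _, axs = plt.subplots(2, num_examples, figsize=(15, 5))
    for i in range(num_examples):
        # Top row: undo the input normalization and display the image.
        ax = axs[0, i]
        img = np.clip(
            inputs[i].permute(1, 2, 0).cpu().numpy() * datamodule.std + datamodule.mean, 0, 1
        )
        ax.imshow(img)
        ax.set_title(f"True: {classes[labels[i]]}")
        ax.axis("off")
        # Bottom row: one horizontal bar per class.
        ax = axs[1, i]
        for j in range(len(classes)):
            ax.barh(classes[j], confidence_scores[i, j], color="blue")
        ax.set_xlim(0, 1)
        ax.set_xlabel("Confidence Score")
    plt.tight_layout()
    plt.show()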

5. Estimate Prediction Sets with ConformalClsTHR#

Using alpha=0.01, we aim for a 1% error rate.
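
To make the role of alpha concrete, here is a minimal NumPy sketch of thresholding-style split conformal calibration as it is commonly described. It is not TorchUncertainty's implementation: the function names `thr_calibrate` and `thr_prediction_sets` are hypothetical, and the probabilities are synthetic stand-ins for softmax outputs.

```python
import numpy as np


def thr_calibrate(probs, labels, alpha):
    """Return the threshold q_hat on the score s = 1 - p(true class)."""
    n = len(labels)
    scores = 1.0 - probs[np.arange(n), labels]
    # Rank of the finite-sample corrected (1 - alpha) quantile.
    k = int(np.ceil((n + 1) * (1 - alpha)))
    if k > n:  # too few calibration samples: every class must be kept
        return np.inf
    return np.sort(scores)[k - 1]


def thr_prediction_sets(probs, q_hat):
    """Boolean mask (n, num_classes): True where a class belongs to the set."""
    return (1.0 - probs) <= q_hat


# Synthetic stand-in for softmax outputs on a calibration set (illustrative only).
rng = np.random.default_rng(0)
n, num_classes, alpha = 2000, 10, 0.01
labels = rng.integers(0, num_classes, size=n)
logits = rng.normal(size=(n, num_classes))
logits[np.arange(n), labels] += 3.0  # make the true class likely
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

q_hat = thr_calibrate(probs, labels, alpha)
sets = thr_prediction_sets(probs, q_hat)
coverage = sets[np.arange(n), labels].mean()
```

In this sketch the coverage is measured on the same samples used to compute the threshold, so it is at least 1 - alpha by construction. An honest estimate would require samples that were not used for calibration.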

print("[Phase 2]: ConformalClsTHR calibration")
conformal_model = ConformalClsTHR(alpha=0.01, device="cuda")

routine_thr = ClassificationRoutine(
    num_classes=10,
    model=model,
    loss=None,  # No loss needed for evaluation
    eval_ood=True,
    post_processing=conformal_model,
    ood_criterion="post_processing",
)
perf_thr = trainer.test(routine_thr, datamodule=datamodule)
[Phase 2]: ConformalClsTHR calibration

┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃      Classification       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     Acc      │          93.380%          │
│    Brier     │          0.10812          │
│   Entropy    │          0.08849          │
│     NLL      │          0.26406          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃        Calibration        ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     ECE      │          3.537%           │
│     MCE      │          23.671%          │
│    SmECE     │          10.143%          │
│     aECE     │          3.500%           │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃       OOD Detection       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     AUPR     │          86.584%          │
│    AUROC     │          79.252%          │
│   Entropy    │          0.08849          │
│    FPR95     │         100.000%          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃ Selective Classification  ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    AUGRC     │          0.779%           │
│     AURC     │          0.959%           │
│  Cov@5Risk   │          96.510%          │
│  Risk@80Cov  │          1.200%           │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃      Post-Processing      ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ CoverageRate │          0.99000          │
│   SetSize    │          1.52330          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃        Complexity         ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    flops     │         142.19 G          │
│    params    │          11.17 M          │
└──────────────┴───────────────────────────┘
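
The CoverageRate and SetSize rows of the Post-Processing table can be read as the fraction of samples whose prediction set contains the true label and the mean number of classes per set. The sketch below assumes these usual definitions (the library's metric implementation may differ in details) and uses a small hand-made example; the helper name `coverage_and_set_size` is hypothetical.

```python
import numpy as np


def coverage_and_set_size(pred_sets, labels):
    """Compute empirical coverage and mean set size from boolean prediction sets."""
    covered = pred_sets[np.arange(len(labels)), labels]  # is the true class in the set?
    return covered.mean(), pred_sets.sum(axis=1).mean()


# Four samples, three classes; True marks a class included in the prediction set.
pred_sets = np.array(
    [
        [True, False, False],
        [True, True, False],
        [False, False, True],
        [False, True, True],
    ]
)
labels = np.array([0, 1, 1, 2])

coverage, set_size = coverage_and_set_size(pred_sets, labels)
print(coverage, set_size)  # 0.75 1.5
```

Three of the four sets contain the true label, and the sets hold 1, 2, 1, and 2 classes respectively, which gives a coverage of 0.75 and a mean set size of 1.5.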

6. Visualization of ConformalClsTHR prediction sets#

inputs, labels = next(iter(datamodule.test_dataloader()[0]))

conformal_model.cuda()
confidence_scores = conformal_model.conformal(inputs.cuda())

classes = datamodule.test.classes

visualize_prediction_sets(inputs, labels, confidence_scores[:5].cpu(), classes)
[Figure: five test images (true labels: cat, ship, ship, airplane, frog) with the ConformalClsTHR scores for each class.]

7. Estimate Prediction Sets with ConformalClsAPS#

print("[Phase 3]: ConformalClsAPS calibration")
conformal_model = ConformalClsAPS(alpha=0.01, device="cuda", enable_ts=True)

routine_aps = ClassificationRoutine(
    num_classes=10,
    model=model,
    loss=None,  # No loss needed for evaluation
    eval_ood=True,
    post_processing=conformal_model,
    ood_criterion="post_processing",
)
perf_aps = trainer.test(routine_aps, datamodule=datamodule)
conformal_model.cuda()
confidence_scores = conformal_model.conformal(inputs.cuda())
visualize_prediction_sets(inputs, labels, confidence_scores[:5].cpu(), classes)
[Figure: five test images (true labels: cat, ship, ship, airplane, frog) with the ConformalClsAPS scores for each class.]
[Phase 3]: ConformalClsAPS calibration

┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃      Classification       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     Acc      │          93.380%          │
│    Brier     │          0.10812          │
│   Entropy    │          0.08849          │
│     NLL      │          0.26406          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃        Calibration        ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     ECE      │          3.537%           │
│     MCE      │          23.671%          │
│    SmECE     │          10.143%          │
│     aECE     │          3.500%           │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃       OOD Detection       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     AUPR     │          82.248%          │
│    AUROC     │          73.062%          │
│   Entropy    │          0.08849          │
│    FPR95     │         100.000%          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃ Selective Classification  ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    AUGRC     │          0.779%           │
│     AURC     │          0.959%           │
│  Cov@5Risk   │          96.510%          │
│  Risk@80Cov  │          1.200%           │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃      Post-Processing      ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ CoverageRate │          0.99380          │
│   SetSize    │          2.28630          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃        Complexity         ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    flops     │         142.19 G          │
│    params    │          11.17 M          │
└──────────────┴───────────────────────────┘

8. Estimate Prediction Sets with ConformalClsRAPS#
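
APS, used in the previous section, scores a sample by the cumulative probability mass of all classes ranked at or above the true class. RAPS adds a penalty for each rank beyond a chosen cutoff, which discourages large sets. The sketch below follows the usual non-randomized formulation of both scores and is not TorchUncertainty's implementation. The helper names are hypothetical, and mapping `penalty` and `regularization_rank` to the constructor arguments used in this section is an assumption.

```python
import numpy as np


def aps_scores(probs, labels):
    """Return APS scores and the 0-based rank of the true class for each sample."""
    order = np.argsort(-probs, axis=1)  # classes by decreasing probability
    sorted_probs = np.take_along_axis(probs, order, axis=1)
    cumsum = np.cumsum(sorted_probs, axis=1)
    ranks = np.argmax(order == labels[:, None], axis=1)
    return cumsum[np.arange(len(labels)), ranks], ranks


def raps_scores(probs, labels, penalty=0.002, regularization_rank=3):
    """APS score plus `penalty` for every rank of the true class beyond the cutoff."""
    scores, ranks = aps_scores(probs, labels)
    return scores + penalty * np.maximum(0, ranks + 1 - regularization_rank)


probs = np.array([[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]])
labels = np.array([1, 0])

aps, _ = aps_scores(probs, labels)
raps = raps_scores(probs, labels, penalty=0.1, regularization_rank=1)
print(aps)   # approximately [0.9 1.0]
print(raps)  # approximately [1.0 1.2]
```

In the first row the true class is ranked second, so its APS score is 0.6 + 0.3 and RAPS adds one penalty step. In the second row the true class is ranked last, so the APS score is 1.0 and RAPS adds two penalty steps.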

print("[Phase 4]: ConformalClsRAPS calibration")
conformal_model = ConformalClsRAPS(
    alpha=0.01, regularization_rank=3, penalty=0.002, model=model, device="cuda", enable_ts=True
)

routine_raps = ClassificationRoutine(
    num_classes=10,
    model=model,
    loss=None,  # No loss needed for evaluation
    eval_ood=True,
    post_processing=conformal_model,
    ood_criterion="post_processing",
)
perf_raps = trainer.test(routine_raps, datamodule=datamodule)
conformal_model.cuda()
confidence_scores = conformal_model.conformal(inputs.cuda())
visualize_prediction_sets(inputs, labels, confidence_scores[:5].cpu(), classes)
[Figure: five test images (true labels: cat, ship, ship, airplane, frog) with the ConformalClsRAPS scores for each class.]
[Phase 4]: ConformalClsRAPS calibration

┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃      Classification       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     Acc      │          93.380%          │
│    Brier     │          0.10812          │
│   Entropy    │          0.08849          │
│     NLL      │          0.26406          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃        Calibration        ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     ECE      │          3.537%           │
│     MCE      │          23.671%          │
│    SmECE     │          10.143%          │
│     aECE     │          3.500%           │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃       OOD Detection       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     AUPR     │          82.782%          │
│    AUROC     │          73.599%          │
│   Entropy    │          0.08849          │
│    FPR95     │         100.000%          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃ Selective Classification  ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    AUGRC     │          0.779%           │
│     AURC     │          0.959%           │
│  Cov@5Risk   │          96.510%          │
│  Risk@80Cov  │          1.200%           │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃      Post-Processing      ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ CoverageRate │          0.99320          │
│   SetSize    │          2.11490          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃        Complexity         ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    flops     │         142.19 G          │
│    params    │          11.17 M          │
└──────────────┴───────────────────────────┘

Summary#

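In this tutorial, we applied conformal prediction to a pretrained ResNet18 on CIFAR-10. We evaluated three methods: Thresholding (THR), Adaptive Prediction Sets (APS), and Regularized APS (RAPS). For each, we calibrated on the test set (the model was trained on the full training set, so no held-out calibration split was available), evaluated OOD detection performance, and visualized the prediction sets. With alpha=0.01, THR produced the smallest sets on average (1.52 classes), followed by RAPS (2.11) and APS (2.29). You can explore further by adjusting alpha, changing the model, or testing on other datasets.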

Total running time of the script: (1 minutes 39.758 seconds)
