Out-of-distribution detection with TorchUncertainty#

This tutorial demonstrates how to perform OOD detection using TorchUncertainty’s ClassificationRoutine with a ResNet18 model trained on CIFAR-10, evaluating its performance with SVHN as the OOD dataset.

We will:

  • Set up the CIFAR-10 datamodule.

  • Initialize and briefly train a ResNet18 model using the ClassificationRoutine.

  • Evaluate the model’s performance on both in-distribution and out-of-distribution data.

  • Analyze uncertainty metrics for OOD detection.

Imports and Setup#

First, we need to import the necessary libraries and set up our environment. This includes PyTorch, TorchUncertainty components, and TorchUncertainty’s Trainer (built on top of Lightning’s), as well as two OOD detection criteria: the maximum softmax probability [1] and the maximum logit [2].

from torch import nn, optim

from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import CIFAR10DataModule
from torch_uncertainty.models.classification.resnet import resnet
from torch_uncertainty.ood_criteria import MaxLogitCriterion, MaxSoftmaxCriterion
from torch_uncertainty.routines.classification import ClassificationRoutine

DataModule Setup#

TorchUncertainty provides convenient DataModules for standard datasets like CIFAR-10. DataModules handle data loading, preprocessing, and batching, simplifying the data pipeline. Each datamodule also includes the corresponding out-of-distribution and distribution-shift datasets, which are then used by the routine. For CIFAR-10, the community-standard OOD-detection dataset is SVHN. To enable OOD evaluation, activate the eval_ood flag as done below.

datamodule = CIFAR10DataModule(root="./data", batch_size=512, num_workers=8, eval_ood=True)
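
If you want to inspect the data outside of the Trainer, you can call the standard Lightning DataModule hooks yourself. This is a minimal sketch using the generic Lightning API; the Trainer normally triggers these hooks for you:

datamodule.prepare_data()  # download the datasets to ./data if needed
datamodule.setup("test")   # build the test (and OOD) datasets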

Model Initialization#

We use the ResNet18 architecture, a widely adopted convolutional neural network known for its deep residual learning capabilities. The model is initialized with 10 output classes, corresponding to the CIFAR-10 categories. When training on CIFAR, do not forget to set the ResNet style to "cifar": it swaps the large strided first convolution designed for ImageNet for a smaller one, so that less spatial information is lost on 32x32 images.

# Initialize the ResNet18 model
model = resnet(arch=18, in_channels=3, num_classes=10, style="cifar", conv_bias=False)

Define the Classification Routine#

The ClassificationRoutine is one of the most crucial building blocks in TorchUncertainty. It streamlines the training and evaluation processes, integrating the model, loss function, and optimizer into a cohesive routine compatible with PyTorch Lightning’s Trainer. This abstraction simplifies the implementation of standard training loops and evaluation protocols. Most importantly for this tutorial, the routine also handles OOD detection: to enable it, simply activate the eval_ood flag. Note that you can also evaluate the model’s performance under distribution shift at the same time by setting eval_shift to True.

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Initialize the ClassificationRoutine; you could also replace MaxSoftmaxCriterion with the string "msp"
routine = ClassificationRoutine(
    model=model,
    num_classes=10,
    loss=criterion,
    optim_recipe=optimizer,
    eval_ood=True,
    ood_criterion=MaxSoftmaxCriterion,
)
/home/chocolatine/actions-runner/_work/_tool/Python/3.11.12/x64/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `FPR95` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.
  warnings.warn(*args, **kwargs)

Test the Training of the Model#

With the routine defined, we can now set up the Trainer and start training. The Trainer handles the training loop, including epoch management, logging, and checkpointing. We specify the maximum number of epochs, the precision, and the device to be used. To keep the tutorial build time short, we train for a single epoch only and will then load a fully trained model from TorchUncertainty’s HuggingFace page.

# Initialize the TUTrainer
trainer = TUTrainer(
    max_epochs=1, precision="16-mixed", accelerator="cuda", devices=1, enable_progress_bar=False
)

# Train the model for 1 epoch using the CIFAR-10 DataModule
trainer.fit(routine, datamodule=datamodule)
100%|██████████| 64.3M/64.3M [00:05<00:00, 12.3MB/s]

Load the model from HuggingFace#

We simply download a ResNet-18 trained on CIFAR-10 from TorchUncertainty’s HuggingFace page and load its weights into the routine’s model with load_state_dict.

import torch
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="torch-uncertainty/resnet18_c10",
    filename="resnet18_c10.ckpt",
)
state_dict = torch.load(path, map_location="cpu", weights_only=True)
routine.model.load_state_dict(state_dict)
<All keys matched successfully>

Evaluating on In-Distribution and Out-of-Distribution Data#

Now that the model is trained, we can evaluate its performance on the original in-distribution test set, as well as on the OOD set. Running the next line automatically computes both the in-distribution and the OOD detection metrics.

# Evaluate the model on the CIFAR-10 (IID) and SVHN (OOD) test sets
results = trainer.test(routine, datamodule=datamodule)
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃      Classification       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     Acc      │          93.380%          │
│    Brier     │          0.10812          │
│   Entropy    │          0.08849          │
│     NLL      │          0.26408          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃        Calibration        ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     ECE      │          3.546%           │
│     aECE     │          3.499%           │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃       OOD Detection       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     AUPR     │          90.244%          │
│    AUROC     │          82.968%          │
│   Entropy    │          0.08849          │
│    FPR95     │          56.060%          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃ Selective Classification  ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    AUGRC     │          0.779%           │
│     AURC     │          0.960%           │
│  Cov@5Risk   │          96.520%          │
│  Risk@80Cov  │          1.200%           │
└──────────────┴───────────────────────────┘
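
As a side note, trainer.test also returns the computed metrics as a list of dictionaries (one entry per test dataloader), so you can read them programmatically. The exact metric key names are an implementation detail that may vary across TorchUncertainty versions:

# Inspect which metrics were computed during the test phase
print(results[0].keys())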

Changing the OOD Criterion#

The previous OOD detection metrics were computed using the maximum softmax probability score [1], which corresponds to the predicted probability of the most likely class. We could use other scores, such as the maximum logit [2]. To do so, simply change the routine’s ood_criterion and run a second test.
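To make the difference between the two scores concrete, here is what they compute on dummy logits. This is a standalone illustration with made-up data, not TorchUncertainty internals:

import torch

logits = torch.randn(4, 10)  # a dummy batch of 4 samples, 10 classes
msp = logits.softmax(dim=-1).amax(dim=-1)  # maximum softmax probability [1]
max_logit = logits.amax(dim=-1)  # maximum logit [2]
# Lower scores suggest OOD inputs; the built-in criteria handle the sign
# convention internally.
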

routine.ood_criterion = MaxLogitCriterion()

results = trainer.test(routine, datamodule=datamodule)
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃      Classification       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     Acc      │          93.380%          │
│    Brier     │          0.10812          │
│   Entropy    │          0.08849          │
│     NLL      │          0.26408          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃        Calibration        ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     ECE      │          3.546%           │
│     aECE     │          3.499%           │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃       OOD Detection       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     AUPR     │          85.143%          │
│    AUROC     │          73.895%          │
│   Entropy    │          0.08849          │
│    FPR95     │          79.620%          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃ Selective Classification  ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    AUGRC     │          0.779%           │
│     AURC     │          0.960%           │
│  Cov@5Risk   │          96.520%          │
│  Risk@80Cov  │          1.200%           │
└──────────────┴───────────────────────────┘

When changing the OOD criterion, all the in-distribution metrics remain the same: the only values that change are those grouped in the OOD Detection category. Here, the AUPR, AUROC, and FPR95 are worse with the maximum logit score than with the maximum softmax probability, although this may depend on the model you are using. Finally, note that you can implement a custom OOD detection score by creating your own criterion class, as sketched below.
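
For instance, here is a minimal sketch of a custom criterion implementing the energy score (Liu et al., Energy-based Out-of-distribution Detection, NeurIPS 2020). It assumes that the TUOODCriterion base class and the OODCriterionInputType enum are exposed by torch_uncertainty.ood_criteria and behave like the built-in criteria; check the source of MaxLogitCriterion for the exact interface before relying on this.

import torch

from torch_uncertainty.ood_criteria import OODCriterionInputType, TUOODCriterion


class NegativeEnergyCriterion(TUOODCriterion):
    # Assumption: this attribute tells the routine to feed raw logits.
    input_type = OODCriterionInputType.LOGIT

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # The negative energy -logsumexp(logits) is low for confident
        # in-distribution samples and higher for likely-OOD inputs
        # (assumed sign convention, mirroring the built-in criteria).
        return -torch.logsumexp(inputs, dim=-1)


routine.ood_criterion = NegativeEnergyCriterion()
results = trainer.test(routine, datamodule=datamodule)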

References#

[1] Hendrycks, D., & Gimpel, K. A baseline for detecting misclassified and out-of-distribution examples in neural networks. In ICLR 2017.

[2] Hendrycks, D., Basart, S., Mazeika, M., Zou, A., Kwon, J., Mostajabi, M., … & Song, D. Scaling out-of-distribution detection for real-world settings. In ICML 2022.
