DEUP: Direct Epistemic Uncertainty Prediction with TorchUncertainty

DEUP estimates the epistemic component of uncertainty by training a lightweight error-predictor g on out-of-fold generalization errors collected from a held-out calibration set (Algorithm 2 in Lahlou et al. 2023). Once fitted, g(x) returns a non-negative score: higher means “the base model is more likely to be wrong here.”

This tutorial has two parts:

  1. Synthetic walkthrough: illustrates the DEUP API on random tabular data.

  2. CIFAR-10 + ClassificationRoutine: integrates DEUP with a pretrained ResNet-18 for OOD detection on CIFAR-10 versus SVHN, a standard CIFAR-10 OOD benchmark.

How DEUP works

  1. Run the (already-trained) base model on a held-out calibration set and compute per-sample errors (cross-entropy for classification, squared error for regression).

  2. Split the calibration samples into K folds and train K lightweight error-predictor MLPs, each on K-1 folds and evaluated on the remaining one. Because every sample is predicted by a model that never saw it, these out-of-fold (OOF) predictions act as generalization-error proxies and become the targets for the final predictor.

  3. Train the final error predictor g on all calibration features to predict those OOF targets. At inference, g(x) ≥ 0 is the epistemic uncertainty estimate.
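The three steps above can be sketched in plain PyTorch as follows. This is a minimal illustration of the procedure as described here, not TorchUncertainty's implementation: the helper names (per_sample_errors, train_error_mlp, deup_sketch), the Softplus output used to keep g(x) ≥ 0, the MSE training objective, and all hyperparameters are assumptions. Feature extraction from the base model is abstracted away by passing feature tensors directly.

```python
import torch
import torch.nn.functional as F
from torch import nn


def per_sample_errors(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper, step 1: per-sample cross-entropy (no averaging)."""
    return F.cross_entropy(logits, targets, reduction="none")


def train_error_mlp(feats: torch.Tensor, errs: torch.Tensor, hidden_dim: int = 32,
                    epochs: int = 200, lr: float = 1e-2) -> nn.Module:
    """Hypothetical helper: regress errors from features with a small MLP.

    The final Softplus keeps predictions non-negative (an assumption of this sketch).
    """
    net = nn.Sequential(
        nn.Linear(feats.shape[1], hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, 1),
        nn.Softplus(),
    )
    opt = torch.optim.Adam(net.parameters(), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        loss = F.mse_loss(net(feats).squeeze(-1), errs)
        loss.backward()
        opt.step()
    return net


def deup_sketch(feats: torch.Tensor, errs: torch.Tensor, num_folds: int = 5,
                **mlp_kwargs) -> nn.Module:
    """Hypothetical helper, steps 2 and 3: OOF targets, then the final predictor g."""
    n = feats.shape[0]
    folds = torch.randperm(n).tensor_split(num_folds)  # K disjoint index sets
    oof = torch.empty(n)
    for k, held_out in enumerate(folds):
        # Train on the other K-1 folds, then predict errors on the unseen fold.
        train_idx = torch.cat([f for j, f in enumerate(folds) if j != k])
        net_k = train_error_mlp(feats[train_idx], errs[train_idx], **mlp_kwargs)
        with torch.no_grad():
            oof[held_out] = net_k(feats[held_out]).squeeze(-1)
    # Step 3: fit g on all calibration features against the OOF targets.
    return train_error_mlp(feats, oof, **mlp_kwargs)


# Usage on random stand-ins for base-model features and logits.
torch.manual_seed(0)
feats = torch.randn(100, 8)
logits = torch.randn(100, 5)
labels = torch.randint(0, 5, (100,))
errs = per_sample_errors(logits, labels)
g = deup_sketch(feats, errs)
with torch.no_grad():
    scores = g(torch.randn(10, 8)).squeeze(-1)  # one non-negative score per input
```

Training g on the OOF predictions rather than on the raw in-sample errors is what makes its output a proxy for generalization error instead of a memorized training loss.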

Reference:

Lahlou et al. (2023). DEUP: Direct Epistemic Uncertainty Prediction. TMLR. https://openreview.net/forum?id=eGLdVRvvfQ

1. Imports

import os

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from torch_uncertainty.post_processing import DEUP

2. Synthetic classification task

We create a small random dataset to illustrate the DEUP API without any expensive training. In practice the base model should be pre-trained; here we use random weights just to show the interface.

torch.manual_seed(0)
n_cal, in_dim, n_classes = 100, 8, 5

x_cal = torch.randn(n_cal, in_dim)
y_cal = torch.randint(0, n_classes, (n_cal,))
cal_loader = DataLoader(TensorDataset(x_cal, y_cal), batch_size=32)

model = nn.Sequential(nn.Linear(in_dim, 32), nn.ReLU(), nn.Linear(32, n_classes))
# In a real use-case, load a pre-trained model here.

3. Fit DEUP on the calibration split

DEUP.fit collects per-sample cross-entropy errors from the base model on the calibration loader, runs K-fold cross-validation to build OOF error estimates, and trains the final error predictor g on those estimates.

deup = DEUP(
    task="classification",
    model=model,
    num_folds=5,
    hidden_dim=32,
    max_epochs=30,
    device="cpu",
)
deup.fit(cal_loader)

4. Epistemic uncertainty at inference

deup(x) takes raw inputs, runs them through the base model to extract features, and returns a non-negative epistemic score per sample. Higher scores signal that the base model is more likely to be unreliable on that input.

x_test = torch.randn(10, in_dim)
uncertainty = deup(x_test)
print("Epistemic scores g(x):", uncertainty)

probs = deup.predict_proba(x_test)
print("Base-model likelihood shape:", probs.shape)
Epistemic scores g(x): tensor([0.8570, 0.8676, 0.8819, 0.8829, 0.8861, 0.8889, 0.8522, 0.8715, 0.8617,
        0.8622])
Base-model likelihood shape: torch.Size([10, 5])
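One simple way to turn these scores into decisions is to flag inputs whose score exceeds a high quantile of scores computed on held-out in-distribution data. The helper below (flag_unreliable) is hypothetical and not part of the DEUP API, and the quantile rule is just one common thresholding choice, assumed here for illustration; the toy scores stand in for deup(x) outputs.

```python
import torch


def flag_unreliable(test_scores: torch.Tensor, ref_scores: torch.Tensor, q: float = 0.95):
    """Hypothetical helper: flag inputs scoring above the q-quantile of reference scores."""
    threshold = torch.quantile(ref_scores, q)  # linear interpolation by default
    return test_scores > threshold, threshold


# Toy scores; in practice these could be deup(x) on held-out in-distribution
# data (reference) and on new inputs (test).
ref_scores = torch.tensor([0.10, 0.20, 0.30, 0.40, 0.50])
test_scores = torch.tensor([0.05, 0.45, 0.90])
flags, threshold = flag_unreliable(test_scores, ref_scores, q=0.8)
print(flags.tolist())  # [False, True, True]
```

With only five reference scores the 0.8-quantile is interpolated to 0.42; on a realistic held-out set a higher quantile such as 0.95 is more typical.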

5. Apply DEUP on CIFAR-10 with the ClassificationRoutine

In this section we integrate DEUP with TorchUncertainty’s ClassificationRoutine to evaluate OOD detection performance on CIFAR-10 versus SVHN.

The routine fits the error predictor automatically on the validation split at the start of trainer.test(), so no manual fit() call is needed here.

import torch
from huggingface_hub import hf_hub_download

from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import CIFAR10DataModule
from torch_uncertainty.models.classification.resnet import resnet
from torch_uncertainty.post_processing import DEUP
from torch_uncertainty.routines import ClassificationRoutine

6. Load a pretrained ResNet-18 from Hugging Face

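We use a CIFAR-style ResNet-18 (3×3 first convolution, no max-pooling) hosted on the Hugging Face Hub under torch-uncertainty/resnet18_c10. The standard ImageNet variant downsamples its input by a factor of 4 in the stem (a stride-2 convolution followed by stride-2 max-pooling); the CIFAR-style stem avoids this, preserving more spatial information on small 32×32 images.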

cifar_model = resnet(in_channels=3, num_classes=10, arch=18, style="cifar", conv_bias=False)

ckpt_path = hf_hub_download(repo_id="torch-uncertainty/resnet18_c10", filename="resnet18_c10.ckpt")
weights = torch.load(ckpt_path, map_location="cpu", weights_only=True)
cifar_model.load_state_dict(weights)
cifar_model = cifar_model.cuda().eval()
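The spatial-resolution argument can be checked with the standard output-size formula for convolution and pooling layers. The ImageNet stem below uses torchvision's standard ResNet settings (7×7 stride-2 convolution with padding 3, then 3×3 stride-2 max-pooling with padding 1). The CIFAR stem is assumed to be a 3×3 convolution with stride 1 and padding 1, the usual CIFAR ResNet choice, which may not match TorchUncertainty's exact configuration.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a convolution or max-pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# ImageNet-style stem: 7x7 stride-2 conv, then 3x3 stride-2 max-pool.
imagenet = conv_out(conv_out(32, kernel=7, stride=2, padding=3), kernel=3, stride=2, padding=1)
# CIFAR-style stem (assumed stride 1, padding 1): single 3x3 conv, no max-pool.
cifar = conv_out(32, kernel=3, stride=1, padding=1)
print(imagenet, cifar)  # 8 32
```

On a 32×32 input the ImageNet stem leaves an 8×8 feature map, whereas the CIFAR stem keeps the full 32×32 resolution, 16 times as many spatial positions for the residual blocks to work with.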

7. DataModule and Trainer

CIFAR10DataModule with eval_ood=True automatically provides the SVHN out-of-distribution test loader alongside the standard CIFAR-10 test loader.

The key argument here is postprocess_set="val": it tells the routine to fit the DEUP error predictor on the validation split rather than the test set, avoiding any data leakage. We reserve 10% of the training set as validation via val_split=0.1. The routine automatically builds this split before fitting DEUP, so no manual setup call is needed.

datamodule = CIFAR10DataModule(
    root=os.environ.get("TU_DATA_DIR", "data"),
    batch_size=256,
    num_workers=4,
    eval_ood=True,
    val_split=0.1,
    postprocess_set="val",
)

trainer = TUTrainer(
    accelerator="gpu",
    devices=1,
    max_epochs=1,
    enable_progress_bar=True,
)

8. ClassificationRoutine with DEUP post-processing

Passing post_processing=deup_c10 and ood_criterion="deup" wires DEUP into the routine:

  • At the start of trainer.test(), the routine calls deup_c10.fit() on the validation dataloader (postprocess_set="val").

  • During the test loop, the DEUP epistemic score is used as the OOD detection criterion: higher score ⟹ more likely OOD.

No loss or optimizer is needed since we are only running evaluation.

deup_c10 = DEUP(
    task="classification",
    hidden_dim=64,
    max_epochs=50,
    device="cuda",
)

routine = ClassificationRoutine(
    model=cifar_model,
    num_classes=10,
    loss=None,
    eval_ood=True,
    post_processing=deup_c10,
    ood_criterion="deup",
)

9. Evaluate OOD detection with DEUP

trainer.test() runs the following sequence automatically:

  1. Fit the DEUP error predictor on the CIFAR-10 validation split.

  2. Compute in-distribution classification metrics on the CIFAR-10 test set.

  3. Compute OOD detection metrics (AUROC, AUPR, FPR95) using SVHN as OOD data.

OOD detection results are reported under the ood/ prefix.
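For reference, AUROC and FPR95 can be computed from raw scores as below, treating OOD samples as the positive class and larger scores as more OOD, which matches the DEUP criterion's convention. This is a plain-Python illustration of the standard definitions, not TorchUncertainty's metric code: libraries differ in conventions (some treat in-distribution as positive, or interpolate the 95% TPR point), so values may not match exactly. AUPR is omitted for brevity, and the helper names are hypothetical.

```python
def auroc(id_scores, ood_scores):
    """AUROC with OOD as the positive class: P(OOD score > ID score), ties count 1/2."""
    wins = 0.0
    for o in ood_scores:
        for i in id_scores:
            if o > i:
                wins += 1.0
            elif o == i:
                wins += 0.5
    return wins / (len(ood_scores) * len(id_scores))


def fpr_at_95_tpr(id_scores, ood_scores):
    """Fraction of ID samples flagged as OOD at the highest threshold that still
    flags at least 95% of OOD samples (score >= threshold means "flag as OOD")."""
    ranked = sorted(ood_scores, reverse=True)
    k = (95 * len(ranked) + 99) // 100  # ceil(0.95 * n) in integer arithmetic
    threshold = ranked[k - 1]
    return sum(s >= threshold for s in id_scores) / len(id_scores)


# Toy scores where higher means "more likely OOD", as with the DEUP criterion.
id_scores = [0.1, 0.2, 0.3, 0.4]
ood_scores = [0.35, 0.5, 0.6, 0.7]
print(auroc(id_scores, ood_scores))          # 0.9375
print(fpr_at_95_tpr(id_scores, ood_scores))  # 0.25
```

Higher AUROC is better, while lower FPR95 is better: it counts the in-distribution inputs that would be wrongly rejected if the threshold were set to catch nearly all OOD inputs.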

results_deup = trainer.test(routine, datamodule=datamodule)
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃      Classification       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     Acc      │          93.380%          │
│    Brier     │          0.10812          │
│   Entropy    │          0.08849          │
│     NLL      │          0.26405          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃        Calibration        ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     ECE      │          3.537%           │
│     MCE      │          23.670%          │
│    SmECE     │          10.143%          │
│     aECE     │          3.500%           │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃       OOD Detection       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     AUPR     │          81.944%          │
│    AUROC     │          67.990%          │
│   Entropy    │          0.35549          │
│    FPR95     │          79.300%          │
│  SCOD_AUGRC  │          0.32517          │
│  SCOD_AURC   │          0.58219          │
│ SCOD_Cov_5R… │            nan            │
│ SCOD_Risk_8… │          0.69073          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃ Selective Classification  ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    AUGRC     │          0.779%           │
│     AURC     │          0.959%           │
│  Cov_5Risk   │          96.510%          │
│  Risk_80Cov  │          1.200%           │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃        Complexity         ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    flops     │         284.38 G          │
│    params    │          11.17 M          │
└──────────────┴───────────────────────────┘
Testing ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 102/102 0:00:06 • 0:00:00 14.63it/s

10. Compare with the Maximum Softmax Probability baseline

We can swap the OOD criterion to compare DEUP against the Maximum Softmax Probability (MSP) baseline without retraining the base model. Only the OOD detection scores change; the classification, calibration, and selective-classification metrics remain identical.

from torch_uncertainty.ood_criteria import MaxSoftmaxCriterion

routine.ood_criterion = MaxSoftmaxCriterion()

results_msp = trainer.test(routine, datamodule=datamodule)
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃      Classification       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     Acc      │          93.380%          │
│    Brier     │          0.10812          │
│   Entropy    │          0.08849          │
│     NLL      │          0.26405          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃        Calibration        ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     ECE      │          3.537%           │
│     MCE      │          23.670%          │
│    SmECE     │          10.143%          │
│     aECE     │          3.500%           │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃       OOD Detection       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     AUPR     │          90.246%          │
│    AUROC     │          82.969%          │
│   Entropy    │          0.35549          │
│    FPR95     │          56.050%          │
│  SCOD_AUGRC  │          0.29513          │
│  SCOD_AURC   │          0.47860          │
│ SCOD_Cov_5R… │            nan            │
│ SCOD_Risk_8… │          0.67224          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃ Selective Classification  ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    AUGRC     │          0.779%           │
│     AURC     │          0.959%           │
│  Cov_5Risk   │          96.510%          │
│  Risk_80Cov  │          1.200%           │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃        Complexity         ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    flops     │         284.38 G          │
│    params    │          11.17 M          │
└──────────────┴───────────────────────────┘
Testing ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 102/102 0:00:06 • 0:00:00 14.58it/s

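The OOD Detection tables (logged under the ood/ prefix) allow a direct comparison between the DEUP epistemic score and the MSP confidence score as OOD detectors on a well-trained ResNet-18. In this run, MSP is the stronger detector on all three headline metrics: AUROC 82.97% vs. 67.99%, AUPR 90.25% vs. 81.94%, and FPR95 56.05% vs. 79.30% (lower is better). DEUP targets the epistemic component of the model's uncertainty, so it is expected to complement confidence-based criteria rather than replace them; its OOD performance may also depend on the error-predictor settings (hidden_dim, max_epochs) and on the choice of OOD dataset.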
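MSP scores an input by its largest softmax probability, which is high for confident, in-distribution-looking inputs (Hendrycks & Gimpel, 2017). To rank inputs with the same "higher means more OOD" convention as the DEUP score, the confidence can be negated, as sketched below. Whether MaxSoftmaxCriterion negates internally or handles the sign elsewhere is not shown in this tutorial, so the helper msp_ood_score is hypothetical and only illustrates the idea.

```python
import torch


def msp_ood_score(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: negated max softmax probability (higher => more likely OOD)."""
    return -torch.softmax(logits, dim=-1).max(dim=-1).values


logits = torch.tensor([[4.0, 0.0, 0.0],   # confident prediction
                       [1.0, 1.0, 1.0]])  # uniform prediction
scores = msp_ood_score(logits)  # the uniform row gets the higher (more OOD) score
```

Because AUROC depends only on how the scores rank the inputs, negating the confidence changes the direction of the ranking, not the information it carries.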

References

  • DEUP: Lahlou, S., Jain, M., Nekoei, H., Butoi, V. I., Bertin, P., Rector-Brooks, J., … & Bengio, Y. (2023). DEUP: Direct Epistemic Uncertainty Prediction. TMLR. https://openreview.net/forum?id=eGLdVRvvfQ

  • ResNet: He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning for Image Recognition. CVPR 2016.

  • MSP baseline: Hendrycks, D., & Gimpel, K. (2017). A Baseline for Detecting Misclassified and Out-of-Distribution Examples in Neural Networks. ICLR 2017.

Total running time of the script: (0 minutes 26.760 seconds)
