DEUP: Direct Epistemic Uncertainty Prediction with TorchUncertainty
DEUP estimates the epistemic component of uncertainty by training a lightweight
error-predictor g on out-of-fold generalization errors collected from a held-out
calibration set (Algorithm 2 in Lahlou et al. 2023). Once fitted, g(x) returns
a non-negative score: higher means “the base model is more likely to be wrong here.”
This tutorial has two parts:
Synthetic walkthrough: illustrates the DEUP API on random tabular data.
CIFAR-10 + ClassificationRoutine: integrates DEUP with a pretrained ResNet-18 for OOD detection against SVHN, a standard out-of-distribution benchmark for CIFAR-10.
How DEUP works
Run the (already-trained) base model on a held-out calibration set and compute per-sample errors (cross-entropy for classification, squared error for regression).
Perform K-fold cross-validation on those calibration errors to train K lightweight error-predictor MLPs. The out-of-fold predictions become the targets for the final predictor, acting as generalization-error proxies.
Train the final error predictor g on all calibration features to predict those OOF targets. At inference, g(x) ≥ 0 is the epistemic uncertainty estimate.
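The three steps can be sketched in plain Python to make the out-of-fold logic concrete. This is not TorchUncertainty's implementation: a k-nearest-neighbour regressor stands in for the lightweight MLPs, the 1-D features and per-sample errors are synthetic, and the helper names (kfold_indices, knn_regressor, fit_deup_sketch) are hypothetical.

```python
import math
import random


def kfold_indices(n, num_folds, seed=0):
    """Shuffle 0..n-1 and deal the indices into num_folds disjoint folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[f::num_folds] for f in range(num_folds)]


def knn_regressor(train_x, train_y, k=3):
    """Stand-in for an error-predictor MLP: mean target of the k nearest points."""
    def predict(x):
        nearest = sorted((math.dist(x, xi), yi) for xi, yi in zip(train_x, train_y))[:k]
        return sum(y for _, y in nearest) / len(nearest)
    return predict


def fit_deup_sketch(features, errors, num_folds=5, k=3):
    """Steps 2-3: build out-of-fold (OOF) error targets, then fit the final g on them."""
    n = len(features)
    oof = [0.0] * n
    for fold in kfold_indices(n, num_folds):
        held_out = set(fold)
        train = [i for i in range(n) if i not in held_out]
        g_fold = knn_regressor([features[i] for i in train], [errors[i] for i in train], k)
        for i in fold:
            oof[i] = g_fold(features[i])  # g_fold never saw sample i during fitting
    g = knn_regressor(features, oof, k)
    return lambda x: max(0.0, g(x))  # clamp so that g(x) >= 0


# Step 1 (simulated): a hypothetical base model whose error grows away from 0.
rng = random.Random(1)
feats = [[rng.uniform(-3.0, 3.0)] for _ in range(60)]
errs = [abs(f[0]) + rng.uniform(0.0, 0.1) for f in feats]

g = fit_deup_sketch(feats, errs)
near, far = g([0.1]), g([2.9])
print(f"g(0.1) = {near:.3f}, g(2.9) = {far:.3f}")
```

Because every OOF target comes from a predictor that never saw that sample, g is trained to estimate generalization error rather than in-sample error, which is what makes the OOF predictions usable as generalization-error proxies.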
Reference: Lahlou et al. (2023). DEUP: Direct Epistemic Uncertainty Prediction. TMLR. https://openreview.net/forum?id=eGLdVRvvfQ
1. Imports
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch_uncertainty.post_processing import DEUP
2. Synthetic classification task
We create a small random dataset to illustrate the DEUP API without any expensive training. In practice the base model should be pre-trained; here we use random weights just to show the interface.
torch.manual_seed(0)
n_cal, in_dim, n_classes = 100, 8, 5
x_cal = torch.randn(n_cal, in_dim)
y_cal = torch.randint(0, n_classes, (n_cal,))
cal_loader = DataLoader(TensorDataset(x_cal, y_cal), batch_size=32)
model = nn.Sequential(nn.Linear(in_dim, 32), nn.ReLU(), nn.Linear(32, n_classes))
# In a real use case, load a pre-trained model instead of these random weights.
3. Fit DEUP on the calibration split
DEUP.fit collects per-sample cross-entropy errors from the base model on the
calibration loader, runs K-fold cross-validation to build OOF error estimates, and
trains the final error predictor g on those estimates.
deup = DEUP(
task="classification",
model=model,
num_folds=5,
hidden_dim=32,
max_epochs=30,
device="cpu",
)
deup.fit(cal_loader)
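The first thing fit collects, the per-sample error of the base model, can be written out directly. The sketch below is plain Python rather than the library's code, and the function names are hypothetical: it computes the classification error -log p(y | x) per sample with a numerically stable log-sum-exp, plus the squared-error counterpart used for regression.

```python
import math


def cross_entropy_per_sample(logits, labels):
    """Per-sample -log softmax(logits)[y], using log-sum-exp for stability."""
    errors = []
    for row, y in zip(logits, labels):
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        errors.append(log_z - row[y])
    return errors


def squared_error_per_sample(preds, targets):
    """Per-sample squared error, the regression counterpart."""
    return [(p - t) ** 2 for p, t in zip(preds, targets)]


logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
labels = [0, 2]
ce = cross_entropy_per_sample(logits, labels)
# A uniform row over 3 classes costs log(3) whatever the label is.
print(ce[1], math.log(3))
```

In PyTorch the same per-sample quantity is `torch.nn.functional.cross_entropy(logits, labels, reduction="none")`; the default `reduction="mean"` would collapse it to a single number and lose the per-sample targets DEUP needs.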
4. Epistemic uncertainty at inference
deup(x) takes raw inputs, runs them through the base model to extract features,
and returns a non-negative epistemic score per sample.
Higher scores signal that the base model is more likely to be unreliable on that input.
x_test = torch.randn(10, in_dim)
uncertainty = deup(x_test)
print("Epistemic scores g(x):", uncertainty)
probs = deup.predict_proba(x_test)
print("Base-model likelihood shape:", probs.shape)
Epistemic scores g(x): tensor([0.8570, 0.8676, 0.8819, 0.8829, 0.8861, 0.8889, 0.8522, 0.8715, 0.8617,
0.8622])
Base-model likelihood shape: torch.Size([10, 5])
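A common way to act on these scores is to flag the k least reliable inputs. The sketch below ranks the ten scores printed above and returns the indices of the k highest; with the random base model used here the ranking carries no real meaning and only illustrates the mechanics. The helper most_uncertain is hypothetical; on the tensor returned by deup(x_test), `torch.topk(uncertainty, k=3).indices` performs the same selection.

```python
# The ten scores printed above, rounded to four decimals.
scores = [0.8570, 0.8676, 0.8819, 0.8829, 0.8861, 0.8889, 0.8522, 0.8715, 0.8617, 0.8622]


def most_uncertain(scores, k):
    """Indices of the k highest epistemic scores, most uncertain first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


print(most_uncertain(scores, 3))  # [5, 4, 3]
```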
5. Apply DEUP on CIFAR-10 with the ClassificationRoutine
In this section we integrate DEUP with TorchUncertainty’s
ClassificationRoutine to evaluate OOD
detection performance on CIFAR-10 versus SVHN.
The routine fits the error predictor automatically on the validation
split at the start of trainer.test(), so no manual fit() call is
needed here.
import torch
from huggingface_hub import hf_hub_download
from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import CIFAR10DataModule
from torch_uncertainty.models.classification.resnet import resnet
from torch_uncertainty.post_processing import DEUP
from torch_uncertainty.routines import ClassificationRoutine
6. Load a pretrained ResNet-18 from Hugging Face
We use a CIFAR-style ResNet-18 (3x3 first convolution, no max-pooling) hosted by TorchUncertainty on the Hugging Face Hub. The CIFAR-style variant preserves more spatial information on small 32x32 images than the standard ImageNet variant.
cifar_model = resnet(in_channels=3, num_classes=10, arch=18, style="cifar", conv_bias=False)
ckpt_path = hf_hub_download(repo_id="torch-uncertainty/resnet18_c10", filename="resnet18_c10.ckpt")
weights = torch.load(ckpt_path, map_location="cpu", weights_only=True)
cifar_model.load_state_dict(weights)
cifar_model = cifar_model.cuda().eval()
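The claim about spatial information can be checked with the usual output-size formula floor((n + 2p - k) / s) + 1. The sketch below tracks the side length of a 32x32 input through the stem and the four ResNet stages. The ImageNet stem (7x7 stride-2 convolution, then 3x3 stride-2 max-pooling) and the stage strides follow He et al. (2016); for the CIFAR stem we assume a stride-1 3x3 convolution, which the text above does not state explicitly.

```python
def conv_out(size, kernel, stride, padding):
    """Side length after a conv or pooling layer: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_sizes(size, stem):
    """Side length after the stem and after each of the four ResNet stages."""
    for kernel, stride, padding in stem:
        size = conv_out(size, kernel, stride, padding)
    sizes = [size]
    for stride in (1, 2, 2, 2):  # stage 1 keeps resolution; stages 2-4 halve it
        size = conv_out(size, 3, stride, 1)  # the stage's first 3x3 conv sets the size
        sizes.append(size)
    return sizes


# ImageNet stem: 7x7 stride-2 conv, then 3x3 stride-2 max-pool.
imagenet = stage_sizes(32, [(7, 2, 3), (3, 2, 1)])
# CIFAR stem (assumed stride 1): a single 3x3 conv, no max-pool.
cifar = stage_sizes(32, [(3, 1, 1)])
print(imagenet)  # [8, 8, 4, 2, 1]
print(cifar)     # [32, 32, 16, 8, 4]
```

Under these assumptions, the last stage of the CIFAR variant still sees a 4x4 feature map, whereas the ImageNet variant has reduced a 32x32 image to a single spatial position.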
7. DataModule and Trainer
CIFAR10DataModule with eval_ood=True
automatically provides the SVHN out-of-distribution test loader alongside the
standard CIFAR-10 test loader.
The key argument here is postprocess_set="val": it tells the routine to fit
the DEUP error predictor on the validation split rather than the test set, so
the test images used for evaluation never influence the fitted predictor. We
reserve 10% of the training set as validation via val_split=0.1. The routine
builds this split automatically before fitting DEUP, so no manual setup call is needed.
datamodule = CIFAR10DataModule(
root=os.environ.get("TU_DATA_DIR", "data"),
batch_size=256,
num_workers=4,
eval_ood=True,
val_split=0.1,
postprocess_set="val",
)
trainer = TUTrainer(
accelerator="gpu",
devices=1,
max_epochs=1,
enable_progress_bar=True,
)
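To make the 10% figure concrete: CIFAR-10 has 50,000 training images, so val_split=0.1 holds out 5,000 of them for fitting DEUP and leaves 45,000 for training, while the 10,000 test images stay untouched. The sketch below shows a disjoint random split of that kind; the helper train_val_split is hypothetical and is not CIFAR10DataModule's actual splitting code.

```python
import random


def train_val_split(num_samples, val_split, seed=0):
    """Randomly partition sample indices into disjoint train and validation lists."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    num_val = round(num_samples * val_split)
    return indices[num_val:], indices[:num_val]


train_idx, val_idx = train_val_split(50_000, 0.1)
print(len(train_idx), len(val_idx))  # 45000 5000
```

Fixing the seed makes the split reproducible, so repeated evaluations fit DEUP on the same validation images.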
8. ClassificationRoutine with DEUP post-processing
Passing post_processing=deup_c10 and ood_criterion="deup" wires DEUP
into the routine:
At the start of trainer.test(), the routine calls deup_c10.fit() on the validation dataloader (postprocess_set="val").
During the test loop, the DEUP epistemic score is used as the OOD detection criterion: higher score ⟹ more likely OOD.
No loss or optimizer is needed since we are only running evaluation.
deup_c10 = DEUP(
task="classification",
hidden_dim=64,
max_epochs=50,
device="cuda",
)
routine = ClassificationRoutine(
model=cifar_model,
num_classes=10,
loss=None,
eval_ood=True,
post_processing=deup_c10,
ood_criterion="deup",
)
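trainer.test() reports threshold-free metrics, but acting on the criterion in deployment eventually needs a cut-off. One common recipe, sketched here as an assumption rather than a TorchUncertainty feature, is to choose the threshold so that 95% of in-distribution validation samples are accepted and to flag any input whose score exceeds it. The helper names and the scores are hypothetical.

```python
import math


def id_quantile_threshold(id_scores, keep=0.95):
    """Smallest ID score such that at least a `keep` fraction of ID scores are <= it."""
    ranked = sorted(id_scores)
    index = min(len(ranked) - 1, max(0, math.ceil(keep * len(ranked)) - 1))
    return ranked[index]


def flag_ood(scores, threshold):
    """Higher score means more likely OOD, so flag anything above the threshold."""
    return [s > threshold for s in scores]


# Hypothetical epistemic scores on in-distribution validation data.
id_val_scores = list(range(1, 21))
threshold = id_quantile_threshold(id_val_scores)
print(threshold, flag_ood([5, 19, 20, 35], threshold))
```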
9. Evaluate OOD detection with DEUP
trainer.test() runs the following sequence automatically:
Fit the DEUP error predictor on the CIFAR-10 validation split.
Compute in-distribution classification metrics on the CIFAR-10 test set.
Compute OOD detection metrics (AUROC, AUPR, FPR95) using SVHN as OOD data.
OOD detection results are reported under the ood/ prefix.
results_deup = trainer.test(routine, datamodule=datamodule)
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ Acc │ 93.380% │
│ Brier │ 0.10812 │
│ Entropy │ 0.08849 │
│ NLL │ 0.26405 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Calibration ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ECE │ 3.537% │
│ MCE │ 23.670% │
│ SmECE │ 10.143% │
│ aECE │ 3.500% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ OOD Detection ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUPR │ 81.944% │
│ AUROC │ 67.990% │
│ Entropy │ 0.35549 │
│ FPR95 │ 79.300% │
│ SCOD_AUGRC │ 0.32517 │
│ SCOD_AURC │ 0.58219 │
│ SCOD_Cov_5R… │ nan │
│ SCOD_Risk_8… │ 0.69073 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Selective Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUGRC │ 0.779% │
│ AURC │ 0.959% │
│ Cov_5Risk │ 96.510% │
│ Risk_80Cov │ 1.200% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Complexity ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ flops │ 284.38 G │
│ params │ 11.17 M │
└──────────────┴───────────────────────────┘
Testing ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 102/102 0:00:06 • 0:00:00 14.63it/s
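Two of the headline OOD numbers can be computed from raw scores as below, treating OOD as the positive class and assuming a higher score means more likely OOD. This is a quadratic-time reference sketch for understanding, not TorchUncertainty's metric code, whose conventions (for example, which class counts as positive for FPR95) may differ; AUPR is omitted and the function names are hypothetical.

```python
import math


def auroc(id_scores, ood_scores):
    """P(random OOD score > random ID score), counting ties as one half."""
    wins = 0.0
    for o in ood_scores:
        for i in id_scores:
            wins += 1.0 if o > i else 0.5 if o == i else 0.0
    return wins / (len(ood_scores) * len(id_scores))


def fpr_at_tpr(id_scores, ood_scores, tpr=0.95):
    """Fraction of ID samples flagged when the threshold catches `tpr` of OOD samples."""
    ranked = sorted(ood_scores, reverse=True)
    threshold = ranked[math.ceil(tpr * len(ranked)) - 1]
    return sum(s >= threshold for s in id_scores) / len(id_scores)


id_scores = [0.1, 0.2, 0.3, 0.4]
ood_scores = [0.35, 0.5, 0.6, 0.7]
print(auroc(id_scores, ood_scores), fpr_at_tpr(id_scores, ood_scores))  # 0.9375 0.25
```

An AUROC of 0.5 means the score ranks OOD inputs above ID inputs no better than chance, and 1.0 means perfect separation; for FPR95, lower is better.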
10. Compare with the Maximum Softmax Probability baseline
We can swap the OOD criterion to compare DEUP against the Maximum Softmax Probability (MSP) baseline without re-training the base model or re-fitting anything by hand. Only the OOD detection scores change; the in-distribution metrics remain identical.
from torch_uncertainty.ood_criteria import MaxSoftmaxCriterion
routine.ood_criterion = MaxSoftmaxCriterion()
results_msp = trainer.test(routine, datamodule=datamodule)
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ Acc │ 93.380% │
│ Brier │ 0.10812 │
│ Entropy │ 0.08849 │
│ NLL │ 0.26405 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Calibration ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ECE │ 3.537% │
│ MCE │ 23.670% │
│ SmECE │ 10.143% │
│ aECE │ 3.500% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ OOD Detection ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUPR │ 90.246% │
│ AUROC │ 82.969% │
│ Entropy │ 0.35549 │
│ FPR95 │ 56.050% │
│ SCOD_AUGRC │ 0.29513 │
│ SCOD_AURC │ 0.47860 │
│ SCOD_Cov_5R… │ nan │
│ SCOD_Risk_8… │ 0.67224 │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Selective Classification ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ AUGRC │ 0.779% │
│ AURC │ 0.959% │
│ Cov_5Risk │ 96.510% │
│ Risk_80Cov │ 1.200% │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ Complexity ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ flops │ 284.38 G │
│ params │ 11.17 M │
└──────────────┴───────────────────────────┘
Testing ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 102/102 0:00:06 • 0:00:00 14.58it/s
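For reference, the MSP baseline (Hendrycks & Gimpel, 2017) scores each input by the largest softmax probability of the base model, which is high for confident in-distribution inputs; unlike DEUP's g, it needs no fitting. The sketch negates MSP so that, as with DEUP, a higher score means more likely OOD. Whether MaxSoftmaxCriterion uses this exact sign convention internally is an assumption, and the function names are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax of a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def msp_ood_score(logits):
    """Negative maximum softmax probability: higher means less confident."""
    return -max(softmax(logits))


confident = [6.0, 0.0, 0.0]
unsure = [0.2, 0.1, 0.0]
print(msp_ood_score(confident), msp_ood_score(unsure))
```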
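The OOD Detection blocks of the two result tables (logged under the ood/ prefix)
allow a direct comparison between the DEUP epistemic score and the MSP confidence
score as OOD detectors on a well-trained ResNet-18. In this run MSP is the stronger
detector on CIFAR-10 versus SVHN, with a higher AUROC (82.969% vs 67.990%) and AUPR
and a lower FPR95. DEUP is expected to complement confidence-based criteria, rather
than replace them, by focusing on the epistemic component of the model's uncertainty.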
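The gaps between the two criteria can be read off the OOD Detection tables above. The short computation below uses only those reported numbers, taking AUROC and AUPR as higher-is-better and FPR95 as lower-is-better.

```python
# OOD-detection numbers copied from the two result tables above (percent).
deup_ood = {"AUROC": 67.990, "AUPR": 81.944, "FPR95": 79.300}
msp_ood = {"AUROC": 82.969, "AUPR": 90.246, "FPR95": 56.050}

# AUROC and AUPR are higher-is-better; FPR95 is lower-is-better.
higher_is_better = {"AUROC": True, "AUPR": True, "FPR95": False}

winners = {}
for name, up in higher_is_better.items():
    gap = msp_ood[name] - deup_ood[name]
    winners[name] = "MSP" if (gap > 0) == up else "DEUP"
    print(f"{name}: MSP - DEUP = {gap:+.3f} points, {winners[name]} is better")
```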
References
DEUP: Lahlou, S., Jain, M., Nekoei, H., Butoi, V. I., Bertin, P., Rector-Brooks, J., … & Bengio, Y. (2023). DEUP: Direct Epistemic Uncertainty Prediction. TMLR. https://openreview.net/forum?id=eGLdVRvvfQ
ResNet: He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning for Image Recognition. CVPR 2016.
MSP baseline: Hendrycks, D., & Gimpel, K. (2017). A Baseline for Detecting Misclassified and Out-of-Distribution Examples in Neural Networks. ICLR 2017.
Total running time of the script: (0 minutes 26.760 seconds)