.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_tutorials/Classification/tutorial_ood_detection.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_tutorials_Classification_tutorial_ood_detection.py:


Out-of-distribution detection with TorchUncertainty
===================================================

This tutorial demonstrates how to perform OOD detection using TorchUncertainty's ClassificationRoutine
with a ResNet18 model trained on CIFAR-10, evaluating its performance with SVHN as the OOD dataset.

We will:

- Set up the CIFAR-10 datamodule.
- Initialize and briefly train a ResNet18 model using the ClassificationRoutine.
- Evaluate the model's performance on both in-distribution and out-of-distribution data.
- Analyze uncertainty metrics for OOD detection.

.. GENERATED FROM PYTHON SOURCE LINES 19-25

Imports and Setup
------------------

First, we import the necessary libraries and set up our environment. This includes PyTorch,
TorchUncertainty components, and TorchUncertainty's Trainer (built on top of Lightning's), as well as
two criteria for OOD detection: the maximum softmax probability [1] and the Max Logit [2].

.. GENERATED FROM PYTHON SOURCE LINES 25-33

.. code-block:: Python

    from torch import nn, optim

    from torch_uncertainty import TUTrainer
    from torch_uncertainty.datamodules import CIFAR10DataModule
    from torch_uncertainty.models.classification.resnet import resnet
    from torch_uncertainty.ood_criteria import MaxLogitCriterion, MaxSoftmaxCriterion
    from torch_uncertainty.routines.classification import ClassificationRoutine

.. GENERATED FROM PYTHON SOURCE LINES 34-42

DataModule Setup
----------------

TorchUncertainty provides convenient DataModules for standard datasets like CIFAR-10.
DataModules handle data loading, preprocessing, and batching, simplifying the data pipeline. Each datamodule also
includes the corresponding out-of-distribution and distribution-shift datasets, which are then used by the routine.
For CIFAR-10, the corresponding OOD-detection dataset is SVHN, as commonly used in the community. To enable OOD
evaluation, set the `eval_ood` flag as done below.

.. GENERATED FROM PYTHON SOURCE LINES 42-45

.. code-block:: Python

    datamodule = CIFAR10DataModule(root="./data", batch_size=512, num_workers=8, eval_ood=True)

.. GENERATED FROM PYTHON SOURCE LINES 46-52

Model Initialization
--------------------

We use the ResNet18 architecture, a widely adopted convolutional neural network built on residual connections.
The model is initialized with 10 output classes corresponding to the CIFAR-10 dataset categories. When training
on CIFAR, do not forget to set the style of the resnet to CIFAR; otherwise the model will lose more information
from the small input images in the first convolution.

.. GENERATED FROM PYTHON SOURCE LINES 52-56

.. code-block:: Python

    # Initialize the ResNet18 model
    model = resnet(arch=18, in_channels=3, num_classes=10, style="cifar", conv_bias=False)

.. GENERATED FROM PYTHON SOURCE LINES 57-67

Define the Classification Routine
---------------------------------

The `ClassificationRoutine` is one of the most important building blocks in TorchUncertainty. It streamlines
training and evaluation by integrating the model, loss function, and optimizer into a single routine
compatible with PyTorch Lightning's Trainer. This abstraction simplifies the implementation of standard
training loops and evaluation protocols.

Most importantly for this tutorial, the routine also handles OOD detection. To enable it, set the `eval_ood`
flag. You can also evaluate the distribution-shift performance of the model at the same time by setting
`eval_shift` to True.

.. GENERATED FROM PYTHON SOURCE LINES 67-84

..
.. code-block:: Python

    # Loss function
    criterion = nn.CrossEntropyLoss()

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Initialize the ClassificationRoutine, you could replace MaxSoftmaxCriterion by "msp"
    routine = ClassificationRoutine(
        model=model,
        num_classes=10,
        loss=criterion,
        optim_recipe=optimizer,
        eval_ood=True,
        ood_criterion=MaxSoftmaxCriterion,
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/chocolatine/actions-runner/_work/_tool/Python/3.11.12/x64/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `FPR95` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.
      warnings.warn(*args, **kwargs)

.. GENERATED FROM PYTHON SOURCE LINES 85-92

Test the Training of the Model
------------------------------

With the routine defined, we can now set up the Trainer and commence training. The Trainer handles the
training loop, including epoch management, logging, and checkpointing. We specify the maximum number of
epochs, the precision and the device to be used. To reduce the tutorial building time, we will train for a
single epoch and load a model from `TorchUncertainty's HuggingFace `_.

.. GENERATED FROM PYTHON SOURCE LINES 92-101

.. code-block:: Python

    # Initialize the TUTrainer
    trainer = TUTrainer(
        max_epochs=1, precision="16-mixed", accelerator="cuda", devices=1, enable_progress_bar=False
    )

    # Train the model for 1 epoch using the CIFAR-10 DataModule
    trainer.fit(routine, datamodule=datamodule)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0%| | 0.00/64.3M [00:00`_ and load it with the `load_from_checkpoint` method.

.. GENERATED FROM PYTHON SOURCE LINES 107-118

..
.. code-block:: Python

    import torch
    from huggingface_hub import hf_hub_download

    path = hf_hub_download(
        repo_id="torch-uncertainty/resnet18_c10",
        filename="resnet18_c10.ckpt",
    )
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    routine.model.load_state_dict(state_dict)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

.. GENERATED FROM PYTHON SOURCE LINES 119-124

Evaluating on In-Distribution and Out-of-distribution Data
----------------------------------------------------------

Now that the model is trained, we can evaluate its performance on the original in-distribution test set,
as well as the OOD set. Typing the next line will automatically compute the in-distribution and OOD
detection metrics.

.. GENERATED FROM PYTHON SOURCE LINES 124-128

.. code-block:: Python

    # Evaluate the model on the CIFAR-10 (IID) and SVHN (OOD) test sets
    results = trainer.test(routine, datamodule=datamodule)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ Test metric  ┃      Classification       ┃
    ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │     Acc      │          93.380%          │
    │    Brier     │          0.10812          │
    │   Entropy    │          0.08849          │
    │     NLL      │          0.26408          │
    └──────────────┴───────────────────────────┘
    ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ Test metric  ┃        Calibration        ┃
    ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │     ECE      │          3.546%           │
    │     aECE     │          3.499%           │
    └──────────────┴───────────────────────────┘
    ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ Test metric  ┃       OOD Detection       ┃
    ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │     AUPR     │          90.244%          │
    │    AUROC     │          82.968%          │
    │   Entropy    │          0.08849          │
    │    FPR95     │          56.060%          │
    └──────────────┴───────────────────────────┘
    ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ Test metric  ┃ Selective Classification  ┃
    ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │    AUGRC     │          0.779%           │
    │     AURC     │          0.960%           │
    │  Cov@5Risk   │          96.520%          │
    │  Risk@80Cov  │          1.200%           │
    └──────────────┴───────────────────────────┘

..
.. GENERATED FROM PYTHON SOURCE LINES 129-135

Changing the OOD Criterion
--------------------------

The previous out-of-distribution detection metrics were computed using the maximum softmax probability
score [1], which is the probability the model assigns to its predicted class. We could use other scores
such as the maximum logit [2]. To do this, just change the routine's `ood_criterion` and perform a second test.

.. GENERATED FROM PYTHON SOURCE LINES 135-139

.. code-block:: Python

    routine.ood_criterion = MaxLogitCriterion()

    results = trainer.test(routine, datamodule=datamodule)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ Test metric  ┃      Classification       ┃
    ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │     Acc      │          93.380%          │
    │    Brier     │          0.10812          │
    │   Entropy    │          0.08849          │
    │     NLL      │          0.26408          │
    └──────────────┴───────────────────────────┘
    ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ Test metric  ┃        Calibration        ┃
    ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │     ECE      │          3.546%           │
    │     aECE     │          3.499%           │
    └──────────────┴───────────────────────────┘
    ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ Test metric  ┃       OOD Detection       ┃
    ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │     AUPR     │          85.143%          │
    │    AUROC     │          73.895%          │
    │   Entropy    │          0.08849          │
    │    FPR95     │          79.620%          │
    └──────────────┴───────────────────────────┘
    ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ Test metric  ┃ Selective Classification  ┃
    ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │    AUGRC     │          0.779%           │
    │     AURC     │          0.960%           │
    │  Cov@5Risk   │          96.520%          │
    │  Risk@80Cov  │          1.200%           │
    └──────────────┴───────────────────────────┘

.. GENERATED FROM PYTHON SOURCE LINES 140-150

Note that you could create your own class if you want to implement a custom OOD detection score.

When changing the out-of-distribution criterion, all the in-distribution metrics remain the same.
The only values that change are those grouped in the OOD Detection category.
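
To make the difference between the two scores concrete, here is a minimal, dependency-free sketch of how a maximum softmax probability and a maximum logit can be computed from raw logits, and how an AUROC can be derived from them. It does not use TorchUncertainty's criterion classes: the function names, the toy logits, and the convention that higher scores mean "more in-distribution" are all assumptions made for illustration.

```python
import math


def max_softmax_score(logits):
    """Maximum softmax probability (MSP) for one sample's logits."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    return max(exps) / sum(exps)


def max_logit_score(logits):
    """Largest raw logit for one sample."""
    return max(logits)


def auroc(id_scores, ood_scores):
    """Fraction of (ID, OOD) pairs where the ID sample scores higher.

    Ties count for one half. Assumes higher scores mean "more in-distribution".
    """
    wins = 0.0
    for s_id in id_scores:
        for s_ood in ood_scores:
            if s_id > s_ood:
                wins += 1.0
            elif s_id == s_ood:
                wins += 0.5
    return wins / (len(id_scores) * len(ood_scores))


# Hypothetical toy logits for 3 classes, not taken from the tutorial's model.
id_logits = [[6.0, 0.5, 0.2], [0.1, 5.0, 0.3], [2.0, 1.0, 0.5]]
ood_logits = [[1.0, 0.9, 0.8], [4.0, 3.9, 0.1], [0.3, 0.2, 0.1]]

for name, score_fn in [("MSP", max_softmax_score), ("MaxLogit", max_logit_score)]:
    id_scores = [score_fn(z) for z in id_logits]
    ood_scores = [score_fn(z) for z in ood_logits]
    print(f"{name}: AUROC = {auroc(id_scores, ood_scores):.3f}")
```

The two scores can rank the same samples differently: a sample with two large, nearly equal logits has a high maximum logit but a low maximum softmax probability. The toy values only illustrate this effect and say nothing about which score is better in general.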
Comparing the two test runs, the maximum logit score is worse than the maximum softmax probability on every
OOD metric: AUPR drops from 90.2% to 85.1%, AUROC drops from 83.0% to 73.9%, and FPR95 rises from 56.1% to
79.6%. This ranking may depend on the model you are using.

References
----------

[1] Hendrycks, D., & Gimpel, K. (2016). A baseline for detecting misclassified and out-of-distribution examples in neural networks. In ICLR 2017.

[2] Hendrycks, D., Basart, S., Mazeika, M., Zou, A., Kwon, J., Mostajabi, M., ... & Song, D. (2019). Scaling out-of-distribution detection for real-world settings. In ICML 2022.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 34.373 seconds)


.. _sphx_glr_download_auto_tutorials_Classification_tutorial_ood_detection.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tutorial_ood_detection.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tutorial_ood_detection.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tutorial_ood_detection.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_