.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_tutorials/Post_Hoc_Methods/tutorial_temperature.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_tutorials_Post_Hoc_Methods_tutorial_temperature.py:


Improve Top-label Calibration with Temperature Scaling
======================================================

In this tutorial, we use *TorchUncertainty* to improve the calibration of the top-label predictions and the reliability of the underlying neural network.

This tutorial details how to use the TemperatureScaler class. TorchUncertainty also provides its generalizations, VectorScaler, MatrixScaler, and Dirichlet calibration, as well as IsotonicRegressionScaler. Note, however, that in practice this post-processing is usually done automatically in the datamodule when ``postprocess_set`` is set to ``"val"`` or ``"test"``.

Through this tutorial, we also see how to use the datamodules outside of a Lightning trainer, and how to use TorchUncertainty's models.

Note
~~~~

The Expected Calibration Error (ECE) is not sufficient to properly assess the calibration properties of a model: it only summarizes the calibration of the top-label confidence, its value depends on the binning scheme, and an uninformative model can still reach a low ECE.

1. Loading the Utilities
~~~~~~~~~~~~~~~~~~~~~~~~

In this tutorial, we will need:

- TorchUncertainty's Calibration Error metric to evaluate the top-label calibration with the ECE and to plot the reliability diagrams,
- the CIFAR-100 datamodule to handle the data,
- a ResNet-18 as the starting model,
- the temperature scaler to improve the top-label calibration,
- a utility function to easily download models from Hugging Face.

If you use the classification routine with the ``log_plots`` flag, the plots are automatically available in the TensorBoard logs.

.. GENERATED FROM PYTHON SOURCE LINES 37-43

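
The note on the limits of the ECE can be made concrete with a toy example. The sketch below uses plain PyTorch on synthetic data; the 60/40 binary labels, the two hypothetical classifiers, and the helper ``single_bin_ece`` are illustrative assumptions, not part of TorchUncertainty. Each classifier gives the same top-label confidence to every sample, so all samples fall into a single bin and the binned ECE reduces to the absolute gap between that confidence and the accuracy.

.. code-block:: Python

    import torch


    def single_bin_ece(probs: torch.Tensor, targets: torch.Tensor) -> tuple[float, float]:
        """Return (accuracy, ECE), assuming all top-label confidences are equal.

        With a single confidence value, every sample lands in the same bin, so
        the binned ECE is simply |confidence - accuracy|.
        """
        confidences, predictions = probs.max(dim=-1)
        accuracy = (predictions == targets).float().mean()
        return accuracy.item(), (confidences.mean() - accuracy).abs().item()


    # Hypothetical binary labels: 60% of the samples belong to class 0.
    num_samples = 1000
    targets = torch.cat([torch.zeros(600, dtype=torch.long), torch.ones(400, dtype=torch.long)])

    # Uninformative classifier: ignores its input and always outputs the class priors.
    prior_probs = torch.tensor([0.6, 0.4]).expand(num_samples, 2)

    # Perfect but underconfident classifier: always right, yet only 60% confident.
    perfect_probs = torch.full((num_samples, 2), 0.4)
    perfect_probs[torch.arange(num_samples), targets] = 0.6

    for name, probs in [("prior", prior_probs), ("perfect", perfect_probs)]:
        accuracy, ece = single_bin_ece(probs, targets)
        print(f"{name}: accuracy={accuracy:.2f}, ECE={ece:.2f}")

The uninformative classifier gets an ECE of 0.00 with an accuracy of 0.60, while the perfect classifier gets an ECE of 0.40 with an accuracy of 1.00: judged by the ECE alone, the useless model would look better.
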
.. code-block:: Python

    from torch_uncertainty.datamodules import CIFAR100DataModule
    from torch_uncertainty.metrics import CalibrationError
    from torch_uncertainty.models.classification import resnet
    from torch_uncertainty.post_processing import TemperatureScaler
    from torch_uncertainty.utils import load_hf

.. GENERATED FROM PYTHON SOURCE LINES 44-49

2. Loading a Model from TorchUncertainty's HF
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To avoid training a model on CIFAR-100 from scratch, we load a pretrained model from Hugging Face. This only takes a few lines:

.. GENERATED FROM PYTHON SOURCE LINES 49-58

.. code-block:: Python

    # Build the model
    model = resnet(in_channels=3, num_classes=100, arch=18, style="cifar", conv_bias=False)

    # Download the weights (the config is not used here)
    weights, config = load_hf("resnet18_c100")

    # Load the weights in the pre-built model
    model.load_state_dict(weights)

    # Switch to evaluation mode so that BatchNorm uses its running statistics
    model.eval()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

.. GENERATED FROM PYTHON SOURCE LINES 59-66

3. Setting up the Datamodule and Dataloaders
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To get the data, call ``prepare_data`` and ``setup``. Here, we directly use the test dataset ``dm.test``. If you use ``dm.test_dataloader()`` instead, note that it returns a list: it contains more than one element when ``eval_ood`` is True (the dataloader of in-distribution data and the dataloader of out-of-distribution data), and a single element otherwise, so extract its first element.

.. GENERATED FROM PYTHON SOURCE LINES 66-71

.. code-block:: Python

    dm = CIFAR100DataModule(root="./data", eval_ood=False, batch_size=32)
    dm.prepare_data()
    dm.setup("test")

.. GENERATED FROM PYTHON SOURCE LINES 72-82

4. Iterating on the Dataloader and Computing the ECE
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We first split the original test set into a calibration set and a test set for proper evaluation.

When computing the ECE, you need to provide the predicted probabilities associated with the inputs, not the logits. To obtain them, apply PyTorch's softmax to the logits.

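
To see what the metric computes from these probabilities, here is a minimal sketch of a binned top-label ECE in plain PyTorch: the sum over bins of the fraction of samples in the bin times the absolute gap between their mean confidence and their accuracy. The helper ``top_label_ece`` is hypothetical and not part of TorchUncertainty, and the 15 equal-width bins are an assumption; its values are close in spirit to, but not guaranteed to match exactly, those of ``CalibrationError``.

.. code-block:: Python

    import torch
    import torch.nn.functional as F


    def top_label_ece(probs: torch.Tensor, targets: torch.Tensor, num_bins: int = 15) -> float:
        """Binned top-label ECE: sum over bins of |B| / N * |mean confidence - accuracy|."""
        confidences, predictions = probs.max(dim=-1)
        accuracies = (predictions == targets).float()
        # Bin b covers the confidence interval (b / num_bins, (b + 1) / num_bins].
        bin_ids = ((confidences * num_bins).ceil().long() - 1).clamp(0, num_bins - 1)
        ece = 0.0
        for b in range(num_bins):
            in_bin = bin_ids == b
            if in_bin.any():
                weight = in_bin.float().mean().item()  # fraction of samples in bin b
                gap = (confidences[in_bin].mean() - accuracies[in_bin].mean()).abs().item()
                ece += weight * gap
        return ece


    # Hypothetical logits and labels drawn independently of each other, so the
    # confident predictions are right only about 10% of the time.
    torch.manual_seed(0)
    logits = 3.0 * torch.randn(2000, 10)
    targets = torch.randint(0, 10, (2000,))
    probs = logits.softmax(-1)  # the ECE expects probabilities, not logits
    print(f"ECE on random logits: {top_label_ece(probs, targets):.3%}")

    # Always-correct one-hot predictions are perfectly calibrated.
    one_hot = F.one_hot(targets, num_classes=10).float()
    print(f"ECE on always-correct one-hot predictions: {top_label_ece(one_hot, targets):.3%}")

The first ECE is large because the softmax confidences are high while the predictions are essentially random; the second is exactly zero.
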
To avoid lengthy computations, we fit the calibration on a subset of 4000 test samples and evaluate on the remaining ones.

.. GENERATED FROM PYTHON SOURCE LINES 82-103

.. code-block:: Python

    import torch
    from torch.utils.data import DataLoader, random_split

    # Split datasets
    dataset = dm.test
    cal_dataset, test_dataset = random_split(dataset, [4000, len(dataset) - 4000])
    test_dataloader = DataLoader(test_dataset, batch_size=128)
    calibration_dataloader = DataLoader(cal_dataset, batch_size=128)

    # Initialize the ECE
    ece = CalibrationError(task="multiclass", num_classes=100)

    # Iterate on the test dataloader (no gradients are needed for evaluation)
    with torch.no_grad():
        for sample, target in test_dataloader:
            logits = model(sample)
            probs = logits.softmax(-1)
            ece.update(probs, target)

    # Compute & print the calibration error
    print(f"ECE before scaling - {ece.compute():.3%}.")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ECE before scaling - 10.773%.

.. GENERATED FROM PYTHON SOURCE LINES 104-106

We also plot the top-label reliability diagram, which shows that the model is not well calibrated.

.. GENERATED FROM PYTHON SOURCE LINES 106-110

.. code-block:: Python

    fig, ax = ece.plot()
    fig.tight_layout()
    fig.show()

.. image-sg:: /auto_tutorials/Post_Hoc_Methods/images/sphx_glr_tutorial_temperature_001.png
   :alt: Reliability Diagram
   :srcset: /auto_tutorials/Post_Hoc_Methods/images/sphx_glr_tutorial_temperature_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 111-118

5. Fitting the Scaler to Improve the Calibration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The TemperatureScaler has a single parameter, the temperature, which divides the logits before the softmax. We fit it by minimizing the cross-entropy of the tempered predictions on the calibration set, defined above as a subset of 4000 test samples. See the code of the TemperatureScaler ``fit`` method for more details.

.. GENERATED FROM PYTHON SOURCE LINES 118-123

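
As a rough picture of the idea behind ``fit``, the sketch below fits a single temperature on precomputed logits in plain PyTorch. It is not TorchUncertainty's implementation: ``fit_temperature`` is a hypothetical helper, and optimizing the log-temperature with LBFGS and a strong-Wolfe line search is our own choice. Dividing the logits by a positive temperature does not change the arg-max, so the accuracy is preserved and only the confidences change. On the real model, you would first collect the logits and labels of ``calibration_dataloader`` with the model in evaluation mode under ``torch.no_grad()``.

.. code-block:: Python

    import torch
    import torch.nn.functional as F


    def fit_temperature(logits: torch.Tensor, targets: torch.Tensor) -> float:
        """Fit a scalar temperature T > 0 minimizing the NLL of softmax(logits / T)."""
        logits = logits.detach()
        # Optimize log(T) so that T = exp(log_t) stays positive; start at T = 1.
        log_t = torch.zeros(1, requires_grad=True)
        optimizer = torch.optim.LBFGS([log_t], max_iter=50, line_search_fn="strong_wolfe")

        def closure():
            optimizer.zero_grad()
            loss = F.cross_entropy(logits / log_t.exp(), targets)
            loss.backward()
            return loss

        optimizer.step(closure)
        return log_t.exp().item()


    # Synthetic check: labels are sampled from softmax(true_logits), and the
    # hypothetical model is overconfident because it outputs 3 * true_logits.
    torch.manual_seed(0)
    true_logits = 2.0 * torch.randn(5000, 10)
    targets = torch.multinomial(true_logits.softmax(-1), num_samples=1).squeeze(-1)
    overconfident_logits = 3.0 * true_logits

    temperature = fit_temperature(overconfident_logits, targets)
    nll_before = F.cross_entropy(overconfident_logits, targets).item()
    nll_after = F.cross_entropy(overconfident_logits / temperature, targets).item()
    print(f"T = {temperature:.2f}, NLL: {nll_before:.3f} -> {nll_after:.3f}")

On this synthetic data, the fitted temperature should land close to 3, the factor used to sharpen the logits, and the negative log-likelihood should decrease.
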
.. code-block:: Python

    # Fit the scaler on the calibration dataset
    scaled_model = TemperatureScaler(model=model)
    scaled_model.fit(dataloader=calibration_dataloader)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0%| | 0/32 [00:00`_.

- **Temperature Scaling:** Guo, C., Pleiss, G., Sun, Y., & Weinberger, K. Q. (2017). On calibration of modern neural networks. In `ICML 2017 `_.

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (2 minutes 24.360 seconds)

.. _sphx_glr_download_auto_tutorials_Post_Hoc_Methods_tutorial_temperature.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tutorial_temperature.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tutorial_temperature.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tutorial_temperature.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_