
Deep Evidential Classification on a Toy Example

This tutorial provides an introductory overview of Deep Evidential Classification (DEC) through a practical example. We demonstrate an application of DEC by tackling the toy problem of fitting the MNIST dataset with a small LeNet neural network. The output of the network is modeled as a Dirichlet distribution. The network is trained by minimizing the DEC loss function, composed of the Bayesian risk of the squared error and a regularization term based on the KL divergence.

DEC is an evidential approach to quantifying uncertainty in neural network classifiers. It introduces a prior distribution over the parameters of the Categorical likelihood, and the neural network is trained to estimate the parameters of this evidential (Dirichlet) distribution.
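
To make this concrete, the sketch below (a simplified illustration, not TorchUncertainty's DECLoss implementation) shows how non-negative evidence is mapped to Dirichlet parameters, expected class probabilities, and an uncertainty mass, and how the Bayesian-risk squared-error term of the loss is computed. The helper names and the tiny tensors are illustrative only; the KL regularization term toward the uniform Dirichlet is omitted here and is controlled by the reg_weight argument later in this tutorial.

import torch


def dirichlet_from_evidence(evidence: torch.Tensor):
    """Turn non-negative evidence of shape (batch, K) into Dirichlet quantities."""
    alpha = evidence + 1                        # Dirichlet concentration parameters
    strength = alpha.sum(dim=1, keepdim=True)   # total evidence S
    probs = alpha / strength                    # expected class probabilities
    uncertainty = evidence.shape[1] / strength  # vacuity u = K / S
    return alpha, strength, probs, uncertainty


def bayes_risk_mse(targets_onehot: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Bayesian risk of the squared error under Dir(alpha), averaged over the batch."""
    strength = alpha.sum(dim=1, keepdim=True)
    probs = alpha / strength
    error = (targets_onehot - probs) ** 2           # squared error of the predictive mean
    variance = probs * (1 - probs) / (strength + 1) # variance of the Dirichlet mean
    return (error + variance).sum(dim=1).mean()


# Tiny example: 2 samples, 3 classes
evidence = torch.tensor([[5.0, 0.0, 1.0], [0.2, 0.1, 0.3]])
targets = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
alpha, strength, probs, u = dirichlet_from_evidence(evidence)
print(probs, u, bayes_risk_mse(targets, alpha))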

Training a LeNet with DEC using TorchUncertainty models

In this part, we train a neural network based on the models and routines already implemented in TorchUncertainty (TU).

1. Loading the utilities

To train a LeNet with the DEC loss function using TorchUncertainty, we have to load the following utilities from TorchUncertainty:

  • our wrapper of the Lightning Trainer

  • the model: LeNet, which lies in torch_uncertainty.models

  • the classification training routine in the torch_uncertainty.routines

  • the evidential objective: the DECLoss from torch_uncertainty.losses

  • the datamodule that handles dataloaders & transforms: MNISTDataModule from torch_uncertainty.datamodules

We also need to define an optimizer using torch.optim and the neural network utilities from torch.nn.

from pathlib import Path

import torch
from torch import nn, optim

from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import MNISTDataModule
from torch_uncertainty.losses import DECLoss
from torch_uncertainty.models.lenet import lenet
from torch_uncertainty.routines import ClassificationRoutine

2. Creating the Optimizer Wrapper

Following the official DEC implementation, we use the Adam optimizer with a learning rate of 0.001, a weight decay of 0.005, and a step learning-rate scheduler.

def optim_lenet(model: nn.Module) -> dict:
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.005)
    exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    return {"optimizer": optimizer, "lr_scheduler": exp_lr_scheduler}

3. Creating the necessary variables

In the following, we instantiate the trainer, define the root directory of the dataset, and create the datamodule and the model. We use the same MNIST classification example as in the original DEC paper and train for only 3 epochs for the sake of time.

trainer = TUTrainer(accelerator="cpu", max_epochs=3, enable_progress_bar=False)

# datamodule
root = Path() / "data"
datamodule = MNISTDataModule(root=root, batch_size=128)

model = lenet(
    in_channels=datamodule.num_channels,
    num_classes=datamodule.num_classes,
)

4. The Loss and the Training Routine

Next, we define the loss used during training. We then set up the training routine using the single-model classification routine torch_uncertainty.routines.ClassificationRoutine, providing the model, the number of classes, the DEC loss, and the optimization recipe, and keeping the remaining arguments at their defaults.

loss = DECLoss(reg_weight=1e-2)

routine = ClassificationRoutine(
    model=model,
    num_classes=datamodule.num_classes,
    loss=loss,
    optim_recipe=optim_lenet(model),
)

5. Gathering Everything and Training the Model

trainer.fit(model=routine, datamodule=datamodule)
trainer.test(model=routine, datamodule=datamodule)
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃      Classification       ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     Acc      │          75.00%           │
│    Brier     │          0.28558          │
│   Entropy    │          0.32940          │
│     NLL      │          1.07252          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃        Calibration        ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     ECE      │          0.06117          │
│     aECE     │          0.10416          │
└──────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric  ┃ Selective Classification  ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    AUGRC     │           4.13%           │
│     AURC     │           5.13%           │
│  Cov@5Risk   │          69.57%           │
│  Risk@80Cov  │           9.40%           │
└──────────────┴───────────────────────────┘

[{'test/cal/ECE': 0.06117115169763565, 'test/cal/aECE': 0.10416258126497269, 'test/cls/Acc': 0.75, 'test/cls/Brier': 0.28558000922203064, 'test/cls/NLL': 1.0725198984146118, 'test/sc/AUGRC': 0.04133406654000282, 'test/sc/AURC': 0.05131366103887558, 'test/sc/Cov@5Risk': 0.6956999897956848, 'test/sc/Risk@80Cov': 0.09399999678134918, 'test/cls/Entropy': 0.32939988374710083}]

6. Testing the Model

Now that the model is trained, let’s test it on MNIST.

import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms.functional as F


def imshow(img) -> None:
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


def rotated_mnist(angle: int) -> None:
    """Rotate MNIST images and show images and confidence.

    Args:
        angle: Rotation angle in degrees.
    """
    rotated_images = F.rotate(images, angle)
    # print rotated images
    plt.axis("off")
    imshow(torchvision.utils.make_grid(rotated_images[:4, ...]))
    print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4)))

    evidence = routine(rotated_images)
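    # ReLU clamps the network outputs to non-negative evidence; adding 1 gives
    # the Dirichlet concentration parameters. Their sum, the "strength", grows
    # with the amount of evidence, so a low strength signals high uncertainty.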
    alpha = torch.relu(evidence) + 1
    strength = torch.sum(alpha, dim=1, keepdim=True)
    probs = alpha / strength
    entropy = -1 * torch.sum(probs * torch.log(probs), dim=1, keepdim=True)
    for j in range(4):
        predicted = torch.argmax(probs[j, :])
        print(
            f"Predicted digits for the image {j}: {predicted} with strength "
            f"{strength[j,0]:.3} and entropy {entropy[j,0]:.3}."
        )


dataiter = iter(datamodule.val_dataloader())
images, labels = next(dataiter)

with torch.no_grad():
    routine.eval()
    rotated_mnist(0)
    rotated_mnist(45)
    rotated_mnist(90)
[Figures: grids of the first four validation images, shown for rotations of 0°, 45°, and 90°]
Ground truth:  7 2 1 0
Predicted digits for the image 0: 7 with strength 97.9 and entropy 0.509.
Predicted digits for the image 1: 2 with strength 69.2 and entropy 0.673.
Predicted digits for the image 2: 0 with strength 10.0 and entropy 2.3.
Predicted digits for the image 3: 0 with strength 32.4 and entropy 1.2.
Ground truth:  7 2 1 0
Predicted digits for the image 0: 9 with strength 15.9 and entropy 1.93.
Predicted digits for the image 1: 0 with strength 10.0 and entropy 2.3.
Predicted digits for the image 2: 8 with strength 16.2 and entropy 2.02.
Predicted digits for the image 3: 0 with strength 22.2 and entropy 1.57.
Ground truth:  7 2 1 0
Predicted digits for the image 0: 0 with strength 10.0 and entropy 2.3.
Predicted digits for the image 1: 4 with strength 53.2 and entropy 1.3.
Predicted digits for the image 2: 7 with strength 48.3 and entropy 1.34.
Predicted digits for the image 3: 9 with strength 10.5 and entropy 2.29.
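
As the rotation increases, the evidential strength tends to fall toward its minimum of K = 10 (all Dirichlet parameters equal to 1) while the entropy rises. If a single uncertainty score in [0, 1] is preferred, the vacuity u = K / S can be derived from the same quantities. The snippet below is a small sketch reusing the routine and images defined above; it is not a dedicated TorchUncertainty API.

with torch.no_grad():
    evidence = routine(images[:4])
    alpha = torch.relu(evidence) + 1           # Dirichlet parameters
    strength = alpha.sum(dim=1, keepdim=True)  # total evidence S
    vacuity = alpha.shape[1] / strength        # u = K / S, close to 1 without evidence
    print("Vacuity of the unrotated images:", vacuity.squeeze())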

References

  • Deep Evidential Classification: Murat Sensoy, Lance Kaplan, and Melih Kandemir (2018). Evidential Deep Learning to Quantify Classification Uncertainty. In NeurIPS 2018.
