• Docs >
  • Module code >
  • torch_uncertainty.metrics.classification.adaptive_calibration_error
Shortcuts

Source code for torch_uncertainty.metrics.classification.adaptive_calibration_error

from typing import Any, Literal

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torchmetrics.classification.calibration_error import (
    _binary_calibration_error_arg_validation,
    _multiclass_calibration_error_arg_validation,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel


def _equal_binning_bucketize(
    confidences: Tensor, accuracies: Tensor, num_bins: int
) -> tuple[Tensor, Tensor, Tensor]:
    """Compute bins for the adaptive calibration error.

    Args:
        confidences: The confidence (i.e. predicted prob) of the top1
            prediction.
        accuracies: 1.0 if the top-1 prediction was correct, 0.0 otherwise.
        num_bins: Number of bins to use when computing adaptive calibration
            error.

    Returns:
        tuple with binned accuracy, binned confidence and binned probabilities
    """
    confidences, indices = torch.sort(confidences)
    accuracies = accuracies[indices]
    acc_bin, conf_bin = (
        accuracies.tensor_split(num_bins),
        confidences.tensor_split(num_bins),
    )
    count_bin = torch.as_tensor(
        [len(cb) for cb in conf_bin],
        dtype=confidences.dtype,
        device=confidences.device,
    )
    return (
        pad_sequence(acc_bin, batch_first=True).sum(1) / count_bin,
        pad_sequence(conf_bin, batch_first=True).sum(1) / count_bin,
        torch.as_tensor(count_bin) / len(confidences),
    )


def _ace_compute(
    confidences: Tensor,
    accuracies: Tensor,
    num_bins: int,
    norm: Literal["l1", "l2", "max"] = "l1",
    debias: bool = False,
) -> Tensor:
    """Compute the adaptive calibration error given the provided number of bins
        and norm.

    Args:
        confidences: The confidence (i.e. predicted prob) of the top1
            prediction.
        accuracies: 1.0 if the top-1 prediction was correct, 0.0 otherwise.
        num_bins: Number of bins to use when computing adaptive calibration
            error.
        norm: Norm function to use when computing calibration error. Defaults
            to "l1".
        debias: Apply debiasing to L2 norm computation as in
            `Verified Uncertainty Calibration`. Defaults to False.

    Returns:
        Tensor: Adaptive Calibration error scalar.
    """
    with torch.no_grad():
        acc_bin, conf_bin, prop_bin = _equal_binning_bucketize(
            confidences, accuracies, num_bins
        )

    if norm == "l1":
        return torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin)
    if norm == "max":
        ace = torch.max(torch.abs(acc_bin - conf_bin))
    if norm == "l2":
        ace = torch.sum(torch.pow(acc_bin - conf_bin, 2) * prop_bin)
        if debias:  # coverage: ignore
            debias_bins = (acc_bin * (acc_bin - 1) * prop_bin) / (
                prop_bin * accuracies.size()[0] - 1
            )
            ace += torch.sum(
                torch.nan_to_num(debias_bins)
            )  # replace nans with zeros if nothing appeared in a bin
        return torch.sqrt(ace) if ace > 0 else torch.tensor(0)
    return ace


class BinaryAdaptiveCalibrationError(Metric):
    r"""Adaptive Top-label Calibration Error for binary tasks."""

    is_differentiable: bool = False
    higher_is_better: bool = False
    full_state_update: bool = False

    confidences: list[Tensor]
    accuracies: list[Tensor]

    def __init__(
        self,
        n_bins: int = 10,
        norm: Literal["l1", "l2", "max"] = "l1",
        ignore_index: int | None = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if ignore_index is not None:  # coverage: ignore
            raise ValueError(
                "ignore_index is not supported for multiclass tasks."
            )

        if validate_args:
            _binary_calibration_error_arg_validation(n_bins, norm, ignore_index)
        self.n_bins = n_bins
        self.norm = norm

        self.add_state("confidences", [], dist_reduce_fx="cat")
        self.add_state("accuracies", [], dist_reduce_fx="cat")

    def update(self, probs: Tensor, targets: Tensor) -> None:
        """Update metric states with predictions and targets."""
        confidences, preds = torch.max(probs, 1 - probs), torch.round(probs)
        accuracies = preds == targets
        self.confidences.append(confidences.float())
        self.accuracies.append(accuracies.float())

    def compute(self) -> Tensor:
        """Compute metric."""
        confidences = dim_zero_cat(self.confidences)
        accuracies = dim_zero_cat(self.accuracies)
        return _ace_compute(
            confidences, accuracies, self.n_bins, norm=self.norm
        )


class MulticlassAdaptiveCalibrationError(Metric):
    r"""Adaptive Top-label Calibration Error for multiclass tasks."""

    is_differentiable: bool = False
    higher_is_better: bool = False
    full_state_update: bool = False

    confidences: list[Tensor]
    accuracies: list[Tensor]

    def __init__(
        self,
        num_classes: int,
        n_bins: int = 10,
        norm: Literal["l1", "l2", "max"] = "l1",
        ignore_index: int | None = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if ignore_index is not None:  # coverage: ignore
            raise ValueError(
                "ignore_index is not supported for multiclass tasks."
            )

        if validate_args:
            _multiclass_calibration_error_arg_validation(
                num_classes, n_bins, norm, ignore_index
            )
        self.n_bins = n_bins
        self.norm = norm

        self.add_state("confidences", [], dist_reduce_fx="cat")
        self.add_state("accuracies", [], dist_reduce_fx="cat")

    def update(self, probs: Tensor, targets: Tensor) -> None:
        """Update metric states with predictions and targets."""
        confidences, preds = torch.max(probs, 1)
        accuracies = preds == targets
        self.confidences.append(confidences.float())
        self.accuracies.append(accuracies.float())

    def compute(self) -> Tensor:
        """Compute metric."""
        confidences = dim_zero_cat(self.confidences)
        accuracies = dim_zero_cat(self.accuracies)
        return _ace_compute(
            confidences, accuracies, self.n_bins, norm=self.norm
        )


[docs]class AdaptiveCalibrationError: """`Adaptive Top-label Calibration Error`. Reference: Nixon et al. Measuring calibration in deep learning. In CVPRW, 2019. """ def __new__( cls, task: Literal["binary", "multiclass"], num_bins: int = 10, norm: Literal["l1", "l2", "max"] = "l1", num_classes: int | None = None, ignore_index: int | None = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: task = ClassificationTaskNoMultilabel.from_str(task) kwargs.update( { "n_bins": num_bins, "norm": norm, "ignore_index": ignore_index, "validate_args": validate_args, } ) if task == ClassificationTaskNoMultilabel.BINARY: return BinaryAdaptiveCalibrationError(**kwargs) # task is ClassificationTaskNoMultilabel.MULTICLASS if not isinstance(num_classes, int): raise TypeError( f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`" ) return MulticlassAdaptiveCalibrationError(num_classes, **kwargs)