# Source code for torch_uncertainty.metrics.classification.calibration_error

from typing import Any, Literal

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import seaborn as sns
import torch
from torchmetrics.classification.calibration_error import (
    BinaryCalibrationError,
    MulticlassCalibrationError,
)
from torchmetrics.functional.classification.calibration_error import (
    _binning_bucketize,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel
from torchmetrics.utilities.plot import _PLOT_OUT_TYPE

from .adaptive_calibration_error import AdaptiveCalibrationError


def _reliability_diagram_subplot(
    ax,
    accuracies: np.ndarray,
    confidences: np.ndarray,
    bin_sizes: np.ndarray,
    bins: np.ndarray,
    title: str = "Reliability Diagram",
    xlabel: str = "Top-class Confidence (%)",
    ylabel: str = "Success Rate (%)",
) -> None:
    widths = 1.0 / len(bin_sizes)
    positions = bins + widths / 2.0
    alphas = 0.2 + 0.8 * bin_sizes

    colors = np.zeros((len(bin_sizes), 4))
    colors[:, 0] = 240 / 255.0
    colors[:, 1] = 60 / 255.0
    colors[:, 2] = 60 / 255.0
    colors[:, 3] = alphas

    gap_plt = ax.bar(
        positions,
        np.abs(accuracies - confidences),
        bottom=np.minimum(accuracies, confidences),
        width=widths,
        edgecolor=colors,
        color=colors,
        linewidth=1,
        label="Gap",
    )

    acc_plt = ax.bar(
        positions,
        0,
        bottom=accuracies,
        width=widths,
        edgecolor="black",
        color="black",
        alpha=1.0,
        linewidth=3,
        label="Accuracy",
    )

    ax.set_aspect("equal")
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray")

    gaps = np.abs(accuracies - confidences)
    ece = (np.sum(gaps * bin_sizes) / np.sum(bin_sizes)) * 100

    ax.text(
        0.98,
        0.02,
        f"ECE={ece:.03}%",
        color="black",
        ha="right",
        va="bottom",
        transform=ax.transAxes,
    )

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    ax.legend(handles=[gap_plt, acc_plt])


def _confidence_histogram_subplot(
    ax,
    accuracies: np.ndarray,
    confidences: np.ndarray,
    title="Examples per bin",
    xlabel="Top-class Confidence (%)",
    ylabel="Density",
) -> None:
    """Draw the density of the confidences on ``ax``.

    A filled kernel-density estimate of ``confidences`` is plotted, together
    with two vertical lines: a solid one at the mean of ``accuracies`` and a
    dotted one at the mean of ``confidences``.

    Args:
        ax: Axes to draw on.
        accuracies: Per-sample accuracies; only their mean is used.
        confidences: Per-sample confidences.
        title: Title of the subplot.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
    """
    sns.kdeplot(confidences, ax=ax, fill=True, alpha=0.5, linewidth=2)

    ax.set_xlim(0, 1)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Vertical markers for the overall accuracy and the average confidence.
    mean_accuracy_line = ax.axvline(
        x=np.mean(accuracies),
        ls="solid",
        lw=3,
        c="black",
        label="Accuracy",
    )
    mean_confidence_line = ax.axvline(
        x=np.mean(confidences),
        ls="dotted",
        lw=3,
        c="#444",
        label="Avg. confidence",
    )

    ax.legend(
        handles=[mean_accuracy_line, mean_confidence_line],
        loc="upper left",
    )


def reliability_chart(
    accuracies: np.ndarray,
    confidences: np.ndarray,
    bin_accuracies: np.ndarray,
    bin_confidences: np.ndarray,
    bin_sizes: np.ndarray,
    bins: np.ndarray,
    title="Reliability Diagram",
    figsize=(6, 6),
    dpi=72,
) -> _PLOT_OUT_TYPE:
    """Build a two-panel reliability chart.

    The upper panel is the per-bin reliability diagram, the lower panel the
    density of the per-sample confidences. Adapted from
    `Source <https://github.com/hollance/reliability-diagrams>`_.

    Args:
        accuracies: Per-sample accuracies (lower panel).
        confidences: Per-sample confidences (lower panel).
        bin_accuracies: Accuracy of each bin (upper panel).
        bin_confidences: Mean confidence of each bin (upper panel).
        bin_sizes: Weight of each bin (upper panel).
        bins: Bar offsets of the upper panel.
        title: Title of the upper panel.
        figsize: Only the width ``figsize[0]`` is used; the height is always
            set to 1.4 times the width.
        dpi: Resolution of the figure.

    Returns:
        The figure and the array of its two axes.
    """
    fig_width = figsize[0]

    fig, axes = plt.subplots(
        nrows=2,
        ncols=1,
        sharex=True,
        figsize=(fig_width, fig_width * 1.4),
        dpi=dpi,
        gridspec_kw={"height_ratios": [4, 1]},
    )
    diagram_ax, density_ax = axes[0], axes[1]

    # Stack the two panels directly on top of each other.
    plt.tight_layout()
    plt.subplots_adjust(hspace=0)

    # Upper panel: reliability diagram.
    _reliability_diagram_subplot(
        diagram_ax,
        bin_accuracies,
        bin_confidences,
        bin_sizes,
        bins,
        title=title,
    )

    # Lower panel: confidence density, without a title of its own.
    _confidence_histogram_subplot(
        density_ax, accuracies, confidences, title=""
    )

    # Pin the y ticks of the lower panel to non-negative integer values and
    # use those same values as tick labels.
    integer_ticks = np.abs(density_ax.get_yticks()).astype(np.int32)
    density_ax.yaxis.set_major_locator(mticker.FixedLocator(integer_ticks))
    density_ax.set_yticklabels(integer_ticks)

    return fig, axes


def custom_plot(self) -> _PLOT_OUT_TYPE:
    """Plot the reliability chart of the samples accumulated by the metric.

    Intended to be bound as the ``plot`` method of a calibration-error
    metric: it reads the ``confidences`` and ``accuracies`` states and the
    ``n_bins`` attribute of ``self``, bins the samples on ``n_bins + 1``
    evenly spaced boundaries in [0, 1] and forwards everything to
    :func:`reliability_chart`.

    Returns:
        The figure and axes produced by :func:`reliability_chart`.
    """
    # Detach once up front: ``Tensor.numpy()`` raises on tensors that require
    # grad, and the raw states are converted to NumPy outside of the
    # ``no_grad`` block below.
    confidences = dim_zero_cat(self.confidences).detach()
    accuracies = dim_zero_cat(self.accuracies).detach()

    bin_boundaries = torch.linspace(
        0,
        1,
        self.n_bins + 1,
        dtype=torch.float,
        device=confidences.device,
    )

    with torch.no_grad():
        acc_bin, conf_bin, prop_bin = _binning_bucketize(
            confidences, accuracies, bin_boundaries
        )

    # Matplotlib/seaborn work on host NumPy arrays, so move everything to CPU.
    return reliability_chart(
        accuracies=accuracies.cpu().numpy(),
        confidences=confidences.cpu().numpy(),
        bin_accuracies=acc_bin.cpu().numpy(),
        bin_confidences=conf_bin.cpu().numpy(),
        bin_sizes=prop_bin.cpu().numpy(),
        bins=bin_boundaries.cpu().numpy(),
    )


# Overwrite the `plot` method of the original torchmetrics classes with
# `custom_plot`, so that calling `.plot()` on them draws the reliability chart.
# NOTE: these are class-attribute assignments executed at import time, so
# every instance of both classes in the process is affected, not only the
# ones created through `CalibrationError`.
BinaryCalibrationError.plot = custom_plot
MulticlassCalibrationError.plot = custom_plot


class CalibrationError:
    r"""Top-label Calibration Error.

    See `CalibrationError <https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html>`_
    for details.

    Our version of the metric is a wrapper around the original metric
    providing a plotting functionality.

    Reference:
        Naeini et al. "Obtaining well calibrated probabilities using Bayesian
        binning." In AAAI, 2015.
    """

    def __new__(  # type: ignore[misc]
        cls,
        task: Literal["binary", "multiclass"],
        adaptive: bool = False,
        num_bins: int = 10,
        norm: Literal["l1", "l2", "max"] = "l1",
        num_classes: int | None = None,
        ignore_index: int | None = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> Metric:
        """Initialize task metric.

        This class is a factory: it never returns an instance of itself but
        the concrete metric matching ``task`` and ``adaptive``.

        Args:
            task: Either ``"binary"`` or ``"multiclass"``.
            adaptive: If ``True``, return an ``AdaptiveCalibrationError``
                built from the arguments below instead of the torchmetrics
                metric.
            num_bins: Number of bins; forwarded as ``n_bins`` to the
                torchmetrics metrics.
            norm: Norm used to aggregate the calibration gaps.
            num_classes: Number of classes. Must be an ``int`` for the
                non-adaptive multiclass task; it is not forwarded to the
                non-adaptive binary metric.
            ignore_index: Forwarded unchanged to the underlying metric.
            validate_args: Forwarded unchanged to the underlying metric.
            **kwargs: Additional keyword arguments forwarded to the
                underlying metric.

        Returns:
            An ``AdaptiveCalibrationError``, ``BinaryCalibrationError`` or
            ``MulticlassCalibrationError`` instance.

        Raises:
            TypeError: If the non-adaptive multiclass task is requested and
                ``num_classes`` is not an ``int``.
        """
        if adaptive:
            return AdaptiveCalibrationError(
                task=task,
                num_bins=num_bins,
                norm=norm,
                num_classes=num_classes,
                ignore_index=ignore_index,
                validate_args=validate_args,
                **kwargs,
            )

        task = ClassificationTaskNoMultilabel.from_str(task)
        kwargs.update(
            {
                "n_bins": num_bins,
                "norm": norm,
                "ignore_index": ignore_index,
                "validate_args": validate_args,
            }
        )

        if task == ClassificationTaskNoMultilabel.BINARY:
            return BinaryCalibrationError(**kwargs)

        # task is ClassificationTaskNoMultilabel.MULTICLASS
        if not isinstance(num_classes, int):
            raise TypeError(
                "`num_classes` is expected to be `int` but "
                f"`{type(num_classes)}` was passed."
            )
        return MulticlassCalibrationError(num_classes, **kwargs)