SmoothCalibrationError

class torch_uncertainty.metrics.classification.SmoothCalibrationError(kernel_type='logit', bandwidth='auto', eps=0.001, mesh_pts=200, refine_steps=10, **kwargs)

Smooth Expected Calibration Error (SmECE).

This metric implements the kernel-smoothed ECE proposed by Błasiok & Nakkiran (2024). It addresses two limitations of the standard binned ECE: bin-edge effects, and poor resolution for overconfident models, whose confidences crowd into the top bins. It does so by replacing the histogram with a kernel smoothing of the residuals.

Given top-class confidences \(\hat{p}_i \in [0, 1]\) and correctness indicators \(a_i \in \{0, 1\}\) (\(a_i = 1\) if the top-class prediction is correct), define the kernel-smoothed conditional gap

\[r_h(t) = \frac{\sum_{i=1}^{N} K_h(t, \hat{p}_i) (\hat{p}_i - a_i)} {\sum_{i=1}^{N} K_h(t, \hat{p}_i)},\]

where \(K_h\) is a kernel of bandwidth \(h\). The Smooth ECE is then

\[\text{SmECE} = \int_0^1 |r_h(t)| \, \hat{f}(t) \, \mathrm{d}t,\]

with \(\hat{f}(t) = \tfrac{1}{N}\sum_i K_h(t, \hat{p}_i)\) the kernel density estimate of the confidences. The bandwidth \(h\) can be fixed, or selected automatically as the fixed point at which the SmECE computed with bandwidth \(h\) equals \(h\), located by binary search.
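The definitions above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not torch_uncertainty's code: it uses the 'reflected' kernel truncated to three images (each sample plus its mirrors at 0 and 1), a uniform grid with an assumed refinement rule, and the trapezoidal rule. The helpers smooth_ece and auto_bandwidth are hypothetical names; auto_bandwidth reuses the documented parameter names eps and refine_steps, but its search bracket [0.001, 1] and stopping rule are assumptions. The search also assumes that the SmECE shrinks as \(h\) grows, so that SmECE(h) - h changes sign once.

```python
import numpy as np


def _gauss(x, h):
    # Gaussian density with standard deviation h, evaluated elementwise
    return np.exp(-0.5 * (x / h) ** 2) / (h * np.sqrt(2.0 * np.pi))


def smooth_ece(conf, acc, h, mesh_pts=200):
    """SmECE with a reflected Gaussian kernel of bandwidth h (illustrative sketch)."""
    conf = np.asarray(conf, dtype=np.float64)
    acc = np.asarray(acc, dtype=np.float64)
    # Assumed refinement rule: keep the grid step at most h / 5
    n = max(mesh_pts, int(np.ceil(5.0 / h)) + 1)
    t = np.linspace(0.0, 1.0, n)
    tt, p = t[:, None], conf[None, :]  # shapes (n, 1) and (1, N)
    # K_h(t, p_i) with reflections at 0 (p -> -p) and at 1 (p -> 2 - p)
    k = _gauss(tt - p, h) + _gauss(tt + p, h) + _gauss(tt - (2.0 - p), h)
    # |r_h(t)| * f_hat(t) == |(1/N) * sum_i K_h(t, p_i) * (p_i - a_i)|,
    # so r_h is never formed explicitly (no division by f_hat)
    integrand = np.abs((k * (conf - acc)[None, :]).mean(axis=1))
    # Trapezoidal rule over [0, 1]
    return float(np.sum(0.5 * (integrand[1:] + integrand[:-1]) * np.diff(t)))


def auto_bandwidth(conf, acc, eps=1e-3, refine_steps=10, mesh_pts=200):
    """Binary search for the fixed point h with smooth_ece(h) == h."""
    lo, hi = 1e-3, 1.0  # assumed search bracket
    for _ in range(refine_steps):
        if hi - lo < eps:
            break
        mid = 0.5 * (lo + hi)
        if smooth_ece(conf, acc, mid, mesh_pts) > mid:
            lo = mid  # error still exceeds the bandwidth: fixed point lies above mid
        else:
            hi = mid  # bandwidth already exceeds the error: fixed point lies below mid
    return 0.5 * (lo + hi)
```

As a sanity check, ten predictions at confidence 1.0 that are all wrong give smooth_ece(..., h=0.05) close to 1. Four predictions at confidence 0.5 with exactly half of them correct give 0, because the positive and negative residuals cancel at every grid point.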

Parameters:
  • kernel_type (Literal['logit', 'reflected']) – The kernel to use. Choose between:

      • 'logit' – Applies a Gaussian kernel in log-odds space. In probability space this acts as an adaptive bandwidth that narrows toward 0 and 1, which suits modern overconfident models whose confidences cluster near 1.0. (Default)

      • 'reflected' – Applies a Gaussian kernel in probability space with reflections at 0 and 1 to prevent boundary bias. Note that relplot’s original implementation uses 'reflected' by default.

  • bandwidth (Union[float, Literal['auto']]) – The kernel bandwidth \(h\). If set to 'auto', a fixed-point binary search selects the bandwidth at which the resulting SmECE equals \(h\). Defaults to 'auto'.

  • eps (float) – The tolerance for the binary search when bandwidth is 'auto'. Defaults to 0.001.

  • mesh_pts (int) – The base number of points for the grid discretization. The actual number may be higher depending on the bandwidth. Defaults to 200.

  • refine_steps (int) – Number of binary search iterations for the 'auto' bandwidth. Defaults to 10.

  • **kwargs – Additional keyword arguments passed to the torchmetrics.Metric base class.
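The "adaptive bandwidth" behaviour of the 'logit' kernel follows from the shape of the sigmoid. A window of half-width \(h\) placed around \(\mathrm{logit}(p)\) in log-odds space covers roughly \(2h\,p(1-p)\) in probability space, because the slope of the sigmoid is \(p(1-p)\). The NumPy sketch below measures that span. The helper prob_space_width is hypothetical and only illustrates the geometry; it is not the library's kernel code.

```python
import numpy as np


def logit(p):
    # log-odds, computed as log(p) - log(1 - p) for better precision near 1
    return np.log(p) - np.log1p(-p)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def prob_space_width(p, h):
    """Probability-space span of a +/- h window placed in log-odds space around p."""
    z = logit(np.asarray(p, dtype=np.float64))
    return sigmoid(z + h) - sigmoid(z - h)


h = 0.5
conf = np.array([0.6, 0.9, 0.99, 0.999])
widths = prob_space_width(conf, h)
for p, w in zip(conf, widths):
    # compare the exact span with the first-order estimate 2 * h * p * (1 - p)
    print(f"p={p:.3f}  span={w:.5f}  estimate={2 * h * p * (1 - p):.5f}")
```

The same log-odds window spans about 0.24 in probability at \(p = 0.6\) but only about 0.001 at \(p = 0.999\). This is why the 'logit' kernel can still resolve differences among confidences packed close to 1.0.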

Note

In the multiclass case, this metric evaluates the calibration of the maximum probability (top-label calibration). In the binary case, it evaluates the calibration of the predicted class, i.e., it uses the confidence \(\max(p, 1-p)\), where \(p\) is the predicted probability of the positive class.
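A small NumPy sketch can make this reduction concrete. It maps predictions to the confidences \(\hat{p}_i\) and correctness indicators \(a_i\) used in the formulas above. It assumes the inputs are already probabilities (the metric also accepts logits, but how it distinguishes the two is not described here). The helper name top_label and the tie-break at \(p = 0.5\) are hypothetical.

```python
import numpy as np


def top_label(probs, target):
    """Return (confidence, correctness) arrays for top-label calibration."""
    probs = np.asarray(probs, dtype=np.float64)
    target = np.asarray(target)
    if probs.ndim == 2 and probs.shape[1] > 1:
        # Multiclass (N, C): confidence = max probability, prediction = argmax
        conf = probs.max(axis=1)
        pred = probs.argmax(axis=1)
    else:
        # Binary (N,) or (N, 1): p is the probability of the positive class
        p = probs.reshape(-1)
        conf = np.maximum(p, 1.0 - p)
        pred = (p >= 0.5).astype(int)  # assumed tie-break: p = 0.5 -> class 1
    acc = (pred == target).astype(np.float64)
    return conf, acc


conf, acc = top_label([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]], [0, 1])
```

Here the first sample is predicted correctly with confidence 0.7. The second is predicted as class 2 with confidence 0.6 while the label is 1, so it counts as an error.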

Note

This implementation has been checked on a test case and matches the values of relplot’s original implementation to 6 significant figures.

References

  • Błasiok, J. & Nakkiran, P. Smooth ECE: Principled Reliability Diagrams via Kernel Smoothing. ICLR 2024.

compute()

Compute the Smooth ECE based on the accumulated state.

Returns:

The scalar SmECE value.

Return type:

Tensor

update(preds, target)

Update the state with predictions and targets.

Parameters:
  • preds (Tensor) – Predictions from the model (logits or probabilities):

      • Multiclass: shape (N, C).

      • Binary: shape (N,) or (N, 1).

  • target (Tensor) – Ground-truth labels:

      • Multiclass: shape (N,) containing class indices.

      • Binary: shape (N,) containing 0 or 1.

Return type:

None