SmoothCalibrationError

class torch_uncertainty.metrics.classification.SmoothCalibrationError(kernel_type='logit', bandwidth='auto', eps=0.001, mesh_pts=200, refine_steps=10, **kwargs)

Smooth Expected Calibration Error (SmECE).

This metric implements the kernel density estimation (KDE) based ECE proposed by Błasiok & Nakkiran (2024). It addresses limitations of the standard binned ECE, such as bin-edge effects and poor resolution for overconfident models, by using a continuous kernel and an adaptive bandwidth selection strategy. The metric is computed on the top label.
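As a reference point, the estimated quantity can be written as below. This follows our reading of Błasiok & Nakkiran; the symbols ($f_i$ for the top-label confidence of sample $i$, $y_i \in \{0, 1\}$ for whether that top label is correct, $K_h$ for the chosen kernel with bandwidth $h$) are illustrative and are not taken from the library's code.

```latex
\hat{r}_h(t) = \frac{1}{n} \sum_{i=1}^{n} K_h(t, f_i)\,(y_i - f_i),
\qquad
\mathrm{smECE}_h = \int_0^1 \bigl|\hat{r}_h(t)\bigr|\,dt,
\qquad
h^* \text{ such that } \mathrm{smECE}_{h^*} = h^*.
```

Smoothing replaces the hard assignment of each sample to a single bin, so shifting a confidence slightly changes the estimate only slightly.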

Parameters:
  • kernel_type (str, optional) – The kernel to use. Choose between:

    • 'logit' (default): Applies a Gaussian kernel in log-odds space. A fixed bandwidth in log-odds space acts as an adaptive bandwidth in probability space that narrows toward 0 and 1, which suits modern overconfident models whose top-label confidences cluster near 1.0.

    • 'reflected': Applies a Gaussian kernel in probability space with reflections at 0 and 1 to prevent boundary bias.

    Note that relplot’s original implementation uses 'reflected' as the default.

  • bandwidth (Literal['auto'] | float, optional) – The kernel bandwidth $h$. If set to 'auto', a binary search looks for the fixed point $h^*$ at which the smooth ECE computed with bandwidth $h^*$ equals $h^*$, so that the amount of smoothing matches the measured error level. Defaults to 'auto'.

  • eps (float, optional) – The tolerance for the binary search when bandwidth is 'auto'. Defaults to 0.001.

  • mesh_pts (int, optional) – The base number of points for the grid discretization. The actual number may be higher depending on the bandwidth. Defaults to 200.

  • refine_steps (int, optional) – Number of binary search iterations for the 'auto' bandwidth. Defaults to 10.

  • **kwargs – Additional keyword arguments passed to the torchmetrics.Metric base class.
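The NumPy sketch below illustrates the 'reflected' kernel and an 'auto' bandwidth search under stated assumptions. It is not the library's implementation: the helper names, the number of mirror images, the grid-refinement rule, and the way eps and refine_steps stop the search are hypothetical choices made for illustration, and the 'logit' kernel is not covered.

```python
import numpy as np


def reflected_gaussian_kernel(t, f, h, n_images=2):
    """Gaussian kernel in probability space, reflected at 0 and 1.

    Mirror images of each center f are placed at 2k + f and 2k - f, so each
    column integrates to (approximately) 1 over [0, 1]. Returns shape (M, N).
    """
    t = np.asarray(t, dtype=float)[:, None]  # grid points, (M, 1)
    f = np.asarray(f, dtype=float)[None, :]  # confidences, (1, N)
    out = np.zeros((t.shape[0], f.shape[1]))
    for k in range(-n_images, n_images + 1):
        for center in (2 * k + f, 2 * k - f):
            out += np.exp(-0.5 * ((t - center) / h) ** 2) / (h * np.sqrt(2 * np.pi))
    return out


def trapezoid(y, x):
    # Trapezoidal rule, written out to avoid depending on np.trapz / np.trapezoid.
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


def smooth_ece(conf, correct, h, mesh_pts=200):
    """smECE_h for top-label confidences `conf` and 0/1 `correct` flags."""
    conf = np.asarray(conf, dtype=float)
    correct = np.asarray(correct, dtype=float)
    # Hypothetical refinement rule: keep about four grid points per bandwidth.
    m = max(mesh_pts, int(np.ceil(4.0 / h)))
    t = np.linspace(0.0, 1.0, m)
    # Kernel-smoothed residual r_hat(t) = (1/n) * sum_i K_h(t, f_i) * (y_i - f_i)
    r_hat = reflected_gaussian_kernel(t, conf, h) @ (correct - conf) / len(conf)
    return trapezoid(np.abs(r_hat), t)


def smooth_ece_auto(conf, correct, eps=1e-3, refine_steps=10, mesh_pts=200):
    """Bisection for the fixed point h* with smECE_{h*} = h*.

    Assumes smECE_h - h is decreasing in h, so the root lies in [0, 1].
    Stops after `refine_steps` halvings or once the bracket is narrower than `eps`.
    """
    lo, hi = 0.0, 1.0
    for _ in range(refine_steps):
        if hi - lo < eps:
            break
        mid = 0.5 * (lo + hi)
        if smooth_ece(conf, correct, mid, mesh_pts) > mid:
            lo = mid  # error still exceeds the bandwidth: the fixed point is larger
        else:
            hi = mid
    h = 0.5 * (lo + hi)
    return smooth_ece(conf, correct, h, mesh_pts), h
```

For example, if every prediction has confidence 0.9 and is wrong, all residuals share the same sign, so the smoothed error stays near 0.9 for any bandwidth and the fixed point found by the search is also near 0.9.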

Note

In the multiclass case, this metric evaluates the calibration of the maximum probability (top-label calibration). In the binary case, it evaluates the calibration of the predicted class, i.e., it uses $\max(p, 1-p)$, where $p$ is the predicted probability of class 1.
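A minimal NumPy sketch of this top-label reduction, assuming the inputs are already probabilities. The function names are hypothetical, and resolving the tie at p = 0.5 in favor of class 1 is an assumption rather than documented behavior.

```python
import numpy as np


def top_label_multiclass(probs, target):
    """probs: (N, C) probabilities; target: (N,) class indices."""
    conf = probs.max(axis=1)  # confidence of the top label
    correct = (probs.argmax(axis=1) == target).astype(float)
    return conf, correct


def top_label_binary(p, target):
    """p: (N,) probability of class 1; target: (N,) labels in {0, 1}."""
    conf = np.maximum(p, 1.0 - p)  # max(p, 1 - p)
    pred = (p >= 0.5).astype(int)  # predicted class (tie goes to class 1)
    correct = (pred == target).astype(float)
    return conf, correct
```

The resulting (confidence, correctness) pairs are what a smoothing step would operate on.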

Note

This implementation has been tested on one use case, on which it matched relplot’s original implementation to six significant figures.

References

  • Błasiok, J. & Nakkiran, P. Smooth ECE: Principled Reliability Diagrams. ICLR 2024.

compute()

Compute the Smooth ECE based on the accumulated state.

Returns:

The scalar SmECE value.

Return type:

Tensor

update(preds, target)

Update the state with predictions and targets.

Parameters:
  • preds (Tensor) – Predictions from the model (logits or probabilities):

    • Multiclass: shape (N, C).

    • Binary: shape (N,) or (N, 1).

  • target (Tensor) – Ground-truth labels:

    • Multiclass: shape (N,) containing class indices.

    • Binary: shape (N,) containing 0 or 1.
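The accepted input shapes could be normalized as in the sketch below. The helper is hypothetical, not the library's code, and it assumes logits are detected by checking for values outside [0, 1]; with that heuristic, multiclass logits that all happen to lie in [0, 1] would be misread as probabilities.

```python
import numpy as np


def to_probabilities(preds):
    """Hypothetical helper: map raw `preds` to probabilities.

    Returns an (N,) array for binary inputs of shape (N,) or (N, 1), and an
    (N, C) array for multiclass inputs. Values outside [0, 1] are assumed
    to be logits.
    """
    preds = np.asarray(preds, dtype=float)
    if preds.ndim == 2 and preds.shape[1] == 1:
        preds = preds[:, 0]  # binary (N, 1) -> (N,)
    is_logits = preds.min() < 0.0 or preds.max() > 1.0
    if preds.ndim == 1:
        # Binary: probability of class 1 via the logistic sigmoid.
        return 1.0 / (1.0 + np.exp(-preds)) if is_logits else preds
    if is_logits:
        # Multiclass: numerically stable row-wise softmax.
        shifted = preds - preds.max(axis=1, keepdims=True)
        exp = np.exp(shifted)
        return exp / exp.sum(axis=1, keepdims=True)
    return preds
```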