SmoothCalibrationError

class torch_uncertainty.metrics.classification.SmoothCalibrationError(kernel_type='logit', bandwidth='auto', eps=0.001, mesh_pts=200, refine_steps=10, **kwargs)

Smooth Expected Calibration Error (SmECE).

This metric implements the kernel-density-estimation-based ECE proposed by Błasiok & Nakkiran (2023). It addresses limitations of the standard binned ECE, such as bin-edge effects and poor resolution for overconfident models, by using a continuous kernel and an adaptive bandwidth selection strategy. The metric is computed on the top label.

Parameters:
  • kernel_type (Literal['logit', 'reflected']) – The kernel to use. Choose between:

      • 'logit' – Applies a Gaussian kernel in log-odds space. This effectively uses an adaptive bandwidth that is narrower near 1.0, making it ideal for modern overconfident models. (Default)

      • 'reflected' – Applies a Gaussian kernel in probability space with reflections at 0 and 1 to prevent boundary bias. Note that relplot’s original implementation uses 'reflected' by default.

  • bandwidth (Union[float, Literal['auto']]) – The kernel bandwidth \(h\). If set to 'auto', the metric runs a fixed-point binary search to find a bandwidth consistent with the error level. Defaults to 'auto'.

  • eps (float) – The tolerance for the binary search when bandwidth is 'auto'. Defaults to 0.001.

  • mesh_pts (int) – The base number of points for the grid discretization. The actual number may be higher depending on the bandwidth. Defaults to 200.

  • refine_steps (int) – Number of binary search iterations for the 'auto' bandwidth. Defaults to 10.

  • **kwargs – Additional arguments for the torchmetrics.Metric base.
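To make the two kernel_type options more concrete, the plain-Python sketch below illustrates the ideas behind them. It is not the library's implementation. The helper names (logit, sigmoid, prob_width, mass_in_unit_interval, reflected_mass) and the example values of h are hypothetical. It shows (a) that a fixed width in log-odds space maps to a much narrower interval in probability space near 1.0 than near 0.5, and (b) that adding mirror images of a Gaussian kernel about 0 and 1 recovers the mass that would otherwise fall outside [0, 1].

```python
import math


def logit(p):
    """Map a probability in (0, 1) to log-odds."""
    return math.log(p / (1.0 - p))


def sigmoid(z):
    """Map log-odds back to a probability."""
    return 1.0 / (1.0 + math.exp(-z))


def prob_width(p, h):
    """Width, in probability space, of the interval logit(p) +/- h."""
    z = logit(p)
    return sigmoid(z + h) - sigmoid(z - h)


def gaussian_cdf(x, mean, h):
    """CDF of a Gaussian with the given mean and standard deviation h."""
    return 0.5 * (1.0 + math.erf((x - mean) / (h * math.sqrt(2.0))))


def mass_in_unit_interval(mean, h):
    """Gaussian kernel mass that lands inside [0, 1]."""
    return gaussian_cdf(1.0, mean, h) - gaussian_cdf(0.0, mean, h)


def reflected_mass(p, h):
    """Mass inside [0, 1] of the kernel plus its mirror images about 0 and 1."""
    centres = (p, -p, 2.0 - p)  # original, reflection at 0, reflection at 1
    return sum(mass_in_unit_interval(c, h) for c in centres)


# (a) Same log-odds width, very different widths in probability space.
width_mid = prob_width(0.5, 0.5)
width_high = prob_width(0.99, 0.5)

# (b) Near the boundary, an unreflected kernel loses a large share of its mass.
plain = mass_in_unit_interval(0.99, 0.05)
reflected = reflected_mass(0.99, 0.05)

print(width_mid, width_high)
print(plain, reflected)
```

In this sketch the interval around 0.99 is more than ten times narrower than the one around 0.5, and the reflected kernel keeps essentially all of its mass inside [0, 1], whereas the unreflected one keeps only a little over half.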

Note

In the multiclass case, this metric evaluates the calibration of the maximum probability (top-label calibration). In the binary case, it evaluates the calibration of the predicted class (i.e., using \(\max(p, 1-p)\)).
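As a rough illustration of the quantities this note refers to, the plain-Python sketch below extracts a confidence and a correctness indicator for each sample. It is a sketch under assumptions, not the library's code: the function names are hypothetical, inputs are Python lists rather than tensors, they are assumed to already be probabilities, and a binary probability of exactly 0.5 is assigned to class 1 by choice.

```python
def top_label_multiclass(probs, targets):
    """Return (confidences, correct) for rows of class probabilities."""
    confidences, correct = [], []
    for row, y in zip(probs, targets):
        # Index of the largest probability is the predicted class.
        pred = max(range(len(row)), key=row.__getitem__)
        confidences.append(row[pred])
        correct.append(1.0 if pred == y else 0.0)
    return confidences, correct


def top_label_binary(p_pos, targets):
    """Return (confidences, correct) for positive-class probabilities."""
    confidences, correct = [], []
    for p, y in zip(p_pos, targets):
        pred = 1 if p >= 0.5 else 0  # tie-breaking at 0.5 is an assumption
        confidences.append(max(p, 1.0 - p))  # confidence of the predicted class
        correct.append(1.0 if pred == y else 0.0)
    return confidences, correct


mc_conf, mc_correct = top_label_multiclass(
    [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]], [0, 1]
)
bin_conf, bin_correct = top_label_binary([0.9, 0.2], [1, 1])
```

A calibration metric of this kind then compares the confidences with the correctness indicators; in the binary example, the second sample is predicted as class 0 with confidence 0.8 and counted as incorrect.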

Note

This implementation has been tested on one use case, where its values matched those of relplot’s original implementation to 6 significant figures.

References

  • Błasiok, J. & Nakkiran, P. Smooth ECE: Principled Reliability Diagrams. ICLR 2024.

compute()

Compute the Smooth ECE based on the accumulated state.

Returns:

The scalar SmECE value.

Return type:

Tensor
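When bandwidth is 'auto', the returned value depends on the bandwidth search described in the class parameters. One generic way to picture a fixed-point binary search is sketched below. It assumes that the error err(h) does not increase as the bandwidth h grows, so that bisection can locate the point where err(h) and h agree. The function toy_err is a made-up stand-in for the smoothed calibration error, and find_fixed_point, lo, and hi are hypothetical names. Only eps and refine_steps echo the documented parameters, and the library may combine them differently.

```python
def find_fixed_point(err, lo=0.0, hi=1.0, eps=0.001, refine_steps=10):
    """Bisect for h with err(h) == h, assuming err is non-increasing in h."""
    for _ in range(refine_steps):
        if hi - lo <= eps:
            break  # bracket is already tighter than the tolerance
        mid = 0.5 * (lo + hi)
        if err(mid) > mid:
            lo = mid  # bandwidth is below the error level: search larger h
        else:
            hi = mid  # bandwidth is at or above the error level: search smaller h
    return 0.5 * (lo + hi)


def toy_err(h):
    """Hypothetical, decreasing stand-in for the error at bandwidth h."""
    return 0.1 / (1.0 + h)


h_star = find_fixed_point(toy_err)
print(h_star, toy_err(h_star))
```

With the default arguments, the bracket halves at each step, so ten steps on [0, 1] narrow it to a width of 1/1024, just under the default eps of 0.001.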

update(preds, target)

Update the state with predictions and targets.

Parameters:
  • preds (Tensor) – Predictions from the model.

      • Multiclass: Shape (N, C) (logits or probabilities).

      • Binary: Shape (N,) or (N, 1) (logits or probabilities).

  • target (Tensor) – Ground truth labels.

      • Multiclass: Shape (N,) containing class indices.

      • Binary: Shape (N,) containing 0 or 1.

Return type:

None