Source code for torch_uncertainty.metrics.regression.quantile_calibration

import warnings

import torch
from torch import Tensor
from torch.distributions import Distribution, Independent
from torchmetrics.classification import BinaryCalibrationError
from torchmetrics.functional.classification.calibration_error import (
    _binning_bucketize,
)
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.plot import _PLOT_OUT_TYPE

from torch_uncertainty.metrics.classification.calibration_error import reliability_chart


class QuantileCalibrationError(BinaryCalibrationError):
    is_differentiable = False
    higher_is_better = False
    full_state_update = False
    not_implemented_error = False

    def __init__(
        self,
        num_bins=15,
        norm="l1",
        ignore_index=None,
        validate_args=True,
        **kwargs,
    ):
        """Quantile Calibration Error for regression tasks.

        This metric computes the calibration error of quantile predictions
        against the ground-truth values.

        Args:
            num_bins (int, optional): Number of bins to use for calibration.
                Defaults to `15`.
            norm (str, optional): Norm to use for the calibration error
                computation. Defaults to `"l1"`.
            ignore_index (int, optional): Index to ignore during calibration.
                Defaults to `None`.
            validate_args (bool, optional): Whether to validate the input
                arguments. Defaults to `True`.
            kwargs: Additional keyword arguments, see `Advanced metric settings
                <https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metric-kwargs>`_.
        """
        super().__init__(num_bins, norm, ignore_index, validate_args, **kwargs)
        self.conf_intervals = torch.linspace(0.05, 0.95, self.n_bins + 1)
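    # Editorial sketch (not in the original source): with the default
    # `num_bins=15`, `torch.linspace(0.05, 0.95, 16)` produces the 16 nominal
    # coverage levels
    #     0.05, 0.11, 0.17, ..., 0.89, 0.95
    # (a step of 0.06), one confidence level per evaluated central interval.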
    def update(
        self,
        dist: Distribution,
        target: Tensor,
        padding_mask: Tensor | None = None,
    ) -> None:
        """Update the metric with new predictions and targets.

        Args:
            dist (Distribution): The predicted distribution.
            target (Tensor): The ground-truth values.
            padding_mask (Tensor | None, optional): A mask to ignore certain
                values. Defaults to `None`.
        """
        reduce_event_dims = False
        if isinstance(dist, Independent):
            iid_dist = dist.base_dist
            reduce_event_dims = True
        else:
            iid_dist = dist

        # Probe once to check that the distribution implements `icdf()`.
        try:
            iid_dist.icdf((1 - self.conf_intervals[0]) / 2)
        except NotImplementedError:
            warnings.warn(
                "The distribution does not support the `icdf()` method. "
                "This metric will therefore return `nan` values. "
                "Please use a distribution that implements `icdf()`.",
                UserWarning,
                stacklevel=2,
            )
            self.not_implemented_error = True
            return

        confidences = self.conf_intervals.expand(*dist.batch_shape, -1)
        correct_mask = torch.empty_like(confidences)
        # The target log-probability does not depend on the confidence level,
        # so compute it once outside the loop.
        target_log_prob = dist.log_prob(target)
        for i, conf in enumerate(self.conf_intervals):
            # Lower bound of the central interval with nominal coverage `conf`.
            b_min = iid_dist.icdf((1 - conf) / 2)
            bound_log_prob = iid_dist.log_prob(b_min)
            if reduce_event_dims:
                bound_log_prob = bound_log_prob.sum(
                    dim=list(range(-dist.reinterpreted_batch_ndims, 0))
                )
            # For a unimodal symmetric distribution, the target lies inside
            # the central interval iff its log-density is at least the
            # log-density at the interval bound.
            correct_mask[..., i] = (bound_log_prob <= target_log_prob).float()
        if padding_mask is not None:
            confidences = confidences[~padding_mask]
            correct_mask = correct_mask[~padding_mask]
        super().update(confidences.flatten(), correct_mask.flatten())
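    # Usage sketch (editorial, hypothetical shapes): `update()` accepts both
    # plain and `Independent`-wrapped distributions, e.g.
    #
    #     dist = Independent(Normal(torch.zeros(8, 3), torch.ones(8, 3)), 1)
    #     metric.update(dist, target=torch.randn(8, 3))
    #
    # Here the base `Normal` has shape `(8, 3)`; the bound log-probabilities
    # are summed over the single event dimension so that they can be compared
    # with `dist.log_prob(target)`, which is already reduced.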
    def compute(self) -> Tensor:
        """Compute the quantile calibration error.

        Returns:
            Tensor: The quantile calibration error.

        Warning:
            If the distribution does not support the `icdf()` method, this
            metric returns `nan`.
        """
        if self.not_implemented_error:
            return torch.tensor(float("nan"))
        return super().compute()
    def plot(self) -> _PLOT_OUT_TYPE:
        """Plot the quantile calibration reliability diagram.

        Raises:
            NotImplementedError: If the distribution does not support the
                `icdf()` method.
        """
        if self.not_implemented_error:
            raise NotImplementedError(
                "The distribution does not support the `icdf()` method. "
                "This metric will therefore return `nan` values. "
                "Please use a distribution that implements `icdf()`."
            )
        confidences = dim_zero_cat(self.confidences)
        accuracies = dim_zero_cat(self.accuracies)
        bin_boundaries = torch.linspace(
            0,
            1,
            self.n_bins + 1,
            dtype=torch.float,
            device=confidences.device,
        )
        with torch.no_grad():
            acc_bin, conf_bin, prop_bin = _binning_bucketize(
                confidences, accuracies, bin_boundaries
            )
        return reliability_chart(
            accuracies=accuracies.cpu().numpy(),
            confidences=confidences.cpu().numpy(),
            bin_accuracies=acc_bin.cpu().numpy(),
            bin_confidences=conf_bin.cpu().numpy(),
            bin_sizes=prop_bin.cpu().numpy(),
            bins=bin_boundaries.cpu().numpy(),
        )
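
A minimal end-to-end sketch (editorial addition, not part of the module): the metric expects a predictive distribution whose class implements `icdf()`, such as `torch.distributions.Normal`. The import below goes through this module's path; the tensors are illustrative.

import torch
from torch.distributions import Normal

from torch_uncertainty.metrics.regression.quantile_calibration import (
    QuantileCalibrationError,
)

metric = QuantileCalibrationError(num_bins=15, norm="l1")
means = torch.randn(64)
stds = torch.rand(64) + 0.1
# Draw the targets from the predicted distribution itself, so the predictions
# are well calibrated by construction and the error should be small.
target = means + stds * torch.randn(64)
metric.update(Normal(means, stds), target)
print(metric.compute())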