Source code for torch_uncertainty.metrics.classification.set_size

from typing import Literal, cast

import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.utilities.compute import _safe_divide
from torchmetrics.utilities.data import dim_zero_cat


class SetSize(Metric):
    is_differentiable = False
    higher_is_better = False
    full_state_update = False

    # Accumulated set sizes. For reduction "mean"/"sum" this is a scalar long
    # tensor holding the running sum; for "none"/None it is a list of
    # per-batch tensors of per-sample sizes (concatenated across processes).
    sizes: list[Tensor] | Tensor
    # Number of samples seen so far; denominator of the "mean" reduction.
    total: Tensor

    def __init__(
        self,
        reduction: Literal["mean", "sum", "none"] | None = "mean",
        **kwargs,
    ) -> None:
        r"""Average prediction-set size — the standard *efficiency* metric for conformal prediction methods.

        For a set-valued predictor :math:`\mathcal{C}(X) \subseteq \{1, \dots, C\}`,

        .. math::

            \text{SetSize} = \frac{1}{N} \sum_{i=1}^{N} |\mathcal{C}(x_i)|.

        Smaller sets are more informative, hence ``higher_is_better = False``.
        Set size is typically reported jointly with the empirical
        :class:`~torch_uncertainty.metrics.classification.CoverageRate`: a
        useful conformal predictor achieves the target coverage with as small
        a set as possible.

        Args:
            reduction: Determines how to reduce over the :math:`B`/batch dimension:

                - ``'mean'`` [default]: Averages score across samples
                - ``'sum'``: Sum score across samples
                - ``'none'`` or ``None``: Returns score per sample

            kwargs: Additional keyword arguments, see `Advanced metric settings
                <https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metric-kwargs>`_.

        Raises:
            ValueError: If ``reduction`` is not one of ``'mean'``, ``'sum'``,
                ``'none'`` or ``None``.
        """
        super().__init__(**kwargs)

        allowed_reduction = ("sum", "mean", "none", None)
        if reduction not in allowed_reduction:
            # Build ONE message string: passing the pieces as separate
            # arguments would make ``str(err)`` render as a tuple.
            raise ValueError(
                "Expected argument `reduction` to be one of "
                f"{allowed_reduction} but got {reduction}"
            )

        self.reduction = reduction

        if self.reduction in ["mean", "sum"]:
            # Only the running total is needed, so keep a scalar summed across processes.
            self.add_state("sizes", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")
        else:
            # Per-sample output requires keeping every batch's sizes.
            self.add_state("sizes", default=[], dist_reduce_fx="cat")

        self.add_state("total", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")

    # pyrefly: ignore[bad-override]
    def update(self, preds: torch.Tensor, targets: torch.Tensor | None = None) -> None:
        """Update the metric state with predictions and targets.

        Args:
            preds: Predicted sets tensor of shape ``(B, C)``, where ``B`` is the
                batch size and ``C`` is the number of classes. Every non-zero
                entry is counted as a class belonging to the prediction set.
            targets: Unused. Kept for API consistency. Defaults to ``None``.
        """
        batch_size = preds.size(0)
        # Set size of each sample = number of non-zero entries along the class dimension.
        pred_sizes = preds.bool().sum(-1)
        if self.reduction is None or self.reduction == "none":
            sizes = cast("list[Tensor]", self.sizes)
            sizes.append(pred_sizes)
        else:
            self.sizes += pred_sizes.sum()
        self.total += batch_size

    def compute(self) -> Tensor:
        """Compute the set size.

        Returns:
            Tensor: The set size according to the selected reduction: the mean
            set size (``'mean'``), the summed set size (``'sum'``), or the
            per-sample set sizes (``'none'``/``None``; empty if no samples
            have been seen).
        """
        if isinstance(self.sizes, list) and not self.sizes:
            # Per-sample mode before any update(): ``dim_zero_cat`` raises on an
            # empty list, so return an empty result instead (mirrors "mean"/"sum",
            # which already yield 0 when nothing has been accumulated).
            return torch.empty(0, dtype=torch.long, device=self.total.device)

        values = dim_zero_cat(self.sizes)
        if self.reduction == "mean":
            # _safe_divide returns 0 instead of NaN when total == 0.
            return _safe_divide(values, self.total)
        # "sum": the accumulated scalar; "none"/None: the per-sample sizes.
        return values