# Source code for torch_uncertainty.metrics.classification.coverage_rate
import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.utilities.compute import _safe_divide
from torchmetrics.utilities.data import _bincount
class CoverageRate(Metric):
    """Empirical coverage rate metric.

    Measures how often the target label is contained in the predicted set
    (e.g. a conformal prediction set), either globally (``'micro'``) or
    averaged per class (``'macro'``).

    Args:
        num_classes (int | None, optional): Number of classes. Defaults to ``None``.
        average (str, optional): Defines the reduction that is applied over labels. Should be
            one of the following:

            - ``'micro'`` (default): Sum statistics over all labels.
            - ``'macro'``: Compute the metric for each class separately and find their
              unweighted mean. This does not take label imbalance into account.

        validate_args (bool, optional): Whether to validate the arguments. Defaults to ``True``.
        kwargs: Additional keyword arguments, see `Advanced metric settings
            <https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metric-kwargs>`_.

    Raises:
        ValueError: If `num_classes` is `None` and `average` is not `micro`.
        ValueError: If `num_classes` is not an integer larger than 1.
        ValueError: If `average` is not one of `macro` or `micro`.
    """

    is_differentiable = False
    higher_is_better = True
    full_state_update = False

    def __init__(
        self,
        num_classes: int | None = None,
        average: str = "micro",
        validate_args: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        if validate_args:
            # `macro` needs per-class states, so the class count must be known.
            if num_classes is None and average != "micro":
                raise ValueError(
                    f"Argument `num_classes` can only be `None` for `average='micro'`, but got `average={average}`."
                )
            if num_classes is not None and (not isinstance(num_classes, int) or num_classes < 2):
                raise ValueError(
                    f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}"
                )
            if average not in ["macro", "micro"]:
                raise ValueError("average must be either 'macro' or 'micro'.")

        self.num_classes = num_classes
        self.average = average
        self.validate_args = validate_args

        # One scalar state for 'micro', one counter per class for 'macro'.
        size = 1 if (average == "micro" or num_classes is None) else num_classes
        self.add_state("correct", torch.zeros(size, dtype=torch.long), dist_reduce_fx="sum")
        self.add_state("total", torch.zeros(size, dtype=torch.float), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Update the metric state with predictions and targets.

        Args:
            preds (torch.Tensor): predicted sets tensor of shape (B, C), where B is the batch size
                and C is the number of classes. Entry ``preds[i, c]`` is truthy when class ``c``
                belongs to the predicted set of sample ``i``.
            target (torch.Tensor): target sets tensor of shape (B,).
        """
        batch_size = preds.size(0)
        target = target.long()
        # Index on the same device as `preds` to avoid a CPU/GPU device
        # mismatch when the metric is updated with CUDA tensors.
        covered = preds[torch.arange(batch_size, device=preds.device), target]  # (B,)
        if self.average == "micro":
            self.correct += covered.bool().sum()
            self.total += batch_size
        else:
            # Per-class tallies: covered targets vs. all targets.
            self.correct += _bincount(target[covered.bool()], self.num_classes)
            self.total += _bincount(target, self.num_classes)

    def compute(self) -> Tensor:
        """Compute the coverage rate.

        Returns:
            Tensor: The coverage rate (classes with zero samples contribute 0
            to the 'macro' mean via the safe division).
        """
        if self.average == "micro":
            return _safe_divide(self.correct, self.total)
        return _safe_divide(self.correct, self.total).mean()