Source code for torch_uncertainty.metrics.segmentation.seg_binary_average_precision

from typing import Any

import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.classification import BinaryAveragePrecision


class SegmentationBinaryAveragePrecision(Metric):
    is_differentiable = False
    higher_is_better = True
    full_state_update = False

    # Running sum of per-batch AP values, each multiplied by its batch size.
    binary_aupr: Tensor
    # Running number of images seen (sum of batch sizes).
    total: Tensor

    def __init__(
        self,
        thresholds: int | list[float] | Tensor | None = None,
        ignore_index: int | None = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        r"""Binary Average Precision for dense segmentation tasks, averaged over images.

        On every call to :meth:`update`, Average Precision is computed over all
        (non-ignored) pixels of the mini-batch :math:`j`, which contains
        :math:`n_j` images, by sweeping a threshold over the pixel scores:

        .. math::

            \text{AP}_j = \sum_{k} \left( R_j(k) - R_j(k-1) \right) P_j(k),

        where :math:`P_j(k)` and :math:`R_j(k)` are the precision and recall at
        the :math:`k`-th threshold. The final metric is the image-count-weighted
        mean of these batch scores:

        .. math::

            \text{AP} = \frac{\sum_{j} n_j \, \text{AP}_j}{\sum_{j} n_j}.

        When every batch holds a single image (:math:`n_j = 1`), this reduces to
        the per-image average :math:`\frac{1}{B} \sum_{b=1}^{B} \text{AP}_b`,
        which is the image-wise convention used in the dense OOD-detection
        literature. With larger batches, the pixels of all images in a batch are
        pooled before computing AP, so the value can differ from a strict
        per-image average.

        Args:
            thresholds: Optional explicit thresholds for the PR curve, see
                :class:`~torchmetrics.classification.BinaryAveragePrecision`.
            ignore_index: Optional label value to ignore.
            validate_args: Whether to validate input arguments.
            kwargs: Additional keyword arguments, see `Advanced metric settings
                <https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metric-kwargs>`_.
        """
        super().__init__(**kwargs)
        # Child metric used only to score one batch at a time; its state is
        # cleared after every batch in ``update``.
        self.aupr_metric = BinaryAveragePrecision(
            thresholds=thresholds,
            ignore_index=ignore_index,
            validate_args=validate_args,
            **kwargs,
        )
        self.add_state("binary_aupr", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:  # pyrefly: ignore[bad-override]
        """Accumulate the Average Precision of one batch, weighted by its size.

        Args:
            preds: Pixel scores; the first dimension is treated as the batch
                dimension.
            target: Binary pixel labels, forwarded unchanged to
                :class:`~torchmetrics.classification.BinaryAveragePrecision`.
        """
        batch_size = preds.size(0)
        if batch_size == 0:
            # Nothing to score: skip, so an empty batch cannot inject a
            # non-finite value into the running sum (nan * 0 is still nan).
            return
        # Calling the child metric runs ``forward``, which returns the AP of
        # this batch alone but also merges the batch into the child's
        # accumulated state. Reset right away so that state does not grow with
        # every batch; this metric's own ``reset()`` only restores its own
        # registered states and would never clear the child.
        # NOTE(review): a batch with no positive pixels may yield a non-finite
        # AP from torchmetrics, which would propagate into the sum — confirm
        # whether such batches should be skipped.
        aupr = self.aupr_metric(preds, target)
        self.aupr_metric.reset()
        self.binary_aupr += aupr * batch_size
        self.total += batch_size

    def compute(self) -> Tensor:
        """Return the image-count-weighted mean of the batch AP values.

        Returns:
            A scalar tensor; ``0.0`` (on the state's device) when no image has
            been seen yet.
        """
        if self.total == 0:
            return torch.tensor(0.0, device=self.binary_aupr.device)
        return self.binary_aupr / self.total