Source code for torch_uncertainty.metrics.segmentation.seg_binary_auroc
from typing import Any

import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.classification import BinaryAUROC
from torchmetrics.functional.classification import binary_auroc as functional_binary_auroc
[docs]
class SegmentationBinaryAUROC(Metric):
    is_differentiable = False
    higher_is_better = True
    full_state_update = False

    # Running sum of the per-image AUROC values seen since the last reset.
    binary_auroc: Tensor
    # Number of images accumulated since the last reset.
    total: Tensor

    def __init__(
        self,
        max_fpr: float | None = None,
        thresholds: int | list[float] | Tensor | None = None,
        ignore_index: int | None = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        r"""Image-averaged binary AUROC for dense binary segmentation tasks.

        For each image :math:`b`, a pixel-level binary AUROC is computed from
        the pixel scores :math:`s_{ij}` and binary labels
        :math:`y_{ij} \in \{0, 1\}`:

        .. math::

            \text{AUROC}_b = \int_0^1 \text{TPR}_b\!\left(\text{FPR}_b^{-1}(u)\right) \mathrm{d}u,

        where TPR and FPR are obtained by sweeping a threshold over the pixel
        scores of image :math:`b` only. The final metric is the mean over all
        :math:`B` images seen since the last reset:

        .. math::

            \text{AUROC} = \frac{1}{B} \sum_{b=1}^{B} \text{AUROC}_b.

        This image-wise averaging is the convention used in the dense
        OOD-detection literature (e.g., MUAD). It behaves better than computing
        AUROC over the flattened set of all pixels when image sizes or OOD
        prevalences vary.

        Each image is scored independently with
        :func:`torchmetrics.functional.classification.binary_auroc`. Therefore
        no pixel scores are retained between updates, and the result does not
        depend on how images are grouped into batches.

        Args:
            max_fpr: If set, computes the partial AUROC of each image up to
                this FPR value.
            thresholds: Optional thresholds to use when computing each ROC
                (an int number of evenly spaced thresholds, an explicit list,
                or a 1D tensor). ``None`` uses the exact, non-binned ROC.
            ignore_index: Optional label value whose pixels are ignored.
            validate_args: Whether to validate input arguments.
            kwargs: Additional keyword arguments, see `Advanced metric settings
                <https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metric-kwargs>`_.

        Note:
            An image whose (non-ignored) labels contain only one class has an
            undefined AUROC. Torchmetrics presumably warns and returns a
            degenerate score for it. That score is still included in the
            average (TODO confirm whether such images should be skipped).
        """
        super().__init__(**kwargs)
        # Configuration for the stateless per-image AUROC. The names are
        # prefixed to avoid clashing with attributes of the ``Metric`` base class.
        self._auroc_max_fpr = max_fpr
        self._auroc_thresholds = thresholds
        self._auroc_ignore_index = ignore_index
        self._auroc_validate_args = validate_args
        self.add_state("binary_auroc", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:  # pyrefly: ignore[bad-override]
        """Accumulate the AUROC of every image in the batch.

        Args:
            preds: Pixel scores. The first dimension indexes the images of the
                batch; the remaining dimensions are flattened per image.
            target: Binary pixel labels with the same shape as ``preds``.
        """
        batch_size = preds.size(0)
        thresholds = self._auroc_thresholds
        if isinstance(thresholds, Tensor):
            # Explicit threshold tensors are plain attributes (not buffers), so
            # they do not follow the metric across devices; align them here.
            thresholds = thresholds.to(preds.device)
        # One AUROC per image (not one per batch), as promised by the docstring.
        for image_preds, image_target in zip(preds, target, strict=True):
            self.binary_auroc += functional_binary_auroc(
                image_preds,
                image_target,
                max_fpr=self._auroc_max_fpr,
                thresholds=thresholds,
                ignore_index=self._auroc_ignore_index,
                validate_args=self._auroc_validate_args,
            )
        self.total += batch_size

    def compute(self) -> Tensor:
        """Return the mean per-image AUROC, or ``0.0`` if no image was seen."""
        if self.total == 0:
            return torch.tensor(0.0, device=self.binary_auroc.device)
        return self.binary_auroc / self.total