# Source code for torch_uncertainty.metrics.segmentation.seg_binary_auroc
from typing import Any
import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.classification import BinaryAUROC
[docs]
class SegmentationBinaryAUROC(Metric):
    """Batch-averaged binary AUROC for segmentation tasks.

    The Area Under the Receiver Operating Characteristic Curve is computed
    independently on every batch passed to :meth:`update` using a wrapped
    ``BinaryAUROC``. :meth:`compute` returns the mean of those per-batch
    values, each weighted by its batch size (``preds.size(0)``).

    Args:
        max_fpr: Forwarded unchanged to ``BinaryAUROC``.
        thresholds: Forwarded unchanged to ``BinaryAUROC``.
        ignore_index: Forwarded unchanged to ``BinaryAUROC``.
        validate_args: Forwarded unchanged to ``BinaryAUROC``.
        **kwargs: Forwarded to both the base ``Metric`` constructor and the
            wrapped ``BinaryAUROC``.
    """

    is_differentiable = False
    higher_is_better = True
    full_state_update = False

    def __init__(
        self,
        max_fpr: float | None = None,
        thresholds: int | list[float] | Tensor | None = None,
        ignore_index: int | None = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        # Only used as a per-batch AUROC calculator; the aggregation across
        # batches lives in the two states registered below.
        self.auroc_metric = BinaryAUROC(
            max_fpr=max_fpr,
            thresholds=thresholds,
            ignore_index=ignore_index,
            validate_args=validate_args,
            **kwargs,
        )
        # Running sum of (batch AUROC * batch size) and of batch sizes.
        self.add_state("binary_auroc", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Accumulate the AUROC of one batch, weighted by its batch size.

        Args:
            preds: Predictions for the batch; the first dimension is used as
                the batch size.
            target: Ground-truth labels matching ``preds``.
        """
        batch_size = preds.size(0)
        # Calling the wrapped metric yields the AUROC of this batch only, but
        # it also merges the batch into the wrapped metric's own running
        # state, which this class never reads. Clear that state immediately
        # so it cannot grow with every batch (it stores the raw preds/targets
        # when ``thresholds`` is None).
        auroc = self.auroc_metric(preds, target)
        self.auroc_metric.reset()
        self.binary_auroc += auroc * batch_size
        self.total += batch_size

    def compute(self) -> Tensor:
        """Return the batch-size-weighted mean AUROC over all batches seen.

        Returns:
            A scalar tensor; ``0.0`` if no batch has been accumulated.
        """
        if self.total == 0:
            # Avoid a 0/0 division when compute() is called before update().
            return torch.tensor(0.0, device=self.binary_auroc.device)
        return self.binary_auroc / self.total