Shortcuts

Source code for torch_uncertainty.metrics.classification.mean_iou

from torch import Tensor
from torchmetrics.classification.stat_scores import MulticlassStatScores
from torchmetrics.utilities.compute import _safe_divide


class MeanIntersectionOverUnion(MulticlassStatScores):
    # Class-level metric flags.
    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    def __init__(
        self,
        num_classes: int,
        top_k: int = 1,
        ignore_index: int | None = None,
        validate_args: bool = True,
        **kwargs,
    ) -> None:
        r"""Compute the Mean Intersection over Union (mIoU) score.

        The parent ``MulticlassStatScores`` is always configured with
        ``"macro"`` averaging and ``"global"`` multidim averaging; these two
        settings are not exposed to the caller.

        Args:
            num_classes (int): Integer specifying the number of classes.
            top_k (int, optional): Number of highest probability or logit
                score predictions considered to find the correct label. Only
                works when ``preds`` contain probabilities/logits. Defaults to
                ``1``.
            ignore_index (int | None, optional): Specifies a target value that
                is ignored and does not contribute to the metric calculation.
                Defaults to ``None``.
            validate_args (bool, optional): Bool indicating if input arguments
                and tensors should be validated for correctness. Set to
                ``False`` for faster computations. Defaults to ``True``.
            **kwargs: Additional keyword arguments, see
                `Advanced metric settings <https://lightning.ai/docs/torchmetrics/stable/pages/overview.html#metric-kwargs>`_
                for more info.

        Shape:
            As input to ``forward`` and ``update`` the metric accepts the
            following input:

            - ``preds`` (:class:`~torch.Tensor`): An int tensor of shape
              ``(B, ...)`` or float tensor of shape ``(B, C, ..)``. If preds
              is a floating point we apply ``torch.argmax`` along the ``C``
              dimension to automatically convert probabilities/logits into an
              int tensor.
            - ``target`` (:class:`~torch.Tensor`): An int tensor of shape
              ``(B, ...)``.

            As output to ``forward`` and ``compute`` the metric returns the
            following output:

            - ``mean_iou`` (:class:`~torch.Tensor`): The computed Mean
              Intersection over Union (IoU) score. A tensor containing a
              single float value.
        """
        super().__init__(
            num_classes=num_classes,
            top_k=top_k,
            average="macro",
            multidim_average="global",
            ignore_index=ignore_index,
            validate_args=validate_args,
            **kwargs,
        )

    def compute(self) -> Tensor:
        """Compute the Mean Intersection over Union (mIoU) from the saved state.

        The IoU is ``tp / (tp + fp + fn)``. A zero denominator is replaced by
        NaN through ``zero_division`` so that ``nanmean`` leaves those entries
        out of the average instead of counting them as zero.

        Returns:
            Tensor: The NaN-ignoring mean of the IoU values.
        """
        true_pos, false_pos, _, false_neg = self._final_state()
        union = true_pos + false_pos + false_neg
        iou = _safe_divide(true_pos, union, zero_division=float("nan"))
        return iou.nanmean()