# Source code for torch_uncertainty.metrics.classification.grouping_loss
from importlib import util
import torch
if util.find_spec("glest"):
    from glest import GLEstimator as GLEstimatorBase

    glest_installed = True
else:  # coverage: ignore
    glest_installed = False
    # Fallback base class so that the `GLEstimator` subclass below can still
    # be defined when glest is missing. Without it, importing this module
    # fails with a NameError, and the explicit ImportError raised by
    # `GroupingLoss.__init__` is never reached.
    GLEstimatorBase = object
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat
class GLEstimator(GLEstimatorBase):
    """Thin adapter that lets glest's estimator consume PyTorch tensors.

    The tensors handed to :meth:`fit` are detached, moved to the CPU and
    converted to NumPy arrays before delegating to the base class ``fit``.
    """

    def fit(
        self, probs: Tensor, targets: Tensor, features: Tensor
    ) -> "GLEstimator":
        """Fit the base estimator on tensors instead of NumPy arrays.

        Args:
            probs (Tensor): Confidence scores. Their NumPy conversion is stored
                in ``self.classifier``. NOTE(review): this presumably relies on
                glest accepting precomputed scores as ``classifier``; confirm.
            targets (Tensor): Passed to the base ``fit`` as ``y``. They are
                multiplied by 1 first, which turns a boolean tensor into an
                integer one and leaves other dtypes unchanged.
            features (Tensor): Passed to the base ``fit`` as ``X``.

        Returns:
            Whatever the base class ``fit`` returns.
        """

        def as_numpy(tensor: Tensor):
            # Detach from autograd and leave the device before converting.
            return tensor.detach().cpu().numpy()

        numeric_targets = targets * 1
        self.classifier = as_numpy(probs)
        return super().fit(as_numpy(features), as_numpy(numeric_targets))
class GroupingLoss(Metric):
    is_differentiable: bool = False
    higher_is_better: bool | None = False
    full_state_update: bool = False

    def __init__(
        self,
        **kwargs,
    ) -> None:
        r"""Metric to estimate the Top-label Grouping Loss.

        Args:
            kwargs: Additional keyword arguments, see `Advanced metric settings
                <https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metric-kwargs>`_.

        Inputs:
            - :attr:`probs`: :math:`(B)`, :math:`(B, C)` or :math:`(B, N, C)`
            - :attr:`target`: :math:`(B)` or :math:`(B, C)`
            - :attr:`features`: :math:`(B, F)` or :math:`(B, N, F)`

            where :math:`B` is the batch size, :math:`C` is the number of
            classes, :math:`N` is the number of estimators and :math:`F` is the
            number of features.

        Warning:
            Make sure that the probabilities in :attr:`probs` are normalized to
            sum to one. In the binary case (1-D :attr:`probs`), the values are
            treated as the probability of the positive class.

        Raises:
            ImportError:
                If the ``glest`` library is not installed.

        Reference:
            Perez-Lebel, Alexandre, Le Morvan, Marine and Varoquaux, Gaël.
            Beyond calibration: estimating the grouping loss of modern neural
            networks. In ICLR 2023.
        """
        super().__init__(**kwargs)
        if not glest_installed:  # coverage: ignore
            raise ImportError(
                "The glest library is not installed. Please install "
                "torch_uncertainty with the all option: "
                'pip install -U "torch_uncertainty[all]".'
            )
        self.estimator = GLEstimator(None)
        self.add_state("probs", default=[], dist_reduce_fx="cat")
        self.add_state("targets", default=[], dist_reduce_fx="cat")
        self.add_state("features", default=[], dist_reduce_fx="cat")
        # Number of samples accumulated so far (incremented in ``update``).
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        rank_zero_warn(
            "Metric `GroupingLoss` will save all targets, predictions and features"
            " in buffer. For large datasets this may lead to large memory"
            " footprint."
        )

    def update(self, probs: Tensor, target: Tensor, features: Tensor) -> None:
        """Accumulate the tensors for the estimation of the Grouping Loss.

        For each sample, this stores the top-label confidence, whether the top
        label matches the target, and the feature vector. All inputs are
        validated before any state is modified, so a rejected batch leaves the
        buffers untouched.

        Args:
            probs (Tensor): A probability tensor of shape (batch, num_classes),
                (batch, num_estimators, num_classes), or (batch) if binary
                classification. In the binary case, the values are the
                positive-class probability: the positive label is predicted
                when it exceeds 0.5, and the stored confidence is that of the
                predicted label. With several estimators, their probabilities
                are averaged before taking the top label.
            target (Tensor): A tensor of ground truth labels of shape
                (batch, num_classes) or (batch). Two-dimensional targets are
                reduced with ``argmax`` over the last dimension.
            features (Tensor): A tensor of features of shape
                (batch, num_estimators, num_features) or (batch, num_features).
                With several estimators, only the features of the first one
                are kept.

        Raises:
            ValueError: If an input has an unsupported number of dimensions,
                or if the inputs do not share the same batch size.
        """
        if target.ndim == 2:
            target = target.argmax(dim=-1)
        elif target.ndim != 1:
            raise ValueError(
                "Expected `target` to be of shape (batch) or (batch, num_classes) "
                f"but got {target.shape}."
            )

        if probs.ndim == 1:
            is_positive = probs > 0.5
            preds = is_positive.int()
            # Top-label confidence: P(positive) if predicted positive, else
            # P(negative) = 1 - P(positive).
            confidences = torch.where(is_positive, probs, 1 - probs)
        elif probs.ndim in (2, 3):
            mean_probs = probs if probs.ndim == 2 else probs.mean(1)
            confidences, preds = mean_probs.max(-1)
        else:
            raise ValueError(
                "Expected `probs` to be of shape (batch, num_classes) or "
                "(batch, num_estimators, num_classes) or (batch) "
                f"but got {probs.shape}."
            )

        if features.ndim == 3:
            features = features[:, 0, :]
        elif features.ndim != 2:
            raise ValueError(
                "Expected `features` to be of shape (batch, num_features) or "
                "(batch, num_estimators, num_features) but got "
                f"{features.shape}."
            )

        batch_size = target.size(0)
        if probs.size(0) != batch_size or features.size(0) != batch_size:
            raise ValueError(
                "Expected `probs`, `target` and `features` to share the same "
                f"batch size but got {probs.size(0)}, {batch_size} and "
                f"{features.size(0)}."
            )

        self.probs.append(confidences)
        self.targets.append(target == preds)
        self.features.append(features)
        self.total += batch_size

    def compute(self) -> Tensor:
        """Estimate the grouping loss from the inputs passed to ``update``.

        The accumulated top-label confidences, correctness indicators and
        features are concatenated and fitted with the glest estimator.

        Returns:
            The ``"GL"`` entry of ``estimator.metrics("brier")``.
            NOTE(review): its exact type is whatever glest returns (presumably
            a Python float rather than a Tensor); confirm.
        """
        probs = dim_zero_cat(self.probs)
        features = dim_zero_cat(self.features)
        targets = dim_zero_cat(self.targets)
        estimator = self.estimator.fit(probs, targets, features)
        return estimator.metrics("brier")["GL"]