GroupingLoss¶
- class torch_uncertainty.metrics.classification.GroupingLoss(**kwargs)[source]¶
Metric to estimate the Top-label Grouping Loss.
- Parameters:
kwargs – Additional keyword arguments, see Advanced metric settings.
- Inputs:
probs: (B, C) or (B, N, C)
target: (B) or (B, C)
features: (B, F) or (B, N, F)
where B is the batch size, C is the number of classes, N is the number of estimators, and F is the number of features.
Warning
Make sure that the probabilities in probs are normalized to sum to one.
- Raises:
ValueError – If reduction is not one of 'mean', 'sum', 'none', or None.
- Reference:
Perez-Lebel, Alexandre, Le Morvan, Marine and Varoquaux, Gaël. Beyond calibration: estimating the grouping loss of modern neural networks. In ICLR 2023.
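For orientation, here is a minimal usage sketch for the single-estimator case, assuming the import path shown above. The dimension sizes B, C, and F and the random tensors are illustrative placeholders for real model probabilities, labels, and feature representations; check the details against your installed version of TorchUncertainty.

```python
import torch

from torch_uncertainty.metrics.classification import GroupingLoss

B, C, F = 32, 10, 64  # batch size, number of classes, feature dimension (illustrative)

probs = torch.randn(B, C).softmax(dim=-1)  # probabilities normalized to sum to one
target = torch.randint(0, C, (B,))         # integer class labels
features = torch.randn(B, F)               # per-sample feature representations

metric = GroupingLoss()
metric.update(probs, target, features)
print(metric.compute())  # estimated top-label grouping loss
```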
- compute()[source]¶
Compute the final grouping loss based on inputs passed to update.
- Returns:
The final value(s) for the grouping loss
- Return type:
torch.Tensor
- update(probs, target, features)[source]¶
Accumulate the tensors for the estimation of the Grouping Loss.
- Parameters:
probs (Tensor) – A probability tensor of shape (batch, num_classes), (batch, num_estimators, num_classes), or (batch) for binary classification
target (Tensor) – A tensor of ground truth labels of shape (batch, num_classes) or (batch)
features (Tensor) – A tensor of features of shape (batch, num_estimators, num_features) or (batch, num_features)
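The sketch below, under the same illustrative assumptions as above, accumulates several batches in the ensemble case with N estimators, using the (batch, num_estimators, ...) shapes documented for update(); the loop over random tensors stands in for a real evaluation loop.

```python
import torch

from torch_uncertainty.metrics.classification import GroupingLoss

B, N, C, F = 16, 5, 10, 64  # batch, estimators, classes, features (illustrative)

metric = GroupingLoss()

# Accumulate several batches before computing the final value.
for _ in range(4):
    probs = torch.randn(B, N, C).softmax(dim=-1)  # one distribution per estimator
    target = torch.randint(0, C, (B,))
    features = torch.randn(B, N, F)               # matching per-estimator features
    metric.update(probs, target, features)

grouping_loss = metric.compute()
metric.reset()  # clear accumulated state before the next run (standard torchmetrics behavior)
```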