
GroupingLoss

class torch_uncertainty.metrics.classification.GroupingLoss(**kwargs)

Metric to estimate the Top-label Grouping Loss.

Parameters:

kwargs – Additional keyword arguments, see Advanced metric settings.

Inputs:
  • probs: \((B, C)\) or \((B, N, C)\)

  • target: \((B)\) or \((B, C)\)

  • features: \((B, F)\) or \((B, N, F)\)

where \(B\) is the batch size, \(C\) is the number of classes, \(N\) is the number of estimators and \(F\) is the number of features.
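As a minimal sketch of the expected shapes, the tensors below use arbitrary sizes chosen only for illustration (B = 8, C = 3, N = 4, F = 16) and random values. They are not tied to any particular model.

```python
import torch

# Illustrative sizes (arbitrary): batch, classes, estimators, feature dimension
B, C, N, F = 8, 3, 4, 16

# Single-model inputs
probs = torch.softmax(torch.randn(B, C), dim=-1)  # (B, C), rows sum to one
target = torch.randint(0, C, (B,))                # (B,) integer class labels
features = torch.randn(B, F)                      # (B, F)

# Ensemble inputs: one probability vector and one feature vector per estimator
ens_probs = torch.softmax(torch.randn(B, N, C), dim=-1)  # (B, N, C)
ens_features = torch.randn(B, N, F)                      # (B, N, F)

# One-hot targets are the documented (B, C) alternative to integer labels
onehot_target = torch.nn.functional.one_hot(target, num_classes=C)  # (B, C)
```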

Warning

Make sure that the probabilities in probs are normalized, i.e. that they sum to one over the class dimension.
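If a model outputs raw logits, one way to meet this requirement is to apply a softmax over the class dimension before passing the result as probs. The sketch below assumes logits of shape (B, C) or (B, N, C) with the classes on the last axis.

```python
import torch

logits = torch.randn(8, 3)             # raw, unnormalized outputs of shape (B, C)
probs = torch.softmax(logits, dim=-1)  # normalize over the class dimension

# Each row now sums to one, up to floating-point error
assert torch.allclose(probs.sum(dim=-1), torch.ones(8))

# Ensemble outputs (B, N, C) are normalized along the same last (class) axis
ens_probs = torch.softmax(torch.randn(8, 4, 3), dim=-1)
assert torch.allclose(ens_probs.sum(dim=-1), torch.ones(8, 4))
```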

Raises:

ValueError – If reduction is not one of 'mean', 'sum', 'none' or None.

Reference:

Perez-Lebel, Alexandre, Le Morvan, Marine and Varoquaux, Gaël. Beyond calibration: estimating the grouping loss of modern neural networks. In ICLR 2023.
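For orientation, the decomposition behind the grouping loss can be written for a top-label Brier score as follows. This binary form is a simplification chosen for illustration: \(S\) denotes the top-label confidence, \(Y \in \{0, 1\}\) indicates whether the top label is correct, \(Q = P(Y = 1 \mid X)\) and \(C = P(Y = 1 \mid S) = \mathbb{E}[Q \mid S]\). The exact scaling of the value reported by the metric is not specified here.

```latex
\[
\mathbb{E}\left[(S - Y)^2\right]
= \underbrace{\mathbb{E}\left[(S - C)^2\right]}_{\text{calibration loss}}
+ \underbrace{\mathbb{E}\left[(C - Q)^2\right]}_{\text{grouping loss}}
+ \underbrace{\mathbb{E}\left[Q\,(1 - Q)\right]}_{\text{irreducible loss}}
\]
```

Both cross terms vanish because \(\mathbb{E}[Y \mid X] = Q\) and \(\mathbb{E}[Q \mid S] = C\). The grouping loss measures how much the true probability of being correct varies among samples that received the same confidence, which calibration metrics cannot see; estimating it therefore requires some representation of the inputs, such as the features argument.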

compute()

Compute the final Grouping Loss based on the inputs passed to update.

Returns:

The final value(s) of the Grouping Loss

Return type:

torch.Tensor
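To make the returned quantity more concrete, the hypothetical helper naive_grouping_loss below computes a deliberately crude plug-in estimate from accumulated inputs. It is neither the estimator implemented by GroupingLoss nor the one proposed in the reference. It assumes equal-width confidence bins, splits each bin into regions using quantiles of the first feature only, and makes no attempt to correct the upward bias that finite-sample noise adds to the region accuracies. It averages the squared gap between each region's accuracy and its bin's accuracy, a rough proxy for the expected squared gap between the true and the confidence-conditional probability of being correct.

```python
import torch


def naive_grouping_loss(probs, target, features, n_bins=5, n_regions=2):
    """Crude plug-in estimate of the top-label grouping loss (hypothetical helper).

    probs: (B, C) normalized probabilities, target: (B,) integer labels,
    features: (B, F). This is NOT the estimator used by GroupingLoss.
    """
    conf, pred = probs.max(dim=-1)      # top-label confidence S and predicted class
    correct = (pred == target).float()  # Y = 1 when the top label is right
    # Equal-width confidence bins on [0, 1]; a confidence of 1.0 goes to the last bin
    bins = torch.clamp((conf * n_bins).long(), max=n_bins - 1)
    total = 0.0
    for b in range(n_bins):
        in_bin = bins == b
        if int(in_bin.sum()) < n_regions:
            continue  # too few samples to split; these samples contribute zero
        y_bin = correct[in_bin]
        c_hat = y_bin.mean()  # estimate of P(Y = 1 | S in this bin)
        # Regions: quantile groups of the first feature only (simplifying assumption)
        order = torch.argsort(features[in_bin, 0])
        for region in torch.tensor_split(order, n_regions):
            q_hat = y_bin[region].mean()  # accuracy within the region
            total += region.numel() * float((q_hat - c_hat) ** 2)
    return total / probs.shape[0]  # size-weighted average over all samples


# Toy check: identical confidence everywhere, but correctness depends on the feature
probs = torch.tensor([[0.9, 0.1]] * 4)
target = torch.tensor([0, 0, 1, 1])
features = torch.tensor([[0.0], [1.0], [2.0], [3.0]])
gl = naive_grouping_loss(probs, target, features)  # bin accuracy 0.5, regions at 1.0 and 0.0
```

In the toy check every sample has confidence 0.9, so a calibration metric sees a single group with accuracy 0.5, while the feature splits it into a region that is always right and one that is always wrong; the sketch returns 0.25 for this case.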

update(probs, target, features)

Accumulate the tensors for the estimation of the Grouping Loss.

Parameters:
  • probs (Tensor) – A probability tensor of shape (batch, num_classes) or (batch, num_estimators, num_classes), or of shape (batch) for binary classification

  • target (Tensor) – A tensor of ground truth labels of shape (batch, num_classes) or (batch)

  • features (Tensor) – A tensor of features of shape (batch, num_estimators, num_features) or (batch, num_features)
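A rough picture of what accumulating these tensors across batches can look like is sketched below. The class InputAccumulator is a hypothetical illustration, not part of torch_uncertainty, and its handling of the optional shapes (argmax for one-hot targets, averaging over estimators for probs and features, and expanding binary (batch) probabilities into two columns) is an assumption rather than documented behavior.

```python
import torch


class InputAccumulator:
    """Hypothetical stand-in for the state a metric keeps between update() calls."""

    def __init__(self):
        self.probs, self.targets, self.features = [], [], []

    def update(self, probs, target, features):
        if probs.ndim == 1:
            # Binary case (assumption): probs holds P(class 1); build a (batch, 2) tensor
            probs = torch.stack([1 - probs, probs], dim=-1)
        elif probs.ndim == 3:
            # (batch, num_estimators, num_classes): average over estimators (assumption)
            probs = probs.mean(dim=1)
        if target.ndim == 2:
            # One-hot (batch, num_classes) -> integer labels (batch)
            target = target.argmax(dim=-1)
        if features.ndim == 3:
            # (batch, num_estimators, num_features): average over estimators (assumption)
            features = features.mean(dim=1)
        self.probs.append(probs)
        self.targets.append(target)
        self.features.append(features)

    def gathered(self):
        # Concatenate all accumulated batches along the batch dimension
        return torch.cat(self.probs), torch.cat(self.targets), torch.cat(self.features)


# Usage: three ensemble batches of 8 samples, 4 estimators, 3 classes, 16 features
acc = InputAccumulator()
for _ in range(3):
    acc.update(
        torch.softmax(torch.randn(8, 4, 3), dim=-1),
        torch.randint(0, 3, (8,)),
        torch.randn(8, 4, 16),
    )
probs, targets, features = acc.gathered()
```

Keeping every batch and concatenating only at the end reflects that a grouping-loss estimate depends on the whole dataset at once, so it cannot simply be averaged over per-batch values.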