Source code for torch_uncertainty.metrics.regression.mse_log

from torch import Tensor
from torchmetrics import MeanSquaredError


class MeanSquaredLogError(MeanSquaredError):
    def __init__(self, squared: bool = True, **kwargs) -> None:
        r"""Computes the Mean Squared Logarithmic Error (MSLE) regression metric.

        This metric is commonly used in regression problems where the relative
        difference between predictions and targets is of greater importance
        than the absolute difference. It is particularly effective for
        datasets with wide-ranging magnitudes. For the same absolute error, it
        penalizes underestimation more than overestimation.

        .. math:: \text{MSELog} = \frac{1}{N}\sum_i^N  (\log \hat{y_i} - \log y_i)^2

        where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a
        tensor of predictions. When ``squared=False``, the square root of this
        quantity (RMSELog) is returned instead.

        As input to ``forward`` and ``update`` the metric accepts the following input:

        - **preds** (:class:`~torch.Tensor`): Predictions from model. Must be
          strictly positive.
        - **target** (:class:`~torch.Tensor`): Ground truth values. Must be
          strictly positive.

        As output of ``forward`` and ``compute`` the metric returns the following output:

        - **mse_log** (:class:`~torch.Tensor`): A tensor with the mean squared
          logarithmic error over the state. When ``squared=False``, this is
          its square root.

        Args:
            squared: If ``True`` returns the MSELog value, if ``False`` returns
                the RMSELog value.
            kwargs: Additional keyword arguments, see `Advanced metric settings
                <https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metric-kwargs>`_.

        Note:
            Inputs go through the natural logarithm with no clamping or offset.
            Zeros become ``-inf`` and negative values become ``NaN``. These
            non-finite values then propagate into the accumulated metric, so
            mask or filter invalid entries before calling ``update``.

        Reference:
            [1] `From big to small: Multi-scale local planar guidance for
            monocular depth estimation <https://arxiv.org/abs/1907.10326>`_.

        Example:

        .. code-block:: python

            from torch_uncertainty.metrics.regression import MeanSquaredLogError
            import torch

            # Initialize the metric
            msle_metric = MeanSquaredLogError(squared=True)

            # Example predictions and targets (must be strictly positive)
            preds = torch.tensor([2.5, 1.0, 2.0, 8.0])
            target = torch.tensor([3.0, 1.5, 2.0, 7.0])

            # Update the metric state
            msle_metric.update(preds, target)

            # Compute the Mean Squared Logarithmic Error
            result = msle_metric.compute()
            print(f"Mean Squared Logarithmic Error: {result.item()}")
            # Output: Mean Squared Logarithmic Error: 0.05386843904852867
        """
        # Pass by keyword so the flag cannot be mis-bound if the parent's
        # positional parameter order (squared, num_outputs, ...) changes.
        super().__init__(squared=squared, **kwargs)

    def update(self, pred: Tensor, target: Tensor) -> None:
        """Update state with log-transformed predictions and targets.

        Args:
            pred: Predictions from the model. Expected to be strictly positive.
            target: Ground truth values. Expected to be strictly positive.
        """
        # MSELog is plain MSE computed in log space, so reuse the parent's
        # accumulation after taking the natural logarithm of both inputs.
        super().update(pred.log(), target.log())