Shortcuts

Source code for torch_uncertainty.baselines.regression.mlp

from typing import Literal

from torch import nn

from torch_uncertainty.layers.distributions import (
    LaplaceLayer,
    NormalInverseGammaLayer,
    NormalLayer,
)
from torch_uncertainty.models.mlp import mlp, packed_mlp
from torch_uncertainty.routines.regression import (
    RegressionRoutine,
)
from torch_uncertainty.transforms.batch import RepeatTarget

# Model versions that are ensembles; ``MLPBaseline`` passes
# ``version in ENSEMBLE_METHODS`` to the routine as ``is_ensemble``.
ENSEMBLE_METHODS = [
    "packed",
]


class MLPBaseline(RegressionRoutine):
    """MLP regression baseline in a standard (``"std"``) or packed (``"packed"``) version."""

    # Maps the ``version`` argument to the function that builds the model.
    versions = {"std": mlp, "packed": packed_mlp}

    def __init__(
        self,
        output_dim: int,
        in_features: int,
        loss: nn.Module,
        version: Literal["std", "packed"],
        hidden_dims: list[int],
        num_estimators: int | None = 1,
        dropout_rate: float = 0.0,
        alpha: float | None = None,
        gamma: int = 1,
        distribution: Literal["normal", "laplace", "nig"] | None = None,
    ) -> None:
        r"""MLP baseline for regression providing support for various versions.

        Args:
            output_dim: Number of regression outputs. Forwarded to the routine
                as ``output_dim``. Also used as the model's ``num_outputs``,
                multiplied by 2 (``"normal"``/``"laplace"``) or 4 (``"nig"``)
                when a distribution head is used.
            in_features: Number of input features, forwarded to the model builder.
            loss: Loss module handed to the routine. It is excluded from the
                saved hyperparameters.
            version: ``"std"`` builds the model with ``mlp``; ``"packed"``
                builds it with ``packed_mlp`` and marks the routine as an
                ensemble.
            hidden_dims: Hidden layer sizes, forwarded to the model builder.
            num_estimators: Only used when ``version == "packed"``. It is
                forwarded to ``packed_mlp`` and used as the number of target
                repeats in ``RepeatTarget``.
            dropout_rate: Dropout rate, forwarded to the model builder.
            alpha: Only used when ``version == "packed"``. Forwarded to
                ``packed_mlp``.
            gamma: Only used when ``version == "packed"``. Forwarded to
                ``packed_mlp``.
            distribution: Selects the final layer. ``None`` uses
                ``nn.Identity`` and makes the routine non-probabilistic. Any
                other supported value uses the matching distribution layer
                and makes it probabilistic.

        Raises:
            ValueError: If ``version`` is not a key of ``versions``, or if
                ``distribution`` is not ``None``, ``"normal"``, ``"laplace"``
                or ``"nig"``.
        """
        # Fail fast, before any model construction work.
        if version not in self.versions:
            raise ValueError(f"Unknown version: {version}")

        # Distribution head class and the factor by which the backbone's
        # output width is multiplied so the head receives enough raw outputs.
        distribution_heads = {
            "normal": (NormalLayer, 2),
            "laplace": (LaplaceLayer, 2),
            "nig": (NormalInverseGammaLayer, 4),
        }
        if distribution is None:
            probabilistic = False
            final_layer, width_factor = nn.Identity, 1
            final_layer_args = {}
        elif distribution in distribution_heads:
            probabilistic = True
            final_layer, width_factor = distribution_heads[distribution]
            final_layer_args = {"dim": output_dim}
        else:
            # Literal[...] is not checked at runtime. Without this guard a typo
            # would silently produce a non-probabilistic model.
            raise ValueError(f"Unknown distribution: {distribution}")

        params = {
            "dropout_rate": dropout_rate,
            "in_features": in_features,
            "num_outputs": output_dim * width_factor,
            "hidden_dims": hidden_dims,
            "final_layer": final_layer,
            "final_layer_args": final_layer_args,
        }

        format_batch_fn = nn.Identity()
        if version == "packed":
            params |= {
                "alpha": alpha,
                "num_estimators": num_estimators,
                "gamma": gamma,
            }
            # Repeat targets once per estimator so they line up with the
            # ensemble's stacked predictions.
            format_batch_fn = RepeatTarget(num_repeats=num_estimators)

        model = self.versions[version](**params)

        super().__init__(
            probabilistic=probabilistic,
            output_dim=output_dim,
            model=model,
            loss=loss,
            is_ensemble=version in ENSEMBLE_METHODS,
            format_batch_fn=format_batch_fn,
        )
        self.save_hyperparameters(ignore=["loss"])