Shortcuts

Source code for torch_uncertainty.post_processing.laplace

from importlib import util
from typing import Literal

from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset

from .abstract import PostProcessing

if util.find_spec("laplace"):
    from laplace import Laplace

    laplace_installed = True
else:  # coverage: ignore
    laplace_installed = False


[docs]class LaplaceApprox(PostProcessing): def __init__( self, task: Literal["classification", "regression"], model: nn.Module | None = None, weight_subset="last_layer", hessian_struct="kron", pred_type: Literal["glm", "nn"] = "glm", link_approx: Literal[ "mc", "probit", "bridge", "bridge_norm" ] = "probit", batch_size: int = 256, optimize_prior_precision: bool = True, ) -> None: """Laplace approximation for uncertainty estimation. This class is a wrapper of Laplace classes from the laplace-torch library. Args: task (Literal["classification", "regression"]): task type. model (nn.Module): model to be converted. weight_subset (str): subset of weights to be considered. Defaults to "last_layer". hessian_struct (str): structure of the Hessian matrix. Defaults to "kron". pred_type (Literal["glm", "nn"], optional): type of posterior predictive, See the Laplace library for more details. Defaults to "glm". link_approx (Literal["mc", "probit", "bridge", "bridge_norm"], optional): how to approximate the classification link function for the `'glm'`. See the Laplace library for more details. Defaults to "probit". batch_size (int, optional): batch size for the Laplace approximation. Defaults to 256. optimize_prior_precision (bool, optional): whether to optimize the prior precision. Defaults to True. Reference: Daxberger et al. Laplace Redux - Effortless Bayesian Deep Learning. In NeurIPS 2021. """ super().__init__() if not laplace_installed: # coverage: ignore raise ImportError( "The laplace-torch library is not installed. Please install" "torch_uncertainty with the all option:" """pip install -U "torch_uncertainty[all]".""" ) self.pred_type = pred_type self.link_approx = link_approx self.task = task self.weight_subset = weight_subset self.hessian_struct = hessian_struct self.batch_size = batch_size self.optimize_prior_precision = optimize_prior_precision if model is not None: self.set_model(model) def set_model(self, model: nn.Module) -> None: super().set_model(model) self.la = Laplace( model=model, likelihood=self.task, subset_of_weights=self.weight_subset, hessian_structure=self.hessian_struct, ) def fit(self, dataset: Dataset) -> None: dl = DataLoader(dataset, batch_size=self.batch_size) self.la.fit(train_loader=dl) if self.optimize_prior_precision: self.la.optimize_prior_precision(method="marglik") def forward( self, x: Tensor, ) -> Tensor: return self.la( x, pred_type=self.pred_type, link_approx=self.link_approx )