Source code for torch_uncertainty.post_processing.laplace
from importlib import util
from typing import Literal
from torch import Tensor, nn
from torch.utils.data import DataLoader
from .abstract import PostProcessing
# Probe for the optional laplace-torch dependency. ``Laplace`` is only imported
# (and therefore only bound at module level) when the package can be found.
laplace_installed = util.find_spec("laplace") is not None
if laplace_installed:
    from laplace import Laplace
[docs]
class LaplaceApprox(PostProcessing):
    def __init__(
        self,
        task: Literal["classification", "regression"],
        model: nn.Module | None = None,
        weight_subset: str = "last_layer",
        hessian_struct: str = "kron",
        pred_type: Literal["glm", "nn"] = "glm",
        link_approx: Literal["mc", "probit", "bridge", "bridge_norm"] = "probit",
        optimize_prior_precision: bool = True,
        n_samples: int = 100,
    ) -> None:
        r"""Laplace approximation for post-hoc Bayesian uncertainty estimation.

        Fits a Gaussian posterior :math:`\mathcal{N}(\boldsymbol{\theta}_\text{MAP},
        \mathbf{H}^{-1})` around the MAP estimate of a trained network, where
        :math:`\mathbf{H}` is (an approximation of) the Hessian of the negative
        log-posterior, computed on the calibration set. Predictions are then obtained
        by marginalising over the posterior — analytically for regression (and the
        ``"probit"``, ``"bridge"`` and ``"bridge_norm"`` classification
        approximations), or by Monte Carlo sampling for ``"mc"``.

        This class is a thin wrapper around the
        `laplace-torch <https://github.com/aleximmer/Laplace>`_ library.

        Args:
            task: Task type. Either ``"classification"`` or ``"regression"``. It is
                forwarded to laplace-torch as the likelihood.
            model: Model to be converted. If ``None``, :meth:`set_model` must be
                called before :meth:`fit` or :meth:`forward`.
            weight_subset: Subset of weights to be considered (e.g. ``"last_layer"``
                or ``"all"``). Defaults to ``"last_layer"``.
            hessian_struct: Structure of the Hessian approximation (e.g. ``"kron"``,
                ``"diag"``, ``"full"``). Defaults to ``"kron"``.
            pred_type: Type of posterior predictive. See the Laplace library for
                details. Defaults to ``"glm"``.
            link_approx: How to approximate the classification link function for the
                ``"glm"`` predictive. See the Laplace library for details. Defaults to
                ``"probit"``.
            optimize_prior_precision: Whether to optimize the prior precision by
                marginal likelihood after fitting. Defaults to ``True``.
            n_samples: Number of samples forwarded to the laplace-torch posterior
                predictive. It only matters when the predictive is sampled (e.g.
                ``link_approx="mc"``). Defaults to ``100``.

        Raises:
            ImportError: If the laplace-torch library is not installed.
            ValueError: If ``n_samples`` is smaller than 1.

        References:
            [1] `Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., &
            Hennig, P. (2021). Laplace Redux — Effortless Bayesian Deep Learning. NeurIPS 2021
            <https://arxiv.org/abs/2106.14806>`_.
        """
        super().__init__()
        if not laplace_installed:
            raise ImportError(
                "The laplace-torch library is not installed. Please install "
                "torch_uncertainty with the all option: pip install -U torch_uncertainty[all]"
            )
        if n_samples < 1:
            raise ValueError(f"n_samples must be a positive integer. Got {n_samples}.")
        self.pred_type = pred_type
        self.link_approx = link_approx
        self.task = task
        self.weight_subset = weight_subset
        self.hessian_struct = hessian_struct
        self.optimize_prior_precision = optimize_prior_precision
        self.n_samples = n_samples
        if model is not None:
            self.set_model(model)

    def set_model(self, model: nn.Module) -> None:
        """Attach ``model`` and build the underlying laplace-torch ``Laplace`` object.

        A fresh ``Laplace`` object is created on every call, so any previous fit is
        discarded and :meth:`fit` has to be called again.

        Args:
            model: The trained network to wrap.
        """
        super().set_model(model)
        self.la = Laplace(
            model=model,
            likelihood=self.task,
            subset_of_weights=self.weight_subset,
            hessian_structure=self.hessian_struct,
        )

    def _check_model_set(self) -> None:
        """Raise an informative error if :meth:`set_model` has never been called."""
        # ``la`` only exists once set_model() ran; nn.Module.__getattr__ raises
        # AttributeError for unknown names, which getattr() turns into the default.
        if getattr(self, "la", None) is None:
            raise RuntimeError(
                "No model has been set. Pass a model to the constructor or call "
                "set_model() before fit() or forward()."
            )

    def fit(self, dataloader: DataLoader) -> None:
        """Fit the Laplace approximation on the calibration data.

        If ``optimize_prior_precision`` is enabled, the prior precision is then
        tuned by marginal likelihood.

        Args:
            dataloader: Calibration data loader forwarded to laplace-torch.

        Raises:
            RuntimeError: If no model has been set.
        """
        self._check_model_set()
        self.la.fit(train_loader=dataloader)
        if self.optimize_prior_precision:
            self.la.optimize_prior_precision(method="marglik", pred_type=self.pred_type)

    def forward(
        self,
        inputs: Tensor,
    ) -> Tensor:
        """Compute the Laplace posterior predictive for ``inputs``.

        Args:
            inputs: Batch of inputs passed to the underlying model.

        Returns:
            The posterior predictive computed by laplace-torch. When laplace-torch
            returns a tuple (as it does for regression, yielding the predictive mean
            and variance), only the first element is returned.

        Raises:
            RuntimeError: If no model has been set.
        """
        self._check_model_set()
        out = self.la(
            inputs,
            pred_type=self.pred_type,
            link_approx=self.link_approx,
            n_samples=self.n_samples,
        )
        if isinstance(out, tuple):  # coverage: ignore
            # Keep only the first element so callers always receive a single tensor.
            return out[0]
        return out