LaplaceApprox

class torch_uncertainty.post_processing.LaplaceApprox(task, model=None, weight_subset='last_layer', hessian_struct='kron', pred_type='glm', link_approx='probit', optimize_prior_precision=True)

Laplace approximation for post-hoc Bayesian uncertainty estimation.

Fits a Gaussian posterior \(\mathcal{N}(\boldsymbol{\theta}_\text{MAP}, \mathbf{H}^{-1})\) around the MAP estimate of a trained network, where \(\mathbf{H}\) is (an approximation of) the Hessian of the negative log-posterior at \(\boldsymbol{\theta}_\text{MAP}\), computed on the calibration set. Predictions are then obtained by marginalising over this posterior. With the "glm" predictive, this is done analytically for regression and in approximate closed form for classification with the "probit" and "bridge" link approximations, or by Monte Carlo sampling with "mc". The "nn" predictive always relies on Monte Carlo samples of the weights.
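The recipe above can be illustrated end to end on a hypothetical one-parameter logistic-regression model with toy data, in pure Python (no torch or laplace-torch involved; this is not how the wrapper is implemented). Newton's method finds \(\theta_\text{MAP}\), the Hessian at \(\theta_\text{MAP}\) gives the posterior variance, and the predictive probability at a test input is computed twice: with the probit approximation \(\sigma\big(\mu / \sqrt{1 + \pi s^2 / 8}\big)\) and by Monte Carlo sampling. Because the logit is linear in \(\theta\) here, the linearized ("glm") predictive coincides with the exact one.

```python
import math
import random


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


# Hypothetical toy data: 1-D inputs, binary labels (not linearly separable).
xs = [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]
ys = [0, 0, 0, 1, 0, 1, 1, 1, 1]
prior_prec = 1.0  # precision of the Gaussian prior N(0, 1 / prior_prec)


def grad_and_hess(theta: float) -> tuple[float, float]:
    """Gradient and Hessian of the negative log-posterior of p(y=1|x) = sigmoid(theta * x)."""
    grad = prior_prec * theta
    hess = prior_prec
    for x, y in zip(xs, ys):
        p = sigmoid(theta * x)
        grad += (p - y) * x
        hess += p * (1.0 - p) * x * x
    return grad, hess


# 1) MAP estimate via Newton's method on the (convex) negative log-posterior.
theta_map = 0.0
for _ in range(50):
    g, h = grad_and_hess(theta_map)
    theta_map -= g / h

# 2) Laplace posterior N(theta_map, 1 / H), with H evaluated at the MAP.
_, h_map = grad_and_hess(theta_map)
post_var = 1.0 / h_map

# 3) Predictive at a test input: the logit f = theta * x_star is Gaussian
#    with mean mu_f and variance var_f under the Laplace posterior.
x_star = 2.0
mu_f = theta_map * x_star
var_f = x_star**2 * post_var

p_map = sigmoid(mu_f)  # plug-in MAP prediction (ignores weight uncertainty)
p_probit = sigmoid(mu_f / math.sqrt(1.0 + math.pi / 8.0 * var_f))  # probit approximation

# Monte Carlo estimate of E[sigmoid(f)] by sampling theta from the posterior.
rng = random.Random(0)
n_samples = 20_000
p_mc = sum(
    sigmoid(rng.gauss(theta_map, math.sqrt(post_var)) * x_star) for _ in range(n_samples)
) / n_samples

print(f"MAP: {p_map:.3f}  probit: {p_probit:.3f}  MC: {p_mc:.3f}")
```

The probit and Monte Carlo estimates should agree closely, and both are less confident than the plug-in MAP prediction: accounting for weight uncertainty pulls the predicted probability toward 0.5.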

This class is a thin wrapper around the laplace-torch library [1].

Parameters:
  • task (Literal['classification', 'regression']) – Task type. Either "classification" or "regression".

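  • model (Module | None) – Trained model to be converted; its weights serve as the MAP estimate \(\boldsymbol{\theta}_\text{MAP}\). Defaults to None.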

  • weight_subset (str) – Subset of weights treated probabilistically (e.g. "last_layer" or "all"); the remaining weights stay fixed at their MAP values. Defaults to "last_layer".

  • hessian_struct (str) – Structure of the Hessian approximation (e.g. "kron", "diag", "full"). Defaults to "kron".

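  • pred_type (Literal['glm', 'nn']) – Type of posterior predictive: "glm" linearizes the network around \(\boldsymbol{\theta}_\text{MAP}\), while "nn" samples weights from the posterior and runs the network on each sample. See the laplace-torch documentation for details. Defaults to "glm".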

  • link_approx (Literal['mc', 'probit', 'bridge', 'bridge_norm']) – How to approximate the classification link function under the "glm" predictive. See the laplace-torch documentation for details. Defaults to "probit".

  • optimize_prior_precision (bool) – Whether to optimize the prior precision by maximizing the marginal likelihood. Defaults to True.
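For reference, the objective behind optimize_prior_precision can be written out under a simplifying assumption: an isotropic Gaussian prior \(\mathcal{N}(\mathbf{0}, \delta^{-1}\mathbf{I})\) over the \(P\) weights in weight_subset, with \(\boldsymbol{\theta}_\text{MAP}\) held fixed post hoc and only the prior precision \(\delta\) varied. Writing \(\mathbf{H}_\text{lik}\) for the (e.g. "kron" or "diag") approximation of the Hessian of the negative log-likelihood, the Laplace estimate of the log marginal likelihood is

\[
\log p(\mathcal{D} \mid \delta) \approx \log p(\mathcal{D} \mid \boldsymbol{\theta}_\text{MAP}) - \frac{\delta}{2}\lVert \boldsymbol{\theta}_\text{MAP} \rVert^2 + \frac{P}{2}\log\delta - \frac{1}{2}\log\det\mathbf{H},
\qquad \mathbf{H} = \mathbf{H}_\text{lik} + \delta\mathbf{I},
\]

where the \(\frac{P}{2}\log 2\pi\) terms from the prior and from the Gaussian normalizer cancel. Larger \(\delta\) shrinks the posterior covariance \(\mathbf{H}^{-1}\); the maximizer balances the fit of the prior against the curvature of the likelihood.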

References

[1] Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., & Hennig, P. (2021). Laplace Redux — Effortless Bayesian Deep Learning. NeurIPS 2021.