LaplaceApprox
- class torch_uncertainty.post_processing.LaplaceApprox(task, model=None, weight_subset='last_layer', hessian_struct='kron', pred_type='glm', link_approx='probit', optimize_prior_precision=True)
Laplace approximation for post-hoc Bayesian uncertainty estimation.
Fits a Gaussian posterior \(\mathcal{N}(\boldsymbol{\theta}_\text{MAP}, \mathbf{H}^{-1})\) around the MAP estimate of a trained network, where \(\mathbf{H}\) is (an approximation of) the Hessian of the negative log-posterior, computed on the calibration set. Predictions are then obtained by marginalising over this posterior: analytically for regression, through the closed-form "probit"/"bridge" approximations for classification, or by Monte Carlo sampling for "mc". This class is a thin wrapper around the laplace-torch library.
- Parameters:
  - task (Literal['classification', 'regression']) – Task type. Either "classification" or "regression".
  - model (Module | None) – Model to be converted.
  - weight_subset (str) – Subset of weights to be considered (e.g. "last_layer" or "all"). Defaults to "last_layer".
  - hessian_struct (str) – Structure of the Hessian approximation (e.g. "kron", "diag", "full"). Defaults to "kron".
  - pred_type (Literal['glm', 'nn']) – Type of posterior predictive: "glm" uses the linearised model, "nn" samples weights from the posterior and runs the network. See the Laplace library for details. Defaults to "glm".
  - link_approx (Literal['mc', 'probit', 'bridge', 'bridge_norm']) – How to approximate the classification link function for the "glm" predictive. See the Laplace library for details. Defaults to "probit".
  - optimize_prior_precision (bool) – Whether to optimize the prior precision by maximising the marginal likelihood. Defaults to True.
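The `link_approx` options differ in how they turn a Gaussian distribution over logits into class probabilities. The plain-Python sketch below compares the probit approximation, as commonly extended to the multiclass case (each logit mean scaled by 1/sqrt(1 + π/8 · variance) before the softmax), with Monte Carlo averaging of the softmax. It assumes a diagonal logit covariance; the logit means and variances are made-up numbers, and the function names are hypothetical rather than laplace-torch's API.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def probit_predictive(f_mu, f_var):
    """Probit approximation: shrink each logit mean by 1 / sqrt(1 + pi/8 * var)."""
    kappa = [1.0 / math.sqrt(1.0 + math.pi / 8.0 * v) for v in f_var]
    return softmax([k * m for k, m in zip(kappa, f_mu)])


def mc_predictive(f_mu, f_var, n_samples=10000, seed=0):
    """Monte Carlo: average the softmax over logits drawn from N(f_mu, diag(f_var))."""
    rng = random.Random(seed)
    probs = [0.0] * len(f_mu)
    for _ in range(n_samples):
        sample = [rng.gauss(m, math.sqrt(v)) for m, v in zip(f_mu, f_var)]
        for k, p in enumerate(softmax(sample)):
            probs[k] += p / n_samples
    return probs


f_mu = [2.0, 0.0, -1.0]  # hypothetical logit means at the MAP weights
f_var = [4.0, 4.0, 4.0]  # hypothetical logit variances under the posterior
map_probs = softmax(f_mu)
probit_probs = probit_predictive(f_mu, f_var)
mc_probs = mc_predictive(f_mu, f_var)
```

Both predictives keep the MAP's top class but are less confident than the softmax of the MAP logits. The probit route needs a single softmax per input, whereas Monte Carlo trades extra compute for not relying on a closed-form approximation.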
References