
LaplaceApprox

class torch_uncertainty.post_processing.LaplaceApprox(task, model=None, weight_subset='last_layer', hessian_struct='kron', pred_type='glm', link_approx='probit', batch_size=256, optimize_prior_precision=True)

Laplace approximation for uncertainty estimation.

This class wraps the Laplace classes of the laplace-torch library.
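To make the underlying idea concrete, here is a minimal, dependency-free sketch of a Laplace approximation. It does not use torch_uncertainty or laplace-torch; it assumes a hypothetical one-parameter logistic regression p(y=1 | x, w) = sigmoid(w * x) with a Gaussian prior N(0, 1 / prior_precision). The sketch finds the MAP estimate with Newton's method, then takes the inverse Hessian of the negative log posterior at the MAP as the variance of a Gaussian posterior. The name `laplace_fit` and the toy data are illustrative only.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def laplace_fit(xs, ys, prior_precision=1.0, n_steps=50):
    """Hypothetical helper: return (w_map, posterior_variance) for a 1-D
    logistic regression with a N(0, 1 / prior_precision) prior on w."""
    w = 0.0
    for _ in range(n_steps):
        # Gradient and Hessian of the negative log posterior at the current w.
        grad = prior_precision * w
        hess = prior_precision
        for x, y in zip(xs, ys):
            p = sigmoid(w * x)
            grad += (p - y) * x
            hess += p * (1.0 - p) * x * x
        w -= grad / hess  # Newton step towards the MAP estimate
    # Curvature at the MAP: its inverse is the Laplace posterior variance.
    hess = prior_precision + sum(
        sigmoid(w * x) * (1.0 - sigmoid(w * x)) * x * x for x in xs
    )
    return w, 1.0 / hess


w_map, post_var = laplace_fit([-2.0, -1.0, 1.0, 2.0], [0, 0, 1, 1])
print(w_map, post_var)
```

In a deep network the same recipe is applied to many parameters, so the Hessian must be approximated; presumably this is what the weight_subset (e.g. only the last layer) and hessian_struct (e.g. a Kronecker-factored structure) parameters control in the wrapped library.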

Parameters:
  • task (Literal["classification", "regression"]) – task type.

  • model (nn.Module, optional) – model to be converted. Defaults to None.

  • weight_subset (str) – subset of weights to be considered. Defaults to “last_layer”.

  • hessian_struct (str) – structure of the Hessian matrix. Defaults to “kron”.

  • pred_type (Literal["glm", "nn"], optional) – type of posterior predictive. See the laplace-torch library for more details. Defaults to “glm”.

  • link_approx (Literal["mc", "probit", "bridge", "bridge_norm"], optional) – how to approximate the classification link function when pred_type is “glm”. See the laplace-torch library for more details. Defaults to “probit”.

  • batch_size (int, optional) – batch size for the Laplace approximation. Defaults to 256.

  • optimize_prior_precision (bool, optional) – whether to optimize the prior precision. Defaults to True.
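As an illustration of what link_approx="probit" approximates, the sketch below assumes a single binary logit f that, under the linearized (“glm”) predictive, is Gaussian with mean f_mean and variance f_var. The probit approximation replaces E[sigmoid(f)] with the closed form sigmoid(f_mean / sqrt(1 + pi * f_var / 8)); a Monte Carlo estimate (analogous to the “mc” option) is included for comparison. Both function names are hypothetical and this is not the laplace-torch implementation.

```python
import math
import random


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def probit_predictive(f_mean: float, f_var: float) -> float:
    """Hypothetical helper: probit approximation of E[sigmoid(f)],
    f ~ N(f_mean, f_var)."""
    return sigmoid(f_mean / math.sqrt(1.0 + math.pi * f_var / 8.0))


def mc_predictive(f_mean: float, f_var: float,
                  n_samples: int = 20_000, seed: int = 0) -> float:
    """Hypothetical helper: Monte Carlo estimate of the same expectation."""
    rng = random.Random(seed)
    std = math.sqrt(f_var)
    total = sum(sigmoid(rng.gauss(f_mean, std)) for _ in range(n_samples))
    return total / n_samples


print(probit_predictive(1.0, 2.0), mc_predictive(1.0, 2.0))
```

With zero variance the probit form reduces to sigmoid(f_mean); as the logit variance grows, the predicted probability is pulled towards 0.5, which is how posterior uncertainty softens overconfident predictions. The closed form avoids sampling, at the cost of a small approximation error.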

Reference:

Daxberger et al. Laplace Redux - Effortless Bayesian Deep Learning. In NeurIPS 2021.