# Source code for torch_uncertainty.layers.bayesian.bayes_linear
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import init

from .sampler import CenteredGaussianMixture, TrainableDistribution


class BayesLinear(nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: Tensor
    lprior: Tensor
    lvposterior: Tensor
    def __init__(
        self,
        in_features: int,
        out_features: int,
        prior_sigma_1: float = 0.1,
        prior_sigma_2: float = 0.4,
        prior_pi: float = 1,
        mu_init: float = 0.0,
        sigma_init: float = -7.0,
        frozen: bool = False,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
"""Bayesian Linear Layer with Mixture of Normals prior and Normal posterior.
Args:
in_features (int): Number of input features
out_features (int): Number of output features
prior_sigma_1 (float, optional): Standard deviation of the first prior
distribution. Defaults to 0.1.
prior_sigma_2 (float, optional): Standard deviation of the second prior
distribution. Defaults to 0.1.
prior_pi (float, optional): Mixture control variable. Defaults to 0.1.
mu_init (float, optional): Initial mean of the posterior distribution.
Defaults to 0.0.
sigma_init (float, optional): Initial standard deviation of the
posterior distribution. Defaults to -7.0.
frozen (bool, optional): Whether to freeze the posterior distribution.
Defaults to False.
bias (bool, optional): Whether to use a bias term. Defaults to True.
device (optional): Device to use. Defaults to None.
dtype (optional): Data type to use. Defaults to None.
Paper Reference:
Blundell, Charles, et al. "Weight uncertainty in neural networks"
ICML 2015.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.prior_sigma_1 = prior_sigma_1
self.prior_sigma_2 = prior_sigma_2
self.prior_pi = prior_pi
self.mu_init = mu_init
self.sigma_init = sigma_init
self.frozen = frozen
self.weight_mu = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
self.weight_sigma = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
if bias:
self.bias_mu = nn.Parameter(torch.empty(out_features, **factory_kwargs))
self.bias_sigma = nn.Parameter(torch.empty(out_features, **factory_kwargs))
else:
self.register_parameter("bias_mu", None)
self.register_parameter("bias_log_sigma", None)
self.reset_parameters()
self.weight_sampler = TrainableDistribution(self.weight_mu, self.weight_sigma)
if bias:
self.bias_sampler = TrainableDistribution(self.bias_mu, self.bias_sigma)
self.weight_prior_dist = CenteredGaussianMixture(prior_sigma_1, prior_sigma_2, prior_pi)
if bias:
self.bias_prior_dist = CenteredGaussianMixture(prior_sigma_1, prior_sigma_2, prior_pi)

    def reset_parameters(self) -> None:
        # TODO: change init
        init.normal_(self.weight_mu, mean=self.mu_init, std=0.1)
        init.normal_(self.weight_sigma, mean=self.sigma_init, std=0.1)
        if self.bias_mu is not None:
            init.normal_(self.bias_mu, mean=self.mu_init, std=0.1)
            init.normal_(self.bias_sigma, mean=self.sigma_init, std=0.1)

    def forward(self, inputs: Tensor) -> Tensor:
        if self.frozen:
            return self._frozen_forward(inputs)
        return self._forward(inputs)

    def _frozen_forward(self, inputs: Tensor) -> Tensor:
        # Deterministic forward pass using the posterior means only.
        return F.linear(inputs, self.weight_mu, self.bias_mu)

    def _forward(self, inputs: Tensor) -> Tensor:
        weight = self.weight_sampler.sample()
        if self.bias_mu is not None:
            bias = self.bias_sampler.sample()
            bias_lposterior = self.bias_sampler.log_posterior()
            bias_lprior = self.bias_prior_dist.log_prob(bias)
        else:
            bias, bias_lposterior, bias_lprior = None, 0, 0

        # Store the log-probabilities of the sampled parameters so that the
        # training loop can build a Monte Carlo estimate of the KL divergence
        # between the variational posterior and the prior.
        self.lvposterior = self.weight_sampler.log_posterior() + bias_lposterior
        self.lprior = self.weight_prior_dist.log_prob(weight) + bias_lprior
        return F.linear(inputs, weight, bias)

    def freeze(self) -> None:
        """Freeze the layer by setting the frozen attribute to True."""
        self.frozen = True

    def unfreeze(self) -> None:
        """Unfreeze the layer by setting the frozen attribute to False."""
        self.frozen = False

    def sample(self) -> tuple[Tensor, Tensor | None]:
        """Sample the Bayesian layer's posterior."""
        weight = self.weight_sampler.sample()
        bias = self.bias_sampler.sample() if self.bias_mu is not None else None
        return weight, bias

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias_mu is not None}"
        )