import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import init

# Project-internal helpers (import path assumed): the trainable Normal
# posterior sampler and the two-component Gaussian scale-mixture prior.
from .sampler import CenteredGaussianMixture, TrainableDistribution


class BayesLinear(nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: Tensor
    lprior: Tensor
    lvposterior: Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        prior_sigma_1: float = 0.1,
        prior_sigma_2: float = 0.4,
        prior_pi: float = 1,
        mu_init: float = 0.0,
        sigma_init: float = -7.0,
        frozen: bool = False,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """Bayesian Linear Layer with Mixture of Normals prior and Normal posterior.

        Args:
            in_features (int): Number of input features.
            out_features (int): Number of output features.
            prior_sigma_1 (float, optional): Standard deviation of the first prior
                component. Defaults to 0.1.
            prior_sigma_2 (float, optional): Standard deviation of the second prior
                component. Defaults to 0.4.
            prior_pi (float, optional): Mixture weight of the first prior component.
                Defaults to 1.
            mu_init (float, optional): Initial mean of the posterior distribution.
                Defaults to 0.0.
            sigma_init (float, optional): Initial standard deviation of the posterior
                distribution. Defaults to -7.0.
            frozen (bool, optional): Whether to freeze the posterior distribution.
                Defaults to False.
            bias (bool, optional): Whether to use a bias term. Defaults to True.
            device (optional): Device to use. Defaults to None.
            dtype (optional): Data type to use. Defaults to None.

        Paper Reference:
            Blundell, Charles, et al. "Weight Uncertainty in Neural Networks."
            ICML 2015.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.prior_pi = prior_pi
        self.mu_init = mu_init
        self.sigma_init = sigma_init
        self.frozen = frozen

        # Variational parameters of the Normal posterior over the weights.
        self.weight_mu = nn.Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        self.weight_sigma = nn.Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )

        if bias:
            self.bias_mu = nn.Parameter(torch.empty(out_features, **factory_kwargs))
            self.bias_sigma = nn.Parameter(
                torch.empty(out_features, **factory_kwargs)
            )
        else:
            self.register_parameter("bias_mu", None)
            self.register_parameter("bias_sigma", None)

        self.reset_parameters()

        # Reparameterized samplers for the posterior and the fixed
        # scale-mixture prior from Blundell et al.
        self.weight_sampler = TrainableDistribution(self.weight_mu, self.weight_sigma)
        if bias:
            self.bias_sampler = TrainableDistribution(self.bias_mu, self.bias_sigma)

        self.weight_prior_dist = CenteredGaussianMixture(
            prior_sigma_1, prior_sigma_2, prior_pi
        )
        if bias:
            self.bias_prior_dist = CenteredGaussianMixture(
                prior_sigma_1, prior_sigma_2, prior_pi
            )

    def reset_parameters(self) -> None:
        # TODO: change init
        init.normal_(self.weight_mu, mean=self.mu_init, std=0.1)
        init.normal_(self.weight_sigma, mean=self.sigma_init, std=0.1)
        if self.bias_mu is not None:
            init.normal_(self.bias_mu, mean=self.mu_init, std=0.1)
            init.normal_(self.bias_sigma, mean=self.sigma_init, std=0.1)

    def forward(self, inputs: Tensor) -> Tensor:
        if self.frozen:
            return self._frozen_forward(inputs)
        return self._forward(inputs)

    def _frozen_forward(self, inputs: Tensor) -> Tensor:
        # Deterministic pass using the posterior means only.
        return F.linear(inputs, self.weight_mu, self.bias_mu)

    def _forward(self, inputs: Tensor) -> Tensor:
        # Draw a fresh weight (and bias) sample from the posterior.
        weight = self.weight_sampler.sample()

        if self.bias_mu is not None:
            bias = self.bias_sampler.sample()
            bias_lposterior = self.bias_sampler.log_posterior()
            bias_lprior = self.bias_prior_dist.log_prob(bias)
        else:
            bias, bias_lposterior, bias_lprior = None, 0, 0

        # Cache the log variational posterior and log prior of the sample so
        # the training loop can assemble the KL term of the ELBO.
        self.lvposterior = self.weight_sampler.log_posterior() + bias_lposterior
        self.lprior = self.weight_prior_dist.log_prob(weight) + bias_lprior

        return F.linear(inputs, weight, bias)
    def freeze(self) -> None:
        """Freeze the layer by setting the frozen attribute to True."""
        self.frozen = True
    def unfreeze(self) -> None:
        """Unfreeze the layer by setting the frozen attribute to False."""
        self.frozen = False
    def sample(self) -> tuple[Tensor, Tensor | None]:
        """Sample the Bayesian layer's posterior."""
        weight = self.weight_sampler.sample()
        bias = self.bias_sampler.sample() if self.bias_mu is not None else None
        return weight, bias
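
Below is a minimal usage sketch, not part of the source: one stochastic training step that combines a task loss with the KL term exposed through lvposterior and lprior, followed by a deterministic pass after freeze(). The kl_weight value, the data tensors, and the loss choice are illustrative placeholders.

if __name__ == "__main__":
    # Illustrative only: shapes, data, and kl_weight are placeholders.
    layer = BayesLinear(in_features=16, out_features=4)
    x = torch.randn(32, 16)
    targets = torch.randn(32, 4)

    preds = layer(x)  # samples weights, fills layer.lvposterior / layer.lprior
    nll = F.mse_loss(preds, targets)

    # ELBO-style objective from Blundell et al.: data term plus the weighted
    # gap between the variational log posterior and the log prior.
    kl_weight = 1e-3
    loss = nll + kl_weight * (layer.lvposterior - layer.lprior)
    loss.backward()

    # For deterministic predictions, freeze the layer so forward() uses the
    # posterior means instead of drawing a new sample.
    layer.freeze()
    mean_preds = layer(x)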