import inspect

import torch
import torch.nn.functional as F
from torch import Tensor, nn


def get_dist_linear_layer(dist_family: str) -> type[nn.Module]:
    """Return the linear density layer class matching ``dist_family``."""
    if dist_family == "normal":
        return NormalLinear
    if dist_family == "laplace":
        return LaplaceLinear
    if dist_family == "cauchy":
        return CauchyLinear
    if dist_family == "student":
        return StudentTLinear
    if dist_family == "nig":
        return NormalInverseGammaLinear
    raise NotImplementedError(
        f"{dist_family} distribution is not supported. Raise an issue if needed."
    )


def get_dist_conv_layer(dist_family: str) -> type[nn.Module]:
    """Return the convolutional density layer class matching ``dist_family``."""
    if dist_family == "normal":
        return NormalConvNd
    if dist_family == "laplace":
        return LaplaceConvNd
    if dist_family == "cauchy":
        return CauchyConvNd
    if dist_family == "student":
        return StudentTConvNd
    if dist_family == "nig":
        return NormalInverseGammaConvNd
    raise NotImplementedError(
        f"{dist_family} distribution is not supported. Raise an issue if needed."
    )
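

# A minimal usage sketch of the two factories (illustrative only; the variable
# names below are placeholders, not part of this module). A factory returns the
# class itself, and the caller instantiates it with a base layer plus that
# layer's own arguments:
#
#     dist_cls = get_dist_linear_layer("normal")  # -> NormalLinear
#     head = dist_cls(base_layer=nn.Linear, event_dim=1, in_features=32)
#     params = head(torch.randn(4, 32))  # {"loc": Tensor, "scale": Tensor}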


class _ExpandOutputLinear(nn.Module):
    """Base class expanding the output of any module that takes an ``out_features`` argument.

    The wrapped layer is built with ``out_features = num_params * event_dim`` so that all
    distribution parameters are predicted in a single pass and split by the subclasses.

    Args:
        base_layer (type[nn.Module]): The base layer class.
        event_dim (int): The size of the event, i.e., the number of output features per
            distribution parameter.
        num_params (int): The number of parameters to output. For instance, the normal
            distribution has 2 parameters (loc and scale).
        **layer_args: Additional arguments for the base layer.
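
    Example:
        A sketch of the expansion logic (assuming ``nn.Linear`` as the base layer)::

            wrapper = _ExpandOutputLinear(nn.Linear, event_dim=3, num_params=2, in_features=8)
            # The wrapped layer predicts num_params * event_dim = 6 features,
            # returned unsplit here and sliced apart by the subclasses.
            wrapper(torch.randn(2, 8)).shape  # torch.Size([2, 6])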
"""
def __init__(self, base_layer: type[nn.Module], event_dim: int, num_params: int, **layer_args):
if "out_features" not in inspect.getfullargspec(base_layer.__init__).args:
raise ValueError(f"{base_layer.__name__} does not have an `out_features` argument.")
super().__init__()
self.base_layer = base_layer(out_features=num_params * event_dim, **layer_args)
self.event_dim = event_dim
def forward(self, x: Tensor) -> Tensor:
return self.base_layer(x)


class _ExpandOutputConvNd(nn.Module):
    """Base class expanding the output of any module that takes an ``out_channels`` argument.

    The wrapped layer is built with ``out_channels = num_params * event_dim`` so that all
    distribution parameters are predicted in a single pass and split by the subclasses.

    Args:
        base_layer (type[nn.Module]): The base layer class.
        event_dim (int): The size of the event, i.e., the number of output channels per
            distribution parameter.
        num_params (int): The number of parameters to output. For instance, the normal
            distribution has 2 parameters (loc and scale).
        **layer_args: Additional arguments for the base layer.
    """

    def __init__(self, base_layer: type[nn.Module], event_dim: int, num_params: int, **layer_args):
        if "out_channels" not in inspect.getfullargspec(base_layer.__init__).args:
            raise ValueError(f"{base_layer.__name__} does not have an `out_channels` argument.")
        super().__init__()
        # Widen the output so that all distribution parameters fit side by side.
        self.base_layer = base_layer(out_channels=num_params * event_dim, **layer_args)
        self.event_dim = event_dim

    def forward(self, x: Tensor) -> Tensor:
        return self.base_layer(x)


class _LocScaleLinear(_ExpandOutputLinear):
    """Base linear layer for any distribution with loc and scale parameters.

    Args:
        base_layer (type[nn.Module]): The base layer class.
        event_dim (int): The number of event dimensions.
        min_scale (float): The minimal value of the scale parameter.
        **layer_args: Additional arguments for the base layer.
    """

    def __init__(
        self,
        base_layer: type[nn.Module],
        event_dim: int,
        min_scale: float = 1e-6,
        **layer_args,
    ) -> None:
        super().__init__(
            base_layer=base_layer,
            event_dim=event_dim,
            num_params=2,
            **layer_args,
        )
        self.min_scale = min_scale

    def forward(self, x: Tensor) -> dict[str, Tensor]:
        x = super().forward(x)
        # The first half of the last dimension is the loc, the second half the scale.
        loc = x[..., : self.event_dim]
        # Softplus makes the scale positive; the clamp keeps it away from zero.
        scale = torch.clamp(
            F.softplus(x[..., self.event_dim : 2 * self.event_dim]), min=self.min_scale
        )
        return {"loc": loc, "scale": scale}


class _LocScaleConvNd(_ExpandOutputConvNd):
    """Base convolutional layer for any distribution with loc and scale parameters.

    Args:
        base_layer (type[nn.Module]): The base layer class.
        event_dim (int): The number of event channels.
        min_scale (float): The minimal value of the scale parameter.
        **layer_args: Additional arguments for the base layer.
    """

    def __init__(
        self,
        base_layer: type[nn.Module],
        event_dim: int,
        min_scale: float = 1e-6,
        **layer_args,
    ) -> None:
        super().__init__(
            base_layer=base_layer,
            event_dim=event_dim,
            num_params=2,
            **layer_args,
        )
        self.min_scale = min_scale

    def forward(self, x: Tensor) -> dict[str, Tensor]:
        x = super().forward(x)
        # The first half of the channels is the loc, the second half the scale.
        loc = x[:, : self.event_dim]
        # Softplus makes the scale positive; the clamp keeps it away from zero.
        scale = torch.clamp(
            F.softplus(x[:, self.event_dim : 2 * self.event_dim]), min=self.min_scale
        )
        return {"loc": loc, "scale": scale}


class NormalLinear(_LocScaleLinear):
    r"""Normal Distribution Linear Density Layer.

    Args:
        base_layer (type[nn.Module]): The base layer class.
        event_dim (int): The number of event dimensions.
        min_scale (float): The minimal value of the scale parameter.
        **layer_args: Additional arguments for the base layer.

    Shape:
        - Input: :math:`(\ast, H_{in})` where :math:`\ast` means any number of dimensions
          including none and :math:`H_{in} = \text{in\_features}`.
        - Output: A dict with the following keys:

          - ``"loc"``: The mean of the Normal distribution of shape :math:`(\ast, H_{out})`
            where all but the last dimension are the same as the input and
            :math:`H_{out} = \text{event\_dim}`.
          - ``"scale"``: The standard deviation of the Normal distribution of shape
            :math:`(\ast, H_{out})`.
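
    Example:
        A minimal usage sketch (assuming ``nn.Linear`` as the base layer; any module
        exposing an ``out_features`` argument works)::

            layer = NormalLinear(base_layer=nn.Linear, event_dim=3, in_features=8)
            out = layer(torch.randn(2, 8))
            # out["loc"] and out["scale"] both have shape (2, 3), and the scale is
            # strictly positive. The dict can be splatted into the matching
            # torch.distributions class: dist = torch.distributions.Normal(**out)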
"""


class NormalConvNd(_LocScaleConvNd):
    r"""Normal Distribution Convolutional Density Layer.

    Args:
        base_layer (type[nn.Module]): The base layer class (e.g., ``nn.Conv2d``).
        event_dim (int): The number of event channels.
        min_scale (float): The minimal value of the scale parameter.
        **layer_args: Additional arguments for the base layer (e.g., ``in_channels``,
            ``kernel_size``, ``stride``, ``padding``, ``dilation``, ``groups``,
            ``device``, ``dtype``).

    Shape:
        - Input: :math:`(N, C_{in}, \ast)` where :math:`\ast` means any number of dimensions,
          :math:`C_{in} = \text{in\_channels}`, and :math:`N` is the batch size.
        - Output: A dict with the following keys:

          - ``"loc"``: The mean of the Normal distribution of shape :math:`(N, C_{out}, \ast)`
            where :math:`C_{out} = \text{event\_dim}`.
          - ``"scale"``: The standard deviation of the Normal distribution of shape
            :math:`(N, C_{out}, \ast)`.
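
    Example:
        A minimal usage sketch (assuming ``nn.Conv2d`` as the base layer; any convolution
        exposing an ``out_channels`` argument works)::

            layer = NormalConvNd(nn.Conv2d, event_dim=4, in_channels=3, kernel_size=3, padding=1)
            out = layer(torch.randn(2, 3, 32, 32))
            # out["loc"] and out["scale"] both have shape (2, 4, 32, 32).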
"""


class LaplaceLinear(_LocScaleLinear):
    r"""Laplace Distribution Linear Density Layer.

    Args:
        base_layer (type[nn.Module]): The base layer class.
        event_dim (int): The number of event dimensions.
        min_scale (float): The minimal value of the scale parameter.
        **layer_args: Additional arguments for the base layer.

    Shape:
        - Input: :math:`(\ast, H_{in})` where :math:`\ast` means any number of dimensions
          including none and :math:`H_{in} = \text{in\_features}`.
        - Output: A dict with the following keys:

          - ``"loc"``: The mean of the Laplace distribution of shape :math:`(\ast, H_{out})`
            where all but the last dimension are the same as the input and
            :math:`H_{out} = \text{event\_dim}`.
          - ``"scale"``: The scale (diversity) parameter of the Laplace distribution of
            shape :math:`(\ast, H_{out})`.
    """


class LaplaceConvNd(_LocScaleConvNd):
    r"""Laplace Distribution Convolutional Density Layer.

    Args:
        base_layer (type[nn.Module]): The base layer class (e.g., ``nn.Conv2d``).
        event_dim (int): The number of event channels.
        min_scale (float): The minimal value of the scale parameter.
        **layer_args: Additional arguments for the base layer (e.g., ``in_channels``,
            ``kernel_size``, ``stride``, ``padding``, ``dilation``, ``groups``,
            ``device``, ``dtype``).

    Shape:
        - Input: :math:`(N, C_{in}, \ast)` where :math:`\ast` means any number of dimensions,
          :math:`C_{in} = \text{in\_channels}`, and :math:`N` is the batch size.
        - Output: A dict with the following keys:

          - ``"loc"``: The mean of the Laplace distribution of shape :math:`(N, C_{out}, \ast)`
            where :math:`C_{out} = \text{event\_dim}`.
          - ``"scale"``: The scale (diversity) parameter of the Laplace distribution of
            shape :math:`(N, C_{out}, \ast)`.
    """


class CauchyLinear(_LocScaleLinear):
    r"""Cauchy Distribution Linear Density Layer.

    Args:
        base_layer (type[nn.Module]): The base layer class.
        event_dim (int): The number of event dimensions.
        min_scale (float): The minimal value of the scale parameter.
        **layer_args: Additional arguments for the base layer.

    Shape:
        - Input: :math:`(\ast, H_{in})` where :math:`\ast` means any number of dimensions
          including none and :math:`H_{in} = \text{in\_features}`.
        - Output: A dict with the following keys:

          - ``"loc"``: The location (mode) of the Cauchy distribution of shape
            :math:`(\ast, H_{out})` where all but the last dimension are the same as the
            input and :math:`H_{out} = \text{event\_dim}`.
          - ``"scale"``: The scale of the Cauchy distribution of shape
            :math:`(\ast, H_{out})`. Note that the Cauchy distribution has no mean or
            standard deviation, hence location and scale parameters proper.
    """


class CauchyConvNd(_LocScaleConvNd):
    r"""Cauchy Distribution Convolutional Density Layer.

    Args:
        base_layer (type[nn.Module]): The base layer class (e.g., ``nn.Conv2d``).
        event_dim (int): The number of event channels.
        min_scale (float): The minimal value of the scale parameter.
        **layer_args: Additional arguments for the base layer (e.g., ``in_channels``,
            ``kernel_size``, ``stride``, ``padding``, ``dilation``, ``groups``,
            ``device``, ``dtype``).

    Shape:
        - Input: :math:`(N, C_{in}, \ast)` where :math:`\ast` means any number of dimensions,
          :math:`C_{in} = \text{in\_channels}`, and :math:`N` is the batch size.
        - Output: A dict with the following keys:

          - ``"loc"``: The location (mode) of the Cauchy distribution of shape
            :math:`(N, C_{out}, \ast)` where :math:`C_{out} = \text{event\_dim}`.
          - ``"scale"``: The scale of the Cauchy distribution of shape
            :math:`(N, C_{out}, \ast)`.
    """


class StudentTLinear(_ExpandOutputLinear):
    r"""Student's t-Distribution Linear Density Layer.

    Args:
        base_layer (type[nn.Module]): The base layer class.
        event_dim (int): The number of event dimensions.
        min_scale (float): The minimal value of the scale parameter.
        min_df (float): The minimal value of the degrees of freedom parameter.
        fixed_df (float | None): If not ``None``, the degrees of freedom parameter is
            fixed to this value. Otherwise, it is learned.
        **layer_args: Additional arguments for the base layer.

    Shape:
        - Input: :math:`(\ast, H_{in})` where :math:`\ast` means any number of dimensions
          including none and :math:`H_{in} = \text{in\_features}`.
        - Output: A dict with the following keys:

          - ``"loc"``: The location of the Student's t-distribution of shape
            :math:`(\ast, H_{out})` where all but the last dimension are the same as the
            input and :math:`H_{out} = \text{event\_dim}`.
          - ``"scale"``: The scale of the Student's t-distribution of shape
            :math:`(\ast, H_{out})`.
          - ``"df"``: The degrees of freedom of the Student's t-distribution of shape
            :math:`(\ast, H_{out})`, filled with ``fixed_df`` when it is set.
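
    Example:
        A minimal usage sketch (assuming ``nn.Linear`` as the base layer). Leaving
        ``fixed_df`` unset makes the degrees of freedom a learned output::

            layer = StudentTLinear(base_layer=nn.Linear, event_dim=2, in_features=8)
            out = layer(torch.randn(4, 8))
            # out["loc"], out["scale"], and out["df"] all have shape (4, 2),
            # with out["df"] >= min_df. With fixed_df=3.0, "df" is constant at 3.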
"""
def __init__(
self,
base_layer: type[nn.Module],
event_dim: int,
min_scale: float = 1e-6,
min_df: float = 2.0,
fixed_df: float | None = None,
**layer_args,
) -> None:
super().__init__(
base_layer=base_layer,
event_dim=event_dim,
num_params=3 if fixed_df is None else 2,
**layer_args,
)
self.min_scale = min_scale
self.min_df = min_df
self.fixed_df = fixed_df
def forward(self, x: Tensor) -> dict[str, Tensor]:
x = super().forward(x)
loc = x[..., : self.event_dim]
scale = torch.clamp(
F.softplus(x[..., self.event_dim : 2 * self.event_dim]), min=self.min_scale
)
df = (
torch.clamp(F.softplus(x[..., 2 * self.event_dim :]), min=self.min_df)
if self.fixed_df is None
else torch.full_like(loc, self.fixed_df)
)
return {"loc": loc, "scale": scale, "df": df}


class StudentTConvNd(_ExpandOutputConvNd):
    r"""Student's t-Distribution Convolutional Density Layer.

    Args:
        base_layer (type[nn.Module]): The base layer class (e.g., ``nn.Conv2d``).
        event_dim (int): The number of event channels.
        min_scale (float): The minimal value of the scale parameter.
        min_df (float): The minimal value of the degrees of freedom parameter.
        fixed_df (float | None): If not ``None``, the degrees of freedom parameter is
            fixed to this value. Otherwise, it is learned.
        **layer_args: Additional arguments for the base layer.

    Shape:
        - Input: :math:`(N, C_{in}, \ast)` where :math:`\ast` means any number of dimensions,
          :math:`C_{in} = \text{in\_channels}`, and :math:`N` is the batch size.
        - Output: A dict with the following keys:

          - ``"loc"``: The location of the Student's t-distribution of shape
            :math:`(N, C_{out}, \ast)` where :math:`C_{out} = \text{event\_dim}`.
          - ``"scale"``: The scale of the Student's t-distribution of shape
            :math:`(N, C_{out}, \ast)`.
          - ``"df"``: The degrees of freedom of the Student's t-distribution of shape
            :math:`(N, C_{out}, \ast)`, filled with ``fixed_df`` when it is set.
    """

    def __init__(
        self,
        base_layer: type[nn.Module],
        event_dim: int,
        min_scale: float = 1e-6,
        min_df: float = 2.0,
        fixed_df: float | None = None,
        **layer_args,
    ) -> None:
        super().__init__(
            base_layer=base_layer,
            event_dim=event_dim,
            # One output chunk per parameter: loc, scale, and (if learned) df.
            num_params=3 if fixed_df is None else 2,
            **layer_args,
        )
        self.min_scale = min_scale
        self.min_df = min_df
        self.fixed_df = fixed_df

    def forward(self, x: Tensor) -> dict[str, Tensor]:
        x = super().forward(x)
        loc = x[:, : self.event_dim]
        scale = torch.clamp(
            F.softplus(x[:, self.event_dim : 2 * self.event_dim]), min=self.min_scale
        )
        # Learn the degrees of freedom unless a fixed value was requested.
        df = (
            torch.clamp(F.softplus(x[:, 2 * self.event_dim :]), min=self.min_df)
            if self.fixed_df is None
            else torch.full_like(loc, self.fixed_df)
        )
        return {"loc": loc, "scale": scale, "df": df}


class NormalInverseGammaLinear(_ExpandOutputLinear):
    r"""Normal-Inverse-Gamma Distribution Linear Density Layer.

    Args:
        base_layer (type[nn.Module]): The base layer class.
        event_dim (int): The number of event dimensions.
        min_lmbda (float): The minimal value of the :math:`\lambda` parameter.
        min_alpha (float): The minimal value of the :math:`\alpha` parameter.
        min_beta (float): The minimal value of the :math:`\beta` parameter.
        **layer_args: Additional arguments for the base layer.

    Shape:
        - Input: :math:`(\ast, H_{in})` where :math:`\ast` means any number of dimensions
          including none and :math:`H_{in} = \text{in\_features}`.
        - Output: A dict with the following keys:

          - ``"loc"``: The location of the Normal-Inverse-Gamma distribution of shape
            :math:`(\ast, H_{out})` where all but the last dimension are the same as the
            input and :math:`H_{out} = \text{event\_dim}`.
          - ``"lmbda"``: The :math:`\lambda` parameter of the Normal-Inverse-Gamma
            distribution of shape :math:`(\ast, H_{out})`.
          - ``"alpha"``: The :math:`\alpha` parameter (greater than 1) of the
            Normal-Inverse-Gamma distribution of shape :math:`(\ast, H_{out})`.
          - ``"beta"``: The :math:`\beta` parameter of the Normal-Inverse-Gamma
            distribution of shape :math:`(\ast, H_{out})`.

    Source:
        - `Normal-Inverse-Gamma Distribution <https://en.wikipedia.org/wiki/Normal-inverse-gamma_distribution>`_
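
    Example:
        A minimal usage sketch (assuming ``nn.Linear`` as the base layer)::

            layer = NormalInverseGammaLinear(base_layer=nn.Linear, event_dim=1, in_features=16)
            out = layer(torch.randn(4, 16))
            # out has keys "loc", "lmbda", "alpha", and "beta", each of shape (4, 1);
            # "lmbda" and "beta" are positive, and "alpha" is greater than 1.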
"""
def __init__(
self,
base_layer: type[nn.Module],
event_dim: int,
min_lmbda: float = 1e-6,
min_alpha: float = 1e-6,
min_beta: float = 1e-6,
**layer_args,
) -> None:
super().__init__(
base_layer=base_layer,
event_dim=event_dim,
num_params=4,
**layer_args,
)
self.min_lmbda = min_lmbda
self.min_alpha = min_alpha
self.min_beta = min_beta
def forward(self, x: Tensor) -> dict[str, Tensor]:
x = super().forward(x)
loc = x[..., : self.event_dim]
lmbda = torch.clamp(
F.softplus(x[..., self.event_dim : 2 * self.event_dim]), min=self.min_lmbda
)
alpha = 1 + torch.clamp(
F.softplus(x[..., 2 * self.event_dim : 3 * self.event_dim]), min=self.min_alpha
)
beta = torch.clamp(F.softplus(x[..., 3 * self.event_dim :]), min=self.min_beta)
return {
"loc": loc,
"lmbda": lmbda,
"alpha": alpha,
"beta": beta,
}


class NormalInverseGammaConvNd(_ExpandOutputConvNd):
    r"""Normal-Inverse-Gamma Distribution Convolutional Density Layer.

    Args:
        base_layer (type[nn.Module]): The base layer class (e.g., ``nn.Conv2d``).
        event_dim (int): The number of event channels.
        min_lmbda (float): The minimal value of the :math:`\lambda` parameter.
        min_alpha (float): The minimal value of the :math:`\alpha` parameter.
        min_beta (float): The minimal value of the :math:`\beta` parameter.
        **layer_args: Additional arguments for the base layer.

    Shape:
        - Input: :math:`(N, C_{in}, \ast)` where :math:`\ast` means any number of dimensions,
          :math:`C_{in} = \text{in\_channels}`, and :math:`N` is the batch size.
        - Output: A dict with the following keys:

          - ``"loc"``: The location of the Normal-Inverse-Gamma distribution of shape
            :math:`(N, C_{out}, \ast)` where :math:`C_{out} = \text{event\_dim}`.
          - ``"lmbda"``: The :math:`\lambda` parameter of the Normal-Inverse-Gamma
            distribution of shape :math:`(N, C_{out}, \ast)`.
          - ``"alpha"``: The :math:`\alpha` parameter (greater than 1) of the
            Normal-Inverse-Gamma distribution of shape :math:`(N, C_{out}, \ast)`.
          - ``"beta"``: The :math:`\beta` parameter of the Normal-Inverse-Gamma
            distribution of shape :math:`(N, C_{out}, \ast)`.

    Source:
        - `Normal-Inverse-Gamma Distribution <https://en.wikipedia.org/wiki/Normal-inverse-gamma_distribution>`_
    """

    def __init__(
        self,
        base_layer: type[nn.Module],
        event_dim: int,
        min_lmbda: float = 1e-6,
        min_alpha: float = 1e-6,
        min_beta: float = 1e-6,
        **layer_args,
    ) -> None:
        super().__init__(
            base_layer=base_layer,
            event_dim=event_dim,
            num_params=4,
            **layer_args,
        )
        self.min_lmbda = min_lmbda
        self.min_alpha = min_alpha
        self.min_beta = min_beta

    def forward(self, x: Tensor) -> dict[str, Tensor]:
        x = super().forward(x)
        loc = x[:, : self.event_dim]
        lmbda = torch.clamp(
            F.softplus(x[:, self.event_dim : 2 * self.event_dim]), min=self.min_lmbda
        )
        # The offset of 1 keeps alpha strictly greater than 1, so that the mean of
        # the inverse-gamma component, beta / (alpha - 1), is always defined.
        alpha = 1 + torch.clamp(
            F.softplus(x[:, 2 * self.event_dim : 3 * self.event_dim]), min=self.min_alpha
        )
        beta = torch.clamp(F.softplus(x[:, 3 * self.event_dim :]), min=self.min_beta)
        return {
            "loc": loc,
            "lmbda": lmbda,
            "alpha": alpha,
            "beta": beta,
        }