# Source code for torch_uncertainty.models.classification.inception_time.mimo
from einops import rearrange
from torch import Tensor
from .std import _InceptionTime
class _MIMOInceptionTime(_InceptionTime):
    """Multi-input multi-output (MIMO) ensemble sharing a single InceptionTime.

    All ``num_estimators`` sub-networks live in one backbone. Their inputs are
    stacked along the channel axis before the backbone runs. Their outputs are
    recovered from consecutive ``num_classes``-wide slices of a widened output.

    Args:
        in_channels (int): Number of channels of one estimator's input.
        num_classes (int): Number of output classes of one estimator.
        num_estimators (int): Number of MIMO estimators sharing the backbone.
        kernel_size (int): Forwarded to ``_InceptionTime``. Default is ``40``.
        embed_dim (int): Forwarded to ``_InceptionTime``. Default is ``32``.
        num_blocks (int): Forwarded to ``_InceptionTime``. Default is ``6``.
        dropout (float): Forwarded to ``_InceptionTime``. Default is ``0.0``.
        residual (bool): Forwarded to ``_InceptionTime``. Default is ``True``.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        num_estimators: int,
        kernel_size: int = 40,
        embed_dim: int = 32,
        num_blocks: int = 6,
        dropout: float = 0.0,
        residual: bool = True,
    ):
        # The backbone sees every estimator at once, so both its input-channel
        # count and its output size are scaled by the number of estimators.
        stacked_channels = in_channels * num_estimators
        stacked_classes = num_classes * num_estimators
        super().__init__(
            in_channels=stacked_channels,
            num_classes=stacked_classes,
            kernel_size=kernel_size,
            embed_dim=embed_dim,
            num_blocks=num_blocks,
            dropout=dropout,
            residual=residual,
        )
        self.num_estimators = num_estimators

    def forward(self, x: Tensor) -> Tensor:
        """Run every estimator in a single backbone pass.

        Args:
            x (Tensor): 3-D batch ``(batch, channels, t)``. In training mode,
                dim 0 must already hold ``num_estimators`` estimator-major
                groups, i.e. have size ``num_estimators * b``. NOTE(review):
                presumably this layout is prepared by the MIMO training
                routine; confirm. In eval mode, the plain batch is tiled
                ``num_estimators`` times so every estimator sees the same
                samples.

        Returns:
            Tensor: Estimator-major outputs of shape
            ``(num_estimators * b, num_classes)``. This assumes the base
            ``forward`` returns ``(b, num_classes * num_estimators)``.
        """
        num_est = self.num_estimators
        # Eval: give each estimator its own copy of the batch.
        stacked = x if self.training else x.repeat(num_est, 1, 1)
        # Move the estimator axis from the batch dimension into the channels.
        merged = rearrange(stacked, "(m b) c t -> b (m c) t", m=num_est)
        fused_out = super().forward(merged)
        # Split the widened output back into one row per (estimator, sample).
        return rearrange(fused_out, "b (m d) -> (m b) d", m=num_est)
def mimo_inception_time(
    in_channels: int,
    num_classes: int,
    num_estimators: int,
    kernel_size: int = 40,
    embed_dim: int = 32,
    num_blocks: int = 6,
    dropout: float = 0.0,
    residual: bool = True,
) -> _MIMOInceptionTime:
    """Build a MIMO ensemble of InceptionTime estimators.

    Args:
        in_channels (int): Number of channels of one estimator's input.
        num_classes (int): Number of output classes of one estimator.
        num_estimators (int): Number of MIMO estimators sharing the backbone.
        kernel_size (int): Convolution kernel size passed to the backbone.
            Default is ``40``.
        embed_dim (int): Embedding dimension passed to the backbone.
            Default is ``32``.
        num_blocks (int): Number of inception blocks passed to the backbone.
            Default is ``6``.
        dropout (float): Dropout rate passed to the backbone. Default is ``0.0``.
        residual (bool): Whether the backbone uses residual connections.
            Default is ``True``.

    Returns:
        _MIMOInceptionTime: A freshly constructed MIMO InceptionTime model.
    """
    # Hyperparameters of the shared backbone, forwarded unchanged.
    backbone_kwargs = {
        "kernel_size": kernel_size,
        "embed_dim": embed_dim,
        "num_blocks": num_blocks,
        "dropout": dropout,
        "residual": residual,
    }
    return _MIMOInceptionTime(
        in_channels,
        num_classes,
        num_estimators,
        **backbone_kwargs,
    )