mimo_resnet
- torch_uncertainty.models.mimo_resnet(in_channels, num_classes, arch, num_estimators, conv_bias=True, dropout_rate=0.0, width_multiplier=1.0, groups=1, style='imagenet', normalization_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>)
MIMO (multi-input multi-output) ResNet.
- Parameters:
in_channels (int) – Number of input channels.
num_classes (int) – Number of classes to predict.
arch (int) – The architecture of the ResNet, given as its depth (e.g. 18 or 50).
num_estimators (int) – Number of estimators in the ensemble.
conv_bias (bool, optional) – Whether to use bias in convolutional layers. Defaults to True.
dropout_rate (float, optional) – Dropout rate. Defaults to 0.0.
width_multiplier (float, optional) – Width multiplier. Defaults to 1.0.
groups (int, optional) – Number of groups for grouped convolution. Defaults to 1.
style (Literal["imagenet", "cifar"], optional) – Style of ResNet. Defaults to "imagenet".
normalization_layer (nn.Module, optional) – Normalization layer. Defaults to torch.nn.BatchNorm2d.
- Returns:
A MIMO-style ResNet.
- Return type:
_MIMOResNet
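For orientation, the following is a minimal NumPy sketch of the general MIMO batching idea. It is not the internal implementation of _MIMOResNet. During training, M = num_estimators independent sub-batches are folded into the channel dimension, so a first convolution would take M * in_channels input channels. The head's M * num_classes outputs are then unfolded into one prediction per estimator. At test time, the same batch is repeated M times and the estimators' probabilities are averaged. The helper names (mimo_stack_inputs, mimo_split_outputs, mimo_repeat_input, mimo_average) and the (M*B, ...) batch layout are hypothetical, chosen only for illustration.

```python
import numpy as np


def mimo_stack_inputs(x: np.ndarray, num_estimators: int) -> np.ndarray:
    """Fold (M*B, C, H, W) into (B, M*C, H, W).

    Hypothetical helper: estimator m's sub-batch fills channels m*C .. (m+1)*C - 1.
    """
    mb, c, h, w = x.shape
    if mb % num_estimators != 0:
        raise ValueError("batch size must be divisible by num_estimators")
    b = mb // num_estimators
    # (M*B, C, H, W) -> (M, B, C, H, W) -> (B, M, C, H, W) -> (B, M*C, H, W)
    return (
        x.reshape(num_estimators, b, c, h, w)
        .transpose(1, 0, 2, 3, 4)
        .reshape(b, num_estimators * c, h, w)
    )


def mimo_split_outputs(logits: np.ndarray, num_estimators: int) -> np.ndarray:
    """Unfold a (B, M*K) head output into (M*B, K), one block of B rows per estimator."""
    b, mk = logits.shape
    k = mk // num_estimators
    # (B, M*K) -> (B, M, K) -> (M, B, K) -> (M*B, K)
    return (
        logits.reshape(b, num_estimators, k)
        .transpose(1, 0, 2)
        .reshape(num_estimators * b, k)
    )


def mimo_repeat_input(x: np.ndarray, num_estimators: int) -> np.ndarray:
    """At test time, give every estimator the same batch: (B, ...) -> (M*B, ...)."""
    return np.concatenate([x] * num_estimators, axis=0)


def mimo_average(probs: np.ndarray, num_estimators: int) -> np.ndarray:
    """Average per-estimator probabilities: (M*B, K) -> (B, K)."""
    mb, k = probs.shape
    return probs.reshape(num_estimators, mb // num_estimators, k).mean(axis=0)


# Toy check with M=2 estimators and B=3 single-pixel, single-channel images.
x = np.arange(6, dtype=float).reshape(6, 1, 1, 1)
stacked = mimo_stack_inputs(x, num_estimators=2)
print(stacked.shape)  # (3, 2, 1, 1)
print(stacked[0, :, 0, 0])  # [0. 3.]: image 0 of each estimator's sub-batch
```

Folding sub-batches into channels is what lets a single network act as M subnetworks at roughly the cost of one forward pass. Only the input and output layers grow with M.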