mimo_wideresnet28x10

torch_uncertainty.models.mimo_wideresnet28x10(in_channels, num_classes, num_estimators, conv_bias=True, dropout_rate=0.3, groups=1, style=ResNetStyle.IMAGENET, activation_fn=torch.nn.functional.relu, normalization_layer=torch.nn.BatchNorm2d)

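Multi-input multi-output (MIMO) ensemble built on a single Wide-ResNet-28x10 (depth 28, widen factor 10), in which num_estimators subnetworks share one backbone.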
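To make the MIMO idea concrete, the NumPy sketch below illustrates the batching scheme MIMO is commonly described with: num_estimators (M) inputs are stacked along the channel axis, the backbone emits M x K logits that are split into one prediction per estimator, and at test time the same input is repeated M times and the M predictive distributions are averaged. This is a conceptual illustration under those assumptions, not torch_uncertainty's implementation; the helpers `mimo_pack`, `mimo_unpack`, `mimo_predict` and the linear `toy_net` (a stand-in for the Wide-ResNet backbone) are hypothetical.

```python
import numpy as np


def mimo_pack(x: np.ndarray, num_estimators: int) -> np.ndarray:
    """Stack M sub-batches along channels: (M*B, C, H, W) -> (B, M*C, H, W)."""
    mb, c, h, w = x.shape
    if mb % num_estimators:
        raise ValueError("batch size must be divisible by num_estimators")
    b = mb // num_estimators
    # Sub-batch m ends up in channels [m*C, (m+1)*C) of the packed tensor.
    return (
        x.reshape(num_estimators, b, c, h, w)
        .transpose(1, 0, 2, 3, 4)
        .reshape(b, num_estimators * c, h, w)
    )


def mimo_unpack(logits: np.ndarray, num_estimators: int) -> np.ndarray:
    """Split the M heads: (B, M*K) -> (M*B, K), estimator-major like mimo_pack."""
    b, mk = logits.shape
    k = mk // num_estimators
    return logits.reshape(b, num_estimators, k).transpose(1, 0, 2).reshape(num_estimators * b, k)


def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def mimo_predict(net, x: np.ndarray, num_estimators: int) -> np.ndarray:
    """Test time: give every estimator the same input, average the M distributions."""
    b = x.shape[0]
    repeated = np.concatenate([x] * num_estimators, axis=0)  # (M*B, C, H, W)
    logits = mimo_unpack(net(mimo_pack(repeated, num_estimators)), num_estimators)
    return softmax(logits).reshape(num_estimators, b, -1).mean(axis=0)


# Hypothetical backbone stand-in: global average pooling followed by one linear
# map from M*C input channels to M*K logits (K classes per estimator).
rng = np.random.default_rng(0)
M, C, K = 3, 2, 4
weights = rng.normal(size=(M * C, M * K))


def toy_net(z: np.ndarray) -> np.ndarray:
    return z.mean(axis=(2, 3)) @ weights  # (B, M*C) @ (M*C, M*K) -> (B, M*K)


x = rng.normal(size=(5, C, 8, 8))
probs = mimo_predict(toy_net, x, M)
print(probs.shape)  # (5, 4)
```

Packing along channels is what lets one network hold several subnetworks: only the first layer sees M times more input channels and only the last layer produces M times more outputs, so the ensemble costs roughly one forward pass.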

Parameters:
  • in_channels (int) – Number of input channels.

  • num_classes (int) – Number of classes to predict.

  • num_estimators (int) – Number of estimators in the ensemble.

  • conv_bias (bool, optional) – Whether to use a bias term in the convolutions. Defaults to True.

  • dropout_rate (float, optional) – Dropout rate. Defaults to 0.3.

  • groups (int, optional) – Number of groups in the convolutions. Defaults to 1.

  • style (ResNetStyle | Literal["imagenet", "cifar"], optional) – Whether to use the ImageNet-style or CIFAR-style structure. Defaults to ResNetStyle.IMAGENET.

  • activation_fn (Callable, optional) – Activation function. Defaults to torch.nn.functional.relu.

  • normalization_layer (type[nn.Module], optional) – Normalization layer class. Defaults to torch.nn.BatchNorm2d.

Returns:

A MIMO Wide-ResNet-28x10.

Return type:

_MIMOWideResNet