Source code for torch_uncertainty.models.classification.wideresnet.packed

from collections.abc import Callable
from typing import Literal

from torch import Tensor, nn
from torch.nn.functional import relu

from torch_uncertainty.layers import PackedConv2d, PackedLinear

# Public API of this module: only the factory function is exported; the
# underscore-prefixed classes below are implementation details.
__all__ = [
    "packed_wideresnet28x10",
]


class _WideBasicBlock(nn.Module):
    """Residual block built from two packed 3x3 convolutions.

    Main branch: ``conv1`` -> dropout -> ``bn1`` -> activation -> ``conv2``.
    The shortcut branch is added to the main branch, and the sum goes through
    ``bn2`` followed by the activation.

    Args:
        in_planes (int): Input width given to ``conv1`` and, when present, to
            the 1x1 shortcut convolution.
        planes (int): Output width given to every convolution of the block.
        conv_bias (bool): Forwarded as ``bias`` to every packed convolution.
        dropout_rate (float): Probability used by the ``nn.Dropout2d`` placed
            right after ``conv1``.
        stride (int): Stride of ``conv2`` and of the shortcut convolution
            (``conv1`` keeps the layer's default stride).
        alpha (int): Forwarded to every packed convolution; the normalization
            layers are created with ``alpha * planes`` features.
        num_estimators (int): Forwarded to every packed convolution.
        gamma (int): Forwarded to every packed convolution.
        groups (int): Forwarded to every packed convolution.
        activation_fn (Callable): Activation applied after each normalization.
        normalization_layer (type[nn.Module]): Class instantiated with a
            single positional argument (the number of features).
    """

    def __init__(
        self,
        in_planes: int,
        planes: int,
        conv_bias: bool,
        dropout_rate: float,
        stride: int,
        alpha: int,
        num_estimators: int,
        gamma: int,
        groups: int,
        activation_fn: Callable,
        normalization_layer: type[nn.Module],
    ) -> None:
        super().__init__()
        self.activation_fn = activation_fn

        # Keyword arguments common to all packed convolutions of this block.
        packed_kwargs = {
            "alpha": alpha,
            "num_estimators": num_estimators,
            "gamma": gamma,
            "groups": groups,
            "bias": conv_bias,
        }

        # NOTE: submodules are created and registered in a fixed order
        # (conv1, dropout, bn1, conv2, shortcut, bn2); keep it stable so that
        # parameter initialisation and ``state_dict`` ordering do not change.
        self.conv1 = PackedConv2d(
            in_planes, planes, kernel_size=3, padding=1, **packed_kwargs
        )
        self.dropout = nn.Dropout2d(p=dropout_rate)
        self.bn1 = normalization_layer(alpha * planes)
        self.conv2 = PackedConv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            **packed_kwargs,
        )

        # The shortcut stays an empty ``nn.Sequential`` unless the block
        # changes the stride or the width, in which case a 1x1 packed
        # convolution is inserted.
        projection = []
        if stride != 1 or in_planes != planes:
            projection.append(
                PackedConv2d(
                    in_planes,
                    planes,
                    kernel_size=1,
                    stride=stride,
                    **packed_kwargs,
                )
            )
        self.shortcut = nn.Sequential(*projection)
        self.bn2 = normalization_layer(alpha * planes)

    def forward(self, x: Tensor) -> Tensor:
        """Run the block on ``x`` and return the activated residual sum."""
        hidden = self.conv1(x)
        hidden = self.dropout(hidden)
        hidden = self.bn1(hidden)
        hidden = self.activation_fn(hidden)
        hidden = self.conv2(hidden)
        # In-place accumulation of the shortcut branch onto the main branch.
        hidden += self.shortcut(x)
        hidden = self.bn2(hidden)
        return self.activation_fn(hidden)


class _PackedWideResNet(nn.Module):
    """Wide-ResNet classifier assembled from packed layers.

    The network is a packed stem convolution (plus normalization, activation
    and an optional max-pooling), three stages of ``_WideBasicBlock`` and a
    packed linear head applied to the pooled, flattened features.

    Args:
        depth (int): Network depth; must satisfy ``depth = 6n + 4``.
        widen_factor (int): Multiplier applied to the base stage widths
            ``16``, ``32`` and ``64``.
        in_channels (int): Input width given to the stem convolution.
        num_classes (int): Output width given to the linear head.
        conv_bias (bool): Forwarded as ``bias`` to the packed convolutions.
        dropout_rate (float): Dropout probability used inside the blocks and
            by the final ``nn.Dropout``.
        num_estimators (int): Forwarded to every packed layer.
        alpha (int): Forwarded to every packed layer. Defaults to ``2``.
        gamma (int): Forwarded to the packed convolutions. Defaults to ``1``.
        groups (int): Forwarded to the packed convolutions. Defaults to ``1``.
        style (Literal["imagenet", "cifar"]): Stem variant. ``"imagenet"``
            uses a 7x7 stride-2 convolution followed by a 3x3 stride-2
            max-pooling; ``"cifar"`` uses a 3x3 stride-1 convolution and no
            pooling. Defaults to ``"imagenet"``.
        activation_fn (Callable): Activation function. Defaults to ``relu``.
        normalization_layer (type[nn.Module]): Class instantiated with the
            number of features. Defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``depth - 4`` is not a multiple of 6.
        ValueError: If ``style`` is neither ``"imagenet"`` nor ``"cifar"``.
    """

    def __init__(
        self,
        depth: int,
        widen_factor: int,
        in_channels: int,
        num_classes: int,
        conv_bias: bool,
        dropout_rate: float,
        num_estimators: int,
        alpha: int = 2,
        gamma: int = 1,
        groups: int = 1,
        style: Literal["imagenet", "cifar"] = "imagenet",
        activation_fn: Callable = relu,
        normalization_layer: type[nn.Module] = nn.BatchNorm2d,
    ) -> None:
        super().__init__()
        self.num_estimators = num_estimators
        self.activation_fn = activation_fn
        # Running input width consumed and updated by ``_wide_layer``.
        self.in_planes = 16

        num_blocks, leftover = divmod(depth - 4, 6)
        if leftover != 0:
            raise ValueError(f"Wide-resnet depth should be 6n+4. Got {depth}.")

        # Widths of the stem followed by the three stages.
        num_stages = [16, *(base * widen_factor for base in (16, 32, 64))]

        # Pick the stem configuration for the requested style.
        if style == "imagenet":
            stem_kwargs = {
                "kernel_size": 7,
                "stride": 2,
                "padding": 3,
                "gamma": 1,  # No groups for the first layer
            }
            stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        elif style == "cifar":
            stem_kwargs = {
                "kernel_size": 3,
                "stride": 1,
                "padding": 1,
                "gamma": gamma,
            }
            stem_pool = nn.Identity()
        else:
            raise ValueError(f"Unknown WideResNet style: {style}. ")

        # NOTE: the registration order below (conv1, bn1, optional_pool,
        # layer1-3, final_dropout, pool, flatten, linear) is kept stable.
        self.conv1 = PackedConv2d(
            in_channels,
            num_stages[0],
            alpha=alpha,
            num_estimators=self.num_estimators,
            groups=groups,
            bias=conv_bias,
            first=True,
            **stem_kwargs,
        )
        self.bn1 = normalization_layer(num_stages[0] * alpha)
        self.optional_pool = stem_pool

        # Keyword arguments shared by the three stages; only the width and
        # the stride of the first block differ from one stage to the next.
        stage_kwargs = {
            "conv_bias": conv_bias,
            "dropout_rate": dropout_rate,
            "alpha": alpha,
            "num_estimators": self.num_estimators,
            "gamma": gamma,
            "groups": groups,
            "activation_fn": activation_fn,
            "normalization_layer": normalization_layer,
        }
        self.layer1 = self._wide_layer(
            _WideBasicBlock, num_stages[1], num_blocks, stride=1, **stage_kwargs
        )
        self.layer2 = self._wide_layer(
            _WideBasicBlock, num_stages[2], num_blocks, stride=2, **stage_kwargs
        )
        self.layer3 = self._wide_layer(
            _WideBasicBlock, num_stages[3], num_blocks, stride=2, **stage_kwargs
        )

        self.final_dropout = nn.Dropout(p=dropout_rate)
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.flatten = nn.Flatten(1)

        self.linear = PackedLinear(
            num_stages[3],
            num_classes,
            alpha=alpha,
            num_estimators=num_estimators,
            last=True,
        )

    def _wide_layer(
        self,
        block: type[_WideBasicBlock],
        planes: int,
        num_blocks: int,
        conv_bias: bool,
        dropout_rate: float,
        stride: int,
        alpha: int,
        num_estimators: int,
        gamma: int,
        groups: int,
        activation_fn: Callable,
        normalization_layer: type[nn.Module],
    ) -> nn.Module:
        """Build one stage as an ``nn.Sequential`` of ``block`` instances.

        The first block receives ``stride`` and the current ``self.in_planes``
        as input width; the remaining ``num_blocks - 1`` blocks use stride 1
        and ``planes`` as input width. ``self.in_planes`` is set to ``planes``
        afterwards so the next stage starts from this stage's width.
        """
        # At least one block is always created, even if ``num_blocks`` < 1.
        block_strides = [stride, *([1] * (int(num_blocks) - 1))]
        input_widths = [self.in_planes] + [planes] * (len(block_strides) - 1)

        shared_kwargs = {
            "planes": planes,
            "conv_bias": conv_bias,
            "dropout_rate": dropout_rate,
            "alpha": alpha,
            "num_estimators": num_estimators,
            "gamma": gamma,
            "groups": groups,
            "activation_fn": activation_fn,
            "normalization_layer": normalization_layer,
        }
        stage_blocks = [
            block(in_planes=width, stride=block_stride, **shared_kwargs)
            for width, block_stride in zip(input_widths, block_strides)
        ]
        self.in_planes = planes
        return nn.Sequential(*stage_blocks)

    def feats_forward(self, x: Tensor) -> Tensor:
        """Return the features fed to the linear head.

        Applies the stem, the three stages, the adaptive average pooling,
        flattens from dimension 1 onward and finally applies dropout.
        """
        feats = self.conv1(x)
        feats = self.bn1(feats)
        feats = self.activation_fn(feats)
        feats = self.optional_pool(feats)
        for stage in (self.layer1, self.layer2, self.layer3):
            feats = stage(feats)
        feats = self.pool(feats)
        feats = self.flatten(feats)
        return self.final_dropout(feats)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the linear head to the output of ``feats_forward``."""
        features = self.feats_forward(x)
        return self.linear(features)


def packed_wideresnet28x10(
    in_channels: int,
    num_classes: int,
    num_estimators: int,
    alpha: int,
    gamma: int,
    conv_bias: bool = True,
    dropout_rate: float = 0.3,
    groups: int = 1,
    style: Literal["imagenet", "cifar"] = "imagenet",
    activation_fn: Callable = relu,
    normalization_layer: type[nn.Module] = nn.BatchNorm2d,
) -> _PackedWideResNet:
    """Packed-Ensembles of Wide-ResNet-28x10.

    Thin factory that instantiates ``_PackedWideResNet`` with ``depth=28`` and
    ``widen_factor=10``; every other argument is forwarded unchanged.

    Args:
        in_channels (int): Number of input channels.
        num_classes (int): Number of classes to predict.
        num_estimators (int): Number of estimators in the ensemble.
        alpha (int): Expansion factor affecting the width of the estimators.
        gamma (int): Number of groups within each estimator.
        conv_bias (bool, optional): Whether to use bias in convolutions.
            Defaults to ``True``.
        dropout_rate (float, optional): Dropout rate. Defaults to ``0.3``.
        groups (int, optional): Number of subgroups in the convolutions.
            Defaults to ``1``.
        style (Literal["imagenet", "cifar"], optional): Structure of the
            network stem, either ``"imagenet"`` or ``"cifar"``. Defaults to
            ``"imagenet"``.
        activation_fn (Callable, optional): Activation function. Defaults to
            ``torch.nn.functional.relu``.
        normalization_layer (type[nn.Module], optional): Normalization layer
            class. Defaults to ``torch.nn.BatchNorm2d``.

    Returns:
        _PackedWideResNet: A Packed-Ensembles Wide-ResNet-28x10.
    """
    return _PackedWideResNet(
        in_channels=in_channels,
        num_classes=num_classes,
        depth=28,
        widen_factor=10,
        conv_bias=conv_bias,
        dropout_rate=dropout_rate,
        num_estimators=num_estimators,
        alpha=alpha,
        gamma=gamma,
        groups=groups,
        style=style,
        activation_fn=activation_fn,
        normalization_layer=normalization_layer,
    )