# Source code for torch_uncertainty.models.classification.resnet.packed

from typing import Any, Literal

import torch.nn.functional as F
from torch import Tensor, nn

from torch_uncertainty.layers import PackedConv2d, PackedLinear
from torch_uncertainty.utils import load_hf

from .utils import get_resnet_num_blocks

# Public API of this module: only the factory function is exported.
__all__ = ["packed_resnet"]

# Pretrained checkpoint identifiers, indexed first by ``str(num_classes)`` and
# then by ``str(arch)``. ``None`` marks a configuration without a released
# checkpoint. Only the released checkpoints are spelled out below; every other
# listed architecture is filled in with ``None``.
weight_ids = {
    num_classes: {
        arch: released.get(arch) for arch in ("18", "32", "50", "101", "152")
    }
    for num_classes, released in (
        ("10", {"18": "pe_resnet18_c10", "50": "pe_resnet50_c10"}),
        ("100", {"18": "pe_resnet18_c100", "50": "pe_resnet50_c100"}),
        ("1000", {"50": "pe_resnet50_in1k"}),
    )
}


class _BasicBlock(nn.Module):
    """Residual block made of two 3x3 packed convolutions.

    The main branch is ``conv1 -> norm -> dropout -> ReLU -> conv2 -> norm``.
    Its output is summed in place with the shortcut branch and passed through
    a final ReLU.

    Args:
        in_planes: Input channel count handed to the first convolution and to
            the shortcut projection.
        planes: Output channel count handed to both convolutions.
        stride: Stride of the first convolution and of the shortcut
            projection.
        alpha: Forwarded to every ``PackedConv2d``; the normalization layers
            are built with ``planes * alpha`` features.
        num_estimators: Forwarded to every ``PackedConv2d``.
        gamma: Forwarded to the second convolution and to the shortcut
            projection; the first convolution does not receive it.
        conv_bias: ``bias`` argument of every ``PackedConv2d``.
        dropout_rate: Probability of the ``nn.Dropout2d`` placed after the
            first normalization layer.
        groups: Forwarded to every ``PackedConv2d``.
        normalization_layer: Class called with a feature count to build each
            normalization layer.
    """

    expansion = 1

    def __init__(
        self,
        in_planes: int,
        planes: int,
        stride: int,
        alpha: int,
        num_estimators: int,
        gamma: int,
        conv_bias: bool,
        dropout_rate: float,
        groups: int,
        normalization_layer: type[nn.Module],
    ) -> None:
        super().__init__()
        out_planes = self.expansion * planes

        # No subgroups for the first layer: gamma is deliberately not passed.
        self.conv1 = PackedConv2d(
            in_planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            alpha=alpha,
            num_estimators=num_estimators,
            groups=groups,
            bias=conv_bias,
        )
        self.bn1 = normalization_layer(planes * alpha)
        self.dropout = nn.Dropout2d(p=dropout_rate)
        self.conv2 = PackedConv2d(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=1,
            alpha=alpha,
            num_estimators=num_estimators,
            gamma=gamma,
            groups=groups,
            bias=conv_bias,
        )
        self.bn2 = normalization_layer(planes * alpha)

        # The shortcut stays an empty Sequential unless the spatial stride or
        # the channel count changes, in which case a 1x1 projection is used.
        projection = []
        if stride != 1 or in_planes != out_planes:
            projection = [
                PackedConv2d(
                    in_planes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    alpha=alpha,
                    num_estimators=num_estimators,
                    gamma=gamma,
                    groups=groups,
                    bias=conv_bias,
                ),
                normalization_layer(out_planes * alpha),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x: Tensor) -> Tensor:
        """Run the main branch, add the shortcut branch and apply ReLU."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.dropout(hidden)
        hidden = F.relu(hidden)
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        hidden += self.shortcut(x)
        return F.relu(hidden)


class _Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1, 3x3 and 1x1 packed convolutions.

    The main branch is ``conv1 -> norm -> ReLU -> conv2 -> norm -> dropout ->
    ReLU -> conv3 -> norm``. Its output is summed in place with the shortcut
    branch and passed through a final ReLU. The last convolution outputs
    ``expansion * planes`` channels.

    Args:
        in_planes: Input channel count handed to the first convolution and to
            the shortcut projection.
        planes: Channel count of the two inner convolutions.
        stride: Stride of the 3x3 convolution and of the shortcut projection.
        alpha: Forwarded to every ``PackedConv2d``; each normalization layer
            is built with its channel count multiplied by ``alpha``.
        num_estimators: Forwarded to every ``PackedConv2d``.
        gamma: Forwarded to the second and third convolutions and to the
            shortcut projection; the first convolution always uses
            ``gamma=1``.
        conv_bias: ``bias`` argument of every ``PackedConv2d``.
        dropout_rate: Probability of the ``nn.Dropout2d`` placed after the
            second normalization layer.
        groups: Forwarded to every ``PackedConv2d``.
        normalization_layer: Class called with a feature count to build each
            normalization layer.
    """

    expansion = 4

    def __init__(
        self,
        in_planes: int,
        planes: int,
        stride: int,
        alpha: int,
        num_estimators: int,
        gamma: int,
        conv_bias: bool,
        dropout_rate: float,
        groups: int,
        normalization_layer: type[nn.Module],
    ) -> None:
        super().__init__()
        out_planes = self.expansion * planes

        # No subgroups for the first layer: gamma is pinned to 1.
        self.conv1 = PackedConv2d(
            in_planes,
            planes,
            kernel_size=1,
            alpha=alpha,
            num_estimators=num_estimators,
            gamma=1,
            groups=groups,
            bias=conv_bias,
        )
        self.bn1 = normalization_layer(planes * alpha)
        # The 3x3 convolution is the one that carries the block stride.
        self.conv2 = PackedConv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            alpha=alpha,
            num_estimators=num_estimators,
            gamma=gamma,
            groups=groups,
            bias=conv_bias,
        )
        self.bn2 = normalization_layer(planes * alpha)
        self.dropout = nn.Dropout2d(p=dropout_rate)
        self.conv3 = PackedConv2d(
            planes,
            out_planes,
            kernel_size=1,
            alpha=alpha,
            num_estimators=num_estimators,
            gamma=gamma,
            groups=groups,
            bias=conv_bias,
        )
        self.bn3 = normalization_layer(out_planes * alpha)

        # The shortcut stays an empty Sequential unless the spatial stride or
        # the channel count changes, in which case a 1x1 projection is used.
        projection = []
        if stride != 1 or in_planes != out_planes:
            projection = [
                PackedConv2d(
                    in_planes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    alpha=alpha,
                    num_estimators=num_estimators,
                    gamma=gamma,
                    groups=groups,
                    bias=conv_bias,
                ),
                normalization_layer(out_planes * alpha),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x: Tensor) -> Tensor:
        """Run the main branch, add the shortcut branch and apply ReLU."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        hidden = F.relu(self.dropout(hidden))
        hidden = self.conv3(hidden)
        hidden = self.bn3(hidden)
        hidden += self.shortcut(x)
        return F.relu(hidden)


class _PackedResNet(nn.Module):
    """ResNet backbone built from packed convolutions and a packed classifier.

    The network is a stem convolution, three or four residual stages, global
    average pooling, dropout and a ``PackedLinear`` classification head.

    Args:
        block: Residual block class used for every stage.
        num_blocks: Number of blocks per stage. The first three entries are
            always used; a fourth stage is built only when the list has
            exactly four entries.
        in_channels: Number of input channels of the stem convolution.
        num_classes: Output size handed to the ``PackedLinear`` head.
        conv_bias: ``bias`` argument of every packed convolution.
        num_estimators: Forwarded to every packed layer.
        dropout_rate: Dropout probability used inside the blocks and before
            the head.
        alpha: Forwarded to every packed layer; normalization layers are
            built with their channel count multiplied by ``alpha``.
        gamma: Forwarded to the residual blocks. The stem always uses
            ``gamma=1``.
        groups: Forwarded to every packed convolution.
        style: ``"imagenet"`` builds a 7x7 stride-2 stem followed by a 3x3
            stride-2 max-pool; ``"cifar"`` builds a 3x3 stride-1 stem without
            pooling.
        in_planes: Channel count of the stem output and of the first stage;
            later stages use 2x, 4x and 8x this value.
        normalization_layer: Class called with a feature count to build each
            normalization layer.
        linear_implementation: ``implementation`` argument of the
            ``PackedLinear`` head.

    Raises:
        ValueError: If ``style`` is neither ``"imagenet"`` nor ``"cifar"``.
    """

    def __init__(
        self,
        block: type[_BasicBlock | _Bottleneck],
        num_blocks: list[int],
        in_channels: int,
        num_classes: int,
        conv_bias: bool,
        num_estimators: int,
        dropout_rate: float,
        alpha: int = 2,
        gamma: int = 1,
        groups: int = 1,
        style: Literal["imagenet", "cifar"] = "imagenet",
        in_planes: int = 64,
        normalization_layer: type[nn.Module] = nn.BatchNorm2d,
        linear_implementation: str = "conv1d",
    ) -> None:
        super().__init__()

        # The two styles only differ in the stem geometry and in whether a
        # max-pool follows the stem.
        if style == "imagenet":
            stem_geometry = {"kernel_size": 7, "stride": 2, "padding": 3}
        elif style == "cifar":
            stem_geometry = {"kernel_size": 3, "stride": 1, "padding": 1}
        else:
            raise ValueError(f"Unknown style. Got {style}.")

        self.in_channels = in_channels
        self.alpha = alpha
        self.gamma = gamma
        self.groups = groups
        self.num_estimators = num_estimators

        # ``self.in_planes`` is mutated by ``_make_layer`` as stages are
        # added, so the base width is kept in a local variable.
        self.in_planes = in_planes
        block_planes = in_planes

        self.conv1 = PackedConv2d(
            self.in_channels,
            block_planes,
            alpha=alpha,
            num_estimators=num_estimators,
            gamma=1,  # No groups for the first layer
            groups=groups,
            bias=conv_bias,
            first=True,
            **stem_geometry,
        )
        self.bn1 = normalization_layer(block_planes * alpha)

        if style == "imagenet":
            self.optional_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        else:
            self.optional_pool = nn.Identity()

        # Keyword arguments shared by every residual stage.
        stage_kwargs = {
            "alpha": alpha,
            "num_estimators": num_estimators,
            "conv_bias": conv_bias,
            "dropout_rate": dropout_rate,
            "gamma": gamma,
            "groups": groups,
            "normalization_layer": normalization_layer,
        }
        self.layer1 = self._make_layer(
            block, block_planes, num_blocks[0], stride=1, **stage_kwargs
        )
        self.layer2 = self._make_layer(
            block, block_planes * 2, num_blocks[1], stride=2, **stage_kwargs
        )
        self.layer3 = self._make_layer(
            block, block_planes * 4, num_blocks[2], stride=2, **stage_kwargs
        )
        if len(num_blocks) == 4:
            self.layer4 = self._make_layer(
                block, block_planes * 8, num_blocks[3], stride=2, **stage_kwargs
            )
            linear_multiplier = 8
        else:
            # Three-stage variant: keep the attribute so ``forward`` is uniform.
            self.layer4 = nn.Identity()
            linear_multiplier = 4

        self.final_dropout = nn.Dropout(p=dropout_rate)
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.flatten = nn.Flatten(1)

        self.linear = PackedLinear(
            block_planes * linear_multiplier * block.expansion,
            num_classes,
            alpha=alpha,
            num_estimators=num_estimators,
            last=True,
            implementation=linear_implementation,
        )

    def _make_layer(
        self,
        block: type[_BasicBlock | _Bottleneck],
        planes: int,
        num_blocks: int,
        stride: int,
        alpha: int,
        num_estimators: int,
        conv_bias: bool,
        dropout_rate: float,
        gamma: int,
        groups: int,
        normalization_layer: type[nn.Module],
    ) -> nn.Module:
        """Build one residual stage and advance ``self.in_planes``.

        Only the first block of the stage receives ``stride``; the remaining
        blocks use a stride of 1. After each block ``self.in_planes`` is set
        to ``planes * block.expansion`` so the next block (and the next stage)
        receives the right input width.

        Returns:
            An ``nn.Sequential`` holding the blocks of the stage.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in block_strides:
            layers.append(
                block(
                    in_planes=self.in_planes,
                    planes=planes,
                    stride=block_stride,
                    alpha=alpha,
                    num_estimators=num_estimators,
                    conv_bias=conv_bias,
                    dropout_rate=dropout_rate,
                    gamma=gamma,
                    groups=groups,
                    normalization_layer=normalization_layer,
                )
            )
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the stem, the residual stages, pooling, dropout and the head."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.optional_pool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.pool(out)
        out = self.final_dropout(self.flatten(out))
        return self.linear(out)

    def check_config(self, config: dict[str, Any]) -> bool:
        """Check if the pretrained configuration matches the current model.

        Args:
            config: Mapping expected to provide the ``"alpha"``, ``"gamma"``,
                ``"groups"`` and ``"num_estimators"`` keys.

        Returns:
            ``True`` when all four values equal the corresponding attributes
            of this model, ``False`` otherwise. Comparison stops at the first
            mismatch.

        Raises:
            KeyError: If a compared key is missing from ``config``.
        """
        compared_keys = ("alpha", "gamma", "groups", "num_estimators")
        return all(config[key] == getattr(self, key) for key in compared_keys)


def packed_resnet(
    in_channels: int,
    num_classes: int,
    arch: int,
    num_estimators: int,
    alpha: int,
    gamma: int,
    conv_bias: bool = False,
    width_multiplier: float = 1.0,
    groups: int = 1,
    dropout_rate: float = 0,
    style: Literal["imagenet", "cifar"] = "imagenet",
    normalization_layer: type[nn.Module] = nn.BatchNorm2d,
    pretrained: bool = False,
    linear_implementation: str = "conv1d",
) -> _PackedResNet:
    """Packed-Ensembles of ResNet.

    Args:
        in_channels (int): Number of input channels.
        num_classes (int): Number of classes to predict.
        arch (int): The architecture of the ResNet. Architectures 18, 20, 34,
            44, 56, 110 and 1202 use basic blocks, every other value uses
            bottleneck blocks. Architectures 20, 44, 56, 110 and 1202 start
            from 16 planes, every other value from 64.
        num_estimators (int): Number of estimators in the ensemble.
        alpha (int): Expansion factor affecting the width of the estimators.
        gamma (int): Number of groups within each estimator.
        conv_bias (bool): Whether to use bias in convolutions. Defaults to
            ``False``.
        width_multiplier (float): Width multiplier applied to the base number
            of planes; the product is truncated to an integer. Defaults to
            ``1.0``.
        groups (int): Number of groups within each estimator group. Defaults
            to ``1``.
        dropout_rate (float): Dropout rate. Defaults to ``0``.
        style (str, optional): Structure of the network stem, either
            ``"imagenet"`` or ``"cifar"``. Defaults to ``"imagenet"``.
        normalization_layer (type[nn.Module], optional): Normalization layer
            class. Defaults to ``nn.BatchNorm2d``.
        pretrained (bool, optional): Whether to load pretrained weights.
            Defaults to ``False``.
        linear_implementation (str, optional): Implementation of the packed
            linear layer. Defaults to ``"conv1d"``.

    Returns:
        _PackedResNet: A Packed-Ensembles ResNet.

    Raises:
        ValueError: If ``pretrained`` is set and no checkpoint is registered
            for this ``num_classes``/``arch`` pair, or if the checkpoint
            configuration does not match the requested model.
    """
    basic_block_archs = (18, 20, 34, 44, 56, 110, 1202)
    narrow_archs = (20, 44, 56, 110, 1202)

    block = _BasicBlock if arch in basic_block_archs else _Bottleneck
    in_planes = 16 if arch in narrow_archs else 64

    net = _PackedResNet(
        block=block,
        num_blocks=get_resnet_num_blocks(arch),
        in_channels=in_channels,
        num_estimators=num_estimators,
        alpha=alpha,
        gamma=gamma,
        conv_bias=conv_bias,
        dropout_rate=dropout_rate,
        groups=groups,
        num_classes=num_classes,
        style=style,
        in_planes=int(in_planes * width_multiplier),
        normalization_layer=normalization_layer,
        linear_implementation=linear_implementation,
    )

    if pretrained:  # coverage: ignore
        # ``.get`` keeps unlisted class counts or architectures on the
        # ValueError path below instead of leaking a bare KeyError.
        weights = weight_ids.get(str(num_classes), {}).get(str(arch))
        if weights is None:
            raise ValueError("No pretrained weights for this configuration")
        state_dict, config = load_hf(weights)
        if not net.check_config(config):
            raise ValueError("Pretrained weights do not match current configuration.")
        net.load_state_dict(state_dict)
    return net