Shortcuts

Source code for torch_uncertainty.baselines.segmentation.segformer

from typing import Literal

from torch import nn

from torch_uncertainty.models.segmentation.segformer import (
    seg_former,
)
from torch_uncertainty.routines.segmentation import SegmentationRoutine


class SegFormerBaseline(SegmentationRoutine):
    # Version names listed as "single".
    # NOTE(review): presumably the single-model (non-ensemble) versions —
    # confirm against the other baselines.
    single = ["std"]
    # Maps each supported version name to the factory that builds its model.
    versions = {
        "std": seg_former,
    }
    # Architecture indices documented as supported (SegFormer-B0 to B5).
    # This list is not checked inside ``__init__``.
    archs = [0, 1, 2, 3, 4, 5]

    def __init__(
        self,
        num_classes: int,
        loss: nn.Module,
        version: Literal["std"],
        arch: int,
    ) -> None:
        r"""SegFormer backbone baseline for segmentation providing support for
        various versions and architectures.

        Args:
            num_classes (int): Number of classes to predict.
            loss (nn.Module): Training loss.
            version (str): Determines which SegFormer version to use. Options
                are:

                - ``"std"``: original SegFormer

            arch (int): Determines which architecture to use. Options are:

                - ``0``: SegFormer-B0
                - ``1``: SegFormer-B1
                - ``2``: SegFormer-B2
                - ``3``: SegFormer-B3
                - ``4``: SegFormer-B4
                - ``5``: SegFormer-B5

        Raises:
            ValueError: If ``version`` is not a key of ``versions``.
        """
        # Reject unknown versions before building anything.
        if version not in self.versions:
            raise ValueError(f"Unknown version {version}")

        # Look up the factory for the requested version and build the model.
        model = self.versions[version](num_classes=num_classes, arch=arch)

        super().__init__(
            num_classes=num_classes,
            model=model,
            loss=loss,
            # Batches are passed through unchanged.
            format_batch_fn=nn.Identity(),
        )
        # Store the constructor arguments as hyperparameters, leaving out the
        # loss module.
        self.save_hyperparameters(ignore=["loss"])