Source code for torch_uncertainty.baselines.segmentation.deeplab

from typing import Literal

from torch import nn

from torch_uncertainty.models.segmentation import deep_lab_v3_resnet
from torch_uncertainty.routines.segmentation import SegmentationRoutine


[docs] class DeepLabBaseline(SegmentationRoutine): archs = [50, 101] def __init__( self, num_classes: int, loss: nn.Module, version: Literal["std"], arch: int, style: Literal["v3", "v3+"], output_stride: int, separable: bool, metric_subsampling_rate: float = 1e-2, log_plots: bool = False, num_bins_cal_err: int = 15, pretrained_backbone: bool = True, ) -> None: params = { "num_classes": num_classes, "arch": arch, "style": style, "output_stride": output_stride, "separable": separable, "pretrained_backbone": pretrained_backbone, } format_batch_fn = nn.Identity() model = deep_lab_v3_resnet(**params) super().__init__( num_classes=num_classes, model=model, loss=loss, format_batch_fn=format_batch_fn, metric_subsampling_rate=metric_subsampling_rate, log_plots=log_plots, num_bins_cal_err=num_bins_cal_err, ) self.save_hyperparameters(ignore=["loss"])