Source code for torch_uncertainty.datamodules.segmentation.cityscapes

from pathlib import Path

import torch
from torch import nn
from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair
from torchvision import tv_tensors
from torchvision.transforms import v2

from torch_uncertainty.datamodules import TUDataModule
from torch_uncertainty.datasets.segmentation import Cityscapes
from torch_uncertainty.datasets.utils import create_train_val_split
from torch_uncertainty.transforms import RandomRescale


class CityscapesDataModule(TUDataModule):
    num_classes = 19
    num_channels = 3
    training_task = "segmentation"
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    def __init__(
        self,
        root: str | Path,
        batch_size: int,
        eval_batch_size: int | None = None,
        crop_size: _size_2_t = 1024,
        eval_size: _size_2_t = (1024, 2048),
        train_transform: nn.Module | None = None,
        test_transform: nn.Module | None = None,
        basic_augment: bool = True,
        val_split: float | None = None,
        num_workers: int = 1,
        pin_memory: bool = True,
        persistent_workers: bool = True,
    ) -> None:
        r"""DataModule for the Cityscapes dataset.

        Args:
            root (str or Path): Root directory of the datasets.
            batch_size (int): Number of samples per batch during training.
            eval_batch_size (int | None): Number of samples per batch during evaluation
                (val and test). Set to :attr:`batch_size` if ``None``. Defaults to ``None``.
            crop_size (sequence or int, optional): Desired input image and segmentation mask
                sizes during training. If :attr:`crop_size` is an int instead of a sequence
                like :math:`(H, W)`, a square crop :math:`(\text{size},\text{size})` is made.
                If provided a sequence of length :math:`1`, it will be interpreted as
                :math:`(\text{size[0]},\text{size[0]})`. Only used if :attr:`train_transform`
                is not provided; otherwise it has no effect. Defaults to ``1024``.
            eval_size (sequence or int, optional): Desired input image and segmentation mask
                sizes during evaluation. If size is an int, the smaller edge of the images
                will be matched to this number, i.e., if :math:`\text{height}>\text{width}`,
                then the image will be rescaled to
                :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. Only used
                if :attr:`test_transform` is not provided; otherwise it has no effect.
                Defaults to ``(1024,2048)``.
            train_transform (nn.Module | None): Custom training transform. Defaults to
                ``None``. If not provided, a default transform is used.
            test_transform (nn.Module | None): Custom test transform. Defaults to ``None``.
                If not provided, a default transform is used.
            basic_augment (bool): Whether to apply base augmentations. Defaults to ``True``.
                Only used if ``train_transform`` is not provided.
            val_split (float or None, optional): Share of training samples to use for
                validation. Defaults to ``None``.
            num_workers (int, optional): Number of dataloader workers to use. Defaults to ``1``.
            pin_memory (bool, optional): Whether to pin memory. Defaults to ``True``.
            persistent_workers (bool, optional): Whether to use persistent workers.
                Defaults to ``True``.

        Note:
            By default this datamodule injects the following transforms into the training and
            validation/test datasets:

            Training transforms:

            .. code-block:: python

                from torchvision.transforms import v2

                v2.Compose(
                    [
                        v2.ToImage(),
                        RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True),
                        v2.RandomCrop(size=crop_size, pad_if_needed=True),
                        v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
                        v2.RandomHorizontalFlip(),
                        v2.ToDtype(
                            {
                                tv_tensors.Image: torch.float32,
                                tv_tensors.Mask: torch.int64,
                                "others": None,
                            },
                            scale=True,
                        ),
                        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                    ]
                )

            Validation/Test transforms:

            .. code-block:: python

                from torchvision.transforms import v2

                v2.Compose(
                    [
                        v2.ToImage(),
                        v2.Resize(size=eval_size, antialias=True),
                        v2.ToDtype(
                            {
                                tv_tensors.Image: torch.float32,
                                tv_tensors.Mask: torch.int64,
                                "others": None,
                            },
                            scale=True,
                        ),
                        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                    ]
                )

            This behavior can be modified by setting up ``train_transform`` and
            ``test_transform`` at initialization.
        """
        super().__init__(
            root=root,
            batch_size=batch_size,
            eval_batch_size=eval_batch_size,
            val_split=val_split,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )

        self.dataset = Cityscapes
        self.mode = "fine"
        self.crop_size = _pair(crop_size)
        self.eval_size = _pair(eval_size)

        if train_transform is not None:
            self.train_transform = train_transform
        else:
            if basic_augment:
                basic_transform = v2.Compose(
                    [
                        RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True),
                        v2.RandomCrop(
                            size=self.crop_size,
                            pad_if_needed=True,
                            fill={tv_tensors.Image: 0, tv_tensors.Mask: 255},
                        ),
                        v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
                        v2.RandomHorizontalFlip(),
                    ]
                )
            else:
                basic_transform = nn.Identity()

            self.train_transform = v2.Compose(
                [
                    v2.ToImage(),
                    v2.Resize(size=self.eval_size, antialias=True),
                    basic_transform,
                    v2.ToDtype(
                        dtype={
                            tv_tensors.Image: torch.float32,
                            tv_tensors.Mask: torch.int64,
                            "others": None,
                        },
                        scale=True,
                    ),
                    v2.Normalize(mean=self.mean, std=self.std),
                ]
            )

        if test_transform is not None:
            self.test_transform = test_transform
        else:
            self.test_transform = v2.Compose(
                [
                    v2.ToImage(),
                    v2.Resize(size=self.eval_size, antialias=True),
                    v2.ToDtype(
                        dtype={
                            tv_tensors.Image: torch.float32,
                            tv_tensors.Mask: torch.int64,
                            "others": None,
                        },
                        scale=True,
                    ),
                    v2.Normalize(mean=self.mean, std=self.std),
                ]
            )

    def prepare_data(self) -> None:  # coverage: ignore
        self.dataset(root=self.root, split="train", mode=self.mode)

    def setup(self, stage: str | None = None) -> None:
        if stage == "fit" or stage is None:
            full = self.dataset(
                root=self.root,
                split="train",
                mode=self.mode,
                target_type="semantic",
                transforms=self.train_transform,
            )

            if self.val_split is not None:
                self.train, self.val = create_train_val_split(
                    full,
                    self.val_split,
                    self.test_transform,
                )
            else:
                self.train = full
                self.val = self.dataset(
                    root=self.root,
                    split="val",
                    mode=self.mode,
                    target_type="semantic",
                    transforms=self.test_transform,
                )
        if stage == "test" or stage is None:
            self.test = self.dataset(
                root=self.root,
                split="val",
                mode=self.mode,
                target_type="semantic",
                transforms=self.test_transform,
            )
        if stage not in ["fit", "test", None]:
            raise ValueError(f"Stage {stage} is not supported.")
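
Example usage (a minimal sketch, not part of the module source above): assuming the Cityscapes archives are already available under ``root`` and that ``TUDataModule`` exposes the usual Lightning-style dataloader accessors, the datamodule can be instantiated and queried as follows. The path, batch size, crop size, and split values are illustrative only.

    from torch_uncertainty.datamodules.segmentation.cityscapes import CityscapesDataModule

    dm = CityscapesDataModule(
        root="data/cityscapes",   # hypothetical dataset location
        batch_size=8,
        crop_size=768,            # square 768x768 training crops
        eval_size=(1024, 2048),   # full-resolution evaluation
        val_split=0.1,            # hold out 10% of the training set for validation
        num_workers=4,
    )
    dm.prepare_data()
    dm.setup("fit")
    # Assumes the standard Lightning `train_dataloader()` accessor provided by the base class.
    images, masks = next(iter(dm.train_dataloader()))

Passing a custom ``train_transform`` or ``test_transform`` at initialization replaces the default pipelines described in the docstring above.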