# Shortcuts
#
# Source code for torch_uncertainty.datamodules.segmentation.camvid

import logging
from pathlib import Path

import torch
from torch import nn
from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair
from torchvision import tv_tensors
from torchvision.transforms import v2

from torch_uncertainty.datamodules import TUDataModule
from torch_uncertainty.datasets.segmentation import CamVid
from torch_uncertainty.transforms import RandomRescale


class CamVidDataModule(TUDataModule):
    num_channels = 3
    training_task = "segmentation"
    # Per-channel statistics handed to ``v2.Normalize`` in both pipelines.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    def __init__(
        self,
        root: str | Path,
        batch_size: int,
        crop_size: _size_2_t = 640,
        eval_size: _size_2_t = (720, 960),
        group_classes: bool = True,
        basic_augment: bool = True,
        val_split: float | None = None,
        num_workers: int = 1,
        pin_memory: bool = True,
        persistent_workers: bool = True,
    ) -> None:
        r"""DataModule for the CamVid dataset.

        Args:
            root (str or Path): Root directory of the datasets.
            batch_size (int): Number of samples per batch.
            crop_size (sequence or int, optional): Desired input image and
                segmentation mask sizes during training. If :attr:`crop_size`
                is an int instead of sequence like :math:`(H, W)`, a square
                crop :math:`(\text{size},\text{size})` is made. If provided a
                sequence of length :math:`1`, it will be interpreted as
                :math:`(\text{size[0]},\text{size[0]})`. Only used when
                :attr:`basic_augment` is ``True``. Defaults to ``640``.
            eval_size (sequence or int, optional): Desired input image and
                segmentation mask sizes during evaluation. If
                :attr:`eval_size` is an int, it is expanded to the square size
                :math:`(\text{size},\text{size})` before being passed to
                ``v2.Resize``. Defaults to ``(720,960)``.
            group_classes (bool, optional): Whether to group the 32 classes
                into 11 superclasses. Default: ``True``.
            basic_augment (bool): Whether to apply base augmentations
                (random rescale, random crop, color jitter and horizontal
                flip) during training. Defaults to ``True``.
            val_split (float or None, optional): Not used by this datamodule:
                the validation set is the dataset's own ``"val"`` split. A
                warning is logged when a value is given; the value is still
                forwarded to the parent class. Defaults to ``None``.
            num_workers (int, optional): Number of workers used by the
                dataloaders. Defaults to ``1``.
            pin_memory (bool, optional): Whether to pin memory. Defaults to
                ``True``.
            persistent_workers (bool, optional): Whether to use persistent
                workers. Defaults to ``True``.

        Note:
            This datamodule injects the following transforms into the
            training dataset (the ``RandomRescale`` ... ``RandomHorizontalFlip``
            group is skipped when :attr:`basic_augment` is ``False``):

            .. code-block:: python

                from torchvision.transforms import v2

                v2.Compose(
                    [
                        v2.ToImage(),
                        v2.Compose(
                            [
                                RandomRescale(
                                    min_scale=0.5, max_scale=2.0, antialias=True
                                ),
                                v2.RandomCrop(
                                    size=crop_size,
                                    pad_if_needed=True,
                                    fill={
                                        tv_tensors.Image: 0,
                                        tv_tensors.Mask: 255,
                                    },
                                ),
                                v2.ColorJitter(
                                    brightness=0.5, contrast=0.5, saturation=0.5
                                ),
                                v2.RandomHorizontalFlip(),
                            ]
                        ),
                        v2.ToDtype(
                            dtype={
                                tv_tensors.Image: torch.float32,
                                tv_tensors.Mask: torch.int64,
                                "others": None,
                            },
                            scale=True,
                        ),
                        v2.Normalize(mean=mean, std=std),
                    ]
                )

            and the following ones into the validation and test datasets:

            .. code-block:: python

                v2.Compose(
                    [
                        v2.ToImage(),
                        v2.Resize(size=eval_size, antialias=True),
                        v2.ToDtype(
                            dtype={
                                tv_tensors.Image: torch.float32,
                                tv_tensors.Mask: torch.int64,
                                "others": None,
                            },
                            scale=True,
                        ),
                        v2.Normalize(mean=mean, std=std),
                    ]
                )

            This behavior can be modified by overriding
            ``self.train_transform`` and ``self.test_transform`` after
            initialization.
        """
        if val_split is not None:  # coverage: ignore
            logging.warning("val_split is not used for CamVidDataModule.")

        super().__init__(
            root=root,
            batch_size=batch_size,
            val_split=val_split,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )

        self.num_classes = 11 if group_classes else 32
        self.dataset = CamVid
        self.group_classes = group_classes
        # `_pair` turns a bare int into a 2-tuple and leaves sequences as is.
        self.crop_size = _pair(crop_size)
        self.eval_size = _pair(eval_size)

        train_steps = [v2.ToImage()]
        if basic_augment:
            train_steps.append(
                v2.Compose(
                    [
                        RandomRescale(
                            min_scale=0.5, max_scale=2.0, antialias=True
                        ),
                        v2.RandomCrop(
                            size=self.crop_size,
                            pad_if_needed=True,
                            # Pad images with 0 and masks with 255.
                            fill={tv_tensors.Image: 0, tv_tensors.Mask: 255},
                        ),
                        v2.ColorJitter(
                            brightness=0.5, contrast=0.5, saturation=0.5
                        ),
                        v2.RandomHorizontalFlip(),
                    ]
                )
            )
        # Without augmentation the step is simply left out. An `nn.Identity`
        # placeholder only accepts a single positional input, whereas
        # `v2.Compose` forwards every input it receives to each step, so the
        # placeholder would raise a TypeError when the pipeline is called
        # with an (image, mask) pair (presumably how the dataset invokes
        # `transforms` -- confirm against CamVid).
        train_steps += [
            v2.ToDtype(
                dtype={
                    tv_tensors.Image: torch.float32,
                    tv_tensors.Mask: torch.int64,
                    "others": None,
                },
                scale=True,
            ),
            v2.Normalize(mean=self.mean, std=self.std),
        ]
        self.train_transform = v2.Compose(train_steps)

        # Evaluation pipeline: deterministic resize, no augmentation.
        self.test_transform = v2.Compose(
            [
                v2.ToImage(),
                v2.Resize(size=self.eval_size, antialias=True),
                v2.ToDtype(
                    dtype={
                        tv_tensors.Image: torch.float32,
                        tv_tensors.Mask: torch.int64,
                        "others": None,
                    },
                    scale=True,
                ),
                v2.Normalize(mean=self.mean, std=self.std),
            ]
        )

    def prepare_data(self) -> None:  # coverage: ignore
        """Instantiate the dataset once with ``download=True``."""
        self.dataset(root=self.root, download=True)

    def setup(self, stage: str | None = None) -> None:
        """Build the dataset splits required for ``stage``.

        Args:
            stage (str or None, optional): ``"fit"`` creates ``self.train``
                and ``self.val``, ``"test"`` creates ``self.test`` and
                ``None`` creates all three. Defaults to ``None``.

        Raises:
            ValueError: If ``stage`` is not ``"fit"``, ``"test"`` or ``None``.
        """
        # Reject unsupported stages up front, before building any dataset.
        if stage not in ("fit", "test", None):
            raise ValueError(f"Stage {stage} is not supported.")

        if stage in ("fit", None):
            self.train = self.dataset(
                root=self.root,
                split="train",
                group_classes=self.group_classes,
                download=False,
                transforms=self.train_transform,
            )
            # Validation uses the evaluation pipeline, not the training one.
            self.val = self.dataset(
                root=self.root,
                split="val",
                group_classes=self.group_classes,
                download=False,
                transforms=self.test_transform,
            )

        if stage in ("test", None):
            self.test = self.dataset(
                root=self.root,
                split="test",
                group_classes=self.group_classes,
                download=False,
                transforms=self.test_transform,
            )