# Source code for torch_uncertainty.datasets.segmentation.cityscapes

from collections.abc import Callable
from typing import Any

import torch
from einops import rearrange
from PIL import Image
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchvision import tv_tensors
from torchvision.datasets import Cityscapes as TVCityscapes
from torchvision.transforms.v2 import functional as F


class Cityscapes(TVCityscapes):
    # RGB color per train ID (list index == train ID); entry 19 is void.
    color_palette = [
        (128, 64, 128),  # 0: road
        (244, 35, 232),  # 1: sidewalk
        (70, 70, 70),  # 2: building
        (102, 102, 156),  # 3: wall
        (190, 153, 153),  # 4: fence
        (153, 153, 153),  # 5: pole
        (250, 170, 30),  # 6: traffic light
        (220, 220, 0),  # 7: traffic sign
        (107, 142, 35),  # 8: vegetation
        (152, 251, 152),  # 9: terrain
        (70, 130, 180),  # 10: sky
        (220, 20, 60),  # 11: person
        (255, 0, 0),  # 12: rider
        (0, 0, 142),  # 13: car
        (0, 0, 70),  # 14: truck
        (0, 60, 100),  # 15: bus
        (0, 80, 100),  # 16: train
        (0, 0, 230),  # 17: motorcycle
        (119, 11, 32),  # 18: bicycle
        (0, 0, 0),  # 19: void
    ]

    def __init__(
        self,
        root: str,
        split: str = "train",
        mode: str = "fine",
        target_type: list[str] | str = "instance",
        transform: Callable[..., Any] | None = None,
        target_transform: Callable[..., Any] | None = None,
        transforms: Callable[..., Any] | None = None,
    ) -> None:
        """Cityscapes dataset wrapper with train-ID color mapping.

        Extends :class:`torchvision.datasets.Cityscapes` to provide a stable
        color palette for visualization and helpers to encode and decode
        semantic segmentation targets using Cityscapes train IDs. A tensor
        mapping train IDs to RGB colors is built on initialization for
        decoding predicted masks.

        Attributes:
            color_palette: List of RGB tuples, one per train ID, with a final
                black entry for void.
            train_id_to_color: Tensor of shape ``(num_train_ids + 1, 3)``; row
                ``i`` is the color of train ID ``i`` and the last row (black)
                is used for void pixels.

        Args:
            root: Root directory of the Cityscapes dataset.
            split: Dataset split to use (e.g., ``"train"``, ``"val"`` or
                ``"test"``). Defaults to ``"train"``.
            mode: Annotation mode (``"fine"`` or ``"coarse"``). Defaults to
                ``"fine"``.
            target_type: One or more target types to load (``"instance"``,
                ``"semantic"``, ``"polygon"``, ...). Defaults to
                ``"instance"``.
            transform: Transformation applied to the input image. Defaults to
                ``None``.
            target_transform: Transformation applied to the target. Defaults
                to ``None``.
            transforms: Combined transformation for image and target. Defaults
                to ``None``.
        """
        super().__init__(
            root,
            split,
            mode,
            target_type,
            transform,
            target_transform,
            transforms,
        )
        # Classes with train_id -1 or 255 are ignored for training and get no
        # dedicated row. Sorting by train ID makes row order independent of
        # the order in which the base class lists its classes.
        valid_classes = sorted(
            (c for c in self.classes if c.train_id not in (-1, 255)),
            key=lambda c: c.train_id,
        )
        train_id_to_color = [c.color for c in valid_classes]
        # Extra final row: color used for void (255) pixels in decode_target.
        train_id_to_color.append([0, 0, 0])
        self.train_id_to_color = torch.tensor(train_id_to_color)

    @classmethod
    def encode_target(cls, target: Image.Image) -> Image.Image:
        """Encode a Cityscapes label-ID image into a train-ID image.

        Each pixel of ``target`` is expected to hold a Cityscapes label ID
        (this is how :meth:`__getitem__` uses it for ``"semantic"`` targets).
        A pixel whose value equals a class ``id`` on every channel is
        replaced by that class's ``train_id``. The following become ``255``
        (void) instead of being silently assigned train ID 0:

        * pixels matching no known label ID, and
        * classes whose train ID is negative.

        Args:
            target: A PIL image containing per-pixel label IDs.

        Returns:
            A single-channel PIL image whose pixel values are Cityscapes train
            IDs, or ``255`` for void.
        """
        label_ids = rearrange(F.pil_to_tensor(target), "c h w -> h w c")
        # Default to void so unknown label IDs are ignored rather than
        # counted as the first train class.
        encoded = torch.full_like(label_ids[..., :1], 255)
        for cityscapes_class in cls.classes:
            # Negative label IDs (e.g. id=-1) cannot occur in an unsigned
            # image; comparing against them would rely on integer wraparound.
            if cityscapes_class.id < 0:
                continue
            train_id = cityscapes_class.train_id
            if train_id < 0:
                train_id = 255
            mask = (label_ids == cityscapes_class.id).all(dim=-1)
            encoded[mask] = train_id
        return F.to_pil_image(rearrange(encoded, "h w c -> c h w"))

    def decode_target(self, target: torch.Tensor) -> torch.Tensor:
        """Decode a train-ID tensor into RGB colors using the palette.

        Pixels equal to ``255`` are treated as void and mapped to the last row
        of :attr:`train_id_to_color` (black). The input tensor is NOT
        modified.

        Args:
            target: Integer tensor of train IDs (any integer dtype, including
                ``uint8``), with ``255`` marking void pixels.

        Returns:
            A tensor of shape ``(*target.shape, 3)`` holding RGB colors, on
            the same device as ``target``.
        """
        # Index with a long-typed copy: avoids mutating the caller's tensor,
        # and prevents uint8 tensors from being interpreted as boolean masks.
        indices = target.to(dtype=torch.long, copy=True)
        palette = self.train_id_to_color.to(indices.device)
        indices[indices == 255] = palette.shape[0] - 1
        return palette[indices]

    def __getitem__(self, index: int) -> tuple[Any, Any]:
        """Return the sample at the given index.

        Args:
            index: Integer index of the sample to retrieve.

        Returns:
            An ``(image, target)`` tuple where ``image`` is a
            ``tv_tensors.Image``. The per-type target is:

            * ``"polygon"``: the parsed JSON object;
            * ``"semantic"``: a ``tv_tensors.Mask`` of train IDs;
            * any other type: a PIL image.

            If ``target_type`` contains several types, ``target`` is a tuple
            with one entry per type. ``self.transforms`` is applied jointly to
            ``(image, target)`` when set.
        """
        image = tv_tensors.Image(Image.open(self.images[index]).convert("RGB"))

        targets: Any = []
        for i, t in enumerate(self.target_type):
            if t == "polygon":
                target = self._load_json(self.targets[index][i])
            elif t == "semantic":
                target = tv_tensors.Mask(
                    self.encode_target(Image.open(self.targets[index][i]))
                )
            else:
                target = Image.open(self.targets[index][i])
            targets.append(target)

        target = tuple(targets) if len(targets) > 1 else targets[0]

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def plot_sample(self, index: int, ax: _AX_TYPE | None = None) -> _PLOT_OUT_TYPE:
        """Plot a dataset sample (image and decoded target) for inspection.

        Intended behavior:

        1. Load the sample at ``index``.
        2. Decode the target to RGB using :attr:`train_id_to_color`.
        3. Render the image and the target (or an overlay) on the provided
           axis.

        If ``ax`` is ``None``, a new matplotlib figure and axis should be
        created.

        Args:
            index: Index of the sample to plot.
            ax: Optional matplotlib axis to draw on. Defaults to ``None``.

        Returns:
            A tuple ``(fig, ax)`` or the axis object used for plotting,
            depending on the plotting utility in use.

        Raises:
            NotImplementedError: This plotting helper is not implemented yet.
        """
        raise NotImplementedError("plot_sample is not implemented yet.")