Shortcuts

Source code for torch_uncertainty.datasets.segmentation.camvid

import json
import logging
import shutil
from collections.abc import Callable
from pathlib import Path
from typing import Literal, NamedTuple

import torch
from einops import rearrange, repeat
from PIL import Image
from torchvision import tv_tensors
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import (
    download_and_extract_archive,
    download_url,
)
from torchvision.transforms.v2 import functional as F


class CamVidClass(NamedTuple):
    name: str
    index: int
    color: tuple[int, int, int]


class CamVid(VisionDataset):
    """CamVid road-scene semantic segmentation dataset (arguments documented on ``__init__``)."""

    # Notes: some classes are not used here
    classes = [
        CamVidClass("animal", 0, (64, 128, 64)),
        CamVidClass("archway", 1, (192, 0, 128)),
        CamVidClass("bicyclist", 2, (0, 128, 192)),
        CamVidClass("bridge", 3, (0, 128, 64)),
        CamVidClass("building", 4, (128, 0, 0)),
        CamVidClass("car", 5, (64, 0, 128)),
        CamVidClass("cart_luggage_pram", 6, (64, 0, 192)),
        CamVidClass("child", 7, (192, 128, 64)),
        CamVidClass("column_pole", 8, (192, 192, 128)),
        CamVidClass("fence", 9, (64, 64, 128)),
        CamVidClass("lane_mkgs_driv", 10, (128, 0, 192)),
        CamVidClass("lane_mkgs_non_driv", 11, (192, 0, 64)),
        CamVidClass("misc_text", 12, (128, 128, 64)),
        CamVidClass("motorcycle_scooter", 13, (192, 0, 192)),
        CamVidClass("othermoving", 14, (128, 64, 64)),
        CamVidClass("parking_block", 15, (64, 192, 128)),
        CamVidClass("pedestrian", 16, (64, 64, 0)),
        CamVidClass("road", 17, (128, 64, 128)),
        CamVidClass("road_shoulder", 18, (128, 128, 192)),
        CamVidClass("sidewalk", 19, (0, 0, 192)),
        CamVidClass("sign_symbol", 20, (192, 128, 128)),
        CamVidClass("sky", 21, (128, 128, 128)),
        CamVidClass("suv_pickup_truck", 22, (64, 128, 192)),
        CamVidClass("traffic_cone", 23, (0, 0, 64)),
        CamVidClass("traffic_light", 24, (0, 64, 64)),
        CamVidClass("train", 25, (192, 64, 128)),
        CamVidClass("tree", 26, (128, 128, 0)),
        CamVidClass("truck_bus", 27, (192, 128, 192)),
        CamVidClass("tunnel", 28, (64, 0, 64)),
        CamVidClass("vegetation_misc", 29, (192, 192, 0)),
        CamVidClass("void", 30, (0, 0, 0)),
        CamVidClass("wall", 31, (64, 192, 0)),
    ]

    superclasses = [
        CamVidClass("sky", 0, (128, 128, 128)),
        CamVidClass("building", 1, (128, 0, 0)),
        CamVidClass("pole", 2, (192, 192, 128)),
        CamVidClass("road", 3, (128, 64, 128)),
        CamVidClass("pavement", 4, (0, 0, 192)),
        CamVidClass("tree", 5, (128, 128, 0)),
        CamVidClass("sign_symbol", 6, (192, 128, 128)),
        CamVidClass("fence", 7, (64, 64, 128)),
        CamVidClass("car", 8, (64, 0, 128)),
        CamVidClass("pedestrian", 9, (64, 64, 0)),
        CamVidClass("bicyclist", 10, (0, 128, 192)),
        # void has no superclass label; encode_target writes ``ignore_index`` for it.
        CamVidClass("void", None, (0, 0, 0)),
    ]

    # superclasses_indices[j] lists the fine-class indices grouped into superclass j.
    superclasses_indices = [
        [21],
        [3, 4, 31, 28, 1],
        [8, 23],
        [17, 10, 11],
        [19, 15, 18],
        [26, 29],
        [20, 12, 24],
        [9],
        [5, 22, 27, 25, 14],
        [16, 7, 6, 0],
        [2, 13],
    ]

    urls = {
        "raw": "http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip",
        "label": "http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip",
        "splits": "https://raw.githubusercontent.com/torch-uncertainty/dataset-metadata/main/segmentation/camvid/splits.json",
    }

    splits_md5 = "db45289aaa83c60201391b11e78c6382"

    filenames = {
        "raw": "701_StillsRaw_full.zip",
        "label": "LabeledApproved_full.zip",
    }

    base_folder = "camvid"

    num_samples = {
        "train": 367,
        "val": 101,
        "test": 233,
        "all": 701,
    }

    # Label value written for void pixels and for any colour that matches no class.
    ignore_index = 255

    # Index of the "void" entry in ``classes``.
    _void_class_index = 30

    def __init__(
        self,
        root: str,
        group_classes: bool = True,
        split: Literal["train", "val", "test"] | None = None,
        transforms: Callable | None = None,
        download: bool = False,
    ) -> None:
        """`CamVid <http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/>`_ Dataset.

        Args:
            root (str): Root directory of dataset where ``camvid/`` exists or
                will be saved to if download is set to ``True``.
            group_classes (bool, optional): Whether to group the 32 classes into
                11 superclasses. Default: ``True``.
            split (str, optional): The dataset split, supports ``train``,
                ``val`` and ``test``. Default: ``None`` (all 701 samples).
            transforms (callable, optional): A function/transform that takes
                input sample and its target as entry and returns a transformed
                version. Default: ``None``.
            download (bool, optional): If true, downloads the dataset from the
                internet and puts it in root directory. If dataset is already
                downloaded, it is not downloaded again.

        Raises:
            ValueError: If ``split`` is not one of the supported values.
            RuntimeError: If the dataset files are missing or inconsistent.
        """
        if split not in ["train", "val", "test", None]:
            raise ValueError(
                f"Unknown split '{split}'. "
                "Supported splits are ['train', 'val', 'test', None]"
            )

        super().__init__(root, transforms, None, None)
        self.group_classes = group_classes
        self.split = split if split is not None else "all"

        # Map fine class index -> superclass index by inverting superclasses_indices.
        # Classes listed in no superclass (void) keep None.
        self.class_to_superclass: list[int | None] = [None] * len(self.classes)
        for superclass_index, members in enumerate(self.superclasses_indices):
            for class_index in members:
                self.class_to_superclass[class_index] = superclass_index

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. "
                "You can use download=True to download it"
            )

        dataset_folder = Path(self.root) / self.base_folder
        raw_paths = (dataset_folder / "raw").glob("*.png")
        label_paths = (dataset_folder / "label").glob("*.png")

        if split is None:
            self.images = sorted(raw_paths)
            self.targets = sorted(label_paths)
        else:
            with (dataset_folder / "splits.json").open(encoding="utf-8") as f:
                split_stems = set(json.load(f)[split])
            self.images = sorted(path for path in raw_paths if path.stem in split_stems)
            # Label stems carry a 2-character suffix (presumably "_L") relative to
            # the matching image stem, hence the [:-2] before the lookup.
            self.targets = sorted(
                path for path in label_paths if path.stem[:-2] in split_stems
            )

        # Images and targets are paired by sorted position; a count mismatch would
        # silently pair the wrong files.
        if len(self.images) != len(self.targets):
            raise RuntimeError(
                f"Found {len(self.images)} images but {len(self.targets)} targets "
                f"for split '{self.split}'."
            )

    def encode_target(self, target: Image.Image) -> torch.Tensor:
        """Encode a colour-coded target image into a label tensor.

        Each pixel whose RGB value equals a class colour gets that class's index
        (or its superclass index when ``group_classes`` is set). Void pixels and
        pixels matching no known colour get ``ignore_index`` (255).

        Args:
            target (Image.Image): Target PIL image.

        Returns:
            torch.Tensor: ``uint8`` label tensor of shape ``(1, H, W)``.
        """
        # Convert so that palette/alpha PNGs also yield exactly three channels.
        colored_target = rearrange(
            F.pil_to_tensor(target.convert("RGB")), "c h w -> h w c"
        )
        encoded = torch.full(
            colored_target.shape[:2], self.ignore_index, dtype=torch.uint8
        )

        for camvid_class in self.classes:
            if camvid_class.index == self._void_class_index:
                continue  # void pixels keep ignore_index
            index = camvid_class.index
            if self.group_classes:
                index = self.class_to_superclass[index]
            color = torch.tensor(camvid_class.color, dtype=colored_target.dtype)
            encoded[(colored_target == color).all(dim=-1)] = index

        return encoded.unsqueeze(0)

    def decode_target(self, target: torch.Tensor) -> Image.Image:
        """Decode a label tensor into a colour-coded image.

        Args:
            target (torch.Tensor): Label tensor of shape ``(H, W)`` or
                ``(1, H, W)`` (the shape produced by ``encode_target``).

        Returns:
            Image.Image: RGB PIL image. Labels that match no class (including
            ``ignore_index``) are drawn black, which is the void colour.

        Raises:
            ValueError: If ``target`` is not 2-D after dropping a leading
                singleton channel.
        """
        if target.ndim == 3 and target.shape[0] == 1:
            target = target[0]
        if target.ndim != 2:
            raise ValueError(
                f"Expected a target of shape (H, W) or (1, H, W), got {tuple(target.shape)}."
            )

        colored_target = torch.zeros(
            (*target.shape, 3), dtype=torch.uint8, device=target.device
        )
        palette = self.superclasses if self.group_classes else self.classes
        for camvid_class in palette:
            if camvid_class.index is None:
                continue  # e.g. the void superclass: stays black
            colored_target[target == camvid_class.index] = torch.tensor(
                camvid_class.color, dtype=torch.uint8, device=target.device
            )

        return F.to_pil_image(rearrange(colored_target, "h w c -> c h w").cpu())

    @property
    def color_palette(self) -> list[tuple[int, int, int]]:
        """Return the color palette of the dataset."""
        if self.group_classes:
            return [camvid_class.color for camvid_class in self.superclasses]
        return [camvid_class.color for camvid_class in self.classes]

    def __getitem__(self, index: int) -> tuple[tv_tensors.Image, tv_tensors.Mask]:
        """Get the image and target at the given index.

        Args:
            index (int): Sample index.

        Returns:
            tuple[tv_tensors.Image, tv_tensors.Mask]: Image and target.
        """
        with Image.open(self.images[index]) as image_file:
            image = tv_tensors.Image(image_file.convert("RGB"))
        with Image.open(self.targets[index]) as target_file:
            target = tv_tensors.Mask(self.encode_target(target_file))

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of samples actually found for the split."""
        return len(self.images)

    def _check_integrity(self) -> bool:
        """Check that both image folders hold all samples and that splits.json exists."""
        dataset_folder = Path(self.root) / self.base_folder
        for subfolder in ("raw", "label"):
            if (
                len(list((dataset_folder / subfolder).glob("*.png")))
                != self.num_samples["all"]
            ):
                return False
        return (dataset_folder / "splits.json").exists()

    def download(self) -> None:
        """Download the CamVid data if it doesn't exist already.

        Any partial ``camvid/`` folder is removed first. Archives are saved in
        ``root`` and extracted into ``root/camvid``.
        """
        if self._check_integrity():
            logging.info("Files already downloaded and verified")
            return

        dataset_folder = Path(self.root) / self.base_folder
        if dataset_folder.exists():
            shutil.rmtree(dataset_folder)

        download_and_extract_archive(
            self.urls["raw"],
            self.root,
            extract_root=dataset_folder,
            filename=self.filenames["raw"],
        )
        # The raw archive extracts to its own folder name; rename it to "raw".
        (dataset_folder / "701_StillsRaw_full").replace(dataset_folder / "raw")

        download_and_extract_archive(
            self.urls["label"],
            self.root,
            extract_root=dataset_folder / "label",
            filename=self.filenames["label"],
        )
        download_url(
            self.urls["splits"],
            dataset_folder,
            filename="splits.json",
            md5=self.splits_md5,
        )