"""MUAD dataset loader (module ``torch_uncertainty.datasets.muad``)."""

import logging
import os
import shutil
from collections.abc import Callable
from importlib import util
from operator import attrgetter
from pathlib import Path
from typing import Literal, NamedTuple

from huggingface_hub import hf_hub_download
from PIL import Image

if util.find_spec("cv2"):
    import cv2

    cv2_installed = True
else:  # coverage: ignore
    cv2_installed = False
import numpy as np
from torchvision import tv_tensors
from torchvision.datasets import VisionDataset


class MUADClass(NamedTuple):
    """Metadata record describing one MUAD semantic class."""

    name: str  # human-readable class name (several merged classes share one entry)
    id: int  # original label value stored in the annotation images
    train_id: int  # label value produced by MUAD.encode_target (use_train_ids=True)
    color: tuple[int, int, int]  # color triple returned by MUAD.color_palette
    is_ood: bool  # True for the classes flagged out-of-distribution (see MUAD.num_ood_classes)


class MUAD(VisionDataset):
    """MUAD dataset for semantic segmentation or depth estimation.

    Splits are downloaded as zip archives from the Hugging Face Hub dataset
    repositories listed in ``base_urls``. See ``__init__`` for the arguments.
    """

    classes_url = "https://raw.githubusercontent.com/torch-uncertainty/dataset-metadata/main/segmentation/muad/classes.json"
    # Hugging Face Hub dataset repository for each version.
    base_urls = {
        "full": "ENSTA-U2IS/MUAD",
        "small": "ENSTA-U2IS/miniMUAD",
    }
    # Valid ``split`` values (and archive names) for each version.
    huggingface_splits = {
        "full": [
            "train",
            "val",
            "test_id",
            "test_ood",
            "test_id_low_adv",
            "test_id_high_adv",
            "test_ood_low_adv",
            "test_ood_high_adv",
        ],
        "small": [
            "train",
            "val",
            "test",
            "ood",
        ],
    }
    # Expected number of files per split; used by ``check_split_integrity``
    # and returned by ``__len__``.
    _num_samples = {
        "full": {
            "train": 3420,
            "val": 492,
            "test_id": 551,
            "test_ood": 1668,
            "test_id_low_adv": 605,
            "test_id_high_adv": 602,
            "test_ood_low_adv": 1552,
            "test_ood_high_adv": 1421,
        },
        "small": {
            "train": 400,
            "val": 54,
            "test": 112,
            "ood": 20,
        },
    }
    classes = [
        MUADClass("road", 0, 0, (128, 64, 128), False),
        MUADClass("sidewalk", 1, 1, (244, 35, 232), False),
        MUADClass("building", 2, 2, (70, 70, 70), False),
        MUADClass("wall", 3, 3, (102, 102, 156), False),
        MUADClass("fence", 4, 4, (190, 153, 153), False),
        MUADClass("pole", 5, 5, (153, 153, 153), False),
        MUADClass("traffic_light", 6, 6, (250, 170, 30), False),
        MUADClass("traffic_sign", 7, 7, (220, 220, 0), False),
        MUADClass("vegetation", 8, 8, (107, 142, 35), False),
        MUADClass("terrain", 9, 9, (152, 251, 152), False),
        MUADClass("sky", 10, 10, (70, 130, 180), False),
        MUADClass("person", 11, 11, (220, 20, 60), False),
        MUADClass("car", 13, 12, (0, 0, 142), False),
        MUADClass("truck", 14, 13, (0, 0, 70), False),
        MUADClass("bus", 15, 14, (0, 60, 100), False),
        MUADClass("rider", 12, 15, (255, 0, 0), True),
        MUADClass("train", 16, 16, (0, 80, 100), True),
        MUADClass("motorcycle", 17, 17, (0, 0, 230), True),
        MUADClass("bicycle", 18, 18, (119, 11, 32), True),
        MUADClass("bear deer cow", 19, 19, (255, 228, 196), True),
        MUADClass("garbage_bag stand_food trash_can", 20, 20, (128, 128, 0), True),
        MUADClass("unlabeled", 21, 255, (0, 0, 0), False),  # id 255 or 21
    ]

    targets: list[Path] = []
    num_id_classes = 15
    num_ood_classes = 6

    def __init__(
        self,
        root: str | Path,
        split: Literal[
            "train",
            "val",
            "test",
            "ood",
            "test_id",
            "test_ood",
            "test_id_low_adv",
            "test_id_high_adv",
            "test_ood_low_adv",
            "test_ood_high_adv",
        ],
        version: Literal["small", "full"] = "full",
        min_depth: float | None = None,
        max_depth: float | None = None,
        target_type: Literal["semantic", "depth"] = "semantic",
        transforms: Callable | None = None,
        download: bool = False,
        use_train_ids: bool = True,
    ) -> None:
        """The MUAD Dataset.

        Args:
            root (str | Path): Parent directory of the dataset. The data is
                stored in ``root/MUAD`` (``full`` version) or
                ``root/MUAD_small`` (``small`` version), with one
                sub-directory per split containing ``leftImg8bit`` and
                ``leftLabel`` or ``leftDepth``.
            split (str): The image split to use. For the ``small`` version:
                ``train``, ``val``, ``test`` or ``ood``. For the ``full``
                version: ``train``, ``val``, ``test_id``, ``test_ood``,
                ``test_id_low_adv``, ``test_id_high_adv``,
                ``test_ood_low_adv`` or ``test_ood_high_adv``.
            version (str, optional): The version of the dataset to use,
                ``small`` or ``full``. Defaults to ``full``.
            min_depth (float, optional): The minimum depth value to use if
                target_type is ``depth``. Depth values less than or equal to
                it are set to NaN. Defaults to ``None`` (no lower bound).
            max_depth (float, optional): The maximum depth value to use if
                target_type is ``depth``. Depth values greater than it are set
                to NaN. Defaults to ``None`` (no upper bound).
            target_type (str, optional): The type of target to use,
                ``semantic`` or ``depth``. Defaults to ``semantic``.
            transforms (callable, optional): A function/transform that takes
                in an image and its target and returns transformed versions
                of both. Defaults to ``None``.
            download (bool, optional): If ``True``, downloads the dataset from
                the internet and puts it in root directory. If dataset is
                already downloaded, it is not downloaded again. Defaults to
                ``False``.
            use_train_ids (bool, optional): If ``True``, uses the train ids
                instead of the original ids. Defaults to ``True``. Note that
                this is only used for the ``semantic`` target type.

        Raises:
            ImportError: If the ``cv2`` library is not installed.
            ValueError: If ``target_type`` is ``depth`` with the ``small``
                version, if ``split`` is not valid for ``version``, or if
                ``target_type`` is unknown.
            FileNotFoundError: If the split is missing or incomplete and
                ``download`` is ``False``.

        Reference:
            https://muad-dataset.github.io

        Note:
            MUAD cannot be used for commercial purposes. Read MUAD's license
            carefully before using it and verify that you can comply.
        """
        if not cv2_installed:  # coverage: ignore
            raise ImportError(
                "The cv2 library is not installed. Please install "
                "torch_uncertainty with the image option: "
                """pip install -U "torch_uncertainty[image]"."""
            )
        if version == "small" and target_type == "depth":
            raise ValueError("Depth target is not available for the small version of MUAD.")

        logging.info(
            "MUAD is restricted to non-commercial use. By using MUAD, you "
            "agree to the terms and conditions."
        )

        dataset_root = Path(root) / ("MUAD" if version == "full" else "MUAD_small")
        super().__init__(dataset_root, transforms=transforms)
        self.min_depth = min_depth
        self.max_depth = max_depth

        if split not in self.huggingface_splits[version]:
            raise ValueError(
                f"split must be one of {self.huggingface_splits[version]}. Got {split}."
            )
        self.split = split
        self.version = version
        self.target_type = target_type
        self.use_train_ids = use_train_ids

        missing_msg = (
            f"MUAD {split} split not found or incomplete. Set download=True to download it."
        )
        if not self.check_split_integrity("leftImg8bit"):
            if not download:
                raise FileNotFoundError(missing_msg)
            self._download(split=split)

        if not self.check_split_integrity("leftLabel") and target_type == "semantic":
            if not download:
                raise FileNotFoundError(missing_msg)
            self._download(split=split)

        if not self.check_split_integrity("leftDepth") and target_type == "depth":
            if not download:
                raise FileNotFoundError(missing_msg)
            self._download(split=f"{split}_depth")
            # Depth targets for train are extracted to a different folder,
            # thus we move them to the correct folder.
            if split == "train":
                shutil.move(
                    self.root / f"{split}_depth",
                    self.root / split / "leftDepth",
                )

        self._make_dataset(self.root / split)

    def encode_target(self, target: tv_tensors.Mask) -> tv_tensors.Mask:
        """Encode the target tensor to the train ids.

        The target is modified in place and returned.

        Args:
            target (torch.Tensor): Target tensor holding original class ids.

        Returns:
            torch.Tensor: Encoded target tensor.
        """
        # Compare against an untouched copy so an id already remapped to a
        # train id is never remapped a second time.
        original_tgt = target.clone()
        for c in self.classes:
            target[original_tgt == c.id] = c.train_id
        return target

    def __getitem__(self, index: int) -> tuple[tv_tensors.Image, tv_tensors.Mask]:
        """Get the sample at the given index.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is either a segmentation mask
            or a depth map in meters whose out-of-range values are NaN.

        Raises:
            RuntimeError: If OpenCV cannot read the depth map file.
        """
        image = tv_tensors.Image(Image.open(self.samples[index]).convert("RGB"))
        if self.target_type == "semantic":
            target = tv_tensors.Mask(Image.open(self.targets[index]))
            if self.use_train_ids:
                target = self.encode_target(target)
        else:
            # NOTE(review): some OpenCV builds read this variable only once;
            # confirm that setting it here (after cv2 import) is early enough.
            os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
            raw_depth = cv2.imread(
                str(self.targets[index]),
                cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH,
            )
            if raw_depth is None:
                raise RuntimeError(f"OpenCV could not read the depth map {self.targets[index]}.")
            target = Image.fromarray(raw_depth)
            # TODO: in the long run it would be better to use a custom
            # tv_tensor for depth maps (e.g. tv_tensors.DepthMap)
            target = np.asarray(target, np.float32)
            target = tv_tensors.Mask(400 * (1 - target))  # convert to meters
            # Each bound is optional; comparing a tensor with None would raise.
            if self.min_depth is not None:
                target[target <= self.min_depth] = float("nan")
            if self.max_depth is not None:
                target[target > self.max_depth] = float("nan")

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def check_split_integrity(self, folder: str) -> bool:
        """Check that ``folder`` of the current split holds the expected number of files.

        Args:
            folder (str): Sub-folder of the split, e.g. ``leftImg8bit``.

        Returns:
            bool: ``True`` if the split directory exists and ``folder``
            contains exactly ``len(self)`` files.
        """
        split_path = self.root / self.split
        return split_path.is_dir() and len(self._list_files(split_path / folder)) == len(self)

    def __len__(self) -> int:
        """The number of samples in the dataset."""
        return self._num_samples[self.version][self.split]

    @staticmethod
    def _list_files(folder: Path) -> list[Path]:
        """Return the regular files found recursively under ``folder``, sorted.

        Directories are skipped so they are neither counted nor used as samples.
        """
        return sorted(p for p in folder.glob("**/*") if p.is_file())

    def _make_dataset(self, path: Path) -> None:
        """Create a list of samples and targets.

        Args:
            path (Path): The path to the dataset.

        Raises:
            ValueError: If ``target_type`` is neither ``semantic`` nor ``depth``.
        """
        # Defensive only: split names validated in __init__ never contain "depth".
        if "depth" in path.name:
            raise NotImplementedError(
                "Depth mode is not implemented yet. Raise an issue if you need it."
            )
        self.samples = self._list_files(path / "leftImg8bit")
        if self.target_type == "semantic":
            self.targets = self._list_files(path / "leftLabel")
        elif self.target_type == "depth":
            self.targets = self._list_files(path / "leftDepth")
        else:
            raise ValueError(
                f"target_type must be one of ['semantic', 'depth']. Got {self.target_type}."
            )

    def _download(self, split: str) -> None:  # coverage: ignore
        """Download and extract the chosen split of the dataset.

        Args:
            split (str): Archive name (without ``.zip``) in the Hub repository.
        """
        repo_id = self.base_urls[self.version]
        filename = f"{split}.zip"
        downloaded_file = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
        shutil.unpack_archive(downloaded_file, extract_dir=self.root)

    @property
    def color_palette(self) -> list[tuple[int, int, int]]:
        """Class colors, ordered by train id if ``use_train_ids`` else by original id."""
        sorting_key = "train_id" if self.use_train_ids else "id"
        sorted_cls = sorted(self.classes, key=attrgetter(sorting_key))
        return [c.color for c in sorted_cls]