Shortcuts

Source code for torch_uncertainty.datasets.muad

import logging
import os
import shutil
from collections.abc import Callable
from importlib import util
from pathlib import Path
from typing import Literal, NamedTuple

from huggingface_hub import hf_hub_download
from PIL import Image

if util.find_spec("cv2"):
    import cv2

    cv2_installed = True
else:  # coverage: ignore
    cv2_installed = False
import numpy as np
from torchvision import tv_tensors
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import (
    download_and_extract_archive,
)


class MUADClass(NamedTuple):
    """Description of a single MUAD segmentation class.

    Attributes:
        name: Human-readable name of the class.
        id: Integer label of the class in the segmentation masks.
        color: Three-component color used for this class in the dataset's
            color palette (presumably RGB — confirm against the metadata).
    """

    name: str  # display name, e.g. "road"
    id: int  # integer class label
    color: tuple[int, int, int]  # palette color triple


class MUAD(VisionDataset):
    # Remote metadata describing the MUAD classes, pinned by its MD5 checksum.
    classes_url = "https://raw.githubusercontent.com/torch-uncertainty/dataset-metadata/main/segmentation/muad/classes.json"
    classes_md5 = "1db6e6143939824792f0af11a4fe7bb1"  # avoid replacement attack

    # Zenodo record hosting the full version, one zip archive per split.
    base_url = "https://zenodo.org/records/10619959/files/"
    zip_md5 = {
        "train": "cea6a672225b10dda1add8b2974a5982",
        "train_depth": "934d122ac09e0471db62ae68c3456b0f",
        "val": "957af9c1c36f0a85c33279e06b6cf8d8",
        "val_depth": "0282030d281aeffee3335f713ba12373",
    }
    # Hugging Face dataset repository hosting the small version.
    small_muad_url = "ENSTA-U2IS/miniMUAD"

    # Expected number of samples per version and split. Used both as the
    # dataset length and to check that a split on disk is complete. ``...``
    # marks a split whose size is not known yet, so it cannot be loaded.
    _num_samples = {
        "full": {
            "train": 3420,
            "val": 492,
            "test": ...,
        },
        "small": {
            "train": 400,
            "val": 54,
            "test": 112,
            "ood": 20,
        },
    }

    classes = [
        MUADClass("road", 0, (128, 64, 128)),
        MUADClass("sidewalk", 1, (244, 35, 232)),
        MUADClass("building", 2, (70, 70, 70)),
        MUADClass("wall", 3, (102, 102, 156)),
        MUADClass("fence", 4, (190, 153, 153)),
        MUADClass("pole", 5, (153, 153, 153)),
        MUADClass("traffic_light", 6, (250, 170, 30)),
        MUADClass("traffic_sign", 7, (220, 220, 0)),
        MUADClass("vegetation", 8, (107, 142, 35)),
        MUADClass("terrain", 9, (152, 251, 152)),
        MUADClass("sky", 10, (70, 130, 180)),
        MUADClass("person", 11, (220, 20, 60)),
        MUADClass("rider", 12, (255, 0, 0)),
        MUADClass("car", 13, (0, 0, 142)),
        MUADClass("truck", 14, (0, 0, 70)),
        MUADClass("bus", 15, (0, 60, 100)),
        MUADClass("train", 16, (0, 80, 100)),
        MUADClass("motorcycle", 17, (0, 0, 230)),
        MUADClass("bicycle", 18, (119, 11, 32)),
        MUADClass("bear deer cow", 19, (255, 228, 196)),
        MUADClass("garbage_bag stand_food trash_can", 20, (128, 128, 0)),
        MUADClass("unlabeled", 21, (0, 0, 0)),  # id 255 or 21
    ]

    targets: list[Path] = []

    def __init__(
        self,
        root: str | Path,
        split: Literal["train", "val", "test", "ood"],
        version: Literal["small", "full"] = "full",
        min_depth: float | None = None,
        max_depth: float | None = None,
        target_type: Literal["semantic", "depth"] = "semantic",
        transforms: Callable | None = None,
        download: bool = False,
    ) -> None:
        """The MUAD Dataset.

        Args:
            root (str | Path): Root directory of the dataset. The data is stored
                in a ``MUAD`` (full version) or ``MUAD_small`` (small version)
                subdirectory. That subdirectory holds one folder per split, each
                with a ``leftImg8bit`` folder and a ``leftLabel`` or
                ``leftDepth`` folder.
            split (str): The image split to use: ``'train'``, ``'val'``,
                ``'test'`` or ``'ood'``. ``'test'`` and ``'ood'`` are currently
                only available for the small version.
            version (str, optional): The version of the dataset to use,
                ``'small'`` or ``'full'``. Defaults to ``'full'``.
            min_depth (float, optional): When ``target_type`` is ``'depth'``,
                depth values less than or equal to this threshold are set to
                NaN. Defaults to ``None`` (no lower bound).
            max_depth (float, optional): When ``target_type`` is ``'depth'``,
                depth values strictly greater than this threshold are set to
                NaN. Defaults to ``None`` (no upper bound).
            target_type (str, optional): The type of target to use,
                ``'semantic'`` or ``'depth'``. Depth targets are only available
                for the full version. Defaults to ``'semantic'``.
            transforms (callable, optional): A function/transform called as
                ``transforms(image, target)`` on the ``tv_tensors`` image and
                target. It returns the transformed pair. Defaults to ``None``.
            download (bool, optional): If ``True``, downloads the dataset from
                the internet and puts it in the root directory. If the dataset
                is already downloaded, it is not downloaded again. Defaults to
                ``False``.

        Raises:
            ImportError: If OpenCV (``cv2``) is not installed.
            ValueError: If ``version``, ``split`` or ``target_type`` is invalid,
                or if the requested combination is not available.
            FileNotFoundError: If the requested data is missing or incomplete
                and ``download`` is ``False``.

        Reference:
            https://muad-dataset.github.io

        Note:
            MUAD cannot be used for commercial purposes. Read MUAD's license
            carefully before using it and verify that you can comply.
        """
        if not cv2_installed:  # coverage: ignore
            raise ImportError(
                "The cv2 library is not installed. Please install "
                "torch_uncertainty with the image option: "
                """pip install -U "torch_uncertainty[image]"."""
            )
        # Validate every argument before touching the file system or network.
        if version not in ("small", "full"):
            raise ValueError(f"version must be one of ['small', 'full']. Got {version}.")
        if split not in ("train", "val", "test", "ood"):
            raise ValueError(
                f"split must be one of ['train', 'val', 'test', 'ood']. Got {split}."
            )
        if target_type not in ("semantic", "depth"):
            raise ValueError(
                f"target_type must be one of ['semantic', 'depth']. Got {target_type}."
            )
        if version == "small" and target_type == "depth":
            raise ValueError("Depth target is not available for the small version of MUAD.")
        # Splits missing from ``_num_samples`` or marked ``...`` have no known size.
        # The integrity check could never pass for them, so fail clearly here.
        if not isinstance(self._num_samples[version].get(split), int):
            raise ValueError(
                f"The {split} split is not available for the {version} version of MUAD."
            )

        logging.info(
            "MUAD is restricted to non-commercial use. By using MUAD, you "
            "agree to the terms and conditions."
        )
        dataset_root = Path(root) / ("MUAD" if version == "full" else "MUAD_small")
        super().__init__(dataset_root, transforms=transforms)
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.split = split
        self.version = version
        self.target_type = target_type

        if not self.check_split_integrity("leftImg8bit"):
            self._download_or_raise(split, download)
        if target_type == "semantic" and not self.check_split_integrity("leftLabel"):
            self._download_or_raise(split, download)
        if target_type == "depth" and not self.check_split_integrity("leftDepth"):
            self._download_or_raise(f"{split}_depth", download)
            # Depth targets for train are extracted into a different folder,
            # so they are moved into the split folder.
            if split == "train":
                shutil.move(self.root / f"{split}_depth", self.root / split / "leftDepth")

        self._make_dataset(self.root / split)

    def _download_or_raise(self, archive: str, download: bool) -> None:
        """Download ``archive`` if allowed, otherwise report the missing split.

        Args:
            archive (str): Name of the archive to fetch, e.g. ``'train'`` or
                ``'train_depth'``.
            download (bool): Whether downloading is allowed.

        Raises:
            FileNotFoundError: If ``download`` is ``False``.
        """
        if not download:
            raise FileNotFoundError(
                f"MUAD {self.split} split not found or incomplete. Set download=True to download it."
            )
        self._download(split=archive)

    def __getitem__(self, index: int) -> tuple[tv_tensors.Image, tv_tensors.Mask]:
        """Get the sample at the given index.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is either a segmentation mask or
            a depth map.
        """
        # Context managers close the underlying files deterministically. The
        # tv_tensors constructors copy the pixel data while the file is open.
        with Image.open(self.samples[index]) as img:
            image = tv_tensors.Image(img.convert("RGB"))
        if self.target_type == "semantic":
            with Image.open(self.targets[index]) as label:
                target = tv_tensors.Mask(label)
        else:
            target = self._load_depth(self.targets[index])

        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def _load_depth(self, path: Path) -> tv_tensors.Mask:
        """Read a depth map, convert it to meters and mask out-of-range values.

        Args:
            path (Path): Path to the depth file.

        Returns:
            tv_tensors.Mask: Float32 depth map. Values outside
            ``(min_depth, max_depth]`` are set to NaN; a bound set to ``None``
            is ignored.

        Raises:
            OSError: If OpenCV cannot decode the file.
        """
        # Enables OpenCV's OpenEXR codec, which recent OpenCV builds disable by
        # default. Depth maps are presumably stored as EXR (confirm).
        os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
        raw = cv2.imread(str(path), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
        if raw is None:  # cv2.imread signals failure by returning None
            raise OSError(f"Could not read the depth map {path} with OpenCV.")
        # TODO: in the long run it would be better to use a custom
        # tv_tensor for depth maps (e.g. tv_tensors.DepthMap)
        depth = np.asarray(Image.fromarray(raw), np.float32)
        target = tv_tensors.Mask(400 * (1 - depth))  # convert to meters
        # Apply each bound only when it is set: comparing against None would
        # raise. NaN entries are never > max_depth, so the two passes compose.
        if self.min_depth is not None:
            target[target <= self.min_depth] = float("nan")
        if self.max_depth is not None:
            target[target > self.max_depth] = float("nan")
        return target

    def check_split_integrity(self, folder: str) -> bool:
        """Check that ``folder`` of the current split holds the expected entries.

        Every entry found recursively under ``root/split/folder`` is counted,
        files and directories alike. The count is compared with the expected
        number of samples for the current version and split.

        Args:
            folder (str): Sub-folder name, e.g. ``'leftImg8bit'``.

        Returns:
            bool: ``True`` if the split directory exists and the count matches.
        """
        split_path = self.root / self.split
        return split_path.is_dir() and len(list((split_path / folder).glob("**/*"))) == len(
            self
        )

    def __len__(self) -> int:
        """The number of samples in the dataset."""
        return self._num_samples[self.version][self.split]

    def _make_dataset(self, path: Path) -> None:
        """Create a list of samples and targets.

        Images and targets are paired by position after sorting. This relies on
        both folders using file names that sort in the same order.

        Args:
            path (Path): The path to the split directory.

        Raises:
            ValueError: If ``self.target_type`` is neither ``'semantic'`` nor
                ``'depth'``.
        """
        self.samples = sorted((path / "leftImg8bit/").glob("**/*"))
        if self.target_type == "semantic":
            self.targets = sorted((path / "leftLabel/").glob("**/*"))
        elif self.target_type == "depth":
            self.targets = sorted((path / "leftDepth/").glob("**/*"))
        else:
            raise ValueError(
                f"target_type must be one of ['semantic', 'depth']. Got {self.target_type}."
            )

    def _download(self, split: str) -> None:
        """Download and extract the chosen split of the dataset.

        Args:
            split (str): Archive name without the ``.zip`` extension, e.g.
                ``'val'`` or ``'val_depth'``.
        """
        if self.version == "small":
            # The small version is hosted on the Hugging Face Hub.
            downloaded_file = hf_hub_download(
                repo_id=self.small_muad_url, filename=f"{split}.zip", repo_type="dataset"
            )
            shutil.unpack_archive(downloaded_file, extract_dir=self.root)
        else:
            # The full version comes from Zenodo and is checked against a pinned MD5.
            split_url = self.base_url + split + ".zip"
            download_and_extract_archive(split_url, self.root, md5=self.zip_md5[split])

    @property
    def color_palette(self) -> list[tuple[int, int, int]]:
        """Colors of the classes, in the order of ``classes``."""
        return [c.color for c in self.classes]