Source code for torch_uncertainty.datasets.muad

import json
import logging
import os
import shutil
from collections.abc import Callable
from importlib import util
from pathlib import Path
from typing import Literal

if util.find_spec("cv2"):
    import cv2

    cv2_installed = True
else:  # coverage: ignore
    cv2_installed = False
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from torchvision import tv_tensors
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import (
    check_integrity,
    download_and_extract_archive,
    download_url,
)
from torchvision.transforms.v2 import functional as F


class MUAD(VisionDataset):
    classes_url = "https://raw.githubusercontent.com/torch-uncertainty/dataset-metadata/main/segmentation/muad/classes.json"
    classes_md5 = "1db6e6143939824792f0af11a4fe7bb1"  # avoid replacement attack
    base_url = "https://zenodo.org/records/10619959/files/"

    zip_md5 = {
        "train": "cea6a672225b10dda1add8b2974a5982",
        "train_depth": "934d122ac09e0471db62ae68c3456b0f",
        "val": "957af9c1c36f0a85c33279e06b6cf8d8",
        "val_depth": "0282030d281aeffee3335f713ba12373",
    }
    _num_samples = {
        "train": 3420,
        "val": 492,
        "test": ...,
    }
    targets: list[Path] = []

    def __init__(
        self,
        root: str | Path,
        split: Literal["train", "val"],
        min_depth: float | None = None,
        max_depth: float | None = None,
        target_type: Literal["semantic", "depth"] = "semantic",
        transforms: Callable | None = None,
        download: bool = False,
    ) -> None:
        """The MUAD Dataset.

        Args:
            root (str): Root directory of dataset where directory 'leftImg8bit'
                and 'leftLabel' or 'leftDepth' are located.
            split (str): The image split to use, 'train' or 'val'.
            min_depth (float, optional): The minimum depth value to use if
                target_type is 'depth'. Defaults to None.
            max_depth (float, optional): The maximum depth value to use if
                target_type is 'depth'. Defaults to None.
            target_type (str, optional): The type of target to use, 'semantic'
                or 'depth'.
            transforms (callable, optional): A function/transform that takes in
                a tuple of PIL images and returns a transformed version.
            download (bool, optional): If true, downloads the dataset from the
                internet and puts it in the root directory. If the dataset is
                already downloaded, it is not downloaded again.

        Reference:
            https://muad-dataset.github.io

        Note:
            MUAD cannot be used for commercial purposes. Read MUAD's license
            carefully before using it and verify that you can comply.
        """
        if not cv2_installed:  # coverage: ignore
            raise ImportError(
                "The cv2 library is not installed. Please install "
                "torch_uncertainty with the image option: "
                """pip install -U "torch_uncertainty[image]"."""
            )
        logging.info(
            "MUAD is restricted to non-commercial use. By using MUAD, you "
            "agree to the terms and conditions."
        )
        super().__init__(
            root=Path(root) / "MUAD",
            transforms=transforms,
        )
        self.min_depth = min_depth
        self.max_depth = max_depth

        if split not in ["train", "val"]:
            raise ValueError(
                f"split must be one of ['train', 'val']. Got {split}."
            )
        self.split = split
        self.target_type = target_type

        if not self.check_split_integrity("leftImg8bit"):
            if download:
                self._download(split=split)
            else:
                raise FileNotFoundError(
                    f"MUAD {split} split not found or incomplete. "
                    "Set download=True to download it."
                )

        if (
            not self.check_split_integrity("leftLabel")
            and target_type == "semantic"
        ):
            if download:
                self._download(split=split)
            else:
                raise FileNotFoundError(
                    f"MUAD {split} split not found or incomplete. "
                    "Set download=True to download it."
                )

        if (
            not self.check_split_integrity("leftDepth")
            and target_type == "depth"
        ):
            if download:
                self._download(split=f"{split}_depth")
                # Depth targets for the train split are extracted to a
                # separate folder, so move them to the expected location.
                if split == "train":
                    shutil.move(
                        self.root / f"{split}_depth",
                        self.root / split / "leftDepth",
                    )
            else:
                raise FileNotFoundError(
                    f"MUAD {split} split not found or incomplete. "
                    "Set download=True to download it."
                )

        # Load classes metadata
        cls_path = self.root / "classes.json"
        if (not check_integrity(cls_path, self.classes_md5)) and download:
            download_url(
                self.classes_url,
                self.root,
                "classes.json",
                self.classes_md5,
            )

        with (self.root / "classes.json").open() as file:
            self.classes = json.load(file)

        train_id_to_color = [
            c["object_id"]
            for c in self.classes
            if c["train_id"] not in [-1, 255]
        ]
        train_id_to_color.append([0, 0, 0])
        self.train_id_to_color = np.array(train_id_to_color)

        self._make_dataset(self.root / split)
    def encode_target(self, target: Image.Image) -> Image.Image:
        """Encode a color segmentation target into train ids.

        Args:
            target (Image.Image): Target PIL image.

        Returns:
            Image.Image: Encoded target.
        """
        target = F.pil_to_tensor(target)
        target = rearrange(target, "c h w -> h w c")
        out = torch.zeros_like(target[..., :1])
        # convert the target colors to class indices
        for muad_class in self.classes:
            out[
                (
                    target
                    == torch.tensor(muad_class["id"], dtype=target.dtype)
                ).all(dim=-1)
            ] = muad_class["train_id"]
        return F.to_pil_image(rearrange(out, "h w c -> c h w"))
    def decode_target(self, target: Image.Image) -> np.ndarray:
        """Decode a train-id target into a color array."""
        target[target == 255] = 19
        return self.train_id_to_color[target]

    def __getitem__(
        self, index: int
    ) -> tuple[tv_tensors.Image, tv_tensors.Mask]:
        """Get the sample at the given index.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is either a segmentation mask
                or a depth map.
        """
        image = tv_tensors.Image(Image.open(self.samples[index]).convert("RGB"))
        if self.target_type == "semantic":
            target = tv_tensors.Mask(
                self.encode_target(Image.open(self.targets[index]))
            )
        else:
            os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
            target = Image.fromarray(
                cv2.imread(
                    str(self.targets[index]),
                    cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH,
                )
            )
            # TODO: in the long run it would be better to use a custom
            # tv_tensor for depth maps (e.g. tv_tensors.DepthMap)
            target = np.asarray(target, np.float32)
            target = tv_tensors.Mask(400 * (1 - target))  # convert to meters
            target[(target <= self.min_depth) | (target > self.max_depth)] = (
                float("nan")
            )

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def check_split_integrity(self, folder: str) -> bool:
        """Check that a split folder contains the expected number of files."""
        split_path = self.root / self.split
        return (
            split_path.is_dir()
            and len(list((split_path / folder).glob("**/*")))
            == self._num_samples[self.split]
        )

    def __len__(self) -> int:
        """The number of samples in the dataset."""
        return self._num_samples[self.split]

    def _make_dataset(self, path: Path) -> None:
        """Create a list of samples and targets.

        Args:
            path (Path): The path to the dataset.
        """
        if "depth" in path.name:
            raise NotImplementedError(
                "Depth mode is not implemented yet. Raise an issue "
                "if you need it."
            )
        self.samples = sorted((path / "leftImg8bit/").glob("**/*"))
        if self.target_type == "semantic":
            self.targets = sorted((path / "leftLabel/").glob("**/*"))
        elif self.target_type == "depth":
            self.targets = sorted((path / "leftDepth/").glob("**/*"))
        else:
            raise ValueError(
                "target_type must be one of ['semantic', 'depth']. "
                f"Got {self.target_type}."
            )

    def _download(self, split: str) -> None:
        """Download and extract the chosen split of the dataset."""
        split_url = self.base_url + split + ".zip"
        download_and_extract_archive(
            split_url, self.root, md5=self.zip_md5[split]
        )

    @property
    def color_palette(self) -> list:
        return self.train_id_to_color.tolist()
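
For context, a minimal usage sketch follows. It is not part of the module source: the ./data root path is an arbitrary choice for illustration, and the import goes through the full module path shown in the title above.

# Minimal usage sketch (not part of the library source). Assumes the package
# was installed with the image extra and that ./data is writable; the root
# path is a placeholder chosen for this example.
from torch_uncertainty.datasets.muad import MUAD

train_set = MUAD(
    root="./data",  # the dataset is stored under ./data/MUAD
    split="train",
    target_type="semantic",
    download=True,  # fetches the split archive from Zenodo if missing
)
image, mask = train_set[0]  # tv_tensors.Image, tv_tensors.Mask
print(len(train_set))  # 3420 for the train split

For depth targets, pass target_type="depth" and provide both min_depth and max_depth: __getitem__ compares the depth map against these bounds to mask out-of-range values with NaN, so leaving them as None would fail in depth mode.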