Shortcuts

Source code for torch_uncertainty.datasets.nyu

from collections.abc import Callable
from importlib import util
from pathlib import Path
from typing import Literal

import numpy as np
from PIL import Image
from torchvision import tv_tensors
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import (
    check_integrity,
    download_and_extract_archive,
    download_url,
)

if util.find_spec("cv2"):
    import cv2

    cv2_installed = True
else:  # coverage: ignore
    cv2_installed = False

if util.find_spec("h5py"):
    import h5py

    h5py_installed = True
else:  # coverage: ignore
    h5py_installed = False


[docs]class NYUv2(VisionDataset): root: Path rgb_urls = { "train": "http://www.doc.ic.ac.uk/~ahanda/nyu_train_rgb.tgz", "val": "http://www.doc.ic.ac.uk/~ahanda/nyu_test_rgb.tgz", } rgb_md5 = { "train": "ad124bbde47e371359caa4642a8a4611", "val": "f47f7c7c8a20d1210db7941c4f153b06", } depth_url = "http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat" depth_md5 = "520609c519fba3ba5ac58c8fefcc3530" def __init__( self, root: Path | str, split: Literal["train", "val"], transforms: Callable | None = None, min_depth: float = 0.0, max_depth: float = 10.0, download: bool = False, ): """NYUv2 depth dataset. Args: root (Path | str): Root directory where dataset is stored. split (Literal["train", "val"]): Dataset split. transforms (Callable | None): Transform to apply to samples & targets. Defaults to None. min_depth (float): Minimum depth value. Defaults to 1e-3. max_depth (float): Maximum depth value. Defaults to 10. download (bool): Download dataset if not found. Defaults to False. """ if not cv2_installed: # coverage: ignore raise ImportError( "The cv2 library is not installed. Please install" "torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) if not h5py_installed: # coverage: ignore raise ImportError( "The h5py library is not installed. Please install" "torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) super().__init__(Path(root) / "NYUv2", transforms=transforms) self.min_depth = min_depth self.max_depth = max_depth if split not in ["train", "val"]: raise ValueError( f"split must be one of ['train', 'val']. Got {split}." ) self.split = split if not self._check_integrity(): if download: self._download() else: raise FileNotFoundError( f"NYUv2 {split} split not found or incomplete. Set download=True to download it." ) # make dataset path = self.root / self.split self.samples = sorted((path / "rgb_img").glob("**/*")) self.targets = sorted((path / "depth").glob("**/*")) def __getitem__(self, index: int): """Return image and target at index. Args: index (int): Index of the sample. """ image = tv_tensors.Image(Image.open(self.samples[index]).convert("RGB")) target = Image.fromarray( cv2.imread( str(self.targets[index]), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH, ) ) target = np.asarray(target, np.uint16) target = tv_tensors.Mask(target / 1e4) # convert to meters target[(target <= self.min_depth) | (target > self.max_depth)] = float( "nan" ) if self.transforms is not None: image, target = self.transforms(image, target) return image, target def __len__(self): """Return number of samples in dataset.""" return len(self.samples) def _check_integrity(self) -> bool: """Check if dataset is present and complete.""" return ( check_integrity( self.root / f"nyu_{self.split}_rgb.tgz", self.rgb_md5[self.split], ) and check_integrity(self.root / "depth.mat", self.depth_md5) and (self.root / self.split / "rgb_img").exists() and (self.root / self.split / "depth").exists() ) def _download(self): """Download and extract dataset.""" download_and_extract_archive( self.rgb_urls[self.split], self.root, extract_root=self.root / self.split / "rgb_img", filename=f"nyu_{self.split}_rgb.tgz", md5=self.rgb_md5[self.split], ) if not check_integrity(self.root / "depth.mat", self.depth_md5): download_url( NYUv2.depth_url, self.root, "depth.mat", self.depth_md5 ) self._create_depth_files() def _create_depth_files(self): """Create depth images from the depth.mat file.""" path = self.root / self.split (path / "depth").mkdir() samples = sorted((path / "rgb_img").glob("**/*")) ids = [int(p.stem.split("_")[-1]) for p in samples] file = h5py.File(self.root / "depth.mat", "r") depths = file["depths"] for i in range(len(depths)): img_id = i + 1 if img_id in ids: img = (depths[i] * 1e4).astype(np.uint16).T Image.fromarray(img).save( path / "depth" / f"nyu_depth_{str(img_id).zfill(4)}.png" )