# Source code for torch_uncertainty.datasets.kitti
import json
import logging
import shutil
from collections.abc import Callable
from pathlib import Path
from typing import Literal
from PIL import Image
from torchvision import tv_tensors
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import (
download_and_extract_archive,
download_url,
)
from torchvision.transforms import functional as F
from tqdm import tqdm
class KITTIDepth(VisionDataset):
    """Depth-estimation dataset assembled from the KITTI depth annotations and raw data.

    On disk the dataset lives in ``<root>/KITTIDepth`` with the layout::

        <split>/leftImg8bit/*.png   # input images copied from the raw archives
        <split>/leftDepth/*.png     # depth annotations

    where ``<split>`` is ``train`` or ``val``. An image and its depth map share
    the same file name, so the two sorted listings line up index by index.
    """

    root: Path
    depth_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_annotated.zip"
    depth_md5 = "7d1ce32633dc2f43d9d1656a1f875e47"
    raw_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/"
    raw_filenames_url = "https://raw.githubusercontent.com/torch-uncertainty/dataset-metadata/main/download/kitti/raw_filenames.json"
    raw_filenames_md5 = "e5b7fad5ecd059488ef6c02dc9e444c1"
    # Expected number of PNG files per split; used both by the integrity check
    # and by ``__len__``. "test" is only an Ellipsis placeholder: ``__init__``
    # rejects any split other than "train" and "val".
    _num_samples = {
        "train": 42949,
        "val": 3426,
        "test": ...,
    }
    # (split folder, message logged right before that split's files are copied)
    _SPLIT_MESSAGES = (
        ("train", "Train files..."),
        ("val", "Validation files..."),
    )

    def __init__(
        self,
        root: str | Path,
        split: Literal["train", "val"],
        min_depth: float = 0.0,
        max_depth: float = 80.0,
        transforms: Callable | None = None,
        download: bool = False,
        remove_unused: bool = False,
    ) -> None:
        """Set up the dataset, downloading and re-structuring the files if requested.

        Args:
            root: Parent directory; the dataset itself is stored in
                ``root / "KITTIDepth"``.
            split: Which split to load, ``"train"`` or ``"val"``.
            min_depth: Target values lower than or equal to this are replaced
                by NaN in ``__getitem__``.
            max_depth: Target values strictly greater than this are replaced
                by NaN in ``__getitem__``.
            transforms: Optional callable applied jointly to ``(image, target)``.
            download: If True, download whatever part of the split is missing
                or incomplete instead of raising.
            remove_unused: If True, delete the extracted ``raw`` directory once
                the needed images have been copied. Only has an effect when the
                raw data is (re-)downloaded.

        Raises:
            ValueError: If ``split`` is neither ``"train"`` nor ``"val"``.
            FileNotFoundError: If the split is missing or incomplete and
                ``download`` is False.
        """
        logging.info(
            "KITTIDepth is copyrighted by the Karlsruhe Institute of Technology "
            "(KIT) and the Toyota Technological Institute at Chicago (TTIC). "
            "By using KITTIDepth, you agree to the terms and conditions of the "
            "Creative Commons Attribution-NonCommercial-ShareAlike 3.0 License. "
            "This means that you must attribute the work in the manner specified "
            "by the authors, you may not use this work for commercial purposes "
            "and if you alter, transform, or build upon this work, you may "
            "distribute the resulting work only under the same license."
        )
        super().__init__(
            root=Path(root) / "KITTIDepth",
            transforms=transforms,
        )
        self.min_depth = min_depth
        self.max_depth = max_depth

        if split not in ["train", "val"]:
            raise ValueError(f"split must be one of ['train', 'val']. Got {split}.")
        self.split = split

        missing_message = (
            f"KITTI {split} split not found or incomplete. Set download=True to download it."
        )

        # Depth annotations must be handled first: the raw-data step derives
        # the list of images to keep from the files found in ``leftDepth``.
        if not self.check_split_integrity("leftDepth"):
            if not download:
                raise FileNotFoundError(missing_message)
            self._download_depth()

        if not self.check_split_integrity("leftImg8bit"):
            if not download:
                raise FileNotFoundError(missing_message)
            self._download_raw(remove_unused)

        self._make_dataset()

    def check_split_integrity(self, folder: str) -> bool:
        """Check that ``folder`` of the current split holds the expected number of PNGs.

        Args:
            folder: Sub-folder of the split directory to inspect
                (``"leftDepth"`` or ``"leftImg8bit"`` in this class).

        Returns:
            True if the split directory exists and ``folder`` contains exactly
            ``_num_samples[self.split]`` ``*.png`` files, False otherwise.
        """
        split_path = self.root / self.split
        if not split_path.is_dir():
            return False
        png_count = sum(1 for _ in (split_path / folder).glob("*.png"))
        return png_count == self._num_samples[self.split]

    def __getitem__(self, index: int) -> tuple[tv_tensors.Image, tv_tensors.Mask]:
        """Get the sample at the given index.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a depth map. Entries of the
            target outside ``(min_depth, max_depth]`` are set to NaN.
        """
        # Context managers release the underlying file handles as soon as the
        # pixel data has been converted, instead of waiting for the GC.
        with Image.open(self.samples[index]) as rgb_file:
            image = tv_tensors.Image(rgb_file.convert("RGB"))
        with Image.open(self.targets[index]) as depth_file:
            # Stored integer values are divided by 256 to obtain the depth
            # (presumably in metres -- confirm against the KITTI devkit).
            target = tv_tensors.Mask(F.pil_to_tensor(depth_file).squeeze(0) / 256.0)

        target[(target <= self.min_depth) | (target > self.max_depth)] = float("nan")

        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def __len__(self) -> int:
        """The number of samples in the dataset."""
        return self._num_samples[self.split]

    def _make_dataset(self) -> None:
        """Collect the sorted image and depth-map paths of the current split."""
        split_dir = self.root / self.split
        self.samples = sorted((split_dir / "leftImg8bit").glob("*.png"))
        self.targets = sorted((split_dir / "leftDepth").glob("*.png"))

    def _reset_dir(self, split: str, folder: str) -> Path:
        """Remove ``root/split/folder`` if it exists, recreate it empty and return it."""
        directory = self.root / split / folder
        if directory.exists():
            shutil.rmtree(directory)
        directory.mkdir(parents=True, exist_ok=False)
        return directory

    def _raw_image_path(self, depth_name: str) -> Path:
        """Return the extracted raw image matching a flattened depth file name.

        ``depth_name`` is expected to look like ``a_b_c_<drive>_<frame>.png``
        (the naming produced by ``_download_depth``); the raw image is looked up
        at ``raw/a_b_c/a_b_c_drive_<drive>_sync/image_02/data/<frame>.png``.
        """
        fields = depth_name.split("_")
        first_level = "_".join(fields[:3])
        second_level = "_".join([first_level, "drive", fields[3], "sync"])
        return self.root / "raw" / first_level / second_level / "image_02" / "data" / fields[4]

    def _download_depth(self) -> None:
        """Download and extract the depth annotation dataset.

        The archive is extracted into ``root/tmp`` (the download is skipped if
        that directory already exists), the ``image_02`` depth maps of both
        splits are copied into flat ``<split>/leftDepth`` folders, and ``tmp``
        is deleted afterwards.
        """
        tmp_dir = self.root / "tmp"
        if not tmp_dir.exists():
            download_and_extract_archive(
                self.depth_url,
                download_root=self.root,
                extract_root=tmp_dir,
                md5=self.depth_md5,
            )

        logging.info("Re-structuring the depth annotations...")
        for split, message in self._SPLIT_MESSAGES:
            depth_dir = self._reset_dir(split, "leftDepth")
            depth_files = list(self.root.glob(f"**/tmp/{split}/**/image_02/*.png"))
            logging.info(message)
            for file in tqdm(depth_files):
                # The fourth ancestor is the sequence folder; its underscore-
                # separated fields 0-2 and 4 are kept (fields 3 and 5 are the
                # "drive"/"sync" markers re-inserted by ``_raw_image_path``).
                exp_code = file.parents[3].name.split("_")
                filecode = "_".join([*exp_code[:3], exp_code[4], file.name])
                shutil.copy(file, depth_dir / filecode)

        shutil.rmtree(tmp_dir)

    def _download_raw(self, remove_unused: bool) -> None:
        """Download and extract the raw dataset.

        Every archive listed in ``raw_filenames.json`` is downloaded and
        extracted into ``root/raw``; then, for each depth map of both splits,
        the matching raw image is copied to ``<split>/leftImg8bit`` under the
        same file name.

        Args:
            remove_unused: If True, delete ``root/raw`` once the copies are done.
        """
        download_url(
            self.raw_filenames_url,
            self.root,
            "raw_filenames.json",
            self.raw_filenames_md5,
        )
        with (self.root / "raw_filenames.json").open() as file:
            raw_filenames = json.load(file)

        for filename in tqdm(raw_filenames):
            archive_url = self.raw_url + filename
            logging.info("%s", archive_url)
            download_and_extract_archive(
                archive_url,
                download_root=self.root,
                extract_root=self.root / "raw",
                md5=None,
            )

        logging.info("Re-structuring the raw data...")
        for split, message in self._SPLIT_MESSAGES:
            # Only images that have a depth annotation are kept.
            samples_to_keep = list((self.root / split / "leftDepth").glob("*.png"))
            image_dir = self._reset_dir(split, "leftImg8bit")
            logging.info(message)
            for sample in tqdm(samples_to_keep):
                shutil.copy(self._raw_image_path(sample.name), image_dir / sample.name)

        if remove_unused:
            shutil.rmtree(self.root / "raw")