Source code for torch_uncertainty.datasets.segmentation.cityscapes
from collections.abc import Callable
from typing import Any
import torch
from einops import rearrange
from PIL import Image
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchvision import tv_tensors
from torchvision.datasets import Cityscapes as TVCityscapes
from torchvision.transforms.v2 import functional as F
class Cityscapes(TVCityscapes):
    """Cityscapes dataset wrapper returning ``torchvision`` tv_tensors.

    Extends :class:`torchvision.datasets.Cityscapes` so that images are
    returned as :class:`tv_tensors.Image`, ``"semantic"`` targets are mapped
    from Cityscapes class ids to train ids and wrapped in
    :class:`tv_tensors.Mask`, and train-id masks can be decoded back to RGB.
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        mode: str = "fine",
        target_type: list[str] | str = "instance",
        transform: Callable[..., Any] | None = None,
        target_transform: Callable[..., Any] | None = None,
        transforms: Callable[..., Any] | None = None,
    ) -> None:
        super().__init__(
            root,
            split,
            mode,
            target_type,
            transform,
            target_transform,
            transforms,
        )
        # Color lookup table indexed by train id: one row per class that has
        # a real train id (i.e. not -1 and not 255), sorted by train id so
        # that row ``i`` holds the color of train id ``i``.
        evaluated_classes = sorted(
            (c for c in self.classes if c.train_id not in (-1, 255)),
            key=lambda c: c.train_id,
        )
        train_id_to_color = [list(c.color) for c in evaluated_classes]
        # Trailing black row; ``decode_target`` routes ignored pixels (255)
        # here through index -1.
        train_id_to_color.append([0, 0, 0])
        self.train_id_to_color = torch.tensor(train_id_to_color)

    @classmethod
    def encode_target(cls, target: Image.Image) -> Image.Image:
        """Convert a Cityscapes class-id label image into a train-id image.

        Args:
            target (Image.Image): Label image whose pixel values are
                Cityscapes class ids (the ``"semantic"`` target opened in
                ``__getitem__``).

        Returns:
            Image.Image: Image of the same size whose pixel values are the
            corresponding train ids. Pixels matching no class id keep the
            value 0.
        """
        label_ids = rearrange(F.pil_to_tensor(target), "c h w -> h w c")
        encoded = torch.zeros_like(label_ids[..., :1])
        # Map every class id to its train id; a pixel matches a class only
        # when all of its channels equal that class id.
        for cityscapes_class in cls.classes:
            class_id = torch.tensor(cityscapes_class.id, dtype=encoded.dtype)
            mask = (label_ids == class_id).all(dim=-1)
            encoded[mask] = cityscapes_class.train_id
        return F.to_pil_image(rearrange(encoded, "h w c -> c h w"))

    def decode_target(self, target: torch.Tensor) -> torch.Tensor:
        """Convert a train-id mask into RGB colors.

        The input tensor is left unmodified.

        Args:
            target (torch.Tensor): Integer tensor of train ids, where the
                value ``255`` marks ignored pixels.

        Returns:
            torch.Tensor: Tensor of shape ``(*target.shape, 3)`` holding the
            RGB color of every element; ignored pixels are black.
        """
        # Cast to int64 so the tensor is used as indices (a uint8 tensor
        # would be interpreted as a boolean mask), and send ignored pixels to
        # the trailing black row via index -1. ``masked_fill`` is
        # out-of-place, so the caller's tensor is never mutated.
        indices = target.long().masked_fill(target == 255, -1)
        return self.train_id_to_color[indices]

    def __getitem__(self, index: int) -> tuple[Any, Any]:
        """Get the sample at the given index.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a tuple of all target types
            if ``target_type`` is a list with more than one item. Otherwise,
            target is a json object if ``target_type="polygon"``, else the
            image segmentation.
        """
        # Close the file once the pixels have been copied into a tensor.
        with Image.open(self.images[index]) as raw_image:
            image = tv_tensors.Image(raw_image.convert("RGB"))

        targets: Any = []
        for i, t in enumerate(self.target_type):
            if t == "polygon":
                target = self._load_json(self.targets[index][i])
            elif t == "semantic":
                # encode_target fully reads the image, so the file can be
                # closed right after.
                with Image.open(self.targets[index][i]) as raw_target:
                    target = tv_tensors.Mask(self.encode_target(raw_target))
            else:
                target = Image.open(self.targets[index][i])
            targets.append(target)
        target = tuple(targets) if len(targets) > 1 else targets[0]

        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def plot_sample(
        self, index: int, ax: _AX_TYPE | None = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a sample from the dataset.

        Args:
            index: The index of the sample to plot.
            ax: Optional matplotlib axis to plot on.

        Returns:
            The axis on which the sample was plotted.

        Raises:
            NotImplementedError: Always; plotting is not implemented yet.
        """
        raise NotImplementedError("This method is not implemented yet.")

    @property
    def color_palette(self) -> list[tuple[int, int, int]]:
        """Return the color of every Cityscapes class, in ``classes`` order.

        Note: this includes classes without a train id, so positions in the
        list follow ``self.classes`` rather than train ids.
        """
        return [c.color for c in self.classes]