Source code for torch_uncertainty.datasets.frost
from collections.abc import Callable
from importlib import util
from importlib.abc import Traversable
from importlib.resources import files
from pathlib import Path
from typing import Any
from PIL import Image
from torchvision.datasets import VisionDataset
# Dotted name of the package that holds the frost images. FrostImages passes it
# to VisionDataset as the dataset root and to importlib.resources.files() to
# locate the image files.
FROST_ASSETS_MOD = "torch_uncertainty_assets.frost"
# Spec of the optional torch_uncertainty_assets package; find_spec returns None
# when the package cannot be found, so this value is falsy if it is missing.
tu_assets_installed = util.find_spec("torch_uncertainty_assets")
def pil_loader(path: Path | Traversable) -> Image.Image:
    """Open the image stored at ``path`` and return it converted to RGB.

    Args:
        path: Filesystem path or package resource exposing ``open("rb")``.

    Returns:
        The decoded image in ``"RGB"`` mode.
    """
    # Convert while the stream is still open: Pillow decodes pixel data lazily,
    # so the conversion has to happen before the ``with`` block closes the file.
    with path.open("rb") as stream:
        return Image.open(stream).convert("RGB")
[docs]
class FrostImages(VisionDataset):
    """Unlabelled dataset of the frost images bundled with torch-uncertainty-assets.

    The samples are the package resources ``frost1.jpg`` to ``frost5.jpg`` of
    ``torch_uncertainty_assets.frost``. There are no targets: indexing the
    dataset yields only the (optionally transformed) image.

    Args:
        transform: Optional callable applied to each loaded RGB image.
        target_transform: Forwarded to ``VisionDataset`` for signature
            compatibility; this class never applies it since it has no targets.

    Raises:
        ImportError: If the ``torch_uncertainty_assets`` package is not installed.
    """

    def __init__(
        self,
        transform: Callable[..., Any] | None = None,
        target_transform: Callable[..., Any] | None = None,
    ) -> None:
        if not tu_assets_installed:  # coverage: ignore
            # Adjacent literals are concatenated verbatim, so each fragment
            # carries its own trailing space to keep the message readable.
            raise ImportError(
                "The torch-uncertainty-assets library is not installed. Please install "
                "torch_uncertainty with the image option: "
                """pip install -U "torch_uncertainty[image]"."""
            )
        super().__init__(
            FROST_ASSETS_MOD,
            transform=transform,
            target_transform=target_transform,
        )
        self.loader = pil_loader
        # Resolve the images as package resources so they are found wherever
        # the assets package is installed.
        sample_path = files(FROST_ASSETS_MOD)
        self.samples = [sample_path.joinpath(f"frost{i}.jpg") for i in range(1, 6)]

    def __getitem__(self, index: int) -> Any:
        """Load the image at ``index`` and apply ``transform`` if one was given.

        Args:
            index (int): Position of the sample in ``self.samples``.

        Returns:
            The loaded RGB image, or the output of ``transform`` applied to it.
            No target is returned.
        """
        sample = self.loader(self.samples[index])
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.samples)