Source code for torch_uncertainty.datasets.frost

from collections.abc import Callable
from importlib import util
from importlib.resources import files
from pathlib import Path
from typing import Any

from PIL import Image
from torchvision.datasets import VisionDataset

FROST_ASSETS_MOD = "torch_uncertainty_assets.frost"
tu_assets_installed = util.find_spec("torch_uncertainty_assets")


def pil_loader(path: Path) -> Image.Image:
    with path.open("rb") as f:
        img = Image.open(f)
        return img.convert("RGB")


[docs] class FrostImages(VisionDataset): def __init__( self, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, ) -> None: """Frost corruption image dataset. This dataset provides a small collection of frost-corrupted images that are primarily used to simulate distribution shift in vision experiments. It is typically leveraged for robustness evaluation, out-of-distribution (OOD) testing, and uncertainty estimation under image corruption. The dataset contains five JPEG images (``frost1.jpg`` to ``frost5.jpg``) stored in the ``torch-uncertainty-assets`` package. No labels are provided, and each sample consists only of an image. Args: transform: A function/transform applied to the input image. Defaults to ``None``. target_transform: A function/transform applied to the target. Since no targets are provided, this argument is kept for API compatibility. Defaults to ``None``. Raises: ImportError: If the ``torch-uncertainty-assets`` package with image support is not installed. Note: This dataset is intended for generating controlled distribution shifts. """ if not tu_assets_installed: # coverage: ignore raise ImportError( "The torch-uncertainty-assets library is not installed. Please install " "torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) super().__init__( FROST_ASSETS_MOD, transform=transform, target_transform=target_transform, ) self.loader = pil_loader sample_path = files(FROST_ASSETS_MOD) self.samples = [sample_path.joinpath(f"frost{i}.jpg") for i in range(1, 6)] def __getitem__(self, index: int) -> Any: """Get the samples of the dataset. Args: index: Index of the image to get. Returns: tuple: (sample, target) where target is class_index of the target class. """ sample = self.loader(self.samples[index]) if self.transform is not None: sample = self.transform(sample) return sample def __len__(self) -> int: """Get the length of the dataset.""" return len(self.samples)