Shortcuts

Source code for torch_uncertainty.datasets.frost

import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any

from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import (
    check_integrity,
    download_and_extract_archive,
)


def pil_loader(path: Path) -> Image.Image:
    """Read the image stored at ``path`` and return an RGB copy of it.

    The file is opened explicitly and closed when the ``with`` block exits.
    ``convert("RGB")`` builds the returned image while the handle is still
    open, so the result does not depend on the closed file. Opening the file
    ourselves avoids a ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).

    Args:
        path: Location of the image file on disk.

    Returns:
        The image converted to RGB mode.
    """
    with path.open("rb") as handle:
        return Image.open(handle).convert("RGB")


class FrostImages(VisionDataset):  # TODO: Use ImageFolder
    """Dataset of the frost images distributed as ``frost.zip`` on Zenodo.

    The archive is downloaded to (and extracted under) ``root``. The images
    listed in ``samples`` are then read from ``root / "frost"``. After
    construction, ``self.root`` points to that ``frost`` sub-directory,
    because ``VisionDataset.__init__`` stores the path it is given there.

    Args:
        root: Directory that holds, or will receive, ``frost.zip``.
        transform: Optional callable applied to each loaded image.
            Defaults to ``None`` (no transform).
        target_transform: Forwarded to ``VisionDataset``. It is never
            applied, since this dataset yields images without targets.
        download: If ``True``, download and extract the archive unless
            ``root / "frost.zip"`` already exists with the expected MD5.

    Raises:
        RuntimeError: If, after the optional download, ``root / "frost.zip"``
            is missing or its MD5 does not match ``zip_md5``.
    """

    url = "https://zenodo.org/records/10438904/files/frost.zip"
    zip_md5 = "d82f29f620d43a68e71e34b28f7c35cb"
    filename = "frost.zip"
    samples = [
        "frost1.png",
        "frost2.png",
        "frost3.jpg",
        "frost4.jpg",
        "frost5.jpg",
    ]

    def __init__(
        self,
        root: str | Path,
        transform: Callable[..., Any] | None = None,
        target_transform: Callable[..., Any] | None = None,
        download: bool = False,
    ) -> None:
        # Remember where the archive lives. VisionDataset.__init__ (below)
        # rebinds ``self.root`` to the extracted ``frost`` sub-directory.
        # ``download`` and ``_check_integrity`` must keep targeting the
        # archive directory even when called after construction.
        self._archive_root = Path(root)
        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to "
                "download it."
            )

        super().__init__(
            self._archive_root / "frost",
            transform=transform,
            target_transform=target_transform,
        )
        self.loader = pil_loader

    def _check_integrity(self) -> bool:
        """Return whether ``frost.zip`` exists in the archive root with the expected MD5."""
        fpath = self._archive_root / self.filename
        return check_integrity(
            fpath,
            self.zip_md5,
        )

    def download(self) -> None:
        """Download ``frost.zip`` into the archive root and extract it there.

        Only logs a message and returns if the archive is already present and
        passes the MD5 check.
        """
        if self._check_integrity():
            logging.info("Files already downloaded and verified")
            return
        download_and_extract_archive(
            self.url,
            download_root=self._archive_root,
            filename=self.filename,
            md5=self.zip_md5,
        )
        logging.info("Downloaded %s to %s.", self.filename, self._archive_root)

    def __getitem__(self, index: int) -> Any:
        """Load one frost image.

        Args:
            index (int): Position in ``samples``.

        Returns:
            The loaded image (an RGB PIL image with the default
            ``pil_loader``), or the output of ``transform`` when one is set.
            No target is returned.
        """
        # ``self.root`` is the extracted ``frost`` directory at this point.
        path = self.root / self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self) -> int:
        """Get the length of the dataset."""
        return len(self.samples)