Source code for torch_uncertainty.datasets.fractals
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import (
    check_integrity,
    download_file_from_google_drive,
    extract_archive,
)
[docs]
class Fractals(ImageFolder):
    """Images from the ``fractals_and_fvis`` archive, used for PixMix augmentations.

    The archive is fetched from Google Drive into ``root``, verified against
    ``tgz_md5`` and extracted there. The extracted sub-folders are then loaded
    with :class:`torchvision.datasets.ImageFolder`. Indexing the dataset yields
    only the image; the folder-derived class target is dropped.
    """

    # Google Drive identifier of the dataset archive.
    file_id = "1qC2gIUx9ARU7zhgI4IwGD3YcFhm8J4cA"
    # Name under which the archive is stored inside ``root``.
    filename = "fractals_and_fvis.tar"
    # Expected MD5 checksum of the archive.
    tgz_md5 = "3619fb7e2c76130749d97913fdd3ab27"

    def __init__(
        self,
        root: str | Path,
        transform: Callable[..., Any] | None = None,
        target_transform: Callable[..., Any] | None = None,
        download: bool = False,
    ) -> None:
        """Dataset used for PixMix augmentations.

        Args:
            root (str | Path): Root directory of dataset. A leading ``~`` is
                expanded to the user's home directory.
            transform (Callable[..., Any] | None): Transform to apply to the input samples.
                Defaults to ``None``.
            target_transform (Callable[..., Any] | None): Transform to apply to the target labels.
                Defaults to ``None``. Targets are discarded by ``__getitem__``,
                so its output never reaches the caller.
            download (bool): If ``True``, downloads the dataset if not present. Defaults to ``False``.

        Raises:
            RuntimeError: If the archive is missing from ``root`` or fails the
                MD5 check.

        Note:
            There is no information on the license of the dataset. It may not
            be suitable for commercial use.
        """
        # ``Path`` does not expand ``~`` by itself. Without this, "~/data" would
        # be treated as a literal relative directory named "~".
        self.root = Path(root).expanduser()
        if download:
            self.download()
        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it."
            )
        super().__init__(self.root, transform=transform, target_transform=target_transform)

    def _check_integrity(self) -> bool:
        """Return whether the archive exists under ``root`` with the expected MD5."""
        return check_integrity(self.root / self.filename, self.tgz_md5)

    def download(self) -> None:
        """Download the archive into ``root`` and extract it there.

        If the archive is already present and verified, nothing is downloaded
        and it is not re-extracted.
        """
        if self._check_integrity():
            logging.info("Files already downloaded and verified")
            return
        download_file_from_google_drive(
            file_id=self.file_id,
            root=self.root,
            filename=self.filename,
            md5=self.tgz_md5,
        )
        extract_archive(self.root / self.filename, self.root)

    def __getitem__(self, index: int) -> Any:
        """Get the sample at ``index``, without its target.

        Args:
            index (int): The index of the sample to get.

        Returns:
            Any: The image loaded by ``ImageFolder``, after ``transform`` if one
            was given. The class target is dropped, presumably because PixMix
            only needs the images as mixing pictures.
        """
        # ImageFolder yields a ``(sample, target)`` pair; keep only the sample.
        sample, _ = super().__getitem__(index)
        return sample