Source code for torch_uncertainty.datasets.fractals
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import (
check_integrity,
download_file_from_google_drive,
extract_archive,
)
class Fractals(ImageFolder):
    """Dataset used for PixMix augmentations.

    The dataset is distributed as a single tar archive hosted on Google
    Drive. The archive is verified against an MD5 digest and, when
    downloaded, extracted into ``root``; the resulting folder tree is then
    indexed by ``ImageFolder``.

    Indexing the dataset yields only the sample: the class target produced
    by ``ImageFolder`` is discarded (see :meth:`__getitem__`).

    Args:
        root (str | Path): Root directory of dataset.
        transform (Callable | None): Forwarded to ``ImageFolder`` as
            ``transform``. Defaults to ``None``.
        target_transform (Callable | None): Forwarded to ``ImageFolder`` as
            ``target_transform``. Defaults to ``None``.
        download (bool): If ``True``, download and extract the archive into
            ``root`` unless a verified copy is already present. Defaults to
            ``False``.

    Raises:
        RuntimeError: If the archive is not found in ``root`` or fails the
            MD5 check after the optional download step.

    Note:
        There is no information on the license of the dataset. It may not
        be suitable for commercial use.
    """

    # Google Drive identifier of the archive.
    file_id = "1qC2gIUx9ARU7zhgI4IwGD3YcFhm8J4cA"
    # Name under which the archive is stored inside ``root``.
    filename = "fractals_and_fvis.tar"
    # Expected MD5 digest of the archive.
    tgz_md5 = "3619fb7e2c76130749d97913fdd3ab27"

    def __init__(
        self,
        root: str | Path,
        transform: Callable[..., Any] | None = None,
        target_transform: Callable[..., Any] | None = None,
        download: bool = False,
    ) -> None:
        # ``root`` must be set before ``super().__init__`` because both
        # ``download`` and ``_check_integrity`` read ``self.root``.
        self.root = Path(root)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to "
                "download it."
            )

        super().__init__(
            self.root, transform=transform, target_transform=target_transform
        )

    def _check_integrity(self) -> bool:
        """Check that the archive exists in ``root`` and matches its MD5.

        Only the archive file is inspected; the extracted images are not
        verified.

        Returns:
            bool: Result of ``check_integrity`` on the archive path.
        """
        fpath = self.root / self.filename
        return check_integrity(
            fpath,
            self.tgz_md5,
        )

    def download(self) -> None:
        """Download the archive from Google Drive and extract it in ``root``.

        If a verified archive is already present, nothing is downloaded and
        the extraction step is skipped as well.
        """
        if self._check_integrity():
            logging.info("Files already downloaded and verified")
            return

        download_file_from_google_drive(
            file_id=self.file_id,
            root=self.root,
            filename=self.filename,
            md5=self.tgz_md5,
        )
        extract_archive(self.root / self.filename, self.root)

    def __getitem__(self, index: int) -> Any:
        """Get a sample of the dataset, without its target.

        Args:
            index (int): The index of the sample to get.

        Returns:
            Any: The first element of the item returned by
            ``ImageFolder.__getitem__`` (the sample); the second element
            (the target) is dropped.
        """
        # PixMix only needs the mixing picture, so the target is discarded.
        return super().__getitem__(index)[0]