Source code for torch_uncertainty.datasets.classification.not_mnist
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any, Literal
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import (
check_integrity,
download_and_extract_archive,
)
class NotMNIST(ImageFolder):
    """The notMNIST dataset.

    The archive for the requested subset is expected at ``root / filename``
    and the images are read from ``root / "notMNIST_<subset>"`` through
    ``ImageFolder``.

    Args:
        root (str | Path): Root directory of the datasets.
        subset (str): The subset to use, one of ``small`` or ``large``.
            Defaults to ``small``.
        transform (callable, optional): A function/transform that takes in
            a PIL image and returns a transformed version. E.g,
            ``transforms.RandomCrop``. Defaults to None.
        target_transform (callable, optional): A function/transform that
            takes in the target and transforms it. Defaults to None.
        download (bool, optional): If True, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again. Defaults to False.

    Raises:
        ValueError: If ``subset`` is not one of the known subsets.
        RuntimeError: If the archive is missing or fails the integrity
            check once the optional download step has run.

    Note:
        There is no information on the license of the dataset. It may not
        be suitable for commercial use.
    """

    # Base URL of the archives; note that it already ends with a slash.
    url_base = "https://zenodo.org/record/8274268/files/"
    # Archive file names, index-aligned with ``subsets`` and ``tgz_md5s``.
    filenames = ["notMNIST_small.zip", "notMNIST_large.zip"]
    # Expected MD5 digests of the archives, index-aligned with ``filenames``.
    tgz_md5s = [
        "3de91fb69221d9c2d5c57387101ebc6c",
        "c3f9e0862df000a897766593044e366a",
    ]
    # Accepted values for the ``subset`` argument.
    subsets = ["small", "large"]

    def __init__(
        self,
        root: str | Path,
        subset: Literal["small", "large"] = "small",
        transform: Callable[..., Any] | None = None,
        target_transform: Callable[..., Any] | None = None,
        download: bool = False,
    ) -> None:
        self.root = Path(root)

        if subset not in self.subsets:
            raise ValueError(
                f"The subset '{subset}' does not exist for notMNIST."
            )

        ind = self.subsets.index(subset)
        self.filename = self.filenames[ind]
        # ``url_base`` carries its own trailing slash: strip it before
        # joining so the final URL contains exactly one separator instead
        # of ".../files//<name>".
        self.url = self.url_base.rstrip("/") + "/" + self.filename
        self.tgz_md5 = self.tgz_md5s[ind]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to "
                "download it."
            )

        # NOTE(review): the parent initializer receives the extracted
        # sub-folder as its root and presumably rebinds ``self.root`` to it;
        # confirm before relying on ``download()`` / ``_check_integrity()``
        # after construction.
        super().__init__(
            self.root / f"notMNIST_{subset}",
            transform=transform,
            target_transform=target_transform,
        )

    def _check_integrity(self) -> bool:
        """Check the downloaded archive against its expected MD5.

        Only the archive at ``root / filename`` is inspected; the extracted
        image folder is not looked at.

        Returns:
            bool: The result of torchvision's ``check_integrity`` for the
            archive path and the expected digest.
        """
        fpath = self.root / self.filename
        return check_integrity(
            fpath,
            self.tgz_md5,
        )

    def download(self) -> None:
        """Download and extract the archive of the selected subset.

        When the archive already passes the integrity check, nothing is
        downloaded and nothing is extracted.
        """
        if self._check_integrity():
            logging.info("Files already downloaded and verified")
            return

        download_and_extract_archive(
            self.url,
            download_root=self.root,
            filename=self.filename,
            md5=self.tgz_md5,
        )
        logging.info("Downloaded %s to %s.", self.filename, self.root)

    def __getitem__(self, index: int) -> tuple[Any, Any]:
        """Get the samples and targets of the dataset.

        Args:
            index (int): The index of the sample to get.

        Returns:
            tuple[Any, Any]: The sample and its target, exactly as returned
            by the parent class for ``index``.
        """
        return super().__getitem__(index)