Shortcuts

Source code for torch_uncertainty.datasets.classification.cifar.cifar_h

import warnings
from collections.abc import Callable
from pathlib import Path
from typing import Any

import numpy as np
import torch
from torchvision.datasets import CIFAR10
from torchvision.datasets.utils import check_integrity, download_url


class CIFAR10H(CIFAR10):
    """`CIFAR-10H <https://github.com/jcpeterson/cifar-10h>`_ Dataset.

    Builds the parent ``CIFAR10`` dataset with ``train=False`` and then
    replaces ``self.targets`` with the entries of the array stored in
    ``cifar-10h-probs.npy`` (one tensor per index along its first axis).

    Args:
        root (string): Root directory of dataset where file
            ``cifar-10h-probs.npy`` exists or will be saved to if download
            is set to True.
        train (bool, optional): Accepted for API consistency only. Must be
            falsy (``None`` or ``False``); a truthy value raises
            ``ValueError``. Defaults to None.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version. E.g,
            ``transforms.RandomCrop``. Defaults to None.
        target_transform (callable, optional): A function/transform that
            takes in the target and transforms it. Defaults to None.
        download (bool, optional): If True, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again. Defaults to False.

    Raises:
        ValueError: If ``train`` is truthy.
        RuntimeError: If ``cifar-10h-probs.npy`` is missing from ``root`` or
            does not match the expected MD5 checksum.
    """

    # Local file name and expected MD5 checksum of the CIFAR-10H file.
    h_test_list = ["cifar-10h-probs.npy", "7b41f73eee90fdefc73bfc820ab29ba8"]
    # NOTE: the remote file is named ``cifar10h-probs.npy`` (no dash after
    # "cifar"); ``download_h`` stores it locally under ``h_test_list[0]``.
    h_url = (
        "https://github.com/jcpeterson/cifar-10h/raw/master/data/"
        "cifar10h-probs.npy"
    )

    def __init__(
        self,
        root: str | Path,
        train: bool | None = None,
        transform: Callable[..., Any] | None = None,
        target_transform: Callable[..., Any] | None = None,
        download: bool = False,
    ) -> None:
        if train:
            raise ValueError("CIFAR10H does not support training data.")

        # Emit a real warning instead of printing to stdout so that callers
        # can filter, capture or escalate it; stacklevel=2 attributes it to
        # the code constructing the dataset.
        warnings.warn(
            "CIFAR10H cannot be used with Classification routines for now.",
            stacklevel=2,
        )

        # The parent dataset is always built with train=False, whatever
        # (falsy) value was passed for ``train``.
        super().__init__(
            Path(root),
            train=False,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

        if download:
            self.download_h()

        if not self._check_specific_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to "
                "download it."
            )

        # Override the parent's targets with the content of the .npy file,
        # split into a list of tensors along the first axis.
        self.targets = list(
            torch.as_tensor(np.load(self.root / self.h_test_list[0]))
        )

    def _check_specific_integrity(self) -> bool:
        """Return True if the CIFAR-10H file exists in root with the expected MD5."""
        filename, md5 = self.h_test_list
        fpath = self.root / filename
        return check_integrity(fpath, md5)

    def download_h(self) -> None:
        """Download the CIFAR-10H file into root under its local file name."""
        download_url(
            self.h_url,
            self.root,
            filename=self.h_test_list[0],
            md5=self.h_test_list[1],
        )