Shortcuts

Source code for torch_uncertainty.datasets.classification.cifar.cifar_c

import logging
from collections.abc import Callable
from pathlib import Path

import numpy as np
from torch import Tensor
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import (
    check_integrity,
    download_and_extract_archive,
)


class CIFAR10C(VisionDataset):
    """The corrupted CIFAR-10-C Dataset.

    Args:
        root (Path | str): Root directory of the datasets. The corruption
            files are looked up in ``root / base_folder``.
        transform (callable, optional): A function/transform that takes in a
            sample and returns a transformed version. Samples are passed as
            the NumPy arrays read from the ``.npy`` files (not PIL images).
            Defaults to None.
        target_transform (callable, optional): A function/transform that
            takes in the target and transforms it. Defaults to None.
        subset (str): The subset to use, one of ``all`` or the keys in
            ``cifarc_subsets``. Defaults to ``"all"``.
        shift_severity (int): The shift_severity of the corruption, between 1
            and 5. Defaults to 1.
        download (bool, optional): If True, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again. Defaults to False.

    Raises:
        ValueError: If ``subset`` is unknown or ``shift_severity`` is not in
            1..5. Both are checked before any download or checksum work.
        RuntimeError: If the files are missing or fail their MD5 check.

    References:
        Benchmarking neural network robustness to common corruptions and
        perturbations. Dan Hendrycks and Thomas Dietterich. In ICLR, 2019.
    """

    # Folder (relative to the user-supplied root) holding the ``.npy`` files.
    base_folder = "CIFAR-10-C"
    # MD5 checksum of the downloaded archive.
    tgz_md5 = "56bf5dcef84df0e2308c6dcbcbbd8499"
    # Names accepted for ``subset`` (besides "all"); each maps to "<name>.npy".
    cifarc_subsets = [
        "brightness",
        "contrast",
        "defocus_blur",
        "elastic_transform",
        "fog",
        "frost",
        "gaussian_blur",
        "gaussian_noise",
        "glass_blur",
        "impulse_noise",
        "jpeg_compression",
        "motion_blur",
        "pixelate",
        "saturate",
        "shot_noise",
        "snow",
        "spatter",
        "speckle_noise",
        "zoom_blur",
    ]
    # [filename, md5] pairs verified by ``_check_integrity``.
    ctest_list = [
        ["fog.npy", "7b397314b5670f825465fbcd1f6e9ccd"],
        ["jpeg_compression.npy", "2b9cc4c864e0193bb64db8d7728f8187"],
        ["zoom_blur.npy", "6ea8e63f1c5cdee1517533840641641b"],
        ["speckle_noise.npy", "ef00b87611792b00df09c0b0237a1e30"],
        ["glass_blur.npy", "7361fb4019269e02dbf6925f083e8629"],
        ["spatter.npy", "8a5a3903a7f8f65b59501a6093b4311e"],
        ["shot_noise.npy", "3a7239bb118894f013d9bf1984be7f11"],
        ["defocus_blur.npy", "7d1322666342a0702b1957e92f6254bc"],
        ["elastic_transform.npy", "9421657c6cd452429cf6ce96cc412b5f"],
        ["gaussian_blur.npy", "c33370155bc9b055fb4a89113d3c559d"],
        ["frost.npy", "31f6ab3bce1d9934abfb0cc13656f141"],
        ["saturate.npy", "1cfae0964219c5102abbb883e538cc56"],
        ["brightness.npy", "0a81ef75e0b523c3383219c330a85d48"],
        ["snow.npy", "bb238de8555123da9c282dea23bd6e55"],
        ["gaussian_noise.npy", "ecaf8b9a2399ffeda7680934c33405fd"],
        ["motion_blur.npy", "fffa5f852ff7ad299cfe8a7643f090f4"],
        ["contrast.npy", "3c8262171c51307f916c30a3308235a8"],
        ["impulse_noise.npy", "2090e01c83519ec51427e65116af6b1a"],
        ["labels.npy", "c439b113295ed5254878798ffe28fd54"],
        ["pixelate.npy", "0f14f7e2db14288304e1de10df16832f"],
    ]
    url = "https://zenodo.org/record/2535967/files/CIFAR-10-C.tar"
    filename = "CIFAR-10-C.tar"

    def __init__(
        self,
        root: Path | str,
        transform: Callable | None = None,
        target_transform: Callable | None = None,
        subset: str = "all",
        shift_severity: int = 1,
        download: bool = False,
    ) -> None:
        # Validate the cheap arguments first: a typo in ``subset`` should not
        # cost a full archive download and an MD5 pass over every file.
        if subset not in ["all", *self.cifarc_subsets]:
            raise ValueError(
                f"The subset '{subset}' does not exist in CIFAR-C."
            )
        if shift_severity not in range(1, 6):
            raise ValueError(
                "Corruptions shift_severity should be chosen between 1 and 5 "
                "included."
            )

        # ``VisionDataset.__init__`` rebinds ``self.root`` to the dataset
        # folder below, so remember the user-supplied root separately:
        # ``download`` and ``_check_integrity`` must keep resolving
        # ``<root>/<base_folder>`` even after construction.
        self._download_root = Path(root)
        self.root = self._download_root

        # Download the new targets
        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found. You can use download=True to download it."
            )

        super().__init__(
            root=self._download_root / self.base_folder,
            transform=transform,
            target_transform=target_transform,
        )
        self.subset = subset
        self.shift_severity = shift_severity

        # From here on ``self.root`` is the folder containing the .npy files.
        samples, labels = self.make_dataset(
            self.root, self.subset, self.shift_severity
        )
        self.samples = samples
        self.labels = labels.astype(np.int64)

    def make_dataset(
        self, root: Path, subset: str, shift_severity: int
    ) -> tuple[np.ndarray, np.ndarray]:
        r"""Make the CIFAR-C dataset.

        Build the corrupted dataset according to the chosen subset and
        shift_severity. If the subset is 'all', gather all corruption types
        in the dataset.

        Args:
            root (Path): The path to the folder containing the ``.npy`` files.
            subset (str): The name of the corruption subset to be used. Choose
                `all` for the dataset to contain all subsets.
            shift_severity (int): The shift_severity of the corruption applied
                to the images.

        Returns:
            Tuple[np.ndarray, np.ndarray]: The samples and labels of the
            chosen subset at the chosen severity.
        """
        # The slicing assumes every file stores 10000 entries per severity
        # level, laid out in increasing severity order.
        severity_slice = slice(
            (shift_severity - 1) * 10000, shift_severity * 10000
        )
        labels: np.ndarray = np.load(root / "labels.npy")[severity_slice]

        if subset != "all":
            samples: np.ndarray = np.load(root / (subset + ".npy"))[
                severity_slice
            ]
            return samples, labels

        # Stack every corruption type; the same label block is repeated once
        # per corruption so that labels stay aligned with the samples.
        samples = np.concatenate(
            [
                np.load(root / (cifar_subset + ".npy"))[severity_slice]
                for cifar_subset in self.cifarc_subsets
            ],
            axis=0,
        )
        return samples, np.tile(labels, len(self.cifarc_subsets))

    def __len__(self) -> int:
        """The number of samples in the dataset."""
        return self.labels.shape[0]

    def __getitem__(self, index: int) -> tuple[np.ndarray | Tensor, int]:
        """Get the samples and targets of the dataset.

        Args:
            index (int): The index of the sample to get.

        Returns:
            The sample and its target, each passed through ``transform`` /
            ``target_transform`` when these are set.
        """
        sample = self.samples[index]
        target = self.labels[index]
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def _check_integrity(self) -> bool:
        """Check the integrity of the dataset.

        Returns:
            True when every file listed in ``ctest_list`` exists under
            ``<root>/<base_folder>`` and matches its MD5 checksum.
        """
        folder = self._download_root / self.base_folder
        # ``all`` short-circuits on the first missing or corrupted file.
        return all(
            check_integrity(folder / filename, md5)
            for filename, md5 in self.ctest_list
        )

    def download(self) -> None:
        """Download the dataset.

        Does nothing (besides logging) when the files are already present
        and pass the integrity check.
        """
        if self._check_integrity():
            logging.info("Files already downloaded and verified")
            return
        download_and_extract_archive(
            self.url,
            self._download_root,
            filename=self.filename,
            md5=self.tgz_md5,
        )
class CIFAR100C(CIFAR10C):
    """The corrupted CIFAR-100-C Dataset.

    Reuses all of the loading logic of :class:`CIFAR10C` and only overrides
    the dataset folder, the archive location and the MD5 checksums. The
    constructor arguments are those of :class:`CIFAR10C`.
    """

    # Folder (relative to the user-supplied root) holding the ``.npy`` files.
    base_folder = "CIFAR-100-C"
    # MD5 checksum of the downloaded archive.
    tgz_md5 = "11f0ed0f1191edbf9fa23466ae6021d3"
    # [filename, md5] pairs verified by the inherited integrity check.
    ctest_list = [
        ["fog.npy", "4efc7ebd5e82b028bdbe13048e3ea564"],
        ["jpeg_compression.npy", "c851b7f1324e1d2ffddeb76920576d11"],
        ["zoom_blur.npy", "0204613400c034a81c4830d5df81cb82"],
        ["speckle_noise.npy", "e3f215b1a0f9fd9fd6f0d1cf94a7ce99"],
        ["glass_blur.npy", "0bf384f38e5ccbf8dd479d9059b913e1"],
        ["spatter.npy", "12ccf41d62564d36e1f6a6ada5022728"],
        ["shot_noise.npy", "b0a1fa6e1e465a747c1b204b1914048a"],
        ["defocus_blur.npy", "d923e3d9c585a27f0956e2f2ad832564"],
        ["elastic_transform.npy", "a0792bd6581f6810878be71acedfc65a"],
        ["gaussian_blur.npy", "5204ba0d557839772ef5a4196a052c3e"],
        ["frost.npy", "3a39c6823bdfaa0bf8b12fe7004b8117"],
        ["saturate.npy", "c0697e9fdd646916a61e9c312c77bf6b"],
        ["brightness.npy", "f22d7195aecd6abb541e27fca230c171"],
        ["snow.npy", "0237be164583af146b7b144e73b43465"],
        ["gaussian_noise.npy", "ecc4d366eac432bdf25c024086f5e97d"],
        ["motion_blur.npy", "732a7e2e54152ff97c742d4c388c5516"],
        ["contrast.npy", "322bb385f1d05154ee197ca16535f71e"],
        ["impulse_noise.npy", "3b3c210ddfa0b5cb918ff4537a429fef"],
        ["labels.npy", "bb4026e9ce52996b95f439544568cdb2"],
        ["pixelate.npy", "96c00c60f144539e14cffb02ddbd0640"],
    ]
    url = "https://zenodo.org/record/3555552/files/CIFAR-100-C.tar"
    filename = "CIFAR-100-C.tar"