Shortcuts

Source code for torch_uncertainty.datamodules.classification.tiny_imagenet

from pathlib import Path
from typing import Literal

import numpy as np
import torchvision.transforms as T
from numpy.typing import ArrayLike
from timm.data.auto_augment import rand_augment_transform
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader
from torchvision.datasets import DTD, SVHN

from torch_uncertainty.datamodules import TUDataModule
from torch_uncertainty.datasets.classification import (
    ImageNetO,
    TinyImageNet,
    TinyImageNetC,
)
from torch_uncertainty.utils import (
    create_train_val_split,
    interpolation_modes_from_str,
)


class TinyImageNetDataModule(TUDataModule):
    r"""DataModule for classification on TinyImageNet.

    Optionally exposes an out-of-distribution (OOD) test set and a
    distribution-shifted test set (``TinyImageNetC``) alongside the
    in-distribution one.

    Args:
        root (str | Path): Root directory handed to every dataset.
        batch_size (int): Batch size, forwarded to ``TUDataModule``.
        eval_ood (bool): If ``True``, prepare the OOD dataset and append
            its dataloader to the test dataloaders. Defaults to ``False``.
        eval_shift (bool): If ``True``, prepare the shifted dataset and
            append its dataloader to the test dataloaders. Defaults to
            ``False``.
        shift_severity (int): Severity passed to the shifted dataset.
            Defaults to ``1``.
        val_split (float | None): When truthy, the training set is split
            with ``create_train_val_split`` to build the validation set;
            otherwise the ``"val"`` split of the dataset is used.
            Defaults to ``None``.
        ood_ds (str): OOD dataset name, one of ``"imagenet-o"``,
            ``"svhn"`` or ``"textures"``. Defaults to ``"svhn"``.
        interpolation (str): Interpolation name, converted with
            ``interpolation_modes_from_str`` and used by the test-time
            resize. Defaults to ``"bilinear"``.
        basic_augment (bool): If ``True``, apply a padded random crop
            (64, padding 4) and a random horizontal flip at training
            time. Defaults to ``True``.
        rand_augment_opt (str | None): Config string passed to timm's
            ``rand_augment_transform``; ``None`` disables it. Defaults to
            ``None``.
        num_workers (int): Forwarded to ``TUDataModule``. Defaults to ``1``.
        pin_memory (bool): Forwarded to ``TUDataModule``. Defaults to
            ``True``.
        persistent_workers (bool): Forwarded to ``TUDataModule``.
            Defaults to ``True``.

    Raises:
        ValueError: If ``ood_ds`` is not one of the supported names.
    """

    num_classes = 200
    num_channels = 3
    training_task = "classification"
    # Per-channel statistics used by ``T.Normalize`` in both transforms.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    def __init__(
        self,
        root: str | Path,
        batch_size: int,
        eval_ood: bool = False,
        eval_shift: bool = False,
        shift_severity: int = 1,
        val_split: float | None = None,
        ood_ds: str = "svhn",
        interpolation: str = "bilinear",
        basic_augment: bool = True,
        rand_augment_opt: str | None = None,
        num_workers: int = 1,
        pin_memory: bool = True,
        persistent_workers: bool = True,
    ) -> None:
        super().__init__(
            root=root,
            batch_size=batch_size,
            val_split=val_split,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )
        self.eval_ood = eval_ood
        self.eval_shift = eval_shift
        self.shift_severity = shift_severity
        self.ood_ds = ood_ds
        self.interpolation = interpolation_modes_from_str(interpolation)

        self.dataset = TinyImageNet

        if ood_ds == "imagenet-o":
            self.ood_dataset = ImageNetO
        elif ood_ds == "svhn":
            self.ood_dataset = SVHN
        elif ood_ds == "textures":
            self.ood_dataset = DTD
        else:
            raise ValueError(
                f"OOD dataset {ood_ds} not supported for TinyImageNet."
            )
        self.shift_dataset = TinyImageNetC

        if basic_augment:
            basic_transform = T.Compose(
                [
                    T.RandomCrop(64, padding=4),
                    T.RandomHorizontalFlip(),
                ]
            )
        else:
            basic_transform = nn.Identity()

        if rand_augment_opt is not None:
            main_transform = rand_augment_transform(rand_augment_opt, {})
        else:
            main_transform = nn.Identity()

        # NOTE(review): the augmentations run after ``ToTensor`` and so
        # receive tensors - confirm that the transform returned by timm's
        # ``rand_augment_transform`` accepts tensor inputs.
        self.train_transform = T.Compose(
            [
                T.ToTensor(),
                basic_transform,
                main_transform,
                T.Normalize(mean=self.mean, std=self.std),
            ]
        )

        self.test_transform = T.Compose(
            [
                T.ToTensor(),
                T.Resize(64, interpolation=self.interpolation),
                T.Normalize(mean=self.mean, std=self.std),
            ]
        )

    def _verify_splits(self, split: str) -> None:  # coverage: ignore
        """Check that ``root`` contains a subfolder named ``split``.

        Args:
            split (str): Name of the expected subfolder.

        Raises:
            FileNotFoundError: If ``root/split`` is not an existing
                directory.
        """
        # ``iterdir()`` yields ``Path`` objects, so testing a ``str`` for
        # membership in its output can never succeed; test the subfolder
        # itself instead.
        if not (Path(self.root) / split).is_dir():
            raise FileNotFoundError(
                f"a {split} TinyImagenet split was not found in {self.root},"
                f" make sure the folder contains a subfolder named {split}"
            )

    def _build_textures_ood(self) -> ConcatDataset:
        """Concatenate the train, val and test splits of the OOD dataset.

        Used when ``ood_ds`` is ``"textures"``. Every split is
        instantiated with ``download=True`` and the test transform.
        """
        return ConcatDataset(
            [
                self.ood_dataset(
                    self.root,
                    split=split,
                    download=True,
                    transform=self.test_transform,
                )
                for split in ("train", "val", "test")
            ]
        )

    def prepare_data(self) -> None:  # coverage: ignore
        """Instantiate the optional datasets with ``download=True``.

        Only the OOD dataset (if ``eval_ood``) and the shifted dataset (if
        ``eval_shift``) are handled here; the instances are discarded.
        """
        if self.eval_ood:
            if self.ood_ds != "textures":
                self.ood_dataset(
                    self.root,
                    split="test",
                    download=True,
                    transform=self.test_transform,
                )
            else:
                # Built only for its download side effect.
                self._build_textures_ood()
        if self.eval_shift:
            self.shift_dataset(
                self.root,
                download=True,
                transform=self.test_transform,
                shift_severity=self.shift_severity,
            )

    def setup(self, stage: Literal["fit", "test"] | None = None) -> None:
        """Create the datasets needed for the requested stage.

        Args:
            stage (Literal["fit", "test"] | None): ``"fit"`` builds the
                train/val sets, ``"test"`` builds the test set, ``None``
                builds both. Defaults to ``None``.

        Raises:
            ValueError: If ``stage`` is not ``"fit"``, ``"test"`` or
                ``None``.
        """
        if stage not in ("fit", "test", None):
            raise ValueError(f"Stage {stage} is not supported.")

        if stage == "fit" or stage is None:
            full = self.dataset(
                self.root,
                split="train",
                transform=self.train_transform,
            )
            if self.val_split:
                self.train, self.val = create_train_val_split(
                    full,
                    self.val_split,
                    self.test_transform,
                )
            else:
                self.train = full
                self.val = self.dataset(
                    self.root,
                    split="val",
                    transform=self.test_transform,
                )

        if stage == "test" or stage is None:
            # The "val" split of the dataset serves as the test set.
            self.test = self.dataset(
                self.root,
                split="val",
                transform=self.test_transform,
            )

        if self.eval_ood:
            if self.ood_ds == "textures":
                # NOTE(review): this branch still passes ``download=True``
                # although ``prepare_data`` already does; kept as-is so a
                # direct ``setup()`` call keeps working - confirm intent.
                self.ood = self._build_textures_ood()
            else:
                self.ood = self.ood_dataset(
                    self.root,
                    split="test",
                    transform=self.test_transform,
                )

        if self.eval_shift:
            self.shift = self.shift_dataset(
                self.root,
                download=False,
                shift_severity=self.shift_severity,
                transform=self.test_transform,
            )

    def train_dataloader(self) -> DataLoader:
        r"""Get the training dataloader for TinyImageNet.

        Return:
            DataLoader: TinyImageNet training dataloader (shuffled).
        """
        return self._data_loader(self.train, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        r"""Get the validation dataloader for TinyImageNet.

        Return:
            DataLoader: TinyImageNet validation dataloader.
        """
        return self._data_loader(self.val)

    def test_dataloader(self) -> list[DataLoader]:
        r"""Get test dataloaders for TinyImageNet.

        Return:
            list[DataLoader]: The in-distribution test dataloader,
            followed by the out-of-distribution dataloader if ``eval_ood``
            and then the shifted dataloader if ``eval_shift``.
        """
        dataloader = [self._data_loader(self.test)]
        if self.eval_ood:
            dataloader.append(self._data_loader(self.ood))
        if self.eval_shift:
            dataloader.append(self._data_loader(self.shift))
        return dataloader

    def _get_train_data(self) -> ArrayLike:
        """Return the samples of the training set.

        With a validation split, ``self.train`` is indexed through its
        ``dataset`` and ``indices`` attributes (presumably a Subset-like
        wrapper returned by ``create_train_val_split`` - verify).
        """
        if self.val_split:
            return self.train.dataset.samples[self.train.indices]
        return self.train.samples

    def _get_train_targets(self) -> ArrayLike:
        """Return the labels of the training set as a NumPy array."""
        if self.val_split:
            return np.array(self.train.dataset.label_data)[self.train.indices]
        return np.array(self.train.label_data)