CIFAR100DataModule

class torch_uncertainty.datamodules.CIFAR100DataModule(root, batch_size, eval_batch_size=None, eval_ood=False, eval_shift=False, num_tta=1, shift_severity=1, val_split=None, postprocess_set='val', train_transform=None, test_transform=None, basic_augment=True, cutout=None, randaugment=False, auto_augment=None, num_dataloaders=1, num_workers=1, pin_memory=True, persistent_workers=True)

DataModule for CIFAR100.

Parameters:
  • root (str) – Root directory of the datasets.

  • eval_ood (bool) – Whether to evaluate out-of-distribution performance. Defaults to False.

  • eval_shift (bool) – Whether to evaluate on shifted data. Defaults to False.

  • batch_size (int) – Number of samples per batch during training.

  • eval_batch_size (int | None) – Number of samples per batch during evaluation (val and test). Set to batch_size if None. Defaults to None.

  • val_split (float | None) – Share of samples to use for validation. Defaults to None.

  • postprocess_set (str, optional) – The dataset used to fit the post-processing (post-hoc calibration) method. Defaults to 'val'.

  • train_transform (nn.Module | None) – Custom training transform. Defaults to None. If not provided, a default transform is used.

  • test_transform (nn.Module | None) – Custom test transform. Defaults to None. If not provided, a default transform is used.

  • basic_augment (bool) – Whether to apply base augmentations. Defaults to True. Only used if train_transform is not provided.

  • cutout (int | None) – Size of cutout to apply to images. Defaults to None. Only used if train_transform is not provided.

  • randaugment (bool) – Whether to apply RandAugment. Defaults to False. Only used if train_transform is not provided.

  • auto_augment (str | None) – Which auto-augment policy to apply. Defaults to None.

  • num_tta (int) – Number of test-time augmentations (TTA). Defaults to 1 (no TTA).

  • shift_severity (int) – Severity of corruption to apply to CIFAR100-C. Defaults to 1.

  • num_dataloaders (int) – Number of dataloaders to use. Defaults to 1.

  • num_workers (int) – Number of workers to use for data loading. Defaults to 1.

  • pin_memory (bool) – Whether to pin memory. Defaults to True.

  • persistent_workers (bool) – Whether to use persistent workers. Defaults to True.
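A minimal construction sketch, assuming the package is installed and that the constructor accepts the keyword arguments documented above. The `./data` path and the chosen values are illustrative only, and `build_datamodule` is a hypothetical helper, not part of the library. The import is deferred so that the configuration can be inspected without the package.

```python
# Illustrative keyword arguments; every key matches a parameter documented above.
datamodule_kwargs = {
    "root": "./data",        # hypothetical dataset directory
    "batch_size": 128,       # samples per batch during training
    "eval_batch_size": 256,  # samples per batch for val/test (batch_size is used if None)
    "val_split": 0.1,        # share of samples used for validation
    "eval_ood": True,        # also evaluate out-of-distribution performance
    "num_workers": 4,        # data-loading workers
}


def build_datamodule(**overrides):
    """Hypothetical helper: build a CIFAR100DataModule from the kwargs above."""
    # Imported lazily so the configuration can be read without the package installed.
    from torch_uncertainty.datamodules import CIFAR100DataModule

    # Later keys win, so overrides replace the illustrative defaults.
    return CIFAR100DataModule(**{**datamodule_kwargs, **overrides})
```

Calling `build_datamodule(batch_size=64)` would then construct the datamodule with a smaller training batch size while keeping the other illustrative values.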

test_dataloader()

Get test dataloaders.

Returns:

Test set for in-distribution data, SVHN data, and/or CIFAR100-C data.

Return type:

list[DataLoader]
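Because the method returns a list, it can help to label each loader before evaluation. The sketch below assumes the list follows the order of the description above: the in-distribution test set first, then the SVHN set when `eval_ood` is enabled, then the CIFAR100-C set when `eval_shift` is enabled. That ordering is an assumption and should be checked against the source. `label_test_loaders` and the label names are hypothetical.

```python
def label_test_loaders(loaders, eval_ood=False, eval_shift=False):
    """Hypothetical helper: pair each test loader with a descriptive label.

    Assumes the order in-distribution, SVHN (OOD), CIFAR100-C (shift).
    """
    names = ["in_distribution"]
    if eval_ood:
        names.append("ood_svhn")
    if eval_shift:
        names.append("shift_cifar100c")
    # Fail loudly if the assumed layout does not match what was returned.
    if len(names) != len(loaders):
        raise ValueError(
            f"expected {len(names)} loaders for these flags, got {len(loaders)}"
        )
    return dict(zip(names, loaders))
```

The length check is deliberate: if the library returns a different number of loaders than the flags suggest, the helper raises instead of silently mislabeling them.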

train_dataloader()

Get the training dataloader for CIFAR100.

Returns:

CIFAR100 training dataloader.

Return type:

DataLoader