ImageNetDataModule
- class torch_uncertainty.datamodules.ImageNetDataModule(root, batch_size, eval_batch_size=None, eval_ood=False, eval_shift=False, num_tta=1, shift_severity=1, val_split=None, postprocess_set='val', train_transform=None, test_transform=None, ood_ds='openimage-o', test_alt=None, procedure=None, train_size=224, interpolation='bilinear', basic_augment=True, rand_augment_opt=None, num_workers=1, pin_memory=True, persistent_workers=True)
DataModule for the ImageNet dataset.
This datamodule uses ImageNet as the in-distribution dataset; OpenImage-O, iNaturalist, ImageNet-O, SVHN, or DTD as the out-of-distribution dataset; and ImageNet-C as the shifted dataset.
- Parameters:
root (str) – Root directory of the datasets.
batch_size (int) – Number of samples per batch during training.
eval_batch_size (int | None) – Number of samples per batch during evaluation (val and test). Set to batch_size if None. Defaults to None.
eval_ood (bool) – Whether to evaluate out-of-distribution performance. Defaults to False.
eval_shift (bool) – Whether to evaluate on shifted data. Defaults to False.
num_tta (int) – Number of test-time augmentations (TTA). Defaults to 1 (no TTA).
shift_severity (int) – Severity of the shift. Defaults to 1.
val_split (float or Path) – Share of samples to use for validation, or path to a YAML file containing a list of validation image IDs. Defaults to None.
postprocess_set (str, optional) – The post-hoc calibration dataset to use for the post-processing method. Defaults to "val".
train_transform (nn.Module | None) – Custom training transform. Defaults to None. If not provided, a default transform is used.
test_transform (nn.Module | None) – Custom test transform. Defaults to None. If not provided, a default transform is used.
ood_ds (str) – Which out-of-distribution dataset to use. Defaults to "openimage-o".
test_alt (str) – Which test set to use. Defaults to None.
procedure (str) – Which procedure to use. Defaults to None. Only used if train_transform is not provided.
train_size (int) – Size of training images. Defaults to 224.
interpolation (str) – Interpolation method for the resize crops. Defaults to "bilinear". Only used if train_transform is not provided.
basic_augment (bool) – Whether to apply base augmentations. Defaults to True. Only used if train_transform is not provided.
rand_augment_opt (str) – Which RandAugment to use. Defaults to None. Only used if train_transform is not provided.
num_workers (int) – Number of workers to use for data loading. Defaults to 1.
pin_memory (bool) – Whether to pin memory. Defaults to True.
persistent_workers (bool) – Whether to use persistent workers. Defaults to True.
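The parameters above can be combined as in the following minimal sketch. It only uses keyword names that appear in the class signature. The helper names imagenet_datamodule_kwargs and build_datamodule, the dataset path, and the chosen values (batch size, severity 3, 8 workers) are hypothetical and for illustration only. The import is deferred so that the argument dictionary can be inspected without torch_uncertainty installed.

```python
def imagenet_datamodule_kwargs(root: str, batch_size: int = 256) -> dict:
    """Collect constructor arguments (illustrative values, not recommendations)."""
    return {
        "root": root,                # directory holding ImageNet and the OOD/shift sets
        "batch_size": batch_size,    # training batch size
        "eval_batch_size": None,     # documented to fall back to batch_size
        "eval_ood": True,            # also evaluate on an out-of-distribution set
        "ood_ds": "openimage-o",     # the documented default OOD dataset
        "eval_shift": True,          # also evaluate on the shifted dataset (ImageNet-C)
        "shift_severity": 3,         # hypothetical choice; the documented default is 1
        "num_workers": 8,            # hypothetical choice; the documented default is 1
    }


def build_datamodule(**kwargs):
    """Instantiate the datamodule; requires torch_uncertainty to be installed."""
    # Deferred import: the helper above stays usable without the library.
    from torch_uncertainty.datamodules import ImageNetDataModule

    return ImageNetDataModule(**kwargs)
```

With the library installed and the datasets available under the chosen root, build_datamodule(**imagenet_datamodule_kwargs("/path/to/datasets")) would return the datamodule. The path here is a placeholder.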