CIFAR10DataModule
- class torch_uncertainty.datamodules.CIFAR10DataModule(root, batch_size, eval_batch_size=None, eval_ood=False, eval_shift=False, num_tta=1, shift_severity=1, val_split=None, postprocess_set='val', num_workers=1, train_transform=None, test_transform=None, basic_augment=True, cutout=None, randaugment=False, auto_augment=None, test_alt=None, num_dataloaders=1, pin_memory=True, persistent_workers=True)
DataModule for CIFAR10.
- Parameters:
  - root (str | Path) – Root directory of the datasets.
  - batch_size (int) – Number of samples per batch during training.
  - eval_batch_size (int | None) – Number of samples per batch during evaluation (val and test). Set to batch_size if None. Defaults to None.
  - eval_ood (bool) – Whether to evaluate on out-of-distribution data. Defaults to False.
  - eval_shift (bool) – Whether to evaluate on shifted data. Defaults to False.
  - val_split (float | None) – Share of samples to use for validation. Defaults to None.
  - postprocess_set (Literal['val', 'test']) – Dataset used to fit the post-hoc calibration (post-processing) method. Defaults to 'val'.
  - num_workers (int) – Number of workers to use for data loading. Defaults to 1.
  - train_transform (Module | None) – Custom training transform. Defaults to None, in which case a default transform is used.
  - test_transform (Module | None) – Custom test transform. Defaults to None, in which case a default transform is used.
  - basic_augment (bool) – Whether to apply basic augmentations. Defaults to True. Only used if train_transform is not provided.
  - cutout (int | None) – Size of the cutout to apply to images. Defaults to None. Only used if train_transform is not provided.
  - randaugment (bool) – Whether to apply RandAugment. Defaults to False. Only used if train_transform is not provided.
  - auto_augment (str | None) – Which auto-augment policy to apply. Defaults to None. Only used if train_transform is not provided.
  - test_alt (Optional[Literal['h']]) – Which alternative test set to use. Defaults to None.
  - num_tta (int) – Number of test-time augmentations (TTA). Defaults to 1 (no TTA).
  - shift_severity (int) – Severity of the corruption to apply for CIFAR10-C. Defaults to 1.
  - num_dataloaders (int) – Number of dataloaders to use. Defaults to 1.
  - pin_memory (bool) – Whether to pin memory in the dataloaders. Defaults to True.
  - persistent_workers (bool) – Whether to use persistent dataloader workers. Defaults to True.
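Several of the parameters above interact: eval_batch_size falls back to batch_size, val_split carves a validation share out of the data, and the augmentation flags are ignored once a custom train_transform is supplied. The plain-Python sketch below illustrates those documented rules under stated assumptions; it is not torch_uncertainty's implementation. The helper names (resolve_eval_batch_size, split_sizes, train_pipeline) and the step labels are hypothetical, the 50,000 figure is the size of CIFAR-10's official training split, and treating val_split as a share of that training split (with None holding nothing out) is an assumption.

```python
from typing import Optional

CIFAR10_TRAIN_IMAGES = 50_000  # size of the official CIFAR-10 training split


def resolve_eval_batch_size(batch_size: int, eval_batch_size: Optional[int] = None) -> int:
    # Documented rule: val/test batches use batch_size when eval_batch_size is None.
    return batch_size if eval_batch_size is None else eval_batch_size


def split_sizes(val_split: Optional[float]) -> tuple[int, int]:
    # Assumption: val_split is a share of the 50,000 training images, and
    # None means no images are held out of the training set.
    share = 0.0 if val_split is None else val_split
    if not 0.0 <= share < 1.0:
        raise ValueError("val_split must lie in [0, 1)")
    num_val = round(share * CIFAR10_TRAIN_IMAGES)
    return CIFAR10_TRAIN_IMAGES - num_val, num_val


def train_pipeline(train_transform=None, basic_augment=True, cutout=None,
                   randaugment=False, auto_augment=None) -> list[str]:
    # Documented rule: the augmentation flags only matter when no custom
    # train_transform is given. Labels and their order are illustrative only.
    if train_transform is not None:
        return ["custom train_transform"]
    steps = []
    if basic_augment:
        steps.append("basic augmentations")
    if randaugment:
        steps.append("RandAugment")
    if auto_augment is not None:
        steps.append(f"auto-augment {auto_augment!r}")
    if cutout is not None:
        steps.append(f"cutout of size {cutout}")
    return steps


if __name__ == "__main__":
    print(resolve_eval_batch_size(128))       # 128
    print(resolve_eval_batch_size(128, 512))  # 512
    print(split_sizes(0.1))                   # (45000, 5000)
    print(train_pipeline(cutout=16))          # ['basic augmentations', 'cutout of size 16']
    print(train_pipeline(train_transform=object(), randaugment=True))  # ['custom train_transform']
```

The last call shows the precedence rule: passing any train_transform makes randaugment (and the other augmentation flags) irrelevant.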