MNISTDataModule
- class torch_uncertainty.datamodules.MNISTDataModule(root, batch_size, eval_batch_size=None, eval_ood=False, eval_shift=False, ood_ds='fashion', num_tta=1, val_split=None, postprocess_set='val', num_workers=1, train_transform=None, test_transform=None, ood_transform=None, basic_augment=True, cutout=None, pin_memory=True, persistent_workers=True)
DataModule for MNIST.
- Parameters:
  - root (str | Path) – Root directory of the datasets.
  - batch_size (int) – Number of samples per batch during training.
  - eval_batch_size (int | None) – Number of samples per batch during evaluation (validation and test). Set to batch_size if None. Defaults to None.
  - eval_ood (bool) – Whether to evaluate on out-of-distribution data. Defaults to False.
  - eval_shift (bool) – Whether to evaluate on shifted data. Defaults to False.
  - ood_ds (Literal['fashion', 'notMNIST']) – Which out-of-distribution dataset to use: "fashion" selects FashionMNIST and "notMNIST" selects notMNIST. Defaults to "fashion".
  - num_tta (int) – Number of test-time augmentations (TTA). Defaults to 1 (no TTA).
  - val_split (float | None) – Fraction of samples to hold out for validation. Defaults to None.
  - postprocess_set (Literal['val', 'test']) – Dataset split used to fit the post-hoc post-processing method (e.g. calibration). Defaults to "val".
  - num_workers (int) – Number of workers used for data loading. Defaults to 1.
  - train_transform (Module | None) – Custom training transform. If None, a default transform is used. Defaults to None.
  - test_transform (Module | None) – Custom test transform. If None, a default transform is used. Defaults to None.
  - ood_transform (Module | None) – Custom transform for the out-of-distribution datasets. If None, a default transform is used. Defaults to None.
  - basic_augment (bool) – Whether to apply the basic data augmentations. Defaults to True.
  - cutout (int | None) – Size of the Cutout patch applied to images. Defaults to None (no Cutout).
  - pin_memory (bool) – Whether the data loaders use pinned (page-locked) memory. Defaults to True.
  - persistent_workers (bool) – Whether the data loaders keep their worker processes alive between epochs. Defaults to True.
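Several of the defaults above depend on each other. The sketch below is a hypothetical, library-free helper (`resolve_mnist_options` and `ResolvedOptions` are illustrative names, not part of torch_uncertainty) that mirrors the documented rules: `eval_batch_size` falls back to `batch_size`, `ood_ds` and `postprocess_set` accept only the listed literals, and `num_tta=1` means no test-time augmentation. The check that `num_tta` is at least 1 is an assumption added for illustration.

```python
from __future__ import annotations

from dataclasses import dataclass

# Hypothetical sketch of the documented option rules; NOT torch_uncertainty code.
OOD_DATASETS = {"fashion": "FashionMNIST", "notMNIST": "notMNIST"}
POSTPROCESS_SETS = ("val", "test")


@dataclass(frozen=True)
class ResolvedOptions:
    batch_size: int
    eval_batch_size: int
    ood_dataset_name: str
    use_tta: bool
    postprocess_set: str


def resolve_mnist_options(
    batch_size: int,
    eval_batch_size: int | None = None,
    ood_ds: str = "fashion",
    num_tta: int = 1,
    postprocess_set: str = "val",
) -> ResolvedOptions:
    # Documented: eval_batch_size is set to batch_size when it is None.
    if eval_batch_size is None:
        eval_batch_size = batch_size
    # Documented literals: "fashion" -> FashionMNIST, "notMNIST" -> notMNIST.
    if ood_ds not in OOD_DATASETS:
        raise ValueError(f"ood_ds must be one of {sorted(OOD_DATASETS)}, got {ood_ds!r}")
    if postprocess_set not in POSTPROCESS_SETS:
        raise ValueError(f"postprocess_set must be 'val' or 'test', got {postprocess_set!r}")
    # Assumption: num_tta must be positive; 1 means no test-time augmentation.
    if num_tta < 1:
        raise ValueError(f"num_tta must be >= 1, got {num_tta}")
    return ResolvedOptions(
        batch_size=batch_size,
        eval_batch_size=eval_batch_size,
        ood_dataset_name=OOD_DATASETS[ood_ds],
        use_tta=num_tta > 1,
        postprocess_set=postprocess_set,
    )


if __name__ == "__main__":
    opts = resolve_mnist_options(batch_size=128)
    print(opts.eval_batch_size, opts.ood_dataset_name, opts.use_tta)
```

Run as a script, this prints `128 FashionMNIST False`: with only `batch_size` given, evaluation reuses the training batch size and FashionMNIST serves as the out-of-distribution set.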
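`val_split` is a fraction, not a sample count. Assuming the validation samples are carved out of MNIST's 60,000-image training set (an assumption about the implementation, not stated above), `val_split=0.1` yields 54,000 training and 6,000 validation samples. The hypothetical helpers `split_sizes` and `random_partition` below compute the sizes and draw a reproducible partition with the standard library. The rounding rule and seeded shuffle are illustrative choices; torch_uncertainty may split differently.

```python
from __future__ import annotations

import random

MNIST_TRAIN_SIZE = 60_000  # size of the official MNIST training set


def split_sizes(n_samples: int, val_split: float | None) -> tuple[int, int]:
    """Return (n_train, n_val) for a fractional validation split (hypothetical helper)."""
    # Assumption: None or 0.0 means no samples are held out for validation.
    if not val_split:
        return n_samples, 0
    if not 0.0 < val_split < 1.0:
        raise ValueError(f"val_split must be in (0, 1), got {val_split}")
    n_val = round(n_samples * val_split)
    return n_samples - n_val, n_val


def random_partition(n_samples: int, val_split: float | None, seed: int = 0) -> tuple[list[int], list[int]]:
    """Shuffle indices with a fixed seed and cut them into train/val index lists."""
    n_train, _ = split_sizes(n_samples, val_split)
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
    return indices[:n_train], indices[n_train:]


if __name__ == "__main__":
    print(split_sizes(MNIST_TRAIN_SIZE, 0.1))
```

Fixing the seed matters in practice: without it, each run would validate on a different subset, making validation metrics and any post-processing fitted on `postprocess_set="val"` hard to compare across runs.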