
# Source code for torch_uncertainty.datamodules.uci_regression

from functools import partial
from pathlib import Path

from torch import Generator
from torch.utils.data import random_split

from torch_uncertainty.datasets.regression import UCIRegression

from .abstract import TUDataModule

class UCIRegressionDataModule(TUDataModule):
    """Datamodule for the UCI regression datasets.

    The full dataset is partitioned into train / test / val subsets with
    fractions ``0.8 - val_split``, ``0.2`` and ``val_split``. The partition is
    driven by ``split_seed`` and is the same on every call to :meth:`setup`.
    """

    # Task type advertised by this datamodule.
    training_task = "regression"

    def __init__(
        self,
        root: str | Path,
        batch_size: int,
        dataset_name: str,
        val_split: float = 0.0,
        num_workers: int = 1,
        pin_memory: bool = True,
        persistent_workers: bool = True,
        input_shape: tuple[int, ...] | None = None,
        split_seed: int = 42,
    ) -> None:
        """The UCI regression datasets.

        Args:
            root (str | Path): Root directory of the datasets.
            batch_size (int): The batch size for training and testing.
            dataset_name (str): The name of the dataset. One of
                "boston-housing", "concrete", "energy", "kin8nm",
                "naval-propulsion-plant", "power-plant", "protein",
                "wine-quality-red", and "yacht".
            val_split (float, optional): Share of validation samples, carved
                out of the 80% non-test share; must lie in ``[0, 0.8]``. When
                ``0``, the test subset is reused as the validation subset.
                Defaults to ``0``.
            num_workers (int, optional): How many subprocesses to use for data
                loading. Defaults to ``1``.
            pin_memory (bool, optional): Whether to pin memory in the GPU.
                Defaults to ``True``.
            persistent_workers (bool, optional): Whether to use persistent
                workers. Defaults to ``True``.
            input_shape (tuple, optional): The shape of the input data.
                Defaults to ``None``.
            split_seed (int, optional): The seed passed to the dataset (as
                ``seed``) and used for the train/test/val split. Defaults to
                ``42``.
        """
        super().__init__(
            root=root,
            batch_size=batch_size,
            val_split=val_split,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )
        # Dataset factory: only ``root`` and ``download`` vary between calls.
        self.dataset = partial(UCIRegression, dataset_name=dataset_name, seed=split_seed)
        self.input_shape = input_shape
        # Holds the split seed; :meth:`setup` clones it rather than consuming
        # its state, so that every call yields the same partition.
        self.gen = Generator().manual_seed(split_seed)

    def prepare_data(self) -> None:
        """Download the dataset."""
        self.dataset(root=self.root, download=True)

    # ruff: noqa: ARG002
    def setup(self, stage: str | None = None) -> None:
        """Split the datasets into train, val, and test.

        The split fractions are ``0.8 - val_split`` (train), ``0.2`` (test)
        and ``val_split`` (val). A fresh generator seeded with ``self.gen``'s
        initial seed is used on every call, so calling ``setup`` several times
        on the same instance (e.g. once per stage) always yields the exact
        same partition instead of reshuffling samples between train and test.

        Args:
            stage (str | None, optional): Unused; the same split is built for
                every stage.

        Raises:
            ValueError: If ``val_split`` is not within ``[0, 0.8]``.
        """
        # Also rejects NaN, since every comparison with NaN is False.
        if not 0.0 <= self.val_split <= 0.8:
            raise ValueError(
                "val_split must be in [0, 0.8] (a share of the non-test "
                f"samples), got {self.val_split}."
            )
        full = self.dataset(root=self.root, download=False)
        # Re-seed a new generator instead of passing ``self.gen`` directly: a
        # shared generator's state advances with each call, which would give a
        # different (overlapping) train/test partition on every ``setup``.
        generator = Generator().manual_seed(self.gen.initial_seed())
        self.train, self.test, self.val = random_split(
            full,
            [0.8 - self.val_split, 0.2, self.val_split],
            generator=generator,
        )
        if self.val_split == 0:
            # No dedicated validation subset: validate on the test subset.
            self.val = self.test