Source code for torch_uncertainty.datamodules.tabular_regression

from pathlib import Path

from torch_uncertainty.datasets.regression import (
    BostonHousing,
    Concrete,
    EnergyEfficiency,
    EnergyPrediction,
    Kin8NM,
    NavalPropulsionPlant,
    PowerPlant,
    Protein,
    TabularRegressionDataset,
    WineQuality,
    Yacht,
)
from torch_uncertainty.datasets.utils import create_train_val_split

from .abstract import TUDataModule


class TabularRegressionDataModule(TUDataModule):
    """Shared datamodule logic for the UCI tabular regression datasets.

    Concrete datamodules only have to point :attr:`dataset_class` at the
    matching :class:`TabularRegressionDataset` subclass. Instantiating a class
    that leaves :attr:`dataset_class` set to ``None`` raises :class:`TypeError`.
    """

    training_task = "regression"
    dataset_class: type[TabularRegressionDataset] | None = None

    def __init__(
        self,
        root: str | Path,
        batch_size: int,
        eval_batch_size: int | None = None,
        val_split: float = 0.0,
        test_split: float = 0.2,
        num_workers: int = 1,
        pin_memory: bool = True,
        persistent_workers: bool = True,
    ) -> None:
        """Configure a UCI regression datamodule.

        Args:
            root: Directory holding (or receiving) the dataset files.
            batch_size: Batch size used for training.
            eval_batch_size: Batch size used for evaluation. Defaults to
                :attr:`batch_size`.
            val_split: Fraction of the training samples set aside for
                validation. Defaults to ``0``.
            test_split: Fraction of the whole dataset reserved for testing.
                Defaults to ``0.2``.
            num_workers: Number of data-loading subprocesses. Defaults to
                ``1``.
            pin_memory: Whether to pin memory. Defaults to ``True``.
            persistent_workers: Whether workers stay alive between epochs.
                Defaults to ``True``.

        Raises:
            TypeError: If the concrete class did not define
                :attr:`dataset_class`.
        """
        super().__init__(
            root=root,
            batch_size=batch_size,
            eval_batch_size=eval_batch_size,
            val_split=val_split,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )

        # The check runs on the class itself, so only a class-level
        # `dataset_class` declaration satisfies it.
        owner = self.__class__
        if owner.dataset_class is None:
            raise TypeError(
                f"{owner.__name__} must set `dataset_class` as a class attribute."
            )

        # `test_split` is not forwarded to the parent; it is stored here and
        # handed to the dataset constructor in `setup`.
        self.test_split = test_split

    def prepare_data(self) -> None:
        """Instantiate the dataset with ``download=True`` so the files are fetched."""
        dataset_cls = self.dataset_class
        dataset_cls(root=self.root, download=True)

    def setup(self, stage: str | None = None) -> None:
        """Build the datasets required by ``stage``.

        For ``"fit"`` the training portion is loaded; it is either divided
        into train/val subsets with ``create_train_val_split`` (when
        :attr:`val_split` is truthy) or used whole, in which case the held-out
        ``train=False`` portion doubles as the validation set. For ``"test"``
        the ``train=False`` portion is loaded as the test set. ``None`` does
        both.

        Args:
            stage: ``"fit"``, ``"test"``, or ``None`` (both). Defaults to
                ``None``.

        Raises:
            ValueError: If ``stage`` is any other value.
        """
        # Reject unknown stages up front; no dataset is built for them.
        if stage not in ("fit", "test", None):
            raise ValueError(f"Stage {stage!r} is not supported.")

        # Keyword arguments common to every dataset instantiation below.
        shared = {"download": False, "test_split": self.test_split}

        if stage in ("fit", None):
            train_portion = self.dataset_class(self.root, train=True, **shared)
            if self.val_split:
                self.train, self.val = create_train_val_split(
                    train_portion, self.val_split
                )
            else:
                self.train = train_portion
                # No validation share requested: validate on the held-out portion.
                self.val = self.dataset_class(self.root, train=False, **shared)

        if stage in ("test", None):
            self.test = self.dataset_class(self.root, train=False, **shared)

    def _extra_repr(self) -> str:
        """Return extra descriptive text; the base class has none.

        NOTE(review): presumably consumed by the parent class when building a
        textual representation -- the caller is not visible in this file.
        """
        return ""
class BostonHousingDataModule(TabularRegressionDataModule):
    """UCI Boston Housing regression datamodule."""

    dataset_class = BostonHousing


class ConcreteDataModule(TabularRegressionDataModule):
    """UCI Concrete Compressive Strength regression datamodule."""

    dataset_class = Concrete


class EnergyEfficiencyDataModule(TabularRegressionDataModule):
    """UCI Energy Efficiency regression datamodule."""

    dataset_class = EnergyEfficiency


class EnergyPredictionDataModule(TabularRegressionDataModule):
    """UCI Appliances Energy Prediction regression datamodule."""

    dataset_class = EnergyPrediction


class Kin8NMDataModule(TabularRegressionDataModule):
    """Kin8NM robot arm kinematics regression datamodule."""

    dataset_class = Kin8NM


class NavalPropulsionPlantDataModule(TabularRegressionDataModule):
    """UCI Naval Propulsion Plants regression datamodule."""

    dataset_class = NavalPropulsionPlant


class PowerPlantDataModule(TabularRegressionDataModule):
    """UCI Combined Cycle Power Plant regression datamodule."""

    dataset_class = PowerPlant


class ProteinDataModule(TabularRegressionDataModule):
    """UCI Protein Tertiary Structure regression datamodule."""

    dataset_class = Protein


class WineQualityRegressionDataModule(TabularRegressionDataModule):
    """UCI Wine Quality datamodule (regression), parameterised by wine variant."""

    dataset_class = WineQuality

    def __init__(
        self,
        root: str | Path,
        batch_size: int,
        eval_batch_size: int | None = None,
        val_split: float = 0.0,
        test_split: float = 0.2,
        num_workers: int = 1,
        pin_memory: bool = True,
        persistent_workers: bool = True,
        variant: str = "red",
    ) -> None:
        """Configure the Wine Quality regression datamodule.

        Args:
            root: Directory holding (or receiving) the dataset files.
            batch_size: Batch size used for training.
            eval_batch_size: Batch size used for evaluation. Defaults to
                :attr:`batch_size`.
            val_split: Fraction of the training samples set aside for
                validation. Defaults to ``0``.
            test_split: Fraction of the whole dataset reserved for testing.
                Defaults to ``0.2``.
            num_workers: Number of data-loading subprocesses. Defaults to
                ``1``.
            pin_memory: Whether to pin memory. Defaults to ``True``.
            persistent_workers: Whether workers stay alive between epochs.
                Defaults to ``True``.
            variant: ``"red"`` or ``"white"``. Defaults to ``"red"``.
        """
        super().__init__(
            root=root,
            batch_size=batch_size,
            eval_batch_size=eval_batch_size,
            val_split=val_split,
            test_split=test_split,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )
        # Stored so every later dataset instantiation selects the same variant.
        self.variant = variant

    def _extra_repr(self) -> str:
        """Return the selected variant formatted as ``variant=<repr>``."""
        return f"variant={self.variant!r}"

    def prepare_data(self) -> None:
        """Instantiate the dataset with ``download=True`` for the chosen variant."""
        WineQuality(root=self.root, download=True, variant=self.variant)

    def setup(self, stage: str | None = None) -> None:
        """Build the datasets required by ``stage`` for the chosen variant.

        Follows the same scheme as the base class, additionally passing
        ``variant`` to every :class:`WineQuality` instantiation.

        Args:
            stage: ``"fit"``, ``"test"``, or ``None`` (both). Defaults to
                ``None``.

        Raises:
            ValueError: If ``stage`` is any other value.
        """
        split_share = self.test_split
        wine_variant = self.variant

        def load(train: bool) -> WineQuality:
            # Single construction point so every split uses identical options.
            return WineQuality(
                self.root,
                train=train,
                download=False,
                test_split=split_share,
                variant=wine_variant,
            )

        if stage not in ("fit", "test", None):
            raise ValueError(f"Stage {stage!r} is not supported.")

        if stage in ("fit", None):
            train_portion = load(train=True)
            if self.val_split:
                self.train, self.val = create_train_val_split(
                    train_portion, self.val_split
                )
            else:
                self.train = train_portion
                # No validation share requested: validate on the held-out portion.
                self.val = load(train=False)

        if stage in ("test", None):
            self.test = load(train=False)


class YachtDataModule(TabularRegressionDataModule):
    """UCI Yacht Hydrodynamics regression datamodule."""

    dataset_class = Yacht