Source code for torch_uncertainty.datamodules.classification.tabular

from pathlib import Path

from torch_uncertainty.datamodules.abstract import TUDataModule
from torch_uncertainty.datasets.classification.tabular import (
    HTRU2,
    AdultCensusIncome,
    AmazonAccess,
    APSFailure,
    BankMarketing,
    CreditApproval,
    DOTA2Games,
    GermanCredit,
    HiggsBoson,
    KDDChurn,
    OnlineShoppers,
    PimaDiabetes,
    SpamBase,
    TabularClassificationDataset,
    TelcoChurn,
    WineQuality,
)
from torch_uncertainty.datasets.utils import create_train_val_split


class TabularClassificationDataModule(TUDataModule):
    """Shared datamodule logic for tabular binary classification datasets.

    Concrete datamodules only have to point :attr:`dataset_class` at the
    matching :class:`TabularClassificationDataset` subclass; the constructor,
    :meth:`prepare_data` and :meth:`setup` defined here are reused as-is.

    Example::

        class HTRU2DataModule(TabularClassificationDataModule):
            dataset_class = HTRU2
    """

    training_task = "classification"
    # Left as ``None`` here; ``__init__`` refuses to build an instance until a
    # subclass overrides it.
    dataset_class: type[TabularClassificationDataset] | None = None

    def __init__(
        self,
        root: str | Path,
        batch_size: int,
        eval_batch_size: int | None = None,
        val_split: float = 0.0,
        test_split: float = 0.2,
        num_workers: int = 1,
        pin_memory: bool = True,
        persistent_workers: bool = True,
        binary: bool = True,
    ) -> None:
        """Build a tabular binary classification datamodule.

        Args:
            root: Root directory of the datasets.
            batch_size: Number of samples per batch during training.
            eval_batch_size: Number of samples per batch during evaluation.
                Defaults to :attr:`batch_size`.
            val_split: Share of the training samples to use as validation
                set. Defaults to ``0``.
            test_split: Share of the full dataset to hold out as test set
                (used when the dataset has no predefined split). Defaults to
                ``0.2``.
            num_workers: Number of data-loading subprocesses. Defaults to
                ``1``.
            pin_memory: Whether to pin memory. Defaults to ``True``.
            persistent_workers: Whether to keep workers alive between epochs.
                Defaults to ``True``.
            binary: Forwarded to the dataset; if ``True``, scalar targets are
                returned. Defaults to ``True``.

        Raises:
            TypeError: If the concrete class did not define
                :attr:`dataset_class`.
        """
        super().__init__(
            root=root,
            batch_size=batch_size,
            eval_batch_size=eval_batch_size,
            val_split=val_split,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )

        if type(self).dataset_class is None:
            raise TypeError(
                f"{self.__class__.__name__} must set `dataset_class` as a class attribute."
            )

        self.test_split = test_split
        self.binary = binary

    def prepare_data(self) -> None:
        """Trigger the dataset download when the files are not present yet."""
        dataset_cls = self.dataset_class
        dataset_cls(root=self.root, download=True, download_only=True)

    def setup(self, stage: str | None = None) -> None:
        """Instantiate the train, validation and test datasets.

        Args:
            stage: ``"fit"``, ``"test"``, or ``None`` (both). Defaults to
                ``None``.

        Raises:
            ValueError: If ``stage`` is none of the supported values.
        """
        if stage not in ("fit", "test", None):
            raise ValueError(f"Stage {stage!r} is not supported.")

        # Options shared by every dataset instantiation below.
        shared = {
            "download": False,
            "binary": self.binary,
            "test_split": self.test_split,
        }

        if stage in ("fit", None):
            train_set = self.dataset_class(self.root, train=True, **shared)
            if self.val_split:
                self.train, self.val = create_train_val_split(train_set, self.val_split)
            else:
                # No validation share requested: validate on the held-out split.
                self.train = train_set
                self.val = self.dataset_class(self.root, train=False, **shared)

        if stage in ("test", None):
            self.test = self.dataset_class(self.root, train=False, **shared)
class AdultCensusIncomeDataModule(TabularClassificationDataModule):
    """Datamodule serving the UCI Adult Census Income dataset.

    Only :attr:`dataset_class` is set here; everything else is inherited.
    """

    dataset_class = AdultCensusIncome
class AmazonAccessDataModule(TabularClassificationDataModule):
    """Datamodule serving the Amazon Employee Access dataset.

    Only :attr:`dataset_class` is set here; everything else is inherited.
    """

    dataset_class = AmazonAccess
class APSFailureDataModule(TabularClassificationDataModule):
    """Datamodule serving the UCI APS Failure at Scania Trucks dataset.

    Only :attr:`dataset_class` is set here; everything else is inherited.
    """

    dataset_class = APSFailure
class BankMarketingDataModule(TabularClassificationDataModule):
    """Datamodule serving the UCI Bank Marketing dataset.

    Only :attr:`dataset_class` is set here; everything else is inherited.
    """

    dataset_class = BankMarketing
class CreditApprovalDataModule(TabularClassificationDataModule):
    """Datamodule serving the UCI Credit Approval dataset.

    Only :attr:`dataset_class` is set here; everything else is inherited.
    """

    dataset_class = CreditApproval
class DOTA2GamesDataModule(TabularClassificationDataModule):
    """Datamodule serving the UCI DOTA 2 Games Results dataset.

    Only :attr:`dataset_class` is set here; everything else is inherited.
    """

    dataset_class = DOTA2Games
class GermanCreditDataModule(TabularClassificationDataModule):
    """Datamodule serving the UCI Statlog German Credit dataset.

    Only :attr:`dataset_class` is set here; everything else is inherited.
    """

    dataset_class = GermanCredit
class HiggsBosonDataModule(TabularClassificationDataModule):
    """Datamodule serving the Higgs Boson dataset (OpenML 23512).

    Only :attr:`dataset_class` is set here; everything else is inherited.
    """

    dataset_class = HiggsBoson
class HTRU2DataModule(TabularClassificationDataModule):
    """Datamodule serving the UCI HTRU2 pulsar dataset.

    Only :attr:`dataset_class` is set here; everything else is inherited.
    """

    dataset_class = HTRU2
class KDDChurnDataModule(TabularClassificationDataModule):
    """Datamodule serving the KDD Cup 2009 Customer Churn dataset (OpenML 1112).

    Only :attr:`dataset_class` is set here; everything else is inherited.
    """

    dataset_class = KDDChurn
class OnlineShoppersDataModule(TabularClassificationDataModule):
    """Datamodule serving the UCI Online Shoppers Purchasing Intention dataset.

    Only :attr:`dataset_class` is set here; everything else is inherited.
    """

    dataset_class = OnlineShoppers
class PimaDiabetesDataModule(TabularClassificationDataModule):
    """Datamodule serving the UCI Pima Indians Diabetes dataset.

    Only :attr:`dataset_class` is set here; everything else is inherited.
    """

    dataset_class = PimaDiabetes
class SpamBaseDataModule(TabularClassificationDataModule):
    """Datamodule serving the UCI SpamBase e-mail spam dataset.

    Only :attr:`dataset_class` is set here; everything else is inherited.
    """

    dataset_class = SpamBase
class TelcoChurnDataModule(TabularClassificationDataModule):
    """Datamodule serving the Telecom Customer Churn dataset (OpenML 40701).

    Only :attr:`dataset_class` is set here; everything else is inherited.
    """

    dataset_class = TelcoChurn
class WineQualityDataModule(TabularClassificationDataModule):
    """Datamodule serving the UCI Wine Quality dataset.

    Adds the ``variant`` and ``threshold`` options on top of the generic
    tabular datamodule and forwards them to :class:`WineQuality`.
    """

    dataset_class = WineQuality

    def __init__(
        self,
        root: str | Path,
        batch_size: int,
        eval_batch_size: int | None = None,
        val_split: float = 0.0,
        test_split: float = 0.2,
        num_workers: int = 1,
        pin_memory: bool = True,
        persistent_workers: bool = True,
        binary: bool = True,
        variant: str = "red",
        threshold: int = 6,
    ) -> None:
        """Build the Wine Quality datamodule.

        Args:
            root: Root directory of the datasets.
            batch_size: Number of samples per training batch.
            eval_batch_size: Samples per evaluation batch. Defaults to
                :attr:`batch_size`.
            val_split: Share of training samples used for validation.
                Defaults to ``0``.
            test_split: Share of the full dataset held out as test set.
                Defaults to ``0.2``.
            num_workers: Data-loading subprocesses. Defaults to ``1``.
            pin_memory: Whether to pin memory. Defaults to ``True``.
            persistent_workers: Whether to keep workers alive between epochs.
                Defaults to ``True``.
            binary: If ``True``, binarises quality scores. Defaults to
                ``True``.
            variant: ``"red"`` or ``"white"``. Defaults to ``"red"``.
            threshold: Quality threshold for binary mode. Defaults to ``6``.
        """
        super().__init__(
            root=root,
            batch_size=batch_size,
            eval_batch_size=eval_batch_size,
            val_split=val_split,
            test_split=test_split,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            binary=binary,
        )
        self.variant = variant
        self.threshold = threshold

    def prepare_data(self) -> None:
        """Trigger the download of the selected wine variant."""
        dataset_cls = self.dataset_class
        dataset_cls(
            root=self.root,
            download=True,
            download_only=True,
            variant=self.variant,
        )

    def setup(self, stage: str | None = None) -> None:
        """Instantiate the train, validation and test datasets.

        Args:
            stage: ``"fit"``, ``"test"``, or ``None`` (both). Defaults to
                ``None``.

        Raises:
            ValueError: If ``stage`` is none of the supported values.
        """
        if stage not in ("fit", "test", None):
            raise ValueError(f"Stage {stage!r} is not supported.")

        def build(train: bool):
            # Single place where the Wine-specific options are forwarded.
            return self.dataset_class(
                self.root,
                train=train,
                download=False,
                binary=self.binary,
                test_split=self.test_split,
                variant=self.variant,
                threshold=self.threshold,
            )

        if stage in ("fit", None):
            train_set = build(train=True)
            if self.val_split:
                self.train, self.val = create_train_val_split(train_set, self.val_split)
            else:
                # No validation share requested: validate on the held-out split.
                self.train = train_set
                self.val = build(train=False)

        if stage in ("test", None):
            self.test = build(train=False)