Source code for torch_uncertainty.datasets.regression.tabular.base

import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from pathlib import Path

import torch
from torch import Generator, Tensor, tensor
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_and_extract_archive, download_url

from torch_uncertainty.datasets.utils import load_arff  # noqa: F401


[docs] class TabularRegressionDataset(Dataset, ABC): """Abstract base class for UCI regression datasets. Mirrors :class:`~torch_uncertainty.datasets.classification.tabular.TabularClassificationDataset`: subclasses set :attr:`url`, :attr:`filename`, :attr:`dataset_name` and implement :meth:`_make_dataset`. Standardization is computed from the training split only and then applied to both splits. """ root_appendix = "uci_regression" url: str = "" filename: str = "" dataset_name: str = "" need_split: bool = True apply_standardization: bool = True is_archive: bool = True md5: str | None = None def __init__( self, root: Path | str, transform: Callable | None = None, target_transform: Callable | None = None, download: bool = False, train: bool = True, test_split: float = 0.2, split_seed: int = 42, ) -> None: """UCI regression dataset. Args: root: Root directory of the datasets. transform: Transform applied to each input sample. Defaults to ``None``. target_transform: Transform applied to each target. Defaults to ``None``. download: If ``True``, downloads the dataset. Defaults to ``False``. train: If ``True``, returns the training split. Defaults to ``True``. test_split: Fraction of the dataset held out as test set. Defaults to ``0.2``. split_seed: Random seed for the train/test split. Defaults to ``42``. Note: The licenses of the datasets may differ from TorchUncertainty's license. Check before use. """ super().__init__() self.root = Path(root) self.train = train self.transform = transform self.target_transform = target_transform if download: self.download() self._make_dataset() if self.need_split: gen = Generator().manual_seed(split_seed) train_idx = torch.ones(len(self)).multinomial( num_samples=int((1 - test_split) * len(self)), replacement=False, generator=gen, ) test_idx = tensor([i for i in range(len(self)) if i not in train_idx]) if self.apply_standardization: self._compute_statistics(self.data[train_idx], self.targets[train_idx]) self._standardize() self.split_idx = train_idx if self.train else test_idx self.data = self.data[self.split_idx] self.targets = self.targets[self.split_idx] elif self.apply_standardization: self._compute_statistics() self._standardize() @property def _data_path(self) -> Path: return self.root / self.root_appendix / self.dataset_name def __len__(self) -> int: """Get the length of the tabular regression dataset.""" return self.data.shape[0] def __getitem__(self, index: int) -> tuple[Tensor, Tensor]: """Get an element of the tabular regression dataset.""" data = self.data[index] if self.transform is not None: data = self.transform(data) target = self.targets[index] if self.target_transform is not None: target = self.target_transform(target) return data, target def _check_integrity(self) -> bool: return (self._data_path / self.filename).is_file() def _compute_statistics( self, data: Tensor | None = None, targets: Tensor | None = None, ) -> None: # Use float64 to avoid precision loss for large-valued features when # computing the mean (e.g. NavalPropulsionPlant columns at ~1e9). d = (self.data if data is None else data).double() t = (self.targets if targets is None else targets).double() self.data_mean = d.mean(dim=0).float() self.data_std = d.std(dim=0).float() self.data_std[self.data_std == 0] = 1 self.target_mean = t.mean(dim=0).float() self.target_std = t.std(dim=0).float() self.target_std[self.target_std == 0] = 1 def _standardize(self) -> None: self.data = (self.data - self.data_mean) / self.data_std self.targets = (self.targets - self.target_mean) / self.target_std
[docs] def download(self) -> None: """Download and, if needed, extract the dataset.""" if self._check_integrity(): logging.info("Files already downloaded and verified") return self._data_path.mkdir(parents=True, exist_ok=True) if self.is_archive: download_and_extract_archive( self.url, download_root=self._data_path, filename=self.dataset_name + ".zip", md5=self.md5, ) else: download_url( self.url, root=str(self._data_path), filename=self.filename, md5=self.md5, )
@abstractmethod def _make_dataset(self) -> None: """Populate ``self.data`` and ``self.targets`` as float32 tensors."""