# Source code for torch_uncertainty.datasets.classification.tabular.base

import logging
from abc import ABC
from collections.abc import Callable
from pathlib import Path

import torch
from torch import Generator, Tensor
from torch.utils.data import Dataset
from torchvision.datasets.utils import (
    download_and_extract_archive,
    download_url,
)

from torch_uncertainty.datasets.utils import load_arff  # noqa: F401


[docs] class TabularClassificationDataset(Dataset, ABC): """Base class for tabular binary classification datasets. Subclasses must define the class attributes :attr:`url`, :attr:`filename`, :attr:`dataset_name` and implement :meth:`_make_dataset`. If the source is a plain file rather than a zip archive, set :attr:`is_archive` to ``False``; :meth:`download` will then call :func:`~torchvision.datasets.utils.download_url` instead of :func:`~torchvision.datasets.utils.download_and_extract_archive`. """ md5_zip: str | None = None url: str = "" filename: str = "" dataset_name: str = "" num_features: int = 0 need_split: bool = True pre_split: bool = False apply_standardization: bool = True is_archive: bool = True def __init__( self, root: Path | str, transform: Callable | None = None, target_transform: Callable | None = None, binary: bool = True, download: bool = False, train: bool = True, test_split: float = 0.2, split_seed: int = 21893027, download_only: bool = False, ) -> None: """Tabular binary classification dataset. Args: root: Root directory of the datasets. transform: A function/transform that takes in a tensor and returns a transformed version. Defaults to ``None``. target_transform: A function/transform that takes in the target and transforms it. Defaults to ``None``. binary: If ``True``, returns scalar targets; otherwise one-hot encodes them into two classes. Defaults to ``True``. download: If ``True``, downloads the dataset from the internet. If already present, it is not downloaded again. Defaults to ``False``. train: If ``True``, use the training split. Defaults to ``True``. test_split: Fraction of the dataset held out as test set when :attr:`need_split` is ``True``. Defaults to ``0.2``. split_seed: Random seed for the train/test split. Defaults to ``21893027``. download_only: If ``True``, only download the dataset and skip the feature processing. Useful when only triggering the download. Defaults to ``False``. 
Note: The licenses of the datasets may differ from TorchUncertainty's license. Check before use. """ super().__init__() self.root = Path(root) self.train = train self.transform = transform self.target_transform = target_transform if download: self.download() if download_only: return if self.pre_split: train_data, train_targets, test_data, test_targets = self._make_pre_split_dataset() if self.apply_standardization: self._compute_statistics(train_data) train_data = self._standardize_tensor(train_data) test_data = self._standardize_tensor(test_data) if self.train: self.data, self.targets = train_data, train_targets else: self.data, self.targets = test_data, test_targets else: self._make_dataset() if self.need_split: gen = Generator().manual_seed(split_seed) num_train = int((1 - test_split) * len(self)) train_idx = torch.ones(len(self)).multinomial( num_samples=num_train, replacement=False, generator=gen, ) mask = torch.zeros(len(self), dtype=torch.bool) mask[train_idx] = True test_idx = torch.nonzero(~mask, as_tuple=False).squeeze(1) if self.apply_standardization: # Compute statistics from the training split only to avoid # leaking test-set information into the normalization. 
self._compute_statistics(self.data[train_idx]) self._standardize() self.split_idx = train_idx if self.train else test_idx self.data = self.data[self.split_idx] self.targets = self.targets[self.split_idx] elif self.apply_standardization: self._compute_statistics() self._standardize() self._postprocess_targets(binary) def __len__(self) -> int: """Get the number of rows of the tabular data.""" return self.data.shape[0] def _check_integrity(self) -> bool: return (self.root / self.dataset_name / self.filename).is_file() def _standardize_tensor(self, data: Tensor) -> Tensor: out = (data - self.data_mean) / self.data_std return torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0) def _standardize(self) -> None: self.data = self._standardize_tensor(self.data) def _compute_statistics(self, data: Tensor | None = None) -> None: d = self.data if data is None else data # Ignore NaN values when computing statistics so a single residual NaN # doesn't poison the entire column. self.data_mean = torch.nanmean(d, dim=0) # nanstd: var via nanmean of squared deviations centered = d - self.data_mean n_valid = (~torch.isnan(d)).sum(dim=0).clamp(min=1) var = torch.nansum(centered**2, dim=0) / (n_valid - (n_valid > 1).long()).clamp(min=1) self.data_std = var.sqrt() self.data_std[~torch.isfinite(self.data_std)] = 1 self.data_std[self.data_std == 0] = 1
[docs] def download(self) -> None: """Download and, if needed, extract the dataset.""" if self._check_integrity(): logging.info("Files already downloaded and verified") return download_root = self.root / self.dataset_name if self.is_archive: download_and_extract_archive( self.url, download_root=download_root, filename=self.dataset_name + ".zip", md5=self.md5_zip, ) else: download_url( self.url, root=str(download_root), filename=self.filename, md5=self.md5_zip, )
def _postprocess_targets(self, binary: bool) -> None: """Post-process targets after splitting. The default behaviour one-hot encodes targets into two classes when ``binary`` is ``False``. Override this in subclasses that require different target handling (e.g. multi-class datasets). """ if not binary: self.targets = torch.nn.functional.one_hot(self.targets, num_classes=2) def _make_dataset(self) -> None: """Populate ``self.data`` (float32 tensor) and ``self.targets`` (long tensor). Required for datasets without pre-existing train/test files. Datasets with separate train/test files should instead set :attr:`pre_split` and implement :meth:`_make_pre_split_dataset`. """ raise NotImplementedError def _make_pre_split_dataset(self) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Load both train and test files and return ``(train_data, train_targets, test_data, test_targets)``. Implementations must align feature columns between train and test (e.g. by concatenating before one-hot encoding) and use training-set statistics for any imputation. """ raise NotImplementedError def __getitem__(self, index: int) -> tuple[Tensor, Tensor]: """Get the row of id index of the tabular data.""" data = self.data[index, :] if self.transform is not None: data = self.transform(data) target = self.targets[index] if self.target_transform is not None: target = self.target_transform(target) return data, target