API Reference#

Routines#

Routines are the main building blocks of the library. They define the framework in which models are trained and evaluated, and they compute the metrics needed for uncertainty estimation in classification, regression, and segmentation.

Classification#

ClassificationRoutine

Routine for training & testing on classification tasks.

Segmentation#

SegmentationRoutine

Routine for training & testing on segmentation tasks.

Regression#

RegressionRoutine

Routine for training & testing on regression tasks.

Pixelwise Regression#

PixelRegressionRoutine

Routine for training & testing on pixel regression tasks.

Baselines#

Warning

The baselines will soon be removed from the library to avoid confusion with the routines.

TorchUncertainty provides Lightning-based models that can be easily trained and evaluated. These models inherit from the routines and are designed to benchmark different methods in similar settings, here with constant architectures.

Classification#

ResNetBaseline

ResNet backbone baseline for classification providing support for various versions and architectures.

VGGBaseline

VGG backbone baseline for classification providing support for various versions and architectures.

WideResNetBaseline

Wide-ResNet28x10 backbone baseline for classification providing support for various versions.

Regression#

MLPBaseline

MLP baseline for regression providing support for various versions.

Segmentation#

DeepLabBaseline

SegFormerBaseline

SegFormer backbone baseline for segmentation providing support for various versions and architectures.

Monocular Depth Estimation#

Layers#

Ensemble layers#

PackedLinear

Packed-Ensembles-style Linear layer.

PackedConv2d

Packed-Ensembles-style Conv2d layer.

PackedMultiheadAttention

Packed-Ensembles-style MultiheadAttention layer.

PackedLayerNorm

Packed-Ensembles-style LayerNorm layer.

PackedTransformerEncoderLayer

Packed-Ensembles-style TransformerEncoderLayer (made up of self-attention followed by a feedforward network).

PackedTransformerDecoderLayer

Packed-Ensembles-style TransformerDecoderLayer (made up of self-attention, multi-head attention, and feedforward network).

BatchLinear

BatchEnsemble-style Linear layer.

BatchConv2d

BatchEnsemble-style Conv2d layer.

MaskedLinear

Masksembles-style Linear layer.

MaskedConv2d

Masksembles-style Conv2d layer.

Bayesian layers#

BayesLinear

Bayesian Linear Layer with Mixture of Normals prior and Normal posterior.

BayesConv1d

Bayesian Conv1d Layer with Mixture of Normals prior and Normal posterior.

BayesConv2d

Bayesian Conv2d Layer with Gaussian Mixture prior and Normal posterior.

BayesConv3d

Bayesian Conv3d Layer with Gaussian mixture prior and Normal posterior.

LPBNNLinear

LPBNN-style linear layer.

LPBNNConv2d

LPBNN-style 2D convolutional layer.

Density layers#

Linear Layers#

NormalLinear

Normal Distribution Linear Density Layer.

LaplaceLinear

Laplace Distribution Linear Density Layer.

CauchyLinear

Cauchy Distribution Linear Density Layer.

StudentTLinear

Student's T-Distribution Linear Density Layer.

NormalInverseGammaLinear

Normal-Inverse-Gamma Distribution Linear Density Layer.

Convolution Layers#

NormalConvNd

Normal Distribution Convolutional Density Layer.

LaplaceConvNd

Laplace Distribution Convolutional Density Layer.

CauchyConvNd

Cauchy Distribution Convolutional Density Layer.

StudentTConvNd

Student's T-Distribution Convolutional Density Layer.

NormalInverseGammaConvNd

Normal-Inverse-Gamma Distribution Convolutional Density Layer.

Models#

Wrappers#

Functions#

batch_ensemble

BatchEnsemble wrapper for a model.

deep_ensembles

Build a Deep Ensemble out of the original models.

mc_dropout

MC Dropout wrapper for a model.

Classes#

BatchEnsemble

Wrap a BatchEnsemble model to ensure correct batch replication.

CheckpointCollector

Ensemble of models at different points in the training trajectory.

EMA

Exponential Moving Average (EMA).
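For readers unfamiliar with the idea, an exponential moving average of weights is usually maintained with the update `ema = decay * ema + (1 - decay) * weights`. The sketch below illustrates that rule on plain Python lists; `ema_update` and its `decay` argument are hypothetical names chosen for illustration and do not describe the `EMA` class signature.

```python
# Hypothetical helper, not part of the library API.
def ema_update(ema_weights, new_weights, decay=0.99):
    """Blend the running average with the latest weights."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, new_weights)]


ema = [0.0]
ema = ema_update(ema, [1.0], decay=0.9)  # first step moves 10% of the way
ema = ema_update(ema, [1.0], decay=0.9)  # second step moves 10% of what remains
```

A larger `decay` makes the average react more slowly to recent weights.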

MCDropout

MC Dropout wrapper for a model containing nn.Dropout modules.

StochasticModel

SWA

Stochastic Weight Averaging.

SWAG

Stochastic Weight Averaging Gaussian (SWAG).

Metrics#

Classification#

Proper Scores#

BrierScore

Compute the Brier score.
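As a rough illustration of what this metric measures, the Brier score is commonly defined as the squared error between the predicted probability vector and the one-hot target, averaged over samples. The pure-Python sketch below is not the library's implementation: `brier_score` is a hypothetical helper, and summing (rather than averaging) over classes is an assumption.

```python
# Hypothetical helper, not part of the library API.
def brier_score(probs, targets):
    """Mean over samples of the squared error to the one-hot target,
    summed over classes (assumed convention)."""
    total = 0.0
    for p, y in zip(probs, targets):
        total += sum((p_k - (1.0 if k == y else 0.0)) ** 2 for k, p_k in enumerate(p))
    return total / len(probs)


probs = [[0.8, 0.2], [0.4, 0.6]]
targets = [0, 0]
score = brier_score(probs, targets)
```

Lower is better: a perfectly confident and correct prediction contributes zero.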

CategoricalNLL

Computes the Negative Log-Likelihood (NLL) metric for classification tasks.

Out-of-Distribution Detection#

FPRx

Compute the False Positive Rate at x% Recall.

FPR95

Compute the False Positive Rate at 95% Recall.
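The idea behind this metric can be sketched in a few lines: pick the highest score threshold that still accepts the target fraction of positives, then measure how many negatives pass that threshold. The helper below is a hypothetical illustration, not the library's code. It assumes that higher scores mean "more likely positive" and that label 1 marks the positive class; which class counts as positive (in-distribution or OOD) is a convention not stated here.

```python
import math


# Hypothetical helper, not part of the library API.
def fpr_at_recall(scores, labels, recall=0.95):
    """False positive rate at the highest threshold reaching the target recall."""
    positives = sorted((s for s, y in zip(scores, labels) if y == 1), reverse=True)
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    # Number of positives that must be accepted to reach the target recall.
    needed = max(1, math.ceil(recall * len(positives)))
    threshold = positives[needed - 1]
    false_positives = sum(1 for s in negatives if s >= threshold)
    return false_positives / len(negatives)


scores = [0.9, 0.8, 0.7, 0.6, 0.3, 0.2]
labels = [1, 1, 0, 1, 0, 0]
fpr = fpr_at_recall(scores, labels, recall=0.95)
```

In this toy example all three positives must be accepted, so the threshold is 0.6 and one of the three negatives passes it.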

Selective Classification#

AUGRC

Calculate the Area Under the Generalized Risk-Coverage curve (AUGRC).

AURC

Calculate Area Under the Risk-Coverage curve.
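To make the risk-coverage idea concrete: samples are sorted from most to least confident, the error rate (risk) is computed on the top-k samples for every k (coverage), and the resulting curve is averaged. The sketch below uses a simple discrete average as an assumption; the library may integrate the curve differently, and `aurc` here is a hypothetical helper.

```python
# Hypothetical helper, not part of the library API.
def aurc(confidences, correct):
    """Average selective risk over all coverage levels (discrete approximation)."""
    order = sorted(range(len(confidences)), key=lambda i: confidences[i], reverse=True)
    errors = 0
    risks = []
    for kept, index in enumerate(order, start=1):
        errors += 0 if correct[index] else 1
        risks.append(errors / kept)  # risk when keeping the `kept` most confident samples
    return sum(risks) / len(risks)


confidences = [0.9, 0.8, 0.6, 0.5]
correct = [True, True, False, True]
area = aurc(confidences, correct)
```

A model whose errors are concentrated among its least confident predictions obtains a lower value.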

CovAtxRisk

Provide coverage at x Risk.

CovAt5Risk

Provide coverage at 5% Risk.

RiskAtxCov

Compute the risk at a specified coverage threshold.

RiskAt80Cov

Compute the risk at 80% coverage.

Calibration#

AdaptiveCalibrationError

Computes the Adaptive Top-label Calibration Error (ACE) for classification tasks.

CalibrationError

Computes the Calibration Error for classification tasks.
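One widely used instance of calibration error is the expected calibration error (ECE) with equal-width confidence bins: within each bin, compare the average confidence with the accuracy, then weight the gaps by bin size. The following sketch assumes that variant (L1 norm, ten bins); the helper name and defaults are hypothetical and do not describe the options of `CalibrationError`.

```python
# Hypothetical helper, not part of the library API.
def expected_calibration_error(confidences, correct, num_bins=10):
    """Weighted average of |accuracy - confidence| over equal-width bins."""
    bins = [[] for _ in range(num_bins)]
    for conf, ok in zip(confidences, correct):
        index = min(int(conf * num_bins), num_bins - 1)  # conf == 1.0 goes to the last bin
        bins[index].append((conf, ok))
    total = len(confidences)
    ece = 0.0
    for members in bins:
        if not members:
            continue
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(1 for _, ok in members if ok) / len(members)
        ece += len(members) / total * abs(accuracy - avg_conf)
    return ece


confidences = [0.95, 0.95, 0.55, 0.55]
correct = [True, True, True, False]
ece = expected_calibration_error(confidences, correct)
```

Both bins in this example are off by 0.05, so the weighted gap is 0.05 as well.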

Conformal Predictions#

CoverageRate

Empirical coverage rate metric.

SetSize

Set size to compute the efficiency of conformal prediction methods.

Diversity#

Disagreement

Calculate the Disagreement Metric.

Entropy

The Shannon Entropy Metric to estimate the confidence of a single model or the mean confidence across estimators.

MutualInformation

Compute the Mutual Information Metric.
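For an ensemble, mutual information is commonly computed as the entropy of the averaged prediction minus the average entropy of the individual predictions, so it is zero when all estimators agree. The sketch below illustrates that decomposition for a single input in pure Python; both helpers are hypothetical and say nothing about the tensor shapes the library expects.

```python
import math


# Hypothetical helpers, not part of the library API.
def entropy(probs):
    """Shannon entropy (in nats) of one probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def mutual_information(member_probs):
    """Entropy of the mean prediction minus the mean entropy of the members."""
    num_members = len(member_probs)
    num_classes = len(member_probs[0])
    mean = [sum(m[k] for m in member_probs) / num_members for k in range(num_classes)]
    return entropy(mean) - sum(entropy(m) for m in member_probs) / num_members


mi_agree = mutual_information([[0.5, 0.5], [0.5, 0.5]])
mi_disagree = mutual_information([[1.0, 0.0], [0.0, 1.0]])
```

The first ensemble is uncertain but unanimous, so all of its uncertainty is in the entropy term; the second is confident but contradictory, which is what mutual information picks up.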

VariationRatio

Compute the Variation Ratio.

Others#

GroupingLoss

Metric to estimate the Top-label Grouping Loss.

Regression#

DistributionNLL

Computes the Negative Log-Likelihood (NLL) metric for regression tasks.

Log10

Computes the LOG10 metric.

MeanAbsoluteErrorInverse

Mean Absolute Error of the inverse predictions (iMAE).

MeanGTRelativeAbsoluteError

Compute Mean Absolute Error relative to the Ground Truth (MAErel or ARErel).

MeanGTRelativeSquaredError

Compute mean squared error relative to the Ground Truth (MSErel or SRE).

MeanSquaredErrorInverse

Mean Squared Error of the inverse predictions (iMSE).

MeanSquaredLogError

Computes the Mean Squared Logarithmic Error (MSLE) regression metric.

SILog

Computes the Scale-Invariant Logarithmic Loss metric.
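The scale-invariant logarithmic error is usually written in terms of the log-differences `d_i = log(pred_i) - log(target_i)` as `mean(d^2) - lambda * mean(d)^2`. The sketch below follows that formula under stated assumptions: the `lambd` weight defaults to 1 (fully scale-invariant) and no square root is applied, whereas some implementations use a different weight or report the root. `silog` is a hypothetical helper.

```python
import math


# Hypothetical helper, not part of the library API.
def silog(preds, targets, lambd=1.0):
    """Scale-invariant log error: mean(d^2) - lambd * mean(d)^2."""
    diffs = [math.log(p) - math.log(t) for p, t in zip(preds, targets)]
    mean_sq = sum(d * d for d in diffs) / len(diffs)
    mean = sum(diffs) / len(diffs)
    return mean_sq - lambd * mean**2


targets = [1.0, 2.0, 4.0]
rescaled = silog([2.0, 4.0, 8.0], targets)   # predictions off by a global factor of 2
distorted = silog([1.0, 2.0, 16.0], targets)  # only one prediction is off
```

With `lambd=1.0`, a global rescaling of the predictions is not penalized, while a distortion of the relative depths is.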

ThresholdAccuracy

Computes the Threshold Accuracy metric, also referred to as d1, d2, or d3.
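In the depth-estimation literature, d1, d2, and d3 usually denote the fraction of predictions whose ratio to the ground truth, `max(pred / target, target / pred)`, falls below 1.25, 1.25² and 1.25³ respectively. The sketch below assumes that definition; the helper name and the `power` argument are hypothetical.

```python
# Hypothetical helper, not part of the library API.
def threshold_accuracy(preds, targets, power=1):
    """Fraction of predictions with max(pred/target, target/pred) < 1.25 ** power."""
    threshold = 1.25**power
    hits = sum(1 for p, t in zip(preds, targets) if max(p / t, t / p) < threshold)
    return hits / len(preds)


preds = [1.0, 2.0, 4.0, 10.0]
targets = [1.1, 2.0, 8.0, 9.0]
d1 = threshold_accuracy(preds, targets, power=1)
```

Only the third prediction is off by a factor of two, so three out of four values pass the d1 threshold.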

Segmentation#

MeanIntersectionOverUnion

Computes the mean Intersection over Union (mIoU) score.
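As a reminder of the definition, the IoU of a class is the number of pixels where prediction and target both equal that class, divided by the number of pixels where either does; the mean is then taken over classes. The sketch below works on flattened label lists and skips classes that appear in neither prediction nor target, which is an assumption. `mean_iou` is a hypothetical helper.

```python
# Hypothetical helper, not part of the library API.
def mean_iou(preds, targets, num_classes):
    """Average per-class intersection over union on flattened label lists."""
    ious = []
    for c in range(num_classes):
        intersection = sum(1 for p, t in zip(preds, targets) if p == c and t == c)
        union = sum(1 for p, t in zip(preds, targets) if p == c or t == c)
        if union > 0:  # ignore classes absent from both prediction and target
            ious.append(intersection / union)
    return sum(ious) / len(ious)


preds = [0, 0, 1, 1]
targets = [0, 1, 1, 1]
miou = mean_iou(preds, targets, num_classes=2)
```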

SegmentationBinaryAUROC

SegmentationBinaryAUROC computes the Area Under the Receiver Operating Characteristic Curve (AUROC) for binary segmentation tasks.

SegmentationBinaryAveragePrecision

SegmentationBinaryAveragePrecision computes the Average Precision (AP) for binary segmentation tasks.

SegmentationFPR95

FPR95 metric for segmentation tasks.

Others#

AUSE

The Area Under the Sparsification Error curve (AUSE) metric to evaluate the quality of the uncertainty estimates, i.e., how much they coincide with the true errors.

Losses#

BCEWithLogitsLSLoss

Binary Cross Entropy with Logits Loss with label smoothing.

BetaNLL

The Beta Negative Log-likelihood loss.

ConflictualLoss

The Conflictual Loss.

ConfidencePenaltyLoss

The Confidence Penalty Loss.

DECLoss

The deep evidential classification loss.

DERLoss

The Deep Evidential Regression loss.

DistributionNLLLoss

Negative Log-Likelihood loss using given distributions as inputs.

ELBOLoss

The Evidence Lower Bound (ELBO) loss for Bayesian Neural Networks.

FocalLoss

Focal-Loss for classification tasks.
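The focal loss is commonly written as `-(1 - p_t)^gamma * log(p_t)`, where `p_t` is the probability assigned to the true class: the `(1 - p_t)^gamma` factor down-weights examples that are already well classified. The per-sample sketch below follows that formula and omits the optional class-weighting term; the function and its `gamma` default are hypothetical and do not describe the `FocalLoss` constructor.

```python
import math


# Hypothetical helper, not part of the library API.
def focal_loss(prob_true_class, gamma=2.0):
    """Per-sample focal loss; gamma = 0 recovers the cross-entropy."""
    return -((1.0 - prob_true_class) ** gamma) * math.log(prob_true_class)


easy_ce = focal_loss(0.9, gamma=0.0)     # plain cross-entropy on an easy sample
easy_focal = focal_loss(0.9, gamma=2.0)  # same sample, strongly down-weighted
```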

KLDiv

KL divergence loss for Bayesian Neural Networks.

Post-Processing Methods#

LaplaceApprox

Laplace approximation for uncertainty estimation.

MCBatchNorm

Monte Carlo Batch Normalization wrapper.

Scaling Methods#

MatrixScaler

Matrix scaling post-processing for calibrated probabilities.

TemperatureScaler

Temperature scaling post-processing for calibrated probabilities.
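Temperature scaling divides the logits by a single scalar `T` before the softmax; `T > 1` softens the probabilities and `T < 1` sharpens them, while the predicted class stays the same. In practice `T` is typically fitted on held-out data by minimizing the negative log-likelihood. The sketch below only shows the scaling step with a fixed temperature; the helper names are hypothetical and unrelated to the `TemperatureScaler` interface.

```python
import math


# Hypothetical helpers, not part of the library API.
def softmax(logits):
    """Numerically stable softmax on a list of floats."""
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def temperature_scale(logits, temperature):
    """Softmax of the logits divided by a scalar temperature."""
    return softmax([z / temperature for z in logits])


sharp = temperature_scale([2.0, 0.0], temperature=1.0)
cooled = temperature_scale([2.0, 0.0], temperature=2.0)
```

Here the top-class probability drops when the temperature increases, but the ranking of the classes is unchanged.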

VectorScaler

Vector scaling post-processing for calibrated probabilities.

OOD Scores#

TUOODCriterion

Abstract base class for Out-of-Distribution (OOD) criteria.

MaxLogitCriterion

OOD criterion based on the maximum logit value.

EnergyCriterion

OOD criterion based on the energy function.
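The energy of an input is commonly defined from its logits as `-T * logsumexp(logits / T)`; inputs with large, peaked logits get a low energy, and flat logits get a higher one. The sketch below computes that quantity in pure Python. The helper is hypothetical, and whether the criterion returns the energy or its negative as the OOD score is a convention not stated here.

```python
import math


# Hypothetical helper, not part of the library API.
def energy(logits, temperature=1.0):
    """Energy score: -T * logsumexp(logits / T), computed stably."""
    scaled = [z / temperature for z in logits]
    shift = max(scaled)
    log_sum_exp = shift + math.log(sum(math.exp(z - shift) for z in scaled))
    return -temperature * log_sum_exp


confident = energy([10.0, 0.0, 0.0])
flat = energy([0.0, 0.0, 0.0])
```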

MaxSoftmaxCriterion

OOD criterion based on maximum softmax probability.

EntropyCriterion

OOD criterion based on entropy.

MutualInformationCriterion

OOD criterion based on mutual information.

PostProcessingCriterion

OOD criterion based on maximum softmax probability.

VariationRatioCriterion

OOD criterion based on variation ratio.

Datamodules#

Classification#

CIFAR10DataModule

DataModule for CIFAR10.

CIFAR100DataModule

DataModule for CIFAR100.

ImageNetDataModule

DataModule for the ImageNet dataset.

MNISTDataModule

DataModule for MNIST.

TinyImageNetDataModule

DataModule for the Tiny-ImageNet dataset.

UCI Tabular Classification#

BankMarketingDataModule

The Bank Marketing UCI classification datamodule.

DOTA2GamesDataModule

The Dota2 Games UCI classification datamodule.

HTRU2DataModule

The HTRU2 UCI classification datamodule.

OnlineShoppersDataModule

The Online Shoppers Intention UCI classification datamodule.

SpamBaseDataModule

The SpamBase UCI classification datamodule.

Regression#

UCIRegressionDataModule

The UCI regression datasets.

Segmentation#

CamVidDataModule

DataModule for the CamVid dataset.

CityscapesDataModule

DataModule for the Cityscapes dataset.

MUADDataModule

Segmentation DataModule for the MUAD dataset.

Datasets#

Classification#

MNISTC

The corrupted MNIST-C Dataset.

NotMNIST

The notMNIST dataset.

CIFAR10C

The corrupted CIFAR-10-C Dataset.

CIFAR100C

The corrupted CIFAR-100-C Dataset.

CIFAR10H

CIFAR-10H Dataset.

CIFAR10N

CIFAR-10N Dataset.

CIFAR100N

CIFAR-100N Dataset.

ImageNetA

Initializes the ImageNetA dataset class.

ImageNetC

Initializes the ImageNetC dataset class.

ImageNetO

Initializes the ImageNetO dataset class.

ImageNetR

Initializes the ImageNetR dataset class.

TinyImageNet

Inspired by https://gist.github.com/z-a-f/b862013c0dc2b540cf96a123a6766e54.

TinyImageNetC

The corrupted TinyImageNet-C Dataset.

OpenImageO

OpenImage-O dataset.

UCI Tabular Classification#

BankMarketing

The Bank Marketing UCI classification dataset.

DOTA2Games

The DOTA 2 Games UCI classification dataset.

HTRU2

The HTRU2 UCI classification dataset.

OnlineShoppers

The Online Shoppers Intention UCI classification dataset.

SpamBase

The SpamBase UCI classification dataset.

Regression#

UCIRegression

The UCI regression datasets.

Segmentation#

Others & Cross-Categories#

Fractals

Dataset used for PixMix augmentations.

FrostImages

KITTIDepth

MUAD

The MUAD Dataset.

NYUv2

NYUv2 depth dataset.