API Reference#

Routines#

Routines are the main building blocks of the library. They define the framework in which models are trained and evaluated, and make it easy to compute the metrics crucial for uncertainty estimation across the supported tasks: classification, regression, segmentation, and pixelwise regression. See the Evaluating Models page for a full breakdown of the metrics computed by each routine.

Classification#

ClassificationRoutine

Routine for training & testing on classification tasks.

Segmentation#

SegmentationRoutine

Routine for training & testing on segmentation tasks.

Regression#

RegressionRoutine

Routine for training & testing on regression tasks.

Pixelwise Regression#

PixelRegressionRoutine

Routine for training & testing on pixel regression tasks.

Layers#

Ensemble layers#

PackedLinear

Packed-Ensembles-style Linear layer.

PackedConv2d

Packed-Ensembles-style Conv2d layer.

PackedMultiheadAttention

Packed-Ensembles-style MultiheadAttention layer.

PackedLayerNorm

Packed-Ensembles-style LayerNorm layer.

PackedTransformerEncoderLayer

Packed-Ensembles-style TransformerEncoderLayer (made up of self-attention followed by a feedforward network).

PackedTransformerDecoderLayer

Packed-Ensembles-style TransformerDecoderLayer (made up of self-attention, multi-head cross-attention, and a feedforward network).

BatchLinear

BatchEnsemble-style Linear layer.

BatchConv2d

BatchEnsemble-style Conv2d layer.

MaskedLinear

Masksembles-style Linear layer.

MaskedConv2d

Masksembles-style Conv2d layer.

Bayesian layers#

BayesLinear

Bayesian Linear Layer with Mixture of Normals prior and Normal posterior.

BayesConv1d

Bayesian Conv1d Layer with Mixture of Normals prior and Normal posterior.

BayesConv2d

Bayesian Conv2d Layer with Mixture of Normals prior and Normal posterior.

BayesConv3d

Bayesian Conv3d Layer with Mixture of Normals prior and Normal posterior.

LPBNNLinear

LPBNN-style linear layer.

LPBNNConv2d

LPBNN-style 2D convolutional layer.

Density layers#

Linear Layers#

NormalLinear

Normal Distribution Linear Density Layer.

LaplaceLinear

Laplace Distribution Linear Density Layer.

CauchyLinear

Cauchy Distribution Linear Density Layer.

StudentTLinear

Student's T-Distribution Linear Density Layer.

NormalInverseGammaLinear

Normal-Inverse-Gamma Distribution Linear Density Layer.
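The loc/scale density layers output the parameters of a distribution instead of a point prediction. The following pure-Python sketch illustrates the idea for a single-output Normal head; it is not the library's implementation, and the softplus transform, the `min_scale` floor, and all function names are assumptions chosen for illustration. The Laplace and Cauchy layers follow the same pattern with a different distribution.

```python
import math
from typing import Sequence, Tuple


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x)); positive unless it underflows."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def normal_linear(
    x: Sequence[float],
    w_loc: Sequence[float],
    b_loc: float,
    w_scale: Sequence[float],
    b_scale: float,
    min_scale: float = 1e-6,
) -> Tuple[float, float]:
    """Single-output Normal head: affine map for loc, softplus of an affine map for scale."""
    loc = sum(w * xi for w, xi in zip(w_loc, x)) + b_loc
    raw_scale = sum(w * xi for w, xi in zip(w_scale, x)) + b_scale
    # min_scale keeps the scale strictly positive even if softplus underflows.
    return loc, softplus(raw_scale) + min_scale


def normal_nll(y: float, loc: float, scale: float) -> float:
    """Negative log-density of y under Normal(loc, scale)."""
    return 0.5 * math.log(2 * math.pi * scale ** 2) + (y - loc) ** 2 / (2 * scale ** 2)


loc, scale = normal_linear([1.0, 2.0], [0.5, 0.25], 0.0, [0.0, 0.0], 0.0)
print(loc, scale, normal_nll(1.5, loc, scale))
```

Minimising `normal_nll` over a dataset is the Gaussian case of the probabilistic-regression objective that DistributionNLLLoss describes.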

Convolution Layers#

NormalConvNd

Normal Distribution Convolutional Density Layer.

LaplaceConvNd

Laplace Distribution Convolutional Density Layer.

CauchyConvNd

Cauchy Distribution Convolutional Density Layer.

StudentTConvNd

Student's T-Distribution Convolutional Density Layer.

NormalInverseGammaConvNd

Normal-Inverse-Gamma Distribution Convolutional Density Layer.

Model Backbones#

ResNet#

batched_resnet

BatchEnsemble of ResNet.

lpbnn_resnet

LPBNN version of ResNet.

masked_resnet

Masksembles of ResNet.

mimo_resnet

MIMO ResNet.

packed_resnet

Packed-Ensembles of ResNet.

resnet

ResNet model.

WideResNet#

batched_wideresnet28x10

BatchEnsemble of Wide-ResNet-28x10.

masked_wideresnet28x10

Masksembles of Wide-ResNet-28x10.

mimo_wideresnet28x10

MIMO of Wide-ResNet-28x10.

packed_wideresnet28x10

Packed-Ensembles of Wide-ResNet-28x10.

wideresnet28x10

Wide-ResNet-28x10 from Wide Residual Networks.

InceptionTime#

batched_inception_time

BatchEnsemble of InceptionTime.

bayesian_inception_time

Bayesian InceptionTime.

inception_time

InceptionTime from InceptionTime: Finding AlexNet for Time Series Classification.

mimo_inception_time

MIMO InceptionTime.

packed_inception_time

Packed-Ensembles of InceptionTime.

UQ Methods#

Uncertainty quantification (UQ) methods wrap your models to enable better uncertainty estimation.

Functions#

Some methods can be directly created through functions such as the following:

batch_ensemble

BatchEnsemble wrapper for a model.

deep_ensembles

Build a Deep Ensemble out of the original models.

mc_dropout

MC Dropout wrapper for a model.
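Conceptually, a deep ensemble averages the predictive distributions of its members. The sketch below assumes each member is a plain callable returning logits (the `Model` alias and `ensemble_predict` are hypothetical); the actual deep_ensembles wrapper operates on PyTorch modules, and its aggregation options may differ.

```python
import math
from typing import Callable, List, Sequence

# Hypothetical member type: maps an input vector to class logits.
Model = Callable[[Sequence[float]], List[float]]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_predict(members: Sequence[Model], x: Sequence[float]) -> List[float]:
    """Average the members' softmax probabilities, as in Deep Ensembles."""
    probs = [softmax(member(x)) for member in members]
    return [sum(col) / len(probs) for col in zip(*probs)]


# Two confident members that disagree yield an uncertain ensemble prediction.
members = [lambda x: [10.0, 0.0], lambda x: [0.0, 10.0]]
print(ensemble_predict(members, [1.0]))
```

MC Dropout follows the same averaging pattern, except that the "members" are repeated stochastic forward passes of a single network with dropout kept active at test time.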

Classes#

Some methods need to be instantiated as classes such as the following:

BatchEnsemble

Wrap a BatchEnsemble model to ensure correct batch replication.

CheckpointCollector

Ensemble of models at different points in the training trajectory.

EMA

Exponential Moving Average (EMA).

StochasticModel

SWA

Stochastic Weight Averaging.

SWAG

Stochastic Weight Averaging Gaussian (SWAG).

Zero

Zero for test-time adaptation.
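EMA and SWA both maintain an average of the weights along the training trajectory; they differ only in how snapshots are weighted. The sketch below shows the two update rules on flat lists of floats standing in for parameter tensors; the function names and the `momentum` default are illustrative assumptions, not the library's signatures.

```python
from typing import List, Sequence


def ema_update(ema: Sequence[float], params: Sequence[float], momentum: float = 0.999) -> List[float]:
    """EMA step: ema <- momentum * ema + (1 - momentum) * params."""
    return [momentum * e + (1.0 - momentum) * p for e, p in zip(ema, params)]


def swa_update(swa: Sequence[float], params: Sequence[float], n_averaged: int) -> List[float]:
    """SWA step: equal-weight running mean over n_averaged snapshots plus the new one."""
    return [(s * n_averaged + p) / (n_averaged + 1) for s, p in zip(swa, params)]


# Averaging three weight snapshots with SWA gives their plain mean.
swa, n = [1.0], 1
for snapshot in ([3.0], [5.0]):
    swa = swa_update(swa, snapshot, n)
    n += 1
print(swa)
```

EMA favours recent weights geometrically, whereas SWA gives every collected snapshot the same weight.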

Metrics#

Classification#

Proper Scores#

BrierScore

Compute the Brier score.

CategoricalNLL

Computes the Negative Log-Likelihood (NLL) metric for classification tasks.
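BrierScore and CategoricalNLL are proper scoring rules computed from predicted class probabilities. A minimal sketch of the textbook multiclass definitions follows, assuming probability vectors (not logits) and integer targets; the helper names are hypothetical, and the library's reduction and normalisation options may differ.

```python
import math
from typing import Sequence


def brier_score(probs: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Squared distance to the one-hot target, summed over classes, averaged over samples."""
    total = 0.0
    for p, y in zip(probs, targets):
        total += sum((pk - (1.0 if k == y else 0.0)) ** 2 for k, pk in enumerate(p))
    return total / len(targets)


def categorical_nll(probs: Sequence[Sequence[float]], targets: Sequence[int], eps: float = 1e-12) -> float:
    """Mean negative log-probability assigned to the true class (clipped at eps)."""
    return -sum(math.log(max(p[y], eps)) for p, y in zip(probs, targets)) / len(targets)


probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]
targets = [0, 1]
print(brier_score(probs, targets), categorical_nll(probs, targets))
```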

Out-of-Distribution Detection#

FPRx

Compute the False Positive Rate at x% Recall.

FPR95

Compute the False Positive Rate at 95% Recall.
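FPR95 reports the false positive rate at the threshold where 95% of the positive samples are detected (FPRx generalises this to x%). The sketch below assumes higher scores mean "more likely positive" and treats label 1 as the positive class; whether positives are the OOD or the in-distribution samples is a convention it does not fix, and `fpr_at_recall` is a hypothetical name.

```python
import math
from typing import Sequence


def fpr_at_recall(scores: Sequence[float], labels: Sequence[int], recall: float = 0.95) -> float:
    """False positive rate at the highest threshold whose recall (TPR) on the
    positives reaches `recall`."""
    pos = sorted((s for s, lab in zip(scores, labels) if lab == 1), reverse=True)
    neg = [s for s, lab in zip(scores, labels) if lab == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    # Number of positives that must score at or above the threshold.
    k = max(1, math.ceil(recall * len(pos)))
    threshold = pos[k - 1]
    return sum(s >= threshold for s in neg) / len(neg)


scores = [0.9, 0.8, 0.7, 0.6, 0.75, 0.5, 0.3, 0.1]
labels = [1, 1, 1, 1, 0, 0, 0, 0]
print(fpr_at_recall(scores, labels))
```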

Selective Classification#

AUGRC

Calculate the Area Under the Generalized Risk-Coverage curve (AUGRC).

AURC

Calculate the Area Under the Risk-Coverage curve (AURC).

CovAtxRisk

Provide the coverage at x% risk.

CovAt5Risk

Provide the coverage at 5% risk.

RiskAtxCov

Compute the risk at a specified coverage threshold.

RiskAt80Cov

Compute the risk at 80% coverage.
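The selective-classification metrics all derive from the risk-coverage curve: samples are accepted in order of decreasing confidence, and the selective risk at coverage k/n is the error rate among the k accepted samples. The sketch below computes that curve, a discrete AURC (the plain mean of the selective risks), and the risk at a given coverage; the names and the exact discretisation are assumptions, and the library may integrate the curve differently.

```python
from typing import List, Sequence, Tuple


def risk_coverage_curve(confidences: Sequence[float], errors: Sequence[int]) -> List[Tuple[float, float]]:
    """(coverage, selective risk) pairs, accepting samples by decreasing confidence.
    errors[i] is 1 if sample i is misclassified, else 0."""
    order = sorted(range(len(confidences)), key=lambda i: confidences[i], reverse=True)
    n = len(order)
    curve, cum_err = [], 0
    for k, i in enumerate(order, start=1):
        cum_err += errors[i]
        curve.append((k / n, cum_err / k))
    return curve


def aurc(confidences: Sequence[float], errors: Sequence[int]) -> float:
    """Discrete AURC: mean selective risk over the coverage levels 1/n, ..., n/n."""
    curve = risk_coverage_curve(confidences, errors)
    return sum(risk for _, risk in curve) / len(curve)


def risk_at_coverage(confidences: Sequence[float], errors: Sequence[int], coverage: float = 0.8) -> float:
    """Selective risk on the smallest most-confident subset covering >= `coverage`."""
    for cov, risk in risk_coverage_curve(confidences, errors):
        if cov >= coverage:
            return risk
    raise ValueError("coverage must be in (0, 1]")


confidences = [0.9, 0.8, 0.7, 0.6, 0.5]
errors = [0, 0, 1, 0, 1]
print(aurc(confidences, errors), risk_at_coverage(confidences, errors))
```

CovAtxRisk reads the same curve the other way round, reporting the largest coverage whose selective risk does not exceed x.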

Selective Classification with OOD#

SCODAUGRC

Calculate the Area Under the Generalized Risk-Coverage curve (AUGRC) for selective classification in the presence of OOD samples (SCOD).

SCODAURC

Calculate the Area Under the Risk-Coverage curve (AURC) for selective classification in the presence of OOD samples (SCOD).

SCODCovAtxRisk

Provide the coverage at x% SCOD risk.

SCODCovAt5Risk

Coverage at 5% SCOD risk.

SCODRiskAtxCov

Compute the SCOD risk at a specified coverage threshold.

SCODRiskAt80Cov

SCOD risk at 80% coverage.

Calibration#

AdaptiveCalibrationError

Computes the Adaptive Top-label Calibration Error (ACE) for classification tasks.

CalibrationError

Computes the Calibration Error for classification tasks.

SmoothCalibrationError

Smooth Expected Calibration Error (SmECE).

ClasswiseCalibrationError

Compute the Classwise Expected Calibration Error (ECE).
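The calibration errors summarise the gap between confidence and accuracy. A minimal sketch of the top-label expected calibration error with equal-width bins and an L1 gap follows; the function name, the 15-bin default, and the left-closed bin edges are assumptions, and the library's binning details may differ. AdaptiveCalibrationError instead uses bins holding roughly equal numbers of samples.

```python
from typing import Sequence


def expected_calibration_error(confidences: Sequence[float], correct: Sequence[int], n_bins: int = 15) -> float:
    """Top-label ECE: bin-size-weighted mean of |accuracy - mean confidence| per bin."""
    n = len(confidences)
    bins = [[] for _ in range(n_bins)]
    for c, ok in zip(confidences, correct):
        # Map a confidence in [0, 1] to a bin index; c == 1.0 goes to the last bin.
        idx = min(int(c * n_bins), n_bins - 1)
        bins[idx].append((c, ok))
    ece = 0.0
    for b in bins:
        if b:
            acc = sum(ok for _, ok in b) / len(b)
            conf = sum(c for c, _ in b) / len(b)
            ece += len(b) / n * abs(acc - conf)
    return ece


print(expected_calibration_error([0.95, 0.95, 0.55, 0.55], [1, 0, 1, 1], n_bins=10))
```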

Conformal Predictions#

CoverageRate

Empirical coverage rate metric.

SetSize

Average prediction-set size, the standard efficiency metric for conformal prediction methods.
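For conformal predictors, CoverageRate measures validity (how often the prediction set contains the true label) and SetSize measures efficiency. Both reduce to simple averages, sketched here on Python sets of label indices; the function names are hypothetical.

```python
from typing import Collection, Sequence


def coverage_rate(pred_sets: Sequence[Collection[int]], targets: Sequence[int]) -> float:
    """Fraction of samples whose prediction set contains the true label."""
    return sum(y in s for s, y in zip(pred_sets, targets)) / len(targets)


def average_set_size(pred_sets: Sequence[Collection[int]]) -> float:
    """Mean number of labels per prediction set (lower is more efficient)."""
    return sum(len(s) for s in pred_sets) / len(pred_sets)


pred_sets = [{0}, {0, 1}, {2}, {1, 2}]
targets = [0, 1, 1, 2]
print(coverage_rate(pred_sets, targets), average_set_size(pred_sets))
```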

Diversity#

Disagreement

Calculate the Disagreement Metric.

Entropy

The Shannon Entropy Metric to estimate the confidence of a single model or the mean confidence across estimators.

MutualInformation

Compute the Mutual Information Metric.

VariationRatio

Compute the Variation Ratio.
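The diversity metrics summarise how much the members of an ensemble (or the samples of a stochastic model) agree. The sketch below computes, from per-member probability vectors, the entropy of the averaged prediction, the mutual information (that entropy minus the average member entropy), and a label-vote variation ratio; the names are hypothetical, and the library's VariationRatio may be defined on probabilities rather than votes.

```python
import math
from collections import Counter
from typing import List, Sequence


def entropy(p: Sequence[float], eps: float = 1e-12) -> float:
    """Shannon entropy (in nats) of one probability vector."""
    return -sum(pk * math.log(max(pk, eps)) for pk in p)


def mean_prediction(member_probs: Sequence[Sequence[float]]) -> List[float]:
    """Average of the members' probability vectors."""
    return [sum(col) / len(member_probs) for col in zip(*member_probs)]


def mutual_information(member_probs: Sequence[Sequence[float]]) -> float:
    """Entropy of the mean prediction minus the mean per-member entropy."""
    total = entropy(mean_prediction(member_probs))
    expected = sum(entropy(p) for p in member_probs) / len(member_probs)
    return total - expected


def variation_ratio(member_probs: Sequence[Sequence[float]]) -> float:
    """1 minus the share of members voting for the most common class."""
    votes = Counter(max(range(len(p)), key=p.__getitem__) for p in member_probs)
    return 1.0 - votes.most_common(1)[0][1] / len(member_probs)


disagree = [[1.0, 0.0], [0.0, 1.0]]
agree = [[0.9, 0.1], [0.9, 0.1]]
print(mutual_information(disagree), mutual_information(agree))
```

Confident members that disagree produce high mutual information; identical members produce none, however uncertain each one is.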

Others#

GroupingLoss

Metric to estimate the Top-label Grouping Loss.

Regression#

DistributionNLL

Computes the Negative Log-Likelihood (NLL) metric for probabilistic regression tasks.

Log10

Compute the LOG10 metric.

MeanAbsoluteErrorInverse

Mean Absolute Error of the inverse predictions (iMAE).

MeanGTRelativeAbsoluteError

Compute the Mean Absolute Error relative to the Ground Truth (MAErel or ARErel).

MeanGTRelativeSquaredError

Compute mean squared error relative to the Ground Truth (MSErel or SRE).

MeanSquaredErrorInverse

Mean Squared Error of the inverse predictions (iMSE).

MeanSquaredLogError

Compute the Mean Squared Logarithmic Error (MSLE).

SILog

Compute the Scale-Invariant Logarithmic Loss (SILog) metric.

ThresholdAccuracy

Compute the Threshold Accuracy metric, also referred to as \(\delta_1\), \(\delta_2\), or \(\delta_3\).
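Several of these metrics come from monocular depth estimation. The sketch below shows common textbook forms of SILog, ThresholdAccuracy (\(\delta_k\) counts predictions within a factor \(1.25^k\) of the target), and Log10 on flat lists of strictly positive values; the library's versions may additionally take a square root of or rescale SILog, mask invalid pixels, or use other defaults, and the function names are hypothetical.

```python
import math
from typing import Sequence


def silog(preds: Sequence[float], targets: Sequence[float], lmbda: float = 1.0) -> float:
    """Scale-invariant log error on d = log(pred) - log(target).
    With lmbda = 1, scaling every prediction by a constant leaves it unchanged."""
    d = [math.log(p) - math.log(t) for p, t in zip(preds, targets)]
    n = len(d)
    return sum(x * x for x in d) / n - lmbda * (sum(d) / n) ** 2


def threshold_accuracy(preds: Sequence[float], targets: Sequence[float], power: int = 1) -> float:
    """delta_k: share of values with max(p / t, t / p) < 1.25 ** k."""
    thr = 1.25 ** power
    return sum(max(p / t, t / p) < thr for p, t in zip(preds, targets)) / len(preds)


def log10_error(preds: Sequence[float], targets: Sequence[float]) -> float:
    """Mean absolute difference of base-10 logarithms."""
    return sum(abs(math.log10(p) - math.log10(t)) for p, t in zip(preds, targets)) / len(preds)


print(silog([2.0, 4.0, 8.0], [1.0, 2.0, 4.0]))  # predictions off by a constant factor
print(threshold_accuracy([1.1, 1.5, 1.0], [1.0, 1.0, 1.0]))
```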

Calibration#

QuantileCalibrationError

Quantile Calibration Error for regression tasks.

Segmentation#

MeanIntersectionOverUnion

Computes the Mean Intersection over Union (mIoU) score.

SegmentationBinaryAUROC

Image-averaged binary AUROC for dense binary segmentation tasks.

SegmentationBinaryAveragePrecision

Image-averaged binary Average Precision for dense segmentation tasks.

SegmentationFPR95

Image-averaged FPR@95 TPR for dense binary segmentation tasks.
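MeanIntersectionOverUnion averages a per-class IoU, TP / (TP + FP + FN), over the classes. The sketch below works on flattened label lists and skips classes that appear in neither predictions nor targets; that choice, the absence of an ignore index, and the function name are assumptions, and the library's handling may differ.

```python
from typing import Sequence


def mean_iou(preds: Sequence[int], targets: Sequence[int], num_classes: int) -> float:
    """Per-class IoU = TP / (TP + FP + FN), averaged over classes present
    in the predictions or the targets."""
    inter = [0] * num_classes
    union = [0] * num_classes
    for p, t in zip(preds, targets):
        if p == t:
            inter[p] += 1  # true positive for class p
            union[p] += 1
        else:
            union[p] += 1  # false positive for class p
            union[t] += 1  # false negative for class t
    ious = [i / u for i, u in zip(inter, union) if u > 0]
    return sum(ious) / len(ious)


preds = [0, 0, 1, 1]    # flattened predicted labels
targets = [0, 1, 1, 1]  # flattened ground-truth labels
print(mean_iou(preds, targets, num_classes=3))
```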

Others#

AUSE

The Area Under the Sparsification Error curve (AUSE) metric to evaluate the quality of the uncertainty estimates, i.e., how much they coincide with the true errors.

Losses#

BCEWithLogitsLSLoss

Binary Cross Entropy with Logits Loss with label smoothing.

BetaNLL

The \(\beta\)-Negative Log-Likelihood loss (Seitzer et al., 2022).

ConflictualLoss

The Conflictual Loss.

ConfidencePenaltyLoss

The Confidence Penalty loss.

DECLoss

The Deep Evidential Classification (DEC) loss.

DERLoss

The Deep Evidential Regression (DER) loss.

DistributionNLLLoss

Negative Log-Likelihood loss for probabilistic regression.

ELBOLoss

The (negative) Evidence Lower Bound (ELBO) loss for Bayesian Neural Networks.

FocalLoss

Focal Loss for classification tasks.

KLDiv

KL divergence loss for Bayesian Neural Networks.
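The focal loss down-weights well-classified examples by a factor \((1 - p_t)^\gamma\), where \(p_t\) is the probability of the true class. A minimal sketch on probability vectors follows; the library's FocalLoss may instead take logits and optional class weights, and `focal_loss` here is a hypothetical helper.

```python
import math
from typing import Sequence


def focal_loss(probs: Sequence[Sequence[float]], targets: Sequence[int], gamma: float = 2.0, eps: float = 1e-12) -> float:
    """Mean of -(1 - p_t)^gamma * log(p_t), with p_t the true-class probability."""
    total = 0.0
    for p, y in zip(probs, targets):
        pt = max(p[y], eps)
        # The modulating factor (1 - p_t)^gamma shrinks the loss of easy examples.
        total += -((1.0 - pt) ** gamma) * math.log(pt)
    return total / len(targets)


probs, targets = [[0.7, 0.3], [0.2, 0.8]], [0, 1]
print(focal_loss(probs, targets, gamma=0.0), focal_loss(probs, targets))
```

With gamma = 0 the modulating factor is 1 and the loss reduces to the usual cross-entropy on the predicted probabilities.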

Post-Processing Methods#

LaplaceApprox

Laplace approximation for post-hoc Bayesian uncertainty estimation.

MCBatchNorm

Monte Carlo Batch Normalization (MCBN) wrapper for 2D inputs (Teye, Azizpour & Smith, ICML 2018).

Scaling Methods#

TemperatureScaler

Temperature scaling post-processing for calibrated probabilities.

VectorScaler

Vector scaling post-processing for calibrated probabilities.

MatrixScaler

Matrix scaling post-processing for calibrated probabilities.

DirichletScaler

Dirichlet scaling post-processing for calibrated probabilities (Kull et al., 2019).

IsotonicRegressionScaler

Isotonic Regression post-processing for calibrated probabilities (Zadrozny & Elkan, 2002).
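Temperature scaling divides the logits by a single scalar T > 0 fitted on held-out data to minimise the NLL; because every logit is divided by the same positive value, the argmax (and hence accuracy) is unchanged. The sketch below fits T by a simple grid search, a deliberately simplified stand-in for the gradient-based optimisation a real implementation would typically use; the function names and the grid are assumptions.

```python
import math
from typing import List, Optional, Sequence


def log_softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable log-softmax of logits / temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_norm for z in scaled]


def nll(logits_list: Sequence[Sequence[float]], targets: Sequence[int], temperature: float) -> float:
    """Mean negative log-likelihood of the targets at a given temperature."""
    return -sum(log_softmax(z, temperature)[y] for z, y in zip(logits_list, targets)) / len(targets)


def fit_temperature(logits_list, targets, grid: Optional[Sequence[float]] = None) -> float:
    """Choose the temperature that minimises validation NLL over a grid."""
    if grid is None:
        grid = [0.05 * k for k in range(1, 201)]  # 0.05, 0.10, ..., 10.0
    return min(grid, key=lambda t: nll(logits_list, targets, t))


# Over-confident logits that are right only half the time: the fit pushes T upward.
val_logits = [[10.0, 0.0], [10.0, 0.0]]
val_targets = [0, 1]
print(fit_temperature(val_logits, val_targets))
```

Vector and matrix scaling generalise this by learning, respectively, a per-class scale plus bias or a full affine map of the logits.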

Conformal Methods#

Conformal

Abstract base class for split-conformal classification predictors.

ConformalClsAPS

Conformal classification with Adaptive Prediction Sets (APS; Romano, Sesia & Candès, NeurIPS 2020).

ConformalClsRAPS

Conformal classification with Regularised Adaptive Prediction Sets (RAPS; Angelopoulos, Bates, Jordan & Malik, 2021).

ConformalClsTHR

Threshold-based conformal classifier (THR; Sadinle et al., 2019).
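THR thresholds the predicted probabilities. Under the usual split-conformal recipe, the score of a calibration example is 1 minus the probability of its true class, the threshold is the ceil((n + 1)(1 - alpha))-th smallest of the n calibration scores, and the resulting sets cover the true label with probability at least 1 - alpha under exchangeability. The sketch below follows that recipe with hypothetical function names; it is not the ConformalClsTHR implementation.

```python
import math
from typing import Sequence, Set


def calibrate_threshold(cal_probs: Sequence[Sequence[float]], cal_targets: Sequence[int], alpha: float = 0.1) -> float:
    """Split-conformal THR calibration with score = 1 - p(true class)."""
    scores = sorted(1.0 - p[y] for p, y in zip(cal_probs, cal_targets))
    n = len(scores)
    k = math.ceil((n + 1) * (1 - alpha))
    if k > n:
        return math.inf  # too few calibration points: every label is kept
    return scores[k - 1]


def predict_set(probs: Sequence[float], qhat: float) -> Set[int]:
    """Keep every label whose score 1 - p does not exceed the threshold."""
    return {k for k, pk in enumerate(probs) if 1.0 - pk <= qhat}


cal_probs = [[0.9, 0.1], [0.6, 0.4], [0.3, 0.7], [0.8, 0.2]]
cal_targets = [0, 0, 1, 1]
qhat = calibrate_threshold(cal_probs, cal_targets, alpha=0.5)
print(qhat, predict_set([0.65, 0.35], qhat))
```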

OOD Scores#

TUOODCriterion

Abstract base class for Out-of-Distribution (OOD) criteria.

MaxLogitCriterion

OOD criterion based on the Max-Logit score (Hendrycks et al.).

EnergyCriterion

OOD criterion based on the free-energy score (Liu et al., NeurIPS 2020).

MaxSoftmaxCriterion

OOD criterion based on the Maximum Softmax Probability (MSP) baseline of Hendrycks & Gimpel (ICLR 2017).

EntropyCriterion

OOD criterion based on entropy.

MutualInformationCriterion

OOD criterion based on mutual information (BALD).

PostProcessingCriterion

OOD criterion based on the Maximum Softmax Probability (MSP) baseline of Hendrycks & Gimpel (ICLR 2017).

VariationRatioCriterion

OOD criterion based on variation ratio.
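The logit-based criteria turn a single vector of logits into a scalar score. The sketch below computes the maximum softmax probability, the max logit, and the negative free energy \(T \log \sum_k \exp(z_k / T)\), each oriented so that higher means more in-distribution-like; the library's criteria may negate these so that higher means more OOD, and the function names are hypothetical.

```python
import math
from typing import Sequence


def logsumexp(z: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(z)))."""
    m = max(z)
    return m + math.log(sum(math.exp(v - m) for v in z))


def msp_score(logits: Sequence[float]) -> float:
    """Maximum softmax probability."""
    return math.exp(max(logits) - logsumexp(logits))


def max_logit_score(logits: Sequence[float]) -> float:
    """Largest raw logit."""
    return max(logits)


def negative_energy_score(logits: Sequence[float], temperature: float = 1.0) -> float:
    """-E(x) = T * logsumexp(z / T); the free energy E(x) is its negation."""
    return temperature * logsumexp([v / temperature for v in logits])


confident, flat = [6.0, 0.0, 0.0], [1.0, 1.0, 1.0]
for score in (msp_score, max_logit_score, negative_energy_score):
    print(score.__name__, score(confident), score(flat))
```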

Datamodules#

Classification#

CIFAR10DataModule

DataModule for CIFAR10.

CIFAR100DataModule

DataModule for CIFAR100.

ImageNetDataModule

DataModule for the ImageNet dataset.

MNISTDataModule

DataModule for MNIST.

TinyImageNetDataModule

DataModule for the Tiny-ImageNet dataset.

Tabular Classification#

TabularClassificationDataModule

Tabular binary classification datamodule.

AdultCensusIncomeDataModule

Tabular binary classification datamodule.

AmazonAccessDataModule

Tabular binary classification datamodule.

APSFailureDataModule

Tabular binary classification datamodule.

BankMarketingDataModule

Tabular binary classification datamodule.

CreditApprovalDataModule

Tabular binary classification datamodule.

DOTA2GamesDataModule

Tabular binary classification datamodule.

GermanCreditDataModule

Tabular binary classification datamodule.

HiggsBosonDataModule

Tabular binary classification datamodule.

HTRU2DataModule

Tabular binary classification datamodule.

KDDChurnDataModule

Tabular binary classification datamodule.

OnlineShoppersDataModule

Tabular binary classification datamodule.

PimaDiabetesDataModule

Tabular binary classification datamodule.

SpamBaseDataModule

Tabular binary classification datamodule.

TelcoChurnDataModule

Tabular binary classification datamodule.

WineQualityDataModule

Wine Quality datamodule.

Regression#

TabularRegressionDataModule

UCI regression datamodule.

Segmentation#

CamVidDataModule

DataModule for the CamVid dataset.

CityscapesDataModule

DataModule for the Cityscapes dataset.

MUADDataModule

Segmentation DataModule for the MUAD dataset.

Datasets#

Classification#

MNISTC

The corrupted MNIST-C Dataset.

NotMNIST

The notMNIST dataset.

CIFAR10C

The corrupted CIFAR-10-C Dataset.

CIFAR100C

The corrupted CIFAR-100-C Dataset.

CIFAR10H

CIFAR-10H Dataset.

CIFAR10N

CIFAR-10N Dataset.

CIFAR100N

CIFAR-100N Dataset.

ImageNetA

The ImageNet-A dataset of natural adversarial examples.

ImageNetC

The corrupted ImageNet-C Dataset.

ImageNetO

The ImageNet-O out-of-distribution detection dataset.

ImageNetR

The ImageNet-R dataset of artistic renditions of ImageNet classes.

TinyImageNet

The Tiny-ImageNet dataset, inspired by https://gist.github.com/z-a-f/b862013c0dc2b540cf96a123a6766e54.

TinyImageNetC

The corrupted TinyImageNet-C Dataset.

OpenImageO

OpenImage-O dataset.

Tabular Classification#

TabularClassificationDataset

Tabular binary classification dataset.

AdultCensusIncome

Tabular binary classification dataset.

AmazonAccess

Tabular binary classification dataset.

APSFailure

Tabular binary classification dataset.

BankMarketing

Tabular binary classification dataset.

CreditApproval

Tabular binary classification dataset.

DOTA2Games

Tabular binary classification dataset.

GermanCredit

Tabular binary classification dataset.

HiggsBoson

Tabular binary classification dataset.

HTRU2

Tabular binary classification dataset.

KDDChurn

Tabular binary classification dataset.

OnlineShoppers

Tabular binary classification dataset.

PimaDiabetes

Tabular binary classification dataset.

SpamBase

Tabular binary classification dataset.

TelcoChurn

Tabular binary classification dataset.

WineQuality

Wine Quality classification dataset.

Regression#

TabularRegressionDataset

UCI regression dataset.

Segmentation#

CamVid

CamVid Dataset.

Cityscapes

Cityscapes dataset wrapper with train-ID color mapping.

Others & Cross-Categories#

Fractals

Dataset used for PixMix augmentations.

FrostImages

Frost corruption image dataset.

KITTIDepth

KITTI Depth Estimation dataset.

MUAD

The MUAD Dataset.

NYUv2

NYUv2 depth dataset.

Callbacks#

Custom Lightning callbacks for advanced checkpointing and model saving functionalities.

CompoundCheckpoint

Save the checkpoints maximizing or minimizing a given linear form on the metric values.

TUClsCheckpoint

Keep multiple checkpoints corresponding to the best models in terms of Accuracy, Brier Score, and Negative Log-Likelihood.

TURegCheckpoint

Keep multiple checkpoints corresponding to the best models in terms of Mean Squared Error and, optionally, Negative Log-Likelihood and Quantile Calibration Error.

TUSegCheckpoint

Keep multiple checkpoints corresponding to the best models in terms of Mean Intersection over Union, Brier Score, and Negative Log-Likelihood.
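CompoundCheckpoint ranks checkpoints by a linear form of the logged metric values. The selection rule can be sketched on an in-memory list of per-epoch metric dictionaries; the metric keys, weights, and helper names are hypothetical, and the actual callback hooks into Lightning's checkpointing instead of scanning a list.

```python
from typing import Dict, List


def linear_form(metrics: Dict[str, float], weights: Dict[str, float]) -> float:
    """Weighted sum of the selected metric values."""
    return sum(w * metrics[name] for name, w in weights.items())


def best_epoch(history: List[Dict[str, float]], weights: Dict[str, float], mode: str = "max") -> int:
    """Index of the epoch whose linear form is largest (mode='max') or smallest."""
    scores = [linear_form(m, weights) for m in history]
    pick = max if mode == "max" else min
    return pick(range(len(scores)), key=scores.__getitem__)


history = [
    {"acc": 0.80, "nll": 0.60},
    {"acc": 0.85, "nll": 0.70},
    {"acc": 0.84, "nll": 0.50},
]
# Reward accuracy and penalise NLL: score = acc - 0.5 * nll.
print(best_epoch(history, {"acc": 1.0, "nll": -0.5}))
```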