ClassificationRoutine
- class torch_uncertainty.routines.ClassificationRoutine(model, num_classes, loss=None, *, is_ensemble=False, num_tta=1, format_batch_fn=None, optim_recipe=None, mixup_params=None, eval_ood=False, eval_shift=False, eval_grouping_loss=False, ood_criterion='msp', post_processing=None, num_bins_calibration_error=15, log_plots=False, save_to_csv=False, csv_filename='results.csv')
Routine for training & testing on classification tasks.
- Parameters:
model (Module) – Model to train.
num_classes (int) – Number of classes.
loss (Module | None) – Loss function to optimize the model. Defaults to None.
is_ensemble (bool) – Whether the model is an ensemble at test time. Defaults to False.
num_tta (int) – Number of test-time augmentations (TTA). If 1, no TTA is performed. Defaults to 1.
format_batch_fn (Module | None) – Function to format the batch. Defaults to None.
optim_recipe (Union[Callable[[Module], Union[Optimizer, Sequence[Optimizer], tuple[Sequence[Optimizer], Sequence[Union[LRScheduler, ReduceLROnPlateau, LRSchedulerConfig]]], OptimizerConfig, OptimizerLRSchedulerConfig, Sequence[OptimizerConfig], Sequence[OptimizerLRSchedulerConfig], None]], Optimizer, Sequence[Optimizer], tuple[Sequence[Optimizer], Sequence[Union[LRScheduler, ReduceLROnPlateau, LRSchedulerConfig]]], OptimizerConfig, OptimizerLRSchedulerConfig, Sequence[OptimizerConfig], Sequence[OptimizerLRSchedulerConfig], None]) – The optimizer and, optionally, the scheduler to use, or a callable that returns them. Defaults to None.
mixup_params (dict | None) – Mixup parameters dictionary. It can include the mixup type, mixup mode, distance similarity, kernel tau max, kernel tau std, mixup alpha, and cutmix alpha. If None, no mixup augmentation is applied. Defaults to None.
eval_ood (bool) – Whether to evaluate out-of-distribution (OOD) detection performance. Defaults to False.
eval_shift (bool) – Whether to evaluate performance under distribution shift. Defaults to False.
eval_grouping_loss (bool) – Whether to evaluate the grouping loss. Defaults to False.
ood_criterion (TUOODCriterion | str) – Criterion for the binary OOD detection task. Defaults to "msp", the maximum softmax probability score.
post_processing (PostProcessing | None) – Post-processing method to fit on the calibration set. No post-processing if None. Defaults to None.
num_bins_calibration_error (int) – Number of bins used to compute calibration-error metrics. Defaults to 15.
log_plots (bool) – Whether to log plots from the metrics. Defaults to False.
save_to_csv (bool) – Whether to save the results to a CSV file. Defaults to False.
csv_filename (str) – Name of the CSV file. Only used if save_to_csv is True. Defaults to "results.csv".
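The num_bins_calibration_error parameter sets how many confidence bins calibration-error metrics use. As a framework-free illustration, and not the library's own metric code, the sketch below computes the expected calibration error (ECE) with 15 equal-width bins from top-1 confidences and correctness flags. The helper name `expected_calibration_error` is hypothetical, and the binning convention (left-closed bins, with a confidence of exactly 1.0 placed in the last bin) is an assumption.

```python
def expected_calibration_error(confidences, correct, num_bins=15):
    """Hypothetical helper: ECE with equal-width confidence bins.

    confidences: top-1 predicted probability per sample, in [0, 1].
    correct: 1 if the top-1 prediction matched the target, else 0.
    """
    if len(confidences) != len(correct):
        raise ValueError("confidences and correct must have the same length")
    n = len(confidences)
    if n == 0:
        return 0.0
    bin_conf = [0.0] * num_bins
    bin_acc = [0.0] * num_bins
    bin_count = [0] * num_bins
    for conf, ok in zip(confidences, correct):
        # Left-closed bins [k / num_bins, (k + 1) / num_bins); conf == 1.0 goes to the last bin.
        idx = min(int(conf * num_bins), num_bins - 1)
        bin_conf[idx] += conf
        bin_acc[idx] += ok
        bin_count[idx] += 1
    ece = 0.0
    for total_conf, total_acc, count in zip(bin_conf, bin_acc, bin_count):
        if count:
            # |accuracy - mean confidence| of the bin, weighted by its share of samples.
            ece += (count / n) * abs(total_acc / count - total_conf / count)
    return ece


# Four predictions that each land in a different bin: ECE is the mean confidence gap.
print(expected_calibration_error([0.9, 0.8, 0.6, 0.95], [1, 1, 0, 1]))
```

Fewer bins give smoother but coarser estimates; more bins resolve finer miscalibration but leave fewer samples per bin.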
Warning
You must define optim_recipe if you do not use the Lightning CLI.
Warning
When using an ensemble model, you must:
1. Set is_ensemble to True.
2. Set format_batch_fn to torch_uncertainty.transforms.RepeatTarget(num_repeats=num_estimators).
3. Ensure that the model's forward pass outputs a tensor of shape \((M \times B, C)\), where \(M\) is the number of estimators, \(B\) is the batch size, and \(C\) is the number of classes.
For automated batch handling, consider using the model wrappers available in torch_uncertainty.methods.
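To make the \((M \times B, C)\) requirement concrete, here is a framework-free sketch using plain Python lists. It assumes the outputs are stacked estimator by estimator (all \(B\) rows of estimator 0, then all \(B\) rows of estimator 1, and so on). Under that assumption, tiling the \(B\) targets \(M\) times, which is the role a target-repeating transform such as RepeatTarget plays, keeps every output row aligned with its target, and averaging the per-estimator probabilities gives the ensemble prediction. The helpers `repeat_targets` and `ensemble_mean` are hypothetical illustrations, not torch_uncertainty functions.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def repeat_targets(targets, num_estimators):
    """Hypothetical: tile B targets M times to match an (M * B, C) output."""
    return list(targets) * num_estimators


def ensemble_mean(stacked_logits, num_estimators):
    """Hypothetical: average softmax probabilities over estimators.

    stacked_logits has M * B rows, assumed estimator-major:
    rows [m * B, (m + 1) * B) belong to estimator m.
    """
    if len(stacked_logits) % num_estimators:
        raise ValueError("number of rows must be a multiple of num_estimators")
    batch_size = len(stacked_logits) // num_estimators
    probs = [softmax(row) for row in stacked_logits]
    mean = []
    for b in range(batch_size):
        # Gather the M rows that describe sample b, one per estimator.
        rows = [probs[m * batch_size + b] for m in range(num_estimators)]
        mean.append([sum(col) / num_estimators for col in zip(*rows)])
    return mean


# M = 2 estimators, B = 2 samples, C = 3 classes.
stacked = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0],   # estimator 0
           [0.0, 0.0, 0.0], [0.0, 3.0, 0.0]]   # estimator 1
targets = repeat_targets([0, 1], num_estimators=2)  # aligned with the 4 stacked rows
mean_probs = ensemble_mean(stacked, num_estimators=2)  # shape (B, C)
```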
Note
If eval_ood is True, we perform a binary classification and update the OOD-related metrics twice:
- once during the test on ID samples, where the given binary label is 0 (for ID);
- once during the test on OOD samples, where the given binary label is 1 (for OOD).
Note
optim_recipe can be anything that can be returned by LightningModule.configure_optimizers(). See the Lightning documentation of this method for more details.
- forward(inputs, save_feats=False)
Forward pass of the inner model.
- Parameters:
inputs (Tensor) – Input tensor.
save_feats (bool) – Whether to store the features. Defaults to False.
- Return type:
Tensor
Note
The features are stored in the self.features attribute.
- on_test_epoch_end()
Compute, log, and plot the values of the collected metrics in test_step.
- Return type:
None
- on_test_start()
Prepare the test step.
Set up the post-processing dataset and fit the post-processing method if needed, prepare the storage lists for logit plotting, and update the batchnorms if needed.
- Return type:
None
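Temperature scaling is one common example of a post-processing method fitted on a calibration set; this page does not say which methods PostProcessing covers. The framework-free sketch below fits a single temperature \(T\) on held-out logits by minimizing the mean negative log-likelihood of softmax(logits / T). For simplicity it uses a grid search over candidate temperatures, whereas a practical implementation would usually rely on a gradient-based optimizer. The functions `nll` and `fit_temperature` and the default grid are hypothetical illustrations, not the library's implementation.

```python
import math


def nll(logits_batch, targets, temperature):
    """Mean negative log-likelihood of targets under softmax(logits / T)."""
    total = 0.0
    for logits, target in zip(logits_batch, targets):
        scaled = [x / temperature for x in logits]
        m = max(scaled)
        # log-sum-exp, shifted by the max for numerical stability
        log_norm = m + math.log(sum(math.exp(x - m) for x in scaled))
        total += log_norm - scaled[target]
    return total / len(targets)


def fit_temperature(logits_batch, targets, grid=None):
    """Hypothetical helper: pick the temperature that minimizes calibration NLL."""
    if grid is None:
        # Candidate temperatures 0.05, 0.10, ..., 5.00 (an arbitrary illustrative grid).
        grid = [round(0.05 * k, 2) for k in range(1, 101)]
    return min(grid, key=lambda t: nll(logits_batch, targets, t))


# Toy calibration set: three confident correct predictions and one confident mistake.
cal_logits = [[2.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 0.0]]
cal_targets = [0, 0, 1, 1]
temperature = fit_temperature(cal_logits, cal_targets)
# Here temperature > 1, so the fitted scaling softens the overconfident logits.
```

Because only one scalar is fitted, temperature scaling leaves the argmax, and therefore the accuracy, unchanged while adjusting confidence.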
- on_validation_epoch_end()
Compute and log the values of the collected metrics in validation_step.
- Return type:
None
- on_validation_start()
Prepare the validation step.
Update the model’s wrapper and the batchnorms if needed.
- Return type:
None
- test_step(batch, batch_idx, dataloader_idx=0)
Perform a single test step based on the input tensors.
Compute the prediction of the model and the value of the metrics on the test batch. Also handle OOD and distribution-shifted images.
- Parameters:
batch (
tuple[Tensor,Tensor]) – the test data and their corresponding targets.batch_idx (
int) – the number of the current batch (unused).dataloader_idx (
int) – 0 if in-distribution, 1 if out-of-distribution and 2 if distribution-shifted.
- Return type:
None
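With eval_ood enabled, batches from dataloader 0 act as the negative class (label 0, in-distribution) and batches from dataloader 1 as the positive class (label 1, OOD). The framework-free sketch below illustrates that binary task, not the library's code. It scores each sample with the negated maximum softmax probability, in the spirit of the default "msp" criterion, so that a higher score means "more likely OOD", and then computes the AUROC by pairwise comparison. The helper names are hypothetical, and the sign convention of the score is an assumption.

```python
import math


def msp_ood_score(logits):
    """OOD score = -max softmax probability (higher means more likely OOD)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    return -max(exps) / sum(exps)


def auroc(scores, labels):
    """AUROC via pairwise comparison; labels are 0 (ID) or 1 (OOD), ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one ID and one OOD sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))


# dataloader_idx 0 -> in-distribution (label 0), 1 -> out-of-distribution (label 1)
id_logits = [[5.0, 0.0, 0.0], [0.0, 4.0, 0.0]]
ood_logits = [[1.0, 0.9, 1.1], [0.2, 0.1, 0.0]]
scores = [msp_ood_score(x) for x in id_logits + ood_logits]
labels = [0] * len(id_logits) + [1] * len(ood_logits)
print(auroc(scores, labels))  # 1.0: every OOD sample scores above every ID sample
```

The quadratic pairwise loop keeps the definition explicit; a sort-based rank computation would scale better on real test sets.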