ClassificationRoutine
class torch_uncertainty.routines.ClassificationRoutine(model, num_classes, loss, num_estimators=1, format_batch_fn=None, optim_recipe=None, mixtype='erm', mixmode='elem', dist_sim='emb', kernel_tau_max=1.0, kernel_tau_std=0.5, mixup_alpha=0, cutmix_alpha=0, eval_ood=False, eval_grouping_loss=False, ood_criterion='msp', log_plots=False, save_in_csv=False, calibration_set=None)
Routine for efficient training and testing on classification tasks using LightningModule.
Parameters:
- model (torch.nn.Module) – Model to train.
- num_classes (int) – Number of classes.
- loss (torch.nn.Module) – Loss function used to optimize the model.
- num_estimators (int, optional) – Number of estimators in the ensemble. Defaults to 1 (single model).
- format_batch_fn (torch.nn.Module, optional) – Function used to format each batch. Defaults to None, in which case torch.nn.Identity() is used.
- optim_recipe (dict or torch.optim.Optimizer, optional) – The optimizer and, optionally, the scheduler to use. Defaults to None.
- mixtype (str, optional) – Mixup type. Defaults to "erm".
- mixmode (str, optional) – Mixup mode. Defaults to "elem".
- dist_sim (str, optional) – Distance similarity. Defaults to "emb".
- kernel_tau_max (float, optional) – Maximum value of the kernel tau. Defaults to 1.0.
- kernel_tau_std (float, optional) – Standard deviation of the kernel tau. Defaults to 0.5.
- mixup_alpha (float, optional) – Alpha parameter for Mixup. Defaults to 0.
- cutmix_alpha (float, optional) – Alpha parameter for CutMix. Defaults to 0.
- eval_ood (bool, optional) – Whether to evaluate out-of-distribution (OOD) detection performance. Defaults to False.
- eval_grouping_loss (bool, optional) – Whether to evaluate the grouping loss. Defaults to False.
- ood_criterion (str, optional) – OOD criterion. Available options are:
  - "msp" (default): Maximum softmax probability.
  - "logit": Maximum logit.
  - "energy": Logsumexp of the mean logits.
  - "entropy": Entropy of the mean prediction.
  - "mi": Mutual information of the ensemble.
  - "vr": Variation ratio of the ensemble.
- log_plots (bool, optional) – Whether to log plots from the metrics. Defaults to False.
- save_in_csv (bool, optional) – Whether to save the results in a CSV file. Defaults to False.
- calibration_set (str, optional) – Dataset used to fit the scaling: "val" for the validation set or "test" for the test set. Defaults to None, in which case no calibration set is used.
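The six `ood_criterion` options can be illustrated on a single sample. The sketch below is plain Python, not the routine's implementation. It assumes the ensemble output for one sample is a `[num_estimators][num_classes]` list of logits, computes `"msp"` and `"entropy"` on the mean of the estimators' softmax probabilities (an assumption; the routine may aggregate differently), and returns raw values without any sign flip the routine may apply to turn them into OOD scores. `ood_scores` is a hypothetical helper.

```python
import math
from collections import Counter


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax of a single logit vector."""
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: list[float]) -> float:
    """Shannon entropy in nats; zero-probability classes contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def ood_scores(member_logits: list[list[float]]) -> dict[str, float]:
    """Hypothetical helper: raw values of the six criteria for ONE sample.

    member_logits[m][c] is the logit of class c from estimator m.
    """
    num_estimators = len(member_logits)
    num_classes = len(member_logits[0])

    # Average the logits over estimators (used by "logit" and "energy").
    mean_logits = [
        sum(member[c] for member in member_logits) / num_estimators
        for c in range(num_classes)
    ]
    # Average the per-estimator probabilities (assumed for "msp" and "entropy").
    member_probs = [softmax(logits) for logits in member_logits]
    mean_probs = [
        sum(probs[c] for probs in member_probs) / num_estimators
        for c in range(num_classes)
    ]

    # Stable logsumexp of the mean logits.
    shift = max(mean_logits)
    logsumexp = shift + math.log(sum(math.exp(z - shift) for z in mean_logits))

    # Count how many estimators vote for the most popular class.
    votes = Counter(
        max(range(num_classes), key=probs.__getitem__) for probs in member_probs
    )
    modal_count = votes.most_common(1)[0][1]

    predictive_entropy = entropy(mean_probs)
    expected_entropy = sum(entropy(p) for p in member_probs) / num_estimators

    return {
        "msp": max(mean_probs),
        "logit": max(mean_logits),
        "energy": logsumexp,
        "entropy": predictive_entropy,
        # Mutual information: predictive entropy minus expected member entropy.
        "mi": predictive_entropy - expected_entropy,
        # Variation ratio: fraction of estimators disagreeing with the modal vote.
        "vr": 1.0 - modal_count / num_estimators,
    }
```

Higher "msp", "logit" and "energy" values correspond to more confident predictions, while higher "entropy", "mi" and "vr" values correspond to more uncertainty. In this sketch, "mi" and "vr" are always 0 when `num_estimators=1`, which is why they only make sense for ensembles.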
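For intuition about `mixup_alpha`, here is a minimal sketch of standard Mixup, assuming the routine follows the usual recipe: a mixing coefficient is drawn from Beta(alpha, alpha), and each sample and its one-hot target are blended with a randomly chosen partner from the same batch. `mixup_batch` is a hypothetical helper that works on plain lists rather than tensors. It treats `alpha <= 0` as "no mixing" (an assumption consistent with the default of 0) and does not show how the routine combines `mixup_alpha` with `mixtype`, `mixmode` or `cutmix_alpha`.

```python
import random


def mixup_batch(
    inputs: list[list[float]],
    one_hot: list[list[float]],
    alpha: float,
    rng: random.Random,
) -> tuple[list[list[float]], list[list[float]], float]:
    """Hypothetical helper: blend each sample with a permuted partner.

    Returns the mixed inputs, the mixed (soft) targets and the sampled lambda.
    """
    if alpha <= 0:
        # Assumed behaviour for the default alpha of 0: leave the batch unchanged.
        return inputs, one_hot, 1.0
    lam = rng.betavariate(alpha, alpha)  # mixing coefficient in [0, 1]
    perm = list(range(len(inputs)))
    rng.shuffle(perm)  # partner index for every sample
    mixed_x = [
        [lam * a + (1 - lam) * b for a, b in zip(inputs[i], inputs[j])]
        for i, j in enumerate(perm)
    ]
    mixed_y = [
        [lam * a + (1 - lam) * b for a, b in zip(one_hot[i], one_hot[j])]
        for i, j in enumerate(perm)
    ]
    return mixed_x, mixed_y, lam
```

Because targets are mixed with the same coefficient as inputs, each mixed target remains a valid probability vector, so a soft-label cross-entropy can be used on it.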
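This section does not say which scaling method `calibration_set` feeds. The sketch below assumes temperature scaling, one common post-hoc choice: a single temperature T is fitted on the calibration set ("val" or "test") by minimizing the negative log-likelihood of softmax(logits / T). `mean_nll` and `fit_temperature` are hypothetical helpers, and the grid search is a simplification chosen for clarity.

```python
import math
from typing import Optional


def mean_nll(logits: list[list[float]], labels: list[int], temperature: float) -> float:
    """Mean negative log-likelihood of the labels under softmax(logits / T)."""
    total = 0.0
    for row, label in zip(logits, labels):
        scaled = [z / temperature for z in row]
        shift = max(scaled)
        log_norm = shift + math.log(sum(math.exp(z - shift) for z in scaled))
        total += log_norm - scaled[label]
    return total / len(labels)


def fit_temperature(
    logits: list[list[float]],
    labels: list[int],
    grid: Optional[list[float]] = None,
) -> float:
    """Hypothetical helper: pick the temperature with the lowest calibration NLL."""
    if grid is None:
        grid = [k / 20 for k in range(1, 201)]  # 0.05, 0.10, ..., 10.0
    return min(grid, key=lambda t: mean_nll(logits, labels, t))
```

A fitted T above 1 softens overconfident predictions and a T below 1 sharpens underconfident ones; since dividing by T does not change the argmax, accuracy is unaffected. In practice T is usually optimized with a gradient-based method rather than a grid.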
Warning
You must define optim_recipe if you do not use the CLI.

Note
optim_recipe can be anything that can be returned by LightningModule.configure_optimizers(). See the Lightning documentation of LightningModule.configure_optimizers() for the accepted formats.