WideResNetBaseline
- class torch_uncertainty.baselines.classification.WideResNetBaseline(num_classes, in_channels, loss, version, style='imagenet', num_estimators=1, dropout_rate=0.0, optim_recipe=None, mixup_params=None, groups=1, last_layer_dropout=False, scale=None, alpha=None, gamma=1, rho=1.0, batch_repeat=1, ood_criterion='msp', log_plots=False, save_in_csv=False, eval_ood=False, eval_shift=False, eval_grouping_loss=False)
Wide-ResNet-28x10 backbone baseline for classification, with support for several versions.
- Parameters:
num_classes (int) – Number of classes to predict.
in_channels (int) – Number of input channels.
loss (nn.Module) – Training loss.
optim_recipe (Any) – Optimization recipe; corresponds to what the LightningModule.configure_optimizers() method expects.
version (str) – Determines which Wide-ResNet version to use:
"std": original Wide-ResNet
"mc-dropout": Monte Carlo Dropout Wide-ResNet
"packed": Packed-Ensembles Wide-ResNet
"batched": BatchEnsemble Wide-ResNet
"masked": Masksembles Wide-ResNet
"mimo": MIMO Wide-ResNet
style (str, optional) – Which ResNet style to use. Defaults to "imagenet".
num_estimators (int, optional) – Number of estimators in the ensemble. Only used if version is "packed", "batched" or "masked". Defaults to 1.
dropout_rate (float, optional) – Dropout rate. Defaults to 0.0.
mixup_params (dict, optional) – Mixup parameters. Can include mixtype, mixmode, dist_sim, kernel_tau_max, kernel_tau_std, mixup_alpha, and cutmix_alpha. If None, no mixup augmentation is applied. Defaults to None.
last_layer_dropout (bool, optional) – Whether to apply dropout to the last layer only. Defaults to False.
groups (int, optional) – Number of groups in convolutions. Defaults to 1.
scale (float, optional) – Expansion factor affecting the width of the estimators. Only used if version is "masked". Defaults to None.
alpha (float, optional) – Expansion factor affecting the width of the estimators. Only used if version is "packed". Defaults to None.
gamma (int, optional) – Number of groups within each estimator. Only used if version is "packed"; scales with groups. Defaults to 1.
rho (float, optional) – Probability that all estimators share the same input. Only used if version is "mimo". Defaults to 1.0.
batch_repeat (int, optional) – Number of times to repeat the batch. Only used if version is "mimo". Defaults to 1.
ood_criterion (TUOODCriterion or str, optional) – Criterion for the binary OOD detection task. Defaults to "msp", the maximum softmax probability score (MSP).
log_plots (bool, optional) – Whether to log the plots. Defaults to False.
save_in_csv (bool, optional) – Whether to save the results in a CSV file. Defaults to False.
eval_ood (bool, optional) – Whether to evaluate OOD detection. Defaults to False.
eval_shift (bool, optional) – Whether to evaluate on shifted data. Defaults to False.
eval_grouping_loss (bool, optional) – Whether to evaluate the grouping loss. Defaults to False.
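For version="packed", the alpha, gamma, num_estimators and groups parameters jointly determine the shape of each convolution. The helpers below are hypothetical and not part of torch_uncertainty. They sketch that arithmetic under the Packed-Ensembles scheme: an inner convolution's channel counts are multiplied by alpha, and the layer is split into groups * num_estimators * gamma convolution groups. The rule that rounds channel counts up to a multiple of the group count is an assumption made for this illustration.

```python
def packed_conv_shape(
    in_channels: int,
    out_channels: int,
    alpha: float,
    num_estimators: int,
    gamma: int = 1,
    groups: int = 1,
) -> tuple[int, int, int]:
    """Return (packed_in, packed_out, conv_groups) for an inner packed conv.

    Hypothetical helper: widen both channel counts by alpha, split the layer
    into groups * num_estimators * gamma convolution groups, then round the
    channel counts up to a multiple of the group count (assumed rounding rule).
    """
    conv_groups = groups * num_estimators * gamma
    packed_in = int(in_channels * alpha)
    packed_out = int(out_channels * alpha)
    # Ceil-divide then multiply back so each count is divisible by conv_groups.
    packed_in = -(-packed_in // conv_groups) * conv_groups
    packed_out = -(-packed_out // conv_groups) * conv_groups
    return packed_in, packed_out, conv_groups


def conv_weight_count(in_channels: int, out_channels: int, kernel_size: int, groups: int = 1) -> int:
    """Number of weights (bias excluded) in a square-kernel grouped 2D convolution."""
    return out_channels * (in_channels // groups) * kernel_size * kernel_size


# A 160-channel 3x3 layer (first-stage width of a Wide-ResNet-28x10), packed with alpha=2 and 4 estimators.
packed_in, packed_out, conv_groups = packed_conv_shape(160, 160, alpha=2, num_estimators=4)
print(packed_in, packed_out, conv_groups)  # 320 320 4
print(conv_weight_count(160, 160, 3), conv_weight_count(packed_in, packed_out, 3, conv_groups))  # 230400 230400
```

Ignoring rounding, the weight count of a packed layer scales as alpha**2 / (num_estimators * gamma). With alpha=2 and four estimators, the packed layer therefore has the same number of weights as the original layer.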
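For version="mimo", rho and batch_repeat control how each training batch is dispatched to the estimators. The NumPy sketch below illustrates the idea and is not the library's implementation. mimo_input_indices is a hypothetical helper. It also assumes that rho is applied as the fraction of batch positions whose sample is shared by every estimator, rather than as a per-sample probability, and that the remaining positions are shuffled independently for each estimator.

```python
import numpy as np


def mimo_input_indices(
    batch_size: int,
    num_estimators: int,
    rho: float,
    batch_repeat: int,
    rng: np.random.Generator,
) -> np.ndarray:
    """Return a (num_estimators, batch_size * batch_repeat) array of sample indices.

    Row m lists which sample of the original batch feeds estimator m. A fraction
    rho of the positions receive the same sample for every estimator; the other
    positions are shuffled independently per estimator (hypothetical scheme).
    """
    n = batch_size * batch_repeat
    repeated = np.tile(np.arange(batch_size), batch_repeat)  # batch repeated batch_repeat times
    main = rng.permutation(n)
    n_shuffled = int(n * (1.0 - rho))
    rows = []
    for _ in range(num_estimators):
        # Reshuffle the first n_shuffled positions; keep the tail shared across estimators.
        head = main[:n_shuffled][rng.permutation(n_shuffled)]
        rows.append(np.concatenate([head, main[n_shuffled:]]))
    return repeated[np.stack(rows)]


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3, 8, 8))  # toy batch: (B, C, H, W)
idx = mimo_input_indices(4, num_estimators=3, rho=0.5, batch_repeat=2, rng=rng)
# Stack each estimator's view along the channel axis: (B * batch_repeat, M * C, H, W).
x_mimo = np.concatenate([x[row] for row in idx], axis=1)
print(x_mimo.shape)  # (8, 9, 8, 8)
```

With rho=1.0, every estimator receives the same inputs. Lower values give each estimator a more independent view of the repeated batch.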
- Raises:
ValueError – If version is not one of "std", "mc-dropout", "packed", "batched", "masked" or "mimo".
- Returns:
Wide-ResNet baseline ready for training and evaluation.
- Return type:
LightningModule
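The default ood_criterion, "msp", scores each input by its maximum softmax probability. The NumPy function below is a hypothetical stand-in for the library's criterion. It assumes that larger scores flag out-of-distribution inputs, which is why the probability is negated. It also assumes that an ensemble's member probabilities are averaged before the maximum is taken.

```python
import numpy as np


def msp_ood_score(logits: np.ndarray) -> np.ndarray:
    """Negated maximum softmax probability; larger values flag likely OOD inputs.

    logits has shape (num_estimators, batch, num_classes); the result has shape (batch,).
    """
    z = logits - logits.max(axis=-1, keepdims=True)  # shift for numerical stability
    probs = np.exp(z)
    probs /= probs.sum(axis=-1, keepdims=True)  # per-estimator softmax
    mean_probs = probs.mean(axis=0)  # average the members' predictive distributions
    return -mean_probs.max(axis=-1)


logits = np.array([[[10.0, 0.0, 0.0],   # confident prediction
                    [0.0, 0.0, 0.0]]])  # uniform prediction
print(msp_ood_score(logits))  # roughly [-0.9999, -0.3333]
```

The uniform prediction receives the higher score, so it is the one flagged as more likely out-of-distribution.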