ResNetBaseline
- class torch_uncertainty.baselines.classification.ResNetBaseline(num_classes, in_channels, loss, version, arch, style='imagenet', normalization_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, num_estimators=1, dropout_rate=0.0, optim_recipe=None, mixup_params=None, last_layer_dropout=False, width_multiplier=1.0, groups=1, scale=None, alpha=None, gamma=1, rho=1.0, batch_repeat=1, ood_criterion='msp', log_plots=False, save_in_csv=False, calibration_set='val', eval_ood=False, eval_shift=False, eval_grouping_loss=False, num_calibration_bins=15, pretrained=False)
ResNet backbone baseline for classification providing support for various versions and architectures.
- Parameters:
num_classes (int) – Number of classes to predict.
in_channels (int) – Number of input channels.
loss (nn.Module) – Training loss.
optim_recipe (Any, optional) – Optimization recipe; corresponds to what the LightningModule.configure_optimizers() method expects. Defaults to None.
version (str) – Determines which ResNet version to use:
  - "std": original ResNet
  - "packed": Packed-Ensembles ResNet
  - "batched": BatchEnsemble ResNet
  - "masked": Masksemble ResNet
  - "mimo": MIMO ResNet
  - "mc-dropout": Monte-Carlo Dropout ResNet
arch (int) – Determines which ResNet architecture to use:
  - 18: ResNet-18
  - 32: ResNet-32
  - 50: ResNet-50
  - 101: ResNet-101
  - 152: ResNet-152
style (str, optional) – Which ResNet style to use. Defaults to "imagenet".
normalization_layer (type[nn.Module], optional) – Normalization layer to use. Defaults to nn.BatchNorm2d.
num_estimators (int, optional) – Number of estimators in the ensemble. Only used if version is "packed", "batched", "masked", "mimo" or "mc-dropout". Defaults to 1.
dropout_rate (float, optional) – Dropout rate. Defaults to 0.0.
mixup_params (dict, optional) – Mixup parameters. Can include mixtype, mixmode, dist_sim, kernel_tau_max, kernel_tau_std, mixup_alpha and cutmix_alpha. If None, no augmentations are applied. Defaults to None.
width_multiplier (float, optional) – Expansion factor affecting the width of the estimators. Defaults to 1.0.
groups (int, optional) – Number of groups in convolutions. Defaults to 1.
scale (float, optional) – Expansion factor affecting the width of the estimators. Only used if version is "masked". Defaults to None.
last_layer_dropout (bool, optional) – Whether to apply dropout to the last layer only. Defaults to False.
alpha (float, optional) – Expansion factor affecting the width of the estimators. Only used if version is "packed". Defaults to None.
gamma (int, optional) – Number of groups within each estimator. Only used if version is "packed"; scales with groups. Defaults to 1.
rho (float, optional) – Probability that all estimators share the same input. Only used if version is "mimo". Defaults to 1.0.
batch_repeat (int, optional) – Number of times to repeat the batch. Only used if version is "mimo". Defaults to 1.
ood_criterion (str, optional) – OOD criterion. Defaults to "msp". "msp" is the maximum softmax probability, "logit" is the maximum logit, "entropy" is the entropy of the mean prediction, "mi" is the mutual information of the ensemble and "vr" is the variation ratio of the ensemble.
log_plots (bool, optional) – Whether to log plots. Defaults to False.
save_in_csv (bool, optional) – Whether to save the results in a CSV file. Defaults to False.
calibration_set (str, optional) – Set used for calibration. Defaults to "val".
eval_ood (bool, optional) – Whether to evaluate OOD detection. Defaults to False.
eval_shift (bool, optional) – Whether to evaluate on shifted data. Defaults to False.
eval_grouping_loss (bool, optional) – Whether to evaluate the grouping loss. Defaults to False.
num_calibration_bins (int, optional) – Number of calibration bins. Defaults to 15.
pretrained (bool, optional) – Whether to use pretrained weights. Only used if version is "packed". Defaults to False.
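For a single input sample, the five ood_criterion options listed above can be sketched in pure Python as follows. This is a minimal illustration under stated assumptions, not torch_uncertainty's implementation: each estimator is assumed to return a list of logits, the helper names softmax, entropy and ood_criteria are hypothetical, and the "logit" criterion is assumed to use the ensemble-averaged logits.

```python
import math
from collections import Counter


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    # Shannon entropy in nats; terms with p == 0 contribute nothing.
    return -sum(p * math.log(p) for p in probs if p > 0)


def ood_criteria(ensemble_logits):
    """Hypothetical helper: the five OOD criteria for ONE sample.

    ensemble_logits: M lists (one per estimator), each holding C logits.
    """
    num_estimators = len(ensemble_logits)
    num_classes = len(ensemble_logits[0])
    member_probs = [softmax(z) for z in ensemble_logits]
    # Mean predictive distribution over the ensemble.
    mean_probs = [
        sum(p[c] for p in member_probs) / num_estimators for c in range(num_classes)
    ]
    # Assumption: the "logit" criterion uses the ensemble-averaged logits.
    mean_logits = [
        sum(z[c] for z in ensemble_logits) / num_estimators for c in range(num_classes)
    ]
    mean_entropy = entropy(mean_probs)
    # Average entropy of the individual members.
    expected_entropy = sum(entropy(p) for p in member_probs) / num_estimators
    # Each member votes for its argmax class.
    votes = Counter(max(range(num_classes), key=p.__getitem__) for p in member_probs)
    return {
        "msp": max(mean_probs),  # lower suggests OOD
        "logit": max(mean_logits),  # lower suggests OOD
        "entropy": mean_entropy,  # higher suggests OOD
        "mi": mean_entropy - expected_entropy,  # higher when members disagree
        "vr": 1.0 - votes.most_common(1)[0][1] / num_estimators,  # share of dissenting votes
    }


agree = ood_criteria([[4.0, 0.0, 0.0], [4.0, 0.0, 0.0]])
disagree = ood_criteria([[4.0, 0.0, 0.0], [0.0, 4.0, 0.0]])
print(agree["vr"], disagree["vr"])  # 0.0 0.5
```

With identical members, mutual information and variation ratio are both zero, whereas two members voting for different classes raise both, along with the entropy of the mean prediction. For "msp" and "logit", lower values point toward an OOD input; for the other three criteria, higher values do.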
- Raises:
ValueError – If version is not one of "std", "packed", "batched", "masked", "mimo" or "mc-dropout".
- Returns:
ResNet baseline ready for training and evaluation.
- Return type:
LightningModule
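The ValueError contract under Raises can be sketched as an up-front check on the constructor arguments. check_resnet_config, VALID_VERSIONS and VALID_ARCHS are hypothetical names, not part of torch_uncertainty. The accepted values are taken from the version and arch lists above, and validating arch as well is an extra assumption beyond what Raises documents.

```python
# Hypothetical mirror of the documented version check; not torch_uncertainty code.
VALID_VERSIONS = ("std", "packed", "batched", "masked", "mimo", "mc-dropout")
# Assumption: arch is also checked against the documented depths.
VALID_ARCHS = (18, 32, 50, 101, 152)


def check_resnet_config(version: str, arch: int) -> None:
    """Raise ValueError if version or arch is not one of the documented values."""
    if version not in VALID_VERSIONS:
        raise ValueError(
            f"Unknown version {version!r}; expected one of {VALID_VERSIONS}."
        )
    if arch not in VALID_ARCHS:
        raise ValueError(f"Unknown arch {arch!r}; expected one of {VALID_ARCHS}.")


check_resnet_config("packed", 50)  # valid combination: returns None
try:
    check_resnet_config("deep-ensembles", 18)
except ValueError as err:
    print(err)
```

Failing fast like this surfaces a typo in version before any model or data is built.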