ResNetBaseline

class torch_uncertainty.baselines.classification.ResNetBaseline(num_classes, in_channels, loss, version, arch, style='imagenet', normalization_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, num_estimators=1, dropout_rate=0.0, optim_recipe=None, mixup_params=None, last_layer_dropout=False, width_multiplier=1.0, groups=1, conv_bias=False, scale=None, alpha=None, gamma=1, rho=1.0, batch_repeat=1, ood_criterion='msp', log_plots=False, save_in_csv=False, eval_ood=False, eval_shift=False, eval_grouping_loss=False, num_bins_cal_err=15, pretrained=False)

ResNet backbone baseline for classification providing support for various versions and architectures.

Parameters:
  • num_classes (int) – Number of classes to predict.

  • in_channels (int) – Number of input channels.

  • loss (nn.Module) – Training loss.

  • optim_recipe (Any, optional) – Optimization recipe, corresponding to what the LightningModule.configure_optimizers() method expects. Defaults to None.

  • version (str) –

    Determines which ResNet version to use:

    • "std": original ResNet

    • "packed": Packed-Ensembles ResNet

    • "batched": BatchEnsemble ResNet

    • "masked": Masksemble ResNet

    • "mimo": MIMO ResNet

    • "mc-dropout": Monte-Carlo Dropout ResNet

  • arch (int) –

    Determines which ResNet architecture to use, one of:

    • 18: ResNet-18

    • 32: ResNet-32

    • 50: ResNet-50

    • 101: ResNet-101

    • 152: ResNet-152

  • style (str, optional) – Which ResNet style to use. Defaults to "imagenet".

  • normalization_layer (type[nn.Module], optional) – Normalization layer to use. Defaults to nn.BatchNorm2d.

  • num_estimators (int, optional) – Number of estimators in the ensemble. Only used if version is "packed", "batched", "masked", "mimo" or "mc-dropout". Defaults to 1.

  • dropout_rate (float, optional) – Dropout rate. Defaults to 0.0.

  • mixup_params (dict, optional) – Mixup parameters. Can include mixtype, mixmode, dist_sim, kernel_tau_max, kernel_tau_std, mixup_alpha, and cutmix_alpha. If None, no mixing augmentation is applied. Defaults to None.

  • width_multiplier (float, optional) – Expansion factor affecting the width of the estimators. Defaults to 1.0.

  • groups (int, optional) – Number of groups in convolutions. Defaults to 1.

  • scale (float, optional) – Expansion factor affecting the width of the estimators. Only used if version is "masked". Defaults to None.

  • last_layer_dropout (bool, optional) – Whether to apply dropout to the last layer only. Defaults to False.

  • conv_bias (bool, optional) – Whether to include bias in the convolutional layers. Defaults to False.

  • alpha (float, optional) – Expansion factor affecting the width of the estimators. Only used if version is "packed". Defaults to None.

  • gamma (int, optional) – Number of groups within each estimator. Only used if version is "packed" and scales with groups. Defaults to 1.

  • rho (float, optional) – Probability that all estimators share the same input. Only used if version is "mimo". Defaults to 1.

  • batch_repeat (int, optional) – Number of times to repeat the batch. Only used if version is "mimo". Defaults to 1.

  • ood_criterion (TUOODCriterion or str, optional) – Criterion for the binary OOD detection task. Defaults to "msp", the maximum softmax probability score (MSP).

  • log_plots (bool, optional) – Indicates whether to log the plots or not. Defaults to False.

  • save_in_csv (bool, optional) – Indicates whether to save the results in a csv file or not. Defaults to False.

  • eval_ood (bool, optional) – Indicates whether to evaluate the OOD detection or not. Defaults to False.

  • eval_shift (bool, optional) – Whether to evaluate on shifted data. Defaults to False.

  • eval_grouping_loss (bool, optional) – Indicates whether to evaluate the grouping loss or not. Defaults to False.

  • num_bins_cal_err (int, optional) – Number of calibration bins. Defaults to 15.

  • pretrained (bool, optional) – Indicates whether to use the pretrained weights or not. Only used if version is "packed". Defaults to False.
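
The default ood_criterion ("msp") and the calibration error controlled by num_bins_cal_err both rely on the model's maximum softmax probability. As a rough illustration only, the pure-Python sketch below computes MSP scores and a top-label expected calibration error with 15 equal-width confidence bins. The helpers softmax, msp and expected_calibration_error are hypothetical and not part of torch_uncertainty, and the exact bin-edge convention used by the library may differ.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def msp(logits):
    """Maximum softmax probability: higher means more confident (more in-distribution)."""
    return max(softmax(logits))


def expected_calibration_error(confidences, correct, num_bins=15):
    """Top-label ECE with equal-width bins (hypothetical sketch, not the library code)."""
    n = len(confidences)
    bins = [[] for _ in range(num_bins)]
    for conf, ok in zip(confidences, correct):
        # Map a confidence in [0, 1] to a bin index; conf == 1.0 falls in the last bin.
        idx = min(int(conf * num_bins), num_bins - 1)
        bins[idx].append((conf, ok))
    ece = 0.0
    for b in bins:
        if not b:
            continue
        avg_conf = sum(c for c, _ in b) / len(b)
        accuracy = sum(1 for _, ok in b if ok) / len(b)
        # Weight each bin's |confidence - accuracy| gap by its share of samples.
        ece += len(b) / n * abs(avg_conf - accuracy)
    return ece


logits = [[2.0, 0.5, 0.1], [0.2, 0.1, 0.3], [4.0, 0.0, 0.0]]
labels = [0, 1, 0]
probs = [softmax(z) for z in logits]
confs = [max(p) for p in probs]
preds = [p.index(max(p)) for p in probs]
correct = [pred == y for pred, y in zip(preds, labels)]
print([round(msp(z), 3) for z in logits])
print(round(expected_calibration_error(confs, correct), 3))
```

Lower MSP values flag inputs the model is less sure about, which is why the score doubles as a simple OOD detector; more bins give a finer but noisier calibration estimate.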

Raises:

ValueError – If version is not one of "std", "packed", "batched", "masked", "mimo" or "mc-dropout".
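
The documented version and arch options can be checked up front, before any model is built. The sketch below mirrors those options and the ValueError in plain Python. validate_resnet_config is a hypothetical helper, not part of torch_uncertainty; the arch check in particular is an assumption, since only invalid version values are documented to raise, and the library's own error messages may differ.

```python
# Options as listed in the parameter documentation above.
VERSIONS = ("std", "packed", "batched", "masked", "mimo", "mc-dropout")
ARCHS = {18: "ResNet-18", 32: "ResNet-32", 50: "ResNet-50", 101: "ResNet-101", 152: "ResNet-152"}


def validate_resnet_config(version: str, arch: int) -> str:
    """Return a readable model name, raising ValueError for undocumented options."""
    if version not in VERSIONS:
        raise ValueError(
            f"Unknown version {version!r}; expected one of {', '.join(VERSIONS)}."
        )
    # Assumption: unsupported architectures are also rejected with ValueError.
    if arch not in ARCHS:
        raise ValueError(f"Unknown arch {arch!r}; expected one of {sorted(ARCHS)}.")
    return f"{version} {ARCHS[arch]}"


print(validate_resnet_config("packed", 18))
try:
    validate_resnet_config("deep-ensembles", 18)
except ValueError as err:
    print(err)
```

Failing early on a typo such as "mc_dropout" avoids discovering the problem only after data loading and trainer setup.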

Returns:

ResNet baseline ready for training and evaluation.

Return type:

LightningModule