CLI Guide#

Introduction to the Lightning CLI#

The Lightning CLI tool eases the implementation of a CLI to instantiate models and to train and evaluate them on some data. The CLI tool is a wrapper around the Trainer class and provides a set of subcommands to train and test a LightningModule on a LightningDataModule. To better match our needs, we created a class inheriting from LightningCLI, namely TULightningCLI.

Note

TULightningCLI adds a new argument to the LightningCLI class: eval_after_fit, which indicates whether an evaluation on the test set should be performed after the training phase. We recommend restricting this parameter to single-node training to avoid small inconsistencies in the test performance.
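Since the flag sits at the top level of the configuration (as in the config file shown later in this guide), you should be able to enable it directly from the command line:

python main.py fit --config config.yaml --eval_after_fit true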

Let’s see how to implement the CLI by checking out experiments/classification/cifar10/main.py.

import torch
from torch_uncertainty import TULightningCLI
from torch_uncertainty.datamodules import CIFAR10DataModule
from torch_uncertainty.routines import ClassificationRoutine


def cli_main() -> TULightningCLI:
    # Build the CLI around the classification routine and the CIFAR-10 datamodule.
    return TULightningCLI(ClassificationRoutine, CIFAR10DataModule)


if __name__ == "__main__":
    torch.set_float32_matmul_precision("medium")
    cli = cli_main()
    # After a fit, evaluate on the test set when eval_after_fit is set,
    # unless this is a fast_dev_run debugging pass.
    if (
        (not cli.trainer.fast_dev_run)
        and cli.subcommand == "fit"
        and cli._get(cli.config, "eval_after_fit")
    ):
        cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best")

This file enables both training and testing on the CIFAR-10 dataset. The model, optimizer and learning rate schedulers will be set in the configuration file as shown below.

Depending on the CLI subcommand, calling cli_main() will either train or test the model using the CIFAR-10 dataset. But what are these subcommands?

python main.py --help

This command will display the available subcommands of the CLI tool.

subcommands:
  For more details of each subcommand, add it as an argument followed by --help.

  Available subcommands:
    fit                 Runs the full optimization routine.
    validate            Perform one evaluation epoch over the validation set.
    test                Perform one evaluation epoch over the test set.
    predict             Run evaluation on your data.
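For instance, to list the options accepted by the fit subcommand:

python main.py fit --help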

You can execute whichever subcommand you like and set up all your hyperparameters directly using the command line.
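For instance, reusing overrides that appear later in this guide, a command-line-only run could look like the following sketch (a complete run would also need the routine and model arguments):

python main.py fit --trainer.max_epochs 75 --data.batch_size 128 --optimizer.lr 0.05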

Due to the large number of hyperparameters, we advise against it and suggest using configuration files. Let’s see how to do that.

Configuration files#

By default, the LightningCLI supports configuration files in the YAML format. Taking the previous example, we can create a configuration file named config.yaml with the following content:

# config.yaml
eval_after_fit: true
trainer:
  max_epochs: 75
  accelerator: gpu
  devices: 1
routine:
  model:
    class_path: torch_uncertainty.models.classification.resnet
    init_args:
      in_channels: 3
      num_classes: 10
      arch: 18
      style: cifar
  num_classes: 10
  loss: CrossEntropyLoss
data:
  root: ./data
  batch_size: 128
optimizer:
  class_path: torch.optim.SGD
  init_args:
    lr: 0.05
    momentum: 0.9
    weight_decay: 5e-4
lr_scheduler:
  class_path: torch.optim.lr_scheduler.MultiStepLR
  init_args:
    milestones:
      - 25
      - 50
    gamma: 0.1

Then, we can run the following command to train the model:

python main.py fit --config config.yaml

By default, executing the command above will store the experiment results in a directory named lightning_logs, and the last state of the model will be saved in a directory named lightning_logs/version_{int}/checkpoints. In addition, all arguments passed to instantiate the Trainer, CIFAR10DataModule, torch.optim.SGD, and torch.optim.lr_scheduler.MultiStepLR classes will be saved in a file named lightning_logs/version_{int}/config.yaml. When testing the model, we advise using this configuration file to ensure that the same hyperparameters are used for training and testing.
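For reference, the experiment folder produced by a single fit run therefore contains at least the following (the braces denote values that vary per run):

lightning_logs
└── version_{int}
    ├── checkpoints
    │   └── {filename}.ckpt
    └── config.yaml

The test command then points at these files: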

python main.py test --config lightning_logs/version_{int}/config.yaml --ckpt_path lightning_logs/version_{int}/checkpoints/{filename}.ckpt

Experiment folder usage#

Now that we have seen how to implement the CLI tool and how to use configuration files, let’s explore the configurations available in the experiments directory. The experiments directory is mainly organized as follows:

experiments
├── classification
│   ├── cifar10
│   │   ├── configs
│   │   │   ├── resnet18
│   │   │   │   ├── standard.yaml
│   │   │   │   ├── packed.yaml
│   │   │   │   └── ...
│   │   │   ├── resnet50
│   │   │   └── wideresnet28x10
│   │   └── main.py
│   └── cifar100
│       ├── configs
│       │   └── ...
│       └── main.py
├── regression
│   └── uci_datasets
│       ├── configs
│       └── main.py
└── segmentation
    ├── cityscapes
    │   ├── configs
    │   └── main.py
    └── muad
        ├── configs
        └── main.py

For each task (classification, regression, and segmentation), we have a directory per dataset (e.g., CIFAR-10, CIFAR-100, the UCI datasets, Cityscapes, and MUAD), and each dataset directory contains a CLI file (main.py) along with configuration files for different backbones.

You can use the CLI files directly from the command line or rely on the predefined configuration files to train and test the models. The configuration files are stored in the configs directory. For example, the configuration file for the classic ResNet-18 model on the CIFAR-10 dataset is stored in experiments/classification/cifar10/configs/resnet18/standard.yaml, and the one for the Packed ResNet-18 model on the CIFAR-10 dataset in experiments/classification/cifar10/configs/resnet18/packed.yaml.

If you are interested in using a ResNet model but want to choose some of the hyperparameters yourself, you can start from the configuration file and override those hyperparameters on the command line. For example, to train a ResNet-18 model on the CIFAR-10 dataset with a batch size of \(256\), you can use the following command:

python main.py fit --config configs/resnet18/standard.yaml --data.batch_size 256

To set the weight argument of the torch.nn.CrossEntropyLoss class, you can use the following command:

python main.py fit --config configs/resnet18/standard.yaml --routine.loss CrossEntropyLoss --routine.loss.weight Tensor --routine.loss.weight.dict_kwargs.data [1,2,3,4,5,6,7,8,9,10]
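The same loss could also be described in the configuration file. Below is a sketch of the equivalent YAML, assuming the same class_path/dict_kwargs resolution that jsonargparse applies to the command line above:

routine:
  loss:
    class_path: torch.nn.CrossEntropyLoss
    init_args:
      weight:
        class_path: torch.Tensor
        dict_kwargs:
          data: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]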

In addition, we provide a default configuration file for some backbones in the configs directory. For example, experiments/classification/cifar10/configs/resnet18/standard.yaml contains the default hyperparameters to train a ResNet-18 model on the CIFAR-10 dataset. However, some hyperparameters are purposely left unset so that you can choose them on the command line.

For instance, to train a Packed ResNet-34 model on the CIFAR-10 dataset with \(4\) estimators and an \(\alpha\) value of \(2\), you can use the following command:

python main.py fit --config configs/resnet18/standard.yaml --trainer.max_epochs 75 --routine.model.version packed --routine.model.arch 34 --routine.model.num_estimators 4 --routine.model.alpha 2 --optimizer.lr 0.05 --lr_scheduler.milestones [25,50]

Tip

Explore the Lightning CLI docs to learn more about the CLI tool, the available arguments, and how to use them with configuration files.