Shortcuts

CLI Guide

Introduction to the Lightning CLI

The Lightning CLI tool eases the implementation of a CLI to instanciate models to train and evaluate them on some data. The CLI tool is a wrapper around the Trainer class and provides a set of subcommands to train and test a LightningModule on a LightningDataModule. To better match our needs, we created an inherited class from the LightningCLI class, namely TULightningCLI.

Note

TULightningCLI adds a new argument to the LightningCLI class: eval_after_fit to know whether an evaluation on the test set should be performed after the training phase.

Let’s see how to implement the CLI, by checking out the experiments/classification/cifar10/resnet.py.

import torch
from lightning.pytorch.cli import LightningArgumentParser

from torch_uncertainty.baselines.classification import ResNetBaseline
from torch_uncertainty.datamodules import CIFAR10DataModule
from torch_uncertainty import TULightningCLI


class ResNetCLI(TULightningCLI):
    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        parser.add_optimizer_args(torch.optim.SGD)
        parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR)


def cli_main() -> ResNetCLI:
    return ResNetCLI(ResNetBaseline, CIFAR10DataModule)


if __name__ == "__main__":
    cli = cli_main()
    if (
        (not cli.trainer.fast_dev_run)
        and cli.subcommand == "fit"
        and cli._get(cli.config, "eval_after_fit")
    ):
        cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best")

This file enables both training and testing ResNet architectures on the CIFAR-10 dataset. The ResNetCLI class inherits from the TULightningCLI class and implements the add_arguments_to_parser method to add the optimizer and learning rate scheduler arguments into the parser. In this case, we use the torch.optim.SGD optimizer and the torch.optim.lr_scheduler.MultiStepLR learning rate scheduler.

class ResNetCLI(TULightningCLI):
    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        parser.add_optimizer_args(torch.optim.SGD)
        parser.add_lr_scheduler_args(torch.optim.lr_scheduler.MultiStepLR)

The LightningCLI takes a LightningModule and a LightningDataModule as arguments. Here the cli_main function creates an instance of the ResNetCLI class by taking the ResNetBaseline model and the CIFAR10DataModule as arguments.

def cli_main() -> ResNetCLI:
    return ResNetCLI(ResNetBaseline, CIFAR10DataModule)

Note

The ResNetBaseline is a subclass of the ClassificationRoutine seemlessly instanciating a ResNet model based on a version and an arch to be passed to the routine.

Depending on the CLI subcommand calling cli_main() will either train or test the model on the using the CIFAR-10 dataset. But what are these subcommands?

python resnet.py --help

This command will display the available subcommands of the CLI tool.

subcommands:
For more details of each subcommand, add it as an argument followed by --help.

Available subcommands:
    fit                 Runs the full optimization routine.
    validate            Perform one evaluation epoch over the validation set.
    test                Perform one evaluation epoch over the test set.
    predict             Run evaluation on your data.

You can execute whichever subcommand you like and set up all your hyperparameters directly using the command line

python resnet.py fit --trainer.max_epochs 75 --trainer.accelerators gpu --trainer.devices 1 --model.version std --model.arch 18 --model.in_channels 3 --model.num_classes 10 --model.loss CrossEntropyLoss --model.style cifar --data.root ./data --data.batch_size 128 --optimizer.lr 0.05 --lr_scheduler.milestones [25,50]

All arguments in the __init__() methods of the Trainer, LightningModule (here ResNetBaseline), LightningDataModule (here CIFAR10DataModule), torch.optim.SGD, and torch.optim.lr_scheduler.MultiStepLR classes are configurable using the CLI tool using the --trainer, --model, --data, --optimizer, and --lr_scheduler prefixes, respectively.

However for a large number of hyperparameters, it is not practical to pass them all in the command line. It is more convenient to use configuration files to store these hyperparameters and ease the burden of repeating them each time you want to train or test a model. Let’s see how to do that.

Note

Note that Pytorch classes are supported by the CLI tool, so you can use them directly: --model.loss CrossEntropyLoss and they would be automatically instanciated by the CLI tool with their default arguments (i.e., CrossEntropyLoss()).

Tip

Add the following after calling cli=cli_main() to eventually evaluate the model on the test set after training, if the eval_after_fit argument is set to True and trainer.fast_dev_run is set to False.

if (
    (not cli.trainer.fast_dev_run)
    and cli.subcommand == "fit"
    and cli._get(cli.config, "eval_after_fit")
):
    cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best")

Configuration files

By default the LightningCLI support configuration files in the YAML format (learn more about this format here). Taking the previous example, we can create a configuration file named config.yaml with the following content:

# config.yaml
eval_after_fit: true
trainer:
  max_epochs: 75
  accelerators: gpu
  devices: 1
model:
  version: std
  arch: 18
  in_channels: 3
  num_classes: 10
  loss: CrossEntropyLoss
  style: cifar
data:
  root: ./data
  batch_size: 128
optimizer:
  lr: 0.05
lr_scheduler:
  milestones:
    - 25
    - 50

Then, we can run the following command to train the model:

python resnet.py fit --config config.yaml

By default, executing the command above will store the experiment results in a directory named lightning_logs, and the last state of the model will be saved in a directory named lightning_logs/version_{int}/checkpoints. In addition, all arguments passed to instanciate the Trainer, ResNetBaseline, CIFAR10DataModule, torch.optim.SGD, and torch.optim.lr_scheduler.MultiStepLR classes will be saved in a file named lightning_logs/version_{int}/config.yaml. When testing the model, we advise to use this configuration file to ensure that the same hyperparameters are used for training and testing.

python resnet.py test --config lightning_logs/version_{int}/config.yaml --ckpt_path lightning_logs/version_{int}/checkpoints/{filename}.ckpt

Experiment folder usage

Now that we have seen how to implement the CLI tool and how to use configuration files, let explore the configurations available in the experiments directory. The experiments directory is mainly organized as follows:

experiments
├── classification
│   ├── cifar10
│      ├── configs
│      ├── resnet.py
│      ├── vgg.py
│      └── wideresnet.py
│   └── cifar100
│       ├── configs
│       ├── resnet.py
│       ├── vgg.py
│       └── wideresnet.py
├── regression
│   └── uci_datasets
│       ├── configs
│       └── mlp.py
└── segmentation
    ├── cityscapes
       ├── configs
       └── segformer.py
    └── muad
        ├── configs
        └── segformer.py

For each task (classification, regression, and segmentation), we have a directory containing the datasets (e.g., CIFAR10, CIFAR100, UCI datasets, Cityscapes, and Muad) and for each dataset, we have a directory containing the configuration files and the CLI files for different backbones.

You can directly use the CLI files with the command line or use the predefined configuration files to train and test the models. The configuration files are stored in the configs. For example, the configuration file for the classic ResNet-18 model on the CIFAR-10 dataset is stored in the experiments/classification/cifar10/configs/resnet18/standard.yaml file. For the Packed ResNet-18 model on the CIFAR-10 dataset, the configuration file is stored in the experiments/classification/cifar10/configs/resnet18/packed.yaml file.

If you are interested in using a ResNet model but want to choose some of the hyperparameters using the command line, you can use the configuration file and override the hyperparameters using the command line. For example, to train a ResNet-18 model on the CIFAR-10 dataset with a batch size of \(256\), you can use the following command:

python resnet.py fit --config configs/resnet18/standard.yaml --data.batch_size 256

To use the weights argument of the torch.nn.CrossEntropyLoss class, you can use the following command:

python resnet.py fit --config configs/resnet18/standard.yaml --model.loss CrossEntropyLoss --model.loss.weight Tensor --model.loss.weight.dict_kwargs.data [1,2,3,4,5,6,7,8,9,10]

In addition, we provide a default configuration file for some backbones in the configs directory. For example, experiments/classification/cifar10/configs/resnet.yaml contains the default hyperparameters to train a ResNet model on the CIFAR-10 dataset. Yet, some hyperparameters are purposely missing to be set by the user using the command line.

For instance, to train a Packed ResNet-34 model on the CIFAR-10 dataset with \(4\) estimators and a \(\alpha\) value of \(2\), you can use the following command:

python resnet.py fit --config configs/resnet.yaml --trainer.max_epochs 75 --model.version packed --model.arch 34 --model.num_estimators 4 --model.alpha 2 --optimizer.lr 0.05 --lr_scheduler.milestones [25,50]

Tip

Explore the Lightning CLI docs to learn more about the CLI tool, the available arguments, and how to use them with configuration files.