在机器学习和深度学习项目中,通常会涉及多个模型和多个数据集。当项目变得复杂时,从命令行配置超参数以混合不同的模型和数据集变得非常重要。这不仅提高了项目的灵活性,还减少了不必要的代码更改。本文将介绍如何通过命令行配置超参数,使项目更具可扩展性和可配置性。
为什么要混合模型和数据集
通常,一个深度学习项目在初始阶段可能只包括一个模型和一个数据集。随着项目的发展,您可能需要引入更多的模型和数据集,以提高模型性能或适应不同的任务。在这种情况下,从命令行轻松混合不同的模型和数据集成为一种非常有用的能力。
Lightning项目通常从一个模型和一个数据集开始。然而,当您的项目变得更加复杂,您可能希望能够在不更改代码的情况下从命令行直接混合任何模型和任何数据集。
下面是一个示例,展示了如何从命令行混合不同的模型和数据集:
$ python main.py fit --model=GAN --data=MNIST
$ python main.py fit --model=Transformer --data=MNIST
这里,我们使用LightningCLI使混合模型和数据集变得非常简单。否则,这种配置可能需要大量的样板代码,通常会看起来像这样:
# 选择模型
if args.model == "gan":
model = GAN(args.feat_dim)
elif args.model == "transformer":
model = Transformer(args.feat_dim)
...
# 选择数据模块
if args.data == "MNIST":
datamodule = MNIST()
elif args.data == "imagenet":
datamodule = Imagenet()
...
# 将它们混合在一起!
trainer.fit(model, datamodule)
因此,强烈建议避免编写这种样板代码,而是使用LightningCLI来实现这一目标。
多个Lightning模块
为了支持多个模型,当实例化LightningCLI时,省略model_class
参数:
# main.py
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
class Model1(DemoModel):
def configure_optimizers(self):
print("⚡", "使用 Model1", "⚡")
return super().configure_optimizers()
class Model2(DemoModel):
def configure_optimizers(self):
print("⚡", "使用 Model2", "⚡")
return super().configure_optimizers()
cli = LightningCLI(datamodule_class=BoringDataModule)
现在,您可以在命令行中选择任何模型:
# 使用 Model1
python main.py fit --model Model1
# 使用 Model2
python main.py fit --model Model2
提示: 您还可以选择给出一个基类,并将subclass_mode_model=True
。这将使CLI仅接受作为给定基类子类的模型。
多个Lightning数据模块
为了支持多个数据模块,当实例化LightningCLI时,省略datamodule_class
参数:
# main.py
import torch
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
class FakeDataset1(BoringDataModule):
def train_dataloader(self):
print("⚡", "使用 FakeDataset1", "⚡")
return torch.utils.data.DataLoader(self.random_train)
class FakeDataset2(BoringDataModule):
def train_dataloader(self):
print("⚡", "使用 FakeDataset2", "⚡")
return torch.utils.data.DataLoader(self.random_train)
cli = LightningCLI(DemoModel)
现在,您可以在运行时选择任何数据集:
# 使用 FakeDataset1
python main.py fit --data FakeDataset1
# 使用 FakeDataset2
python main.py fit --data FakeDataset2
提示: 您还可以选择给出一个基类,并将subclass_mode_data=True
。这将使CLI仅接受作为给定基类子类的数据模块。
多个优化器
标准的torch.optim中的优化器可以直接使用:
python main.py fit --optimizer AdamW
如果您需要其他参数以适应所需的优化器,可以通过CLI添加它们(无需更改代码):
python main.py fit --optimizer SGD --optimizer.lr=0.01
此外,任何torch.optim.Optimizer的自定义子类都可以用作优化器:
# main.py
import torch
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
class LitAdam(torch.optim.Adam):
def step(self, closure):
print("⚡", "使用 LitAdam", "⚡")
super().step(closure)
class FancyAdam(torch.optim.Adam):
def step(self, closure):
print("⚡", "使用 FancyAdam", "⚡")
super().step(closure)
cli = LightningCLI(DemoModel, BoringDataModule)
现在,您可以在运行时选择任何优化器:
# 使用 LitAdam
python main.py fit --optimizer LitAdam
# 使用 FancyAdam
python main.py fit --optimizer FancyAdam
多个学习率调度器
标准的torch.optim.lr_scheduler中的学习率调度器可以直接使用:
python main.py fit --optimizer=Adam --lr_scheduler CosineAnnealingLR
请注意,为了使--lr_scheduler
生效,必须添加--optimizer
。
如果您需要其他参数以适应所需的调度器,可以通过CLI添加它们(无需更改代码):
python main.py fit --optimizer=Adam --lr_scheduler=ReduceLROnPlateau --lr_scheduler.monitor=epoch
此外,任何torch.optim.lr_scheduler.LRScheduler的自定义子类都可以用作学习率调度器:
# main.py
import torch
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
class LitLRScheduler(torch.optim.lr_scheduler.CosineAnnealingLR):
def step(self):
print("⚡", "使用 LitLRScheduler", "⚡")
super().step()
cli = LightningCLI(DemoModel, BoringDataModule)
现在,您可以在运行时选择任何学习率调度器:
# 使用 LitLRScheduler
python main.py fit --optimizer=Adam --lr_scheduler LitLRScheduler
来自任何包的类
在前面的部分中,选择要使用的自定义类是在运行LightningCLI类的相同python文件中定义的。为了从任何包中选择类,只需导入相应的包:
from lightning.pytorch.cli import LightningCLI
import my_code.models # noqa: F401
import my_code.data_modules # noqa: F401
import my_code.optimizers # noqa: F401
cli = LightningCLI()
现在,可以使用来自任何包的类:
python main.py fit --model Model1 --data FakeDataset1 --optimizer LitAdam --lr_scheduler LitLRScheduler
# noqa: F401
注释可以避免lint器警告,指出导入未使用。
还可以通过给出完整的导入路径选择尚未导入的子类:
python main.py fit --model my_code.models.Model1
特定类的帮助
当多个模型或数据集被接受时,CLI的主帮助不包括它们的具体参数。为了显示这些特定帮助,可以使用额外的帮助参数,指定类名或其导入路径。例如:
python main.py fit --model.help Model1
python main.py fit --data.help FakeDataset2
python main.py fit --optimizer.help Adagrad
python main.py fit --lr_scheduler.help StepLR