写PyTorch训练代码最烦的是什么?手动管理训练循环、梯度累积、多卡同步、checkpoint保存、日志记录……这些样板代码占了80%的工作量。PyTorch Lightning把这些全自动化了。
核心概念
LightningModule定义模型和训练逻辑。只写三个方法:__init__定义模型结构,training_step定义单步训练逻辑,configure_optimizers定义优化器。其他全部由框架处理。
import lightning as L
class MyModel(L.LightningModule):
def __init__(self):
super().__init__()
self.model = ... # 你的模型
def training_step(self, batch, batch_idx):
x, y = batch
pred = self.model(x)
loss = F.cross_entropy(pred, y)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
return torch.optim.AdamW(self.parameters(), lr=2e-5)
自动化能力
自动处理GPU/CPU切换、分布式训练(DDP、FSDP)、混合精度(FP16/BF16)、梯度累积、梯度裁剪、checkpoint保存和恢复、早停、日志记录。
一行代码切换训练策略:trainer = L.Trainer(strategy=“ddp”, devices=4)。
Callback系统
通过Callback扩展训练逻辑而不修改核心代码。ModelCheckpoint自动保存最优模型,EarlyStopping在验证集不再提升时停止训练,LearningRateMonitor记录学习率变化。
和HuggingFace配合
Lightning可以和Transformers库配合使用。用Lightning管理训练循环,用Transformers加载预训练模型。两者互补不冲突。