写PyTorch训练代码最烦的是什么?手动管理训练循环、梯度累积、多卡同步、checkpoint保存、日志记录……这些样板代码占了80%的工作量。PyTorch Lightning把这些全自动化了。

核心概念

LightningModule定义模型和训练逻辑。只写三个方法:__init__定义模型结构,training_step定义单步训练逻辑,configure_optimizers定义优化器。其他全部由框架处理。

import lightning as L

class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = ...  # 你的模型
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self.model(x)
        loss = F.cross_entropy(pred, y)
        self.log("train_loss", loss)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=2e-5)

自动化能力

自动处理GPU/CPU切换、分布式训练(DDP、FSDP)、混合精度(FP16/BF16)、梯度累积、梯度裁剪、checkpoint保存和恢复、早停、日志记录。

一行代码切换训练策略:trainer = L.Trainer(strategy=“ddp”, devices=4)。

Callback系统

通过Callback扩展训练逻辑而不修改核心代码。ModelCheckpoint自动保存最优模型,EarlyStopping在验证集不再提升时停止训练,LearningRateMonitor记录学习率变化。

和HuggingFace配合

Lightning可以和Transformers库配合使用。用Lightning管理训练循环,用Transformers加载预训练模型。两者互补不冲突。