注意力计算的瓶颈不在算力,在内存带宽。标准注意力需要把完整的Q、K、V矩阵和中间的注意力分数矩阵都保存在GPU显存中。Flash Attention通过重新设计计算顺序,彻底解决了这个问题。
核心思想
标准注意力:计算完整的N×N注意力矩阵,写入显存,再读出来做softmax和加权求和。IO复杂度O(N²)。
Flash Attention:把Q、K、V分成小块,在SRAM(GPU片上高速缓存)中完成小块的注意力计算,用在线softmax算法避免存储完整的注意力矩阵。IO复杂度O(N²/M),M是SRAM大小。
实际效果
训练速度提升2-4倍。显存占用从O(N²)降到O(N)。支持更长的序列——原来2048 token的限制可以扩展到16K甚至更长。
使用方式
Flash Attention 2已经集成到PyTorch和Transformers中。大多数情况下只需要:
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen2.5-7B",
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16
)
不需要修改任何代码,自动生效。
Flash Attention 3
最新版本针对Hopper架构(H100)做了深度优化。异步计算和数据传输、FP8量化注意力、更好的warp调度。在H100上比Flash Attention 2再快1.5-2倍。
现代大模型训练和推理几乎都默认开启Flash Attention。如果你的框架支持但你没开,等于白送的性能不要。