训练稳定性

1 混合精度训练

1.1 如何选择精度 FP16/BF16

BF16 因其与 FP32 相同的8位指数范围,在Ampere架构(A100、H100)及更新的GPU上成为首选。(精度问题可以看之前的文章: LLM精度问题)

相比之下,FP16的动态范围较窄,需要 GradScaler 通过 Loss scaling 机制来维护数值稳定性。

工作原理如下:

  1. 在反向传播前将损失值乘以缩放因子(初始值通常为2162^{16}或65536),放大梯度以避免FP16表示范围内的下溢

  2. 完成梯度计算后再将权重更新除以相同因子恢复原始尺度。

  3. 当检测到Inf或NaN时,缩放因子会按回退系数(通常0.5)衰减并跳过当前优化步骤,反之则在连续若干步骤(增长间隔通常2000步)成功后按增长系数(通常2.0)提升。

这种动态调整确保了训练的稳定推进。

值得注意的是,BF16虽简化了实现,但研究表明即使标准的BF16混合精度训练在特定随机种子下仍可能导致约10%的运行出现发散。

2 激活重计算

激活重计算(Gradient Checkpointing)通过时间换取空间来缓解显存压力。

  • 完全重计算(Full Recomputation)在每个Transformer层边界保存检查点,将激活内存从线性复杂度O(N)降至平方根复杂度O(√N),实现50-70%的内存节省,但会引入30-40%的计算开销。

  • 选择性重计算(Selective Recomputation)识别出内存密集但计算廉价的操作(如注意力机制中的Softmax和Dropout层)进行针对性重计算,仅增加约7%的开销却能达到接近完全重计算的内存优化效果。在长序列场景下,结合序列并行(Sequence Parallelism)可将两者优势叠加,将内存需求降至基线的20%以下,同时将计算开销控制在4%以内。

3 梯度裁剪

  • 梯度裁剪是一种防止梯度爆炸的正则化技术,通过限制梯度范数大小来避免参数更新过大导致模型发散(Loss=NaN)

  • 大模型(7B+参数)因参数量大、序列长(8K+ tokens)、使用低精度训练(FP16/BF16)及Transformer多层链式法则连乘特性,梯度爆炸风险显著高于传统模型

主要实现方式

  • 按范数裁剪:使用 torch.nn.utils.clip_grad_norm_ 限制梯度向量的整体L2范数,保持梯度方向不变仅缩放大小,是大模型训练的首选方案

  • 按值裁剪:使用 torch.clamp 将每个梯度元素限制在固定范围内,实现简单但可能改变梯度方向引入偏差

  • 动态自适应裁剪:包括基于Z-Score统计动态计算阈值、监控裁剪频率自动调整、随训练进度渐进衰减阈值等方法,适用于LLM预训练及 Loss Spike 场景

4 Loss Spike 如何处理

Loss Spike 指在模型预训练或微调过程中,损失值在单个或连续若干步骤内发生数量级跃升的现象。

可能是如下原因:

  • 数据层异常:batch 内包含极端异常值、损坏样本、错误标注或编码异常文本,导致梯度更新方向偏离;训练数据流中特定领域数据(如超长文本、低质量网页内容)突然集中出现,破坏优化轨迹稳定性

  • 学习率过高:Warmup 阶段过早结束导致优化器状态未充分初始化;Cosine Annealing 反弹阶段学习率恢复过快

  • Loss计算错误:前向传播中触发除零、log(0)、exp 上溢或极小数开方(如 LayerNorm 中 variance 接近零时);GradScaler 的 Loss Scaling 因子设置不当导致梯度上溢或下溢;动态缩放窗口(Growth Interval)过长无法及时响应数值漂移

如何处理

  • 数据层面:

    • 数据清洗:移除或修正损坏/错误标注的样本。

    • 异常值处理:对特征进行截断或鲁棒的归一化。

    • 打乱数据:确保训练数据的随机性,避免连续困难样本。

  • 学习率调整:

    • 降低学习率

    • Warmup:在训练初期使用较小的学习率,然后逐渐增加到设定值。

    • 检查/调整 LR Scheduler:确保调度器逻辑正确,峰值学习率和衰减策略合理。

  • 梯度裁剪:

    • 设置一个梯度的范数上限或值上限。当计算出的梯度超过此上限时,将其缩放或截断。

参考

Last updated