训练稳定性
1 混合精度训练
1.1 如何选择精度 FP16/BF16
BF16 因其与 FP32 相同的8位指数范围,在Ampere架构(A100、H100)及更新的GPU上成为首选。(精度问题可以看之前的文章: LLM精度问题)
相比之下,FP16的动态范围较窄,需要 GradScaler 通过 Loss scaling 机制来维护数值稳定性。
工作原理如下:
在反向传播前将损失值乘以缩放因子(初始值通常为216或65536),放大梯度以避免FP16表示范围内的下溢
完成梯度计算后再将权重更新除以相同因子恢复原始尺度。
当检测到Inf或NaN时,缩放因子会按回退系数(通常0.5)衰减并跳过当前优化步骤,反之则在连续若干步骤(增长间隔通常2000步)成功后按增长系数(通常2.0)提升。
这种动态调整确保了训练的稳定推进。
值得注意的是,BF16虽简化了实现,但研究表明即使标准的BF16混合精度训练在特定随机种子下仍可能导致约10%的运行出现发散。
2 激活重计算
激活重计算(Gradient Checkpointing)通过时间换取空间来缓解显存压力。
完全重计算(Full Recomputation)在每个Transformer层边界保存检查点,将激活内存从线性复杂度O(N)降至平方根复杂度O(√N),实现50-70%的内存节省,但会引入30-40%的计算开销。
选择性重计算(Selective Recomputation)识别出内存密集但计算廉价的操作(如注意力机制中的Softmax和Dropout层)进行针对性重计算,仅增加约7%的开销却能达到接近完全重计算的内存优化效果。在长序列场景下,结合序列并行(Sequence Parallelism)可将两者优势叠加,将内存需求降至基线的20%以下,同时将计算开销控制在4%以内。
3 梯度裁剪
梯度裁剪是一种防止梯度爆炸的正则化技术,通过限制梯度范数大小来避免参数更新过大导致模型发散(Loss=NaN)
大模型(7B+参数)因参数量大、序列长(8K+ tokens)、使用低精度训练(FP16/BF16)及Transformer多层链式法则连乘特性,梯度爆炸风险显著高于传统模型
主要实现方式
按范数裁剪:使用
torch.nn.utils.clip_grad_norm_限制梯度向量的整体L2范数,保持梯度方向不变仅缩放大小,是大模型训练的首选方案按值裁剪:使用
torch.clamp将每个梯度元素限制在固定范围内,实现简单但可能改变梯度方向引入偏差动态自适应裁剪:包括基于Z-Score统计动态计算阈值、监控裁剪频率自动调整、随训练进度渐进衰减阈值等方法,适用于LLM预训练及 Loss Spike 场景
4 Loss Spike 如何处理
Loss Spike 指在模型预训练或微调过程中,损失值在单个或连续若干步骤内发生数量级跃升的现象。

可能是如下原因:
数据层异常:batch 内包含极端异常值、损坏样本、错误标注或编码异常文本,导致梯度更新方向偏离;训练数据流中特定领域数据(如超长文本、低质量网页内容)突然集中出现,破坏优化轨迹稳定性
学习率过高:Warmup 阶段过早结束导致优化器状态未充分初始化;Cosine Annealing 反弹阶段学习率恢复过快
Loss计算错误:前向传播中触发除零、log(0)、exp 上溢或极小数开方(如 LayerNorm 中 variance 接近零时);GradScaler 的 Loss Scaling 因子设置不当导致梯度上溢或下溢;动态缩放窗口(Growth Interval)过长无法及时响应数值漂移
如何处理
数据层面:
数据清洗:移除或修正损坏/错误标注的样本。
异常值处理:对特征进行截断或鲁棒的归一化。
打乱数据:确保训练数据的随机性,避免连续困难样本。
学习率调整:
降低学习率
Warmup:在训练初期使用较小的学习率,然后逐渐增加到设定值。
检查/调整 LR Scheduler:确保调度器逻辑正确,峰值学习率和衰减策略合理。
梯度裁剪:
设置一个梯度的范数上限或值上限。当计算出的梯度超过此上限时,将其缩放或截断。
参考
Last updated