评估与工程化

1 中间评估

预训练通常持续很长时间,仅在结束时评估无法及时发现问题(如能力遗忘、训练不稳定)。所以 LLM 预训练通常会使用中间评估。

中间评估有以下几种方式:

1.1 PPL

使用一些非训练数据的高质量语料作为评测,在每个 checkpoint 上观察模型在这些测试集合上的 loss 表现。理想情况下,这里的loss应该是会随着训练时间增加而持续下降的。

1.2 下游任务评测

在每个 checkpoint(如每 100B tokens)执行 Zero-shot / Few-shot 评测,覆盖:

  • 推理能力:HellaSwag、ARC、PIQA

  • 知识密集型:MMLU、C-Eval、CMMLU

  • 数学/代码:GSM8K、HumanEval、MBPP

  • 长文本:L-Eval、LongBench

为了避免数据泄露,可以对 Benchmark 进行一些改造,比如调整选择题的选项顺序、对问题进行重写等等。

1.3 概率探针

从概率的视角监控模型在特定知识上的遗忘或提升。追h踪模型对特定 token 或句子的概率变化,从而判断知识或能力的演化趋势。实际操作中,需要人工构造探针测试集——这通常是逐条设计,难以批量生成。例如:

  • 事实知识:观察条件概率 (\text{Prob}(\text{“巴黎”} \mid \text{“法国的首都是”})) 是否随训练过程持续上升;

  • 一致性:对比 (\text{PPL}(\text{“地球是圆的”})) 是否下降,而 (\text{PPL}(\text{“地球是平的”})) 是否上升;

  • 价值对齐:验证 (\text{PPL}(\text{“诚实是美德”})) 是否持续低于 (\text{PPL}(\text{“欺骗是合理的”}));

  • 指令遵循:检验在提示“以 JSON 格式输出”后,模型输出首个左大括号的概率 (\text{Prob}(\text{“{”} \mid \ldots)) 是否提高。

这类概率探针可根据需求任意构造,重点在于监测指标的变化趋势,而非绝对数值。

3. Scaling Law 验证与训练预测

(1)带学习率退火的 Scaling Law(2024-2025 新进展) 传统 Scaling Law 仅拟合最终 loss,无法指导中间过程。最新研究提出含学习率退火项的幂律公式

L(s)=L0+AS1αCS2L(s) = L_0 + A \cdot S_1^{-\alpha} - C \cdot S_2

其中 $S_1$ 为学习率曲线下面积(表征数据消耗),$S_2$ 为学习率退火区域面积。该公式可从少量早期训练曲线(如 20K 步)预测任意学习率调度下的完整 loss 曲线(如 60K 步),误差 <2%。

工程应用

  • 早期预警:在训练初期(如 5% 进度)拟合参数,预测最终收敛 loss,若偏离目标(如损失过高),提前终止或调整数据混合策略。

  • 学习率调度优化:通过最小化预测的最终 loss,自动搜索最优学习率调度,可发现优于 Cosine 的调度策略(类似 WSD 但性能更优)。

(2)Temporal Scaling Law 针对有限训练步数的实际场景,验证 loss 随时间演化是否符合幂律,用于预测剩余训练时间所需的计算资源。


4. 断点续训与检查点策略(Checkpointing Strategy)

大规模预训练常面临硬件故障、网络中断,需高频、低延迟、一致性保障的检查点体系:

(1)高频 Checkpoint 策略

  • 分层保存:区分轻量检查点(仅模型权重,每 100 步)与完整检查点(含优化器状态、RNG 状态、数据加载器位置,每 1000 步)。

  • 异步保存:利用 SSD/NVMe 进行异步 I/O,主训练流不阻塞。DeepSpeed 提供 save_checkpoint API 自动处理模型状态与客户端状态(iteration、RNG)的分离存储。

  • 内存快照:采用 In-Memory Snapshot 技术,在 NPU/GPU HBM 中维护一份检查点副本,故障时秒级恢复,避免从慢速存储读取。

(2)热备份与冗余

  • 双副本策略:关键检查点(如阶段切换点)同时写入本地 NVMe 和远程对象存储(如 S3)。

  • 检查点验证:加载后执行前向-反向验证步,确保数值一致性(loss 匹配)后再正式恢复训练。


5. 故障自动恢复与弹性训练(Elastic Training)

现代集群规模($10^5$~$10^6$ 加速器)使故障成为常态,需**弹性原生(Elastic-Native)**训练系统:

(1)ElasWave 弹性训练框架(2025) 华为/港科大等提出的 ElasWave 实现每步容错(Per-step Fault Tolerance)

  • 动态通信组:支持 in-place 通信组编辑,节点掉线时无需重建整个通信域(NCCL communicator),恢复时间从分钟级降至 <1 秒(相比完整重建快 82×)。

  • 参数一致性保障:通过 ZeRO 分区交错迁移(Interleaved ZeRO State Movement),在节点增减时保持梯度聚合一致性,避免收敛偏差。

  • 计算一致性:重新分片(Reshard)随机数生成器(RNG)状态,确保弹性伸缩后,相同数据批次产生与故障前一致的随机性,收敛偏差降低 78%

(2)故障分级响应

  • 节点级故障:自动剔除故障节点,剩余节点通过 弹性数据并行(Elastic DP) 重新分配 micro-batch,保持全局 batch size 恒定。

  • 进程级故障:利用热备进程(Warm Standby)秒级接管,结合 Pipeline 冗余(Oobleck) 重新映射流水线阶段。

  • 检查点回滚:若故障导致参数污染(如 NaN),自动回滚到最近验证过的检查点,通常损失 <10 步进度。


6. 数据加载优化(Data Loading Optimization)

数据加载是预训练瓶颈之一(尤其在高并发场景),需零拷贝、预取、动态重采样

(1)Megatron-LM / DeepSpeed 数据并行优化

  • 分桶动态分辨率(Bucket-wise Dynamic Resolution):针对多模态或长文本训练,按序列长度/分辨率分桶,同桶内样本动态填充(Padding)至统一长度,减少无效计算。数据加载器在线分配样本到桶,满桶即触发训练步。

  • 数据重加权(Data Reweighting):通过网格搜索(Grid Search)动态调整各数据源混合比例,平衡不同维度(语言、领域、质量)的性能,避免单一数据集主导。

(2)预取与流水线

  • 三级预取架构Storage → CPU Memory → GPU HBM 异步流水线,利用 torch.utils.data.DataLoaderprefetch_factorpersistent_workers,确保 GPU 计算单元利用率(MFU)>50%。

  • 内存映射(Memory Mapping):对大规模语料(TB 级)使用 numpy.memmapmmap 模式,避免全量加载到 RAM,支持多进程零拷贝共享。

(3)弹性数据加载 节点掉线时,弹性训练系统需同步调整数据加载器的全局 rank 和样本分片策略,确保故障恢复后数据不重复、不丢失。ElasWave 等系统通过动态数据流调度(Dataflow Scheduling)实现这一能力。

Last updated