评估与工程化
1 中间评估
预训练通常持续很长时间,仅在结束时评估无法及时发现问题(如能力遗忘、训练不稳定)。所以 LLM 预训练通常会使用中间评估。
中间评估有以下几种方式:
1.1 PPL
使用一些非训练数据的高质量语料作为评测,在每个 checkpoint 上观察模型在这些测试集合上的 loss 表现。理想情况下,这里的loss应该是会随着训练时间增加而持续下降的。
1.2 下游任务评测
在每个 checkpoint(如每 100B tokens)执行 Zero-shot / Few-shot 评测,覆盖:
推理能力:HellaSwag、ARC、PIQA
知识密集型:MMLU、C-Eval、CMMLU
数学/代码:GSM8K、HumanEval、MBPP
长文本:L-Eval、LongBench
为了避免数据泄露,可以对 Benchmark 进行一些改造,比如调整选择题的选项顺序、对问题进行重写等等。
1.3 概率探针
从概率的视角监控模型在特定知识上的遗忘或提升。追h踪模型对特定 token 或句子的概率变化,从而判断知识或能力的演化趋势。实际操作中,需要人工构造探针测试集——这通常是逐条设计,难以批量生成。例如:
事实知识:观察条件概率 (\text{Prob}(\text{“巴黎”} \mid \text{“法国的首都是”})) 是否随训练过程持续上升;
一致性:对比 (\text{PPL}(\text{“地球是圆的”})) 是否下降,而 (\text{PPL}(\text{“地球是平的”})) 是否上升;
价值对齐:验证 (\text{PPL}(\text{“诚实是美德”})) 是否持续低于 (\text{PPL}(\text{“欺骗是合理的”}));
指令遵循:检验在提示“以 JSON 格式输出”后,模型输出首个左大括号的概率 (\text{Prob}(\text{“{”} \mid \ldots)) 是否提高。
这类概率探针可根据需求任意构造,重点在于监测指标的变化趋势,而非绝对数值。
3. Scaling Law 验证与训练预测
(1)带学习率退火的 Scaling Law(2024-2025 新进展) 传统 Scaling Law 仅拟合最终 loss,无法指导中间过程。最新研究提出含学习率退火项的幂律公式:
L(s)=L0+A⋅S1−α−C⋅S2
其中 $S_1$ 为学习率曲线下面积(表征数据消耗),$S_2$ 为学习率退火区域面积。该公式可从少量早期训练曲线(如 20K 步)预测任意学习率调度下的完整 loss 曲线(如 60K 步),误差 <2%。
工程应用:
早期预警:在训练初期(如 5% 进度)拟合参数,预测最终收敛 loss,若偏离目标(如损失过高),提前终止或调整数据混合策略。
学习率调度优化:通过最小化预测的最终 loss,自动搜索最优学习率调度,可发现优于 Cosine 的调度策略(类似 WSD 但性能更优)。
(2)Temporal Scaling Law 针对有限训练步数的实际场景,验证 loss 随时间演化是否符合幂律,用于预测剩余训练时间所需的计算资源。
4. 断点续训与检查点策略(Checkpointing Strategy)
大规模预训练常面临硬件故障、网络中断,需高频、低延迟、一致性保障的检查点体系:
(1)高频 Checkpoint 策略
分层保存:区分轻量检查点(仅模型权重,每 100 步)与完整检查点(含优化器状态、RNG 状态、数据加载器位置,每 1000 步)。
异步保存:利用 SSD/NVMe 进行异步 I/O,主训练流不阻塞。DeepSpeed 提供
save_checkpointAPI 自动处理模型状态与客户端状态(iteration、RNG)的分离存储。内存快照:采用 In-Memory Snapshot 技术,在 NPU/GPU HBM 中维护一份检查点副本,故障时秒级恢复,避免从慢速存储读取。
(2)热备份与冗余
双副本策略:关键检查点(如阶段切换点)同时写入本地 NVMe 和远程对象存储(如 S3)。
检查点验证:加载后执行前向-反向验证步,确保数值一致性(loss 匹配)后再正式恢复训练。
5. 故障自动恢复与弹性训练(Elastic Training)
现代集群规模($10^5$~$10^6$ 加速器)使故障成为常态,需**弹性原生(Elastic-Native)**训练系统:
(1)ElasWave 弹性训练框架(2025) 华为/港科大等提出的 ElasWave 实现每步容错(Per-step Fault Tolerance):
动态通信组:支持 in-place 通信组编辑,节点掉线时无需重建整个通信域(NCCL communicator),恢复时间从分钟级降至 <1 秒(相比完整重建快 82×)。
参数一致性保障:通过 ZeRO 分区交错迁移(Interleaved ZeRO State Movement),在节点增减时保持梯度聚合一致性,避免收敛偏差。
计算一致性:重新分片(Reshard)随机数生成器(RNG)状态,确保弹性伸缩后,相同数据批次产生与故障前一致的随机性,收敛偏差降低 78%。
(2)故障分级响应
节点级故障:自动剔除故障节点,剩余节点通过 弹性数据并行(Elastic DP) 重新分配 micro-batch,保持全局 batch size 恒定。
进程级故障:利用热备进程(Warm Standby)秒级接管,结合 Pipeline 冗余(Oobleck) 重新映射流水线阶段。
检查点回滚:若故障导致参数污染(如 NaN),自动回滚到最近验证过的检查点,通常损失 <10 步进度。
6. 数据加载优化(Data Loading Optimization)
数据加载是预训练瓶颈之一(尤其在高并发场景),需零拷贝、预取、动态重采样:
(1)Megatron-LM / DeepSpeed 数据并行优化
分桶动态分辨率(Bucket-wise Dynamic Resolution):针对多模态或长文本训练,按序列长度/分辨率分桶,同桶内样本动态填充(Padding)至统一长度,减少无效计算。数据加载器在线分配样本到桶,满桶即触发训练步。
数据重加权(Data Reweighting):通过网格搜索(Grid Search)动态调整各数据源混合比例,平衡不同维度(语言、领域、质量)的性能,避免单一数据集主导。
(2)预取与流水线
三级预取架构:
Storage → CPU Memory → GPU HBM异步流水线,利用torch.utils.data.DataLoader的prefetch_factor和persistent_workers,确保 GPU 计算单元利用率(MFU)>50%。内存映射(Memory Mapping):对大规模语料(TB 级)使用
numpy.memmap或mmap模式,避免全量加载到 RAM,支持多进程零拷贝共享。
(3)弹性数据加载 节点掉线时,弹性训练系统需同步调整数据加载器的全局 rank 和样本分片策略,确保故障恢复后数据不重复、不丢失。ElasWave 等系统通过动态数据流调度(Dataflow Scheduling)实现这一能力。
Last updated