← 返回题库
训练与微调困难

大模型训练 loss spike(损失尖峰)怎么办?

#loss spike#训练稳定性#梯度裁剪#检查点回滚

题目

大模型训练中常出现 loss spike(损失突增)。请说明原因、危害与排查处理流程。

参考答案

loss spike 表现:训练 loss 本来平稳下降,突然飙升(可能数十倍),伴随梯度爆炸、参数混乱,轻则模型退化,重则出现 NaN/Inf 训练中断。

常见原因

  1. 数据异常:混入超长序列、重复 batch、低质脏数据、数值异常 token。
  2. 梯度爆炸:深层网络梯度连乘导致数值溢出,尤其 Post-Norm 架构。
  3. 学习率过大:某步更新跨过最优点,进入不稳定区域。
  4. 数值精度:FP16 下梯度下溢或上溢。
  5. 分布式不同步:all-reduce 异常导致部分卡用错梯度。
  6. 架构问题:注意力未加缩放、未做 QK-norm 等。

急救处理(先止血)

  1. 梯度裁剪(Gradient Clipping):限制梯度范数(如 max_norm=1.0),是最常用的防爆炸手段。
  2. 学习率回退:发现 spike 立即降低学习率,很多训练框架支持自动 LR decay on spike。
  3. 检查点回滚:回滚到 spike 前的健康 checkpoint,跳过引发 spike 的数据 batch 重训。
  4. 跳过异常 batch:检测到 loss 异常高时直接跳过该 step 不更新。

根治手段

  • 架构稳定性:用 Pre-Norm + RMSNorm + QK-norm + 初始化调整(μP 等)。
  • 精度:BF16 替代 FP16,避免溢出;关键累积用 FP32。
  • 数据质量:严格过滤超长/重复/异常样本,训练前做分布检查。
  • 学习率调度:warmup + cosine decay,避免初期过大更新。
  • Embedding 与输出层:这些层参数大、梯度大,常需单独学习率或额外正则。
  • Loss scaling 监控:FP16 下动态 loss scale,监控 scale 是否持续下降(下降意味溢出频发)。

调试方法论

  1. 看 loss 曲线判断是单点 spike 还是持续不收敛。
  2. 查 spike 出现的 step 与数据 batch,定位是否数据问题。
  3. 查梯度/激活统计(min/max/mean/std)逐层排查。
  4. 查 optimizer state 是否 NaN。
  5. 缩小复现范围(小 batch、单卡复现)。

面试加分点

  • 指出 spike 在百亿到千亿规模训练中几乎是常态,不是”是否遇到”而是”如何应对”——GPT-3、PaLM 训练报告都记录了多次 spike 与回滚。
  • 强调”回滚 + 跳 batch”是工程上最实用的应急组合。
  • LR warmup + 梯度裁剪 + BF16 是预防三件套,缺一不可。

出处:CSDN《大模型面试题52:LLM 如果在训练过程中 loss 值出现 spike 应该怎么办》、《大模型最新面试题系列:训练篇之训练稳定性》。

内容来源

整理自 CSDN《大模型面试题52:LLM 如果在训练过程中 loss 值出现 spike 应该怎么办》及《大模型最新面试题系列:训练篇之训练稳定性》

本站内容整理自公开面经与开源仓库,仅供学习交流,严禁杜撰。