LLM训练需要多少显存

1 大模型训练RAM构成

首先,精度会对参数所需内存有影响:

  • fp32,float point,一个参数需要32bits,即4bytes

  • fp16,,一个参数需要 16 bits, 2 bytes

  • int8 精度,一个参数需要 8 bits, 1 byte

大模型训练的显存占用主要由下面几部分构成:

  • 模型参数

    • P = 参数量 * 每个参数所需内存

  • 梯度

    • 与模型参数类似,G = 参数量 * 每个梯度参数所需内存

  • 优化器状态

    • 不同的优化器所储存的参数量不同。

      • 例如,AdamW 需维护一阶动量(m)和二阶动量(v),因此需要存储两倍的模型参数

  • 激活值(前向传播产生的中间张量)

    • A = B * S * E * C

      • B: Batch Size (批处理大小)

      • S: Sequence Length (序列长度,或称上下文长度,即输入和输出的总token数)

      • E: Embedding Dimension (嵌入维度,或模型的隐藏层大小)

      • C: 常数因子,取决于具体的模型架构和实现细节。对于Transformer模型,这个因子通常大于1,因为它需要存储多个层的激活,并且某些操作(如MLP层)会创建维度为 4×E 的中间张量。

  • CUDA内核开销

    • CUDA kernel 也会占据一些 RAM,大概 1.3GB 左右

2 显存计算示例

以 LLaMA-6B 为例,所需的显存占用如下:

  • 模型状态

    • 模型参数 6GB

    • 梯度 6GB

    • 优化器(AdamW)12GB

  • 运行时内存

    • CUDA Kernel 1.3 GB

  • 激活值

    • 基于参数架构

      • hidden_size = 4096, intermediate_size =11008, num_hidden_layers = 32, context_length = 2048

    • 每个实例:

      (4096 +11008) * 2048 *32 * 1byte = 990MB

    • batch size=50时占用 48.3GB

参考

Last updated