结论来自:分析transformer模型的参数量、计算量、中间激活、KV cache

从托马斯那里的一些takeaway


当 h >> L 时 前向计算量约为 24BLH^2 * layers,参数量约为 12H^2 * layers

前向计算量/参数量 = 2

后向计算量=2前向计算量

所以总计算量 = 6 * 参数量,一次训练迭代中,对于每个token,每个模型参数,需要进行 2∗3=6 次浮点数运算。

所以我们知道了参数量,知道了总的token的 token 数,就可以算出训练完一个 epoch 需要的 flops,这样除以 gpu 的可达算力,就可以得到训练时间

一般来讲,GPU利用率一般在 0.3∼0.55 之间。token数最好为 模型参数的10倍到100倍。

对于每个token,每个模型参数,进行2次浮点数计算。使用激活重计算技术来减少中间激活显存(下文会详细介绍)需要进行一次额外的前向传递,因此前向传递 + 后向传递 + 激活重计算的系数=1+2+1=4。使用激活重计算的一次训练迭代中,对于每个token,每个模型参数,需要进行 24=82*4=8

在给定训练tokens数、硬件环境配置的情况下,训练transformer模型的计算时间为

 训练时间 8× tokens数 × 模型参数量 GPU 数 ×GPU 峰值 flops×GPU 利用率 \text { 训练时间 } \approx \frac{8 \times \text { tokens数 } \times \text { 模型参数量 }}{G P U \text { 数 } \times G P U \text { 峰值 } f l o p s \times G P U \text { 利用率 }}