softmax
- softmax中包含指数计算,为了数值稳定性,会将所有输入减去最大值
- safe softmax包含三次循环:1. 找最大值 2. 算指数,求和来得到分母 3. 除以分母
- online softmax将前两次合并到一起。每次求和时使用的是当时的局部最大值,如果发现了新的最大值会乘一个系数进行调整。最后一次循环不变
求导推导
先算
当i!=j时:
最后
看起来稍微有点乱,大概就是反向传回来的梯度对softmax的值做一个加权求和
附加一个对-log_softmax求导的推导,这个会用来算交叉熵 llm.c
求导推导
先算
当i!=j时:
Attention
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
这篇文章里有标准attention的前向计算和反向计算怎么算的,还有它提出的优化。反向计算和一些细节在附录里。
标准self-attention
前向传播公式:
标准attention的反向传播
问自己几个问题:需要保存多少中间值?需要保存注意力矩阵吗?重计算更好吗?mask需要存吗?todo
FlashAttention v1
前向
反向
FlashAttention v2
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
ring attention
ring attention + flash attention:超长上下文之路
大模型训练之序列并行双雄:DeepSpeed Ulysses & Ring-Attention
大模型推理加速技术的学习路线是什么? - 猛猿的回答 - 知乎
https://www.zhihu.com/question/591646269/answer/3309904882
vLLM皇冠上的明珠:深入浅出理解PagedAttention CUDA实现 - 方佳瑞的文章 - 知乎
https://zhuanlan.zhihu.com/p/673284781
【手撕LLM-Flash Attention】从softmax说起,保姆级超长文!! - 小冬瓜AIGC的文章 - 知乎
https://zhuanlan.zhihu.com/p/663932651