Skip to content

SelfAttention_Family.py 总览

Abstract

这个文件不是一个模型,而是一组 attention 组件库。

阅读顺序建议:先看 AttentionLayer 外壳,再看 FullAttention 标准注意力,之后看 ProbAttention 的稀疏近似;DSAttentionReformerLayerTwoStageAttentionLayer 是特殊变体。

0. 类地图

![[zdocs/pytorch-basics/assets/self_attention_family_class_map.svg]]

1. 源码类清单

角色核心输入核心输出典型模型
AttentionLayer外壳:Q/K/V 投影 + 多头拆分 + 输出投影(B,L,d_model)(B,L,d_model)Informer / PatchTST
FullAttention标准 scaled dot-product attention(B,L,H,E)(B,L,H,D)PatchTST
DSAttentionDe-stationary Attention,在 score 上加 tau/delta 修正(B,L,H,E)(B,L,H,D)Non-stationary Transformer 类
ProbAttentionProbSparse Attention,抽样估计重要 query(B,L,H,D)(B,H,L,D)Informer
ReformerLayerReformer LSH attention 包装层(B,N,d_model)(B,N,d_model)Reformer 变体
TwoStageAttentionLayer时间维 attention + 变量维 router attention(B,D,L,d_model)(B,D,L,d_model)Crossformer 类

2. 统一 toy 参数

为了避免维度混淆,本组文档使用一套统一 toy 参数:

符号数值含义
B2batch size
L5query length
S / L_K6key/value length
H2head 数
E / d_keys3query/key 每头维度
D / d_values4value 每头维度
d_model8模型隐藏维
factor2ProbAttention 采样系数
bucket_size4Reformer bucket size

注意:AttentionLayerd_model=8n_heads=2,如果不显式传 d_keys/d_values,默认会得到 4。为了展示 Q/K 和 V 维度可以不同,部分文档会显式设 d_keys=3, d_values=4

3. 数学总览

3.1 标准注意力

Scoresb,h,l,s=eQb,l,h,eKb,s,h,eAb,h,l,:=softmax(Scoresb,h,l,:E)Vb,l,h,d=sAb,h,l,sVb,s,h,d

3.2 ProbSparse 的核心稀疏度

ProbAttention 先抽样一部分 key,估计每个 query 的“尖锐程度”:

M(qi,K)=maxj(qikj)1LKjqikj

M 越大,说明这个 query 对少数 key 特别敏感,越值得完整计算 attention。

4. 文件索引

文件覆盖类重点
[[01-AttentionLayer-多头投影外壳]]AttentionLayerLinear、view、多头合并
[[02-FullAttention-标准缩放点积注意力]]FullAttentioneinsum、mask、softmax、加权求和
[[03-DSAttention-去平稳注意力]]DSAttentiontaudelta 如何修正 score
[[04-ProbAttention-ProbSparse稀疏注意力]]ProbAttention_prob_QKtopk、context 初始化和更新
[[05-ReformerLayer-LSH注意力包装]]ReformerLayerfit_length、bucket 对齐、裁剪回原长
[[06-TwoStageAttentionLayer-两阶段注意力]]TwoStageAttentionLayer时间维 attention、变量维 router attention

5. 读法

如果目标是理解 PatchTST:读 AttentionLayer -> FullAttention

如果目标是理解 Informer:读 AttentionLayer -> ProbAttention

如果目标是把所有 attention 变体串起来:按本文索引从 01 读到 06。

*记录并在线阅读我的笔记*