Appearance
SelfAttention_Family.py 总览
Abstract
这个文件不是一个模型,而是一组 attention 组件库。
阅读顺序建议:先看
AttentionLayer外壳,再看FullAttention标准注意力,之后看ProbAttention的稀疏近似;DSAttention、ReformerLayer、TwoStageAttentionLayer是特殊变体。
0. 类地图
![[zdocs/pytorch-basics/assets/self_attention_family_class_map.svg]]
1. 源码类清单
| 类 | 角色 | 核心输入 | 核心输出 | 典型模型 |
|---|---|---|---|---|
AttentionLayer | 外壳:Q/K/V 投影 + 多头拆分 + 输出投影 | (B,L,d_model) | (B,L,d_model) | Informer / PatchTST |
FullAttention | 标准 scaled dot-product attention | (B,L,H,E) | (B,L,H,D) | PatchTST |
DSAttention | De-stationary Attention,在 score 上加 tau/delta 修正 | (B,L,H,E) | (B,L,H,D) | Non-stationary Transformer 类 |
ProbAttention | ProbSparse Attention,抽样估计重要 query | (B,L,H,D) | (B,H,L,D) | Informer |
ReformerLayer | Reformer LSH attention 包装层 | (B,N,d_model) | (B,N,d_model) | Reformer 变体 |
TwoStageAttentionLayer | 时间维 attention + 变量维 router attention | (B,D,L,d_model) | (B,D,L,d_model) | Crossformer 类 |
2. 统一 toy 参数
为了避免维度混淆,本组文档使用一套统一 toy 参数:
| 符号 | 数值 | 含义 |
|---|---|---|
B | 2 | batch size |
L | 5 | query length |
S / L_K | 6 | key/value length |
H | 2 | head 数 |
E / d_keys | 3 | query/key 每头维度 |
D / d_values | 4 | value 每头维度 |
d_model | 8 | 模型隐藏维 |
factor | 2 | ProbAttention 采样系数 |
bucket_size | 4 | Reformer bucket size |
注意:AttentionLayer 里 d_model=8,n_heads=2,如果不显式传 d_keys/d_values,默认会得到 4。为了展示 Q/K 和 V 维度可以不同,部分文档会显式设 d_keys=3, d_values=4。
3. 数学总览
3.1 标准注意力
3.2 ProbSparse 的核心稀疏度
ProbAttention 先抽样一部分 key,估计每个 query 的“尖锐程度”:
M 越大,说明这个 query 对少数 key 特别敏感,越值得完整计算 attention。
4. 文件索引
| 文件 | 覆盖类 | 重点 |
|---|---|---|
| [[01-AttentionLayer-多头投影外壳]] | AttentionLayer | Linear、view、多头合并 |
| [[02-FullAttention-标准缩放点积注意力]] | FullAttention | einsum、mask、softmax、加权求和 |
| [[03-DSAttention-去平稳注意力]] | DSAttention | tau、delta 如何修正 score |
| [[04-ProbAttention-ProbSparse稀疏注意力]] | ProbAttention | _prob_QK、topk、context 初始化和更新 |
| [[05-ReformerLayer-LSH注意力包装]] | ReformerLayer | fit_length、bucket 对齐、裁剪回原长 |
| [[06-TwoStageAttentionLayer-两阶段注意力]] | TwoStageAttentionLayer | 时间维 attention、变量维 router attention |
5. 读法
如果目标是理解 PatchTST:读 AttentionLayer -> FullAttention。
如果目标是理解 Informer:读 AttentionLayer -> ProbAttention。
如果目标是把所有 attention 变体串起来:按本文索引从 01 读到 06。