Appearance
Layer 4 — AttentionLayer 精读
由 EncoderLayer(03B1-Layer3-EncoderLayer)和 DecoderLayer(03C1-Layer3-DecoderLayer)调用,共计 4 种场景。
本文档只覆盖AttentionLayer.forward这一层(d_model ↔ 多头格式桥梁)。
子层 ProbAttention 见 04A-Layer5-ProbAttention。
1. 在父层中的位置
AttentionLayer 被 4 个地方调用,统一接口,差异只在 mask_flag 和 Q/K/V 来源:
| 调用位置 | mask_flag | Q 来源 | K/V 来源 | L_Q | L_K |
|---|---|---|---|---|---|
| EncoderLayer 0 自注意力 | False | enc_out (3,10,8) | enc_out (3,10,8) | 10 | 10 |
| EncoderLayer 1 自注意力 | False | enc_out (3,6,8) | enc_out (3,6,8) | 6 | 6 |
| DecoderLayer masked 自注意力 | True | dec_out (3,12,8) | dec_out (3,12,8) | 12 | 12 |
| DecoderLayer cross 注意力 | False | dec_out (3,12,8) | enc_out (3,6,8) | 12 | 6 |
EncoderLayer.forward
└─ self.attention(x, x, x, ...) ← AttentionLayer (mask_flag=False)
└─ self.inner_attention(Q, K, V, ...) → 详见 Layer5 ProbAttention
DecoderLayer.forward
├─ self.self_attention(x, x, x, ...) ← AttentionLayer (mask_flag=True)
└─ self.cross_attention(x, cross, cross, ...) ← AttentionLayer (mask_flag=False)2. I/O 接口定义
python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):以 EncoderLayer 0 自注意力(toy 基准)为例:
| shape(toy) | 含义 | |
|---|---|---|
输入 queries | (3, 10, 8) = (B, L_Q, d_model) | encoder token(自注意力时 Q=K=V) |
输入 keys | (3, 10, 8) | 同上 |
输入 values | (3, 10, 8) | 同上 |
输出 out | (3, 10, 8) | 注意力聚合后,out_projection 投影回 d_model |
输出 attn | None | output_attention=False 时为 None |
3. 顺序图(具体层)
⚠️ ProbAttention 返回的 context 格式是
(B, H, L_Q, D)=(3, 4, 10, 2),
与 PatchTST 的 FullAttention 返回(B, L, H, D)不同。.view(B, L, -1)前需要先理解 context 的轴顺序。
4. 语义分组图(索引层)
本层的核心是**"进来拆开,算完合并"**的对称结构,与 PatchTST 的 AttentionLayer 完全相同,只是内部注意力换成了 ProbAttention。
5. 逐步解析
5.0 完整原始代码
python
class AttentionLayer(nn.Module):
def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
super(AttentionLayer, self).__init__()
d_keys = d_keys or (d_model // n_heads)
d_values = d_values or (d_model // n_heads)
self.inner_attention = attention
self.query_projection = nn.Linear(d_model, d_keys * n_heads)
self.key_projection = nn.Linear(d_model, d_keys * n_heads)
self.value_projection = nn.Linear(d_model, d_values * n_heads)
self.out_projection = nn.Linear(d_values * n_heads, d_model)
self.n_heads = n_heads
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads
queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)
out, attn = self.inner_attention(
queries, keys, values, attn_mask, tau=tau, delta=delta
)
out = out.view(B, L, -1)
return self.out_projection(out), attn5.1 格式转换 IN:d_model → 多头
本节的作用
query_projection+.view(B, L, H, -1)把(B, L, d_model=8)拆成(B, L, H=4, d_keys=2)。本节重点解释:变量名 B/L/S/H 为什么对应张量的某个具体维度,以及这一对应关系的第一性原理。
步骤一:提取形状变量
python
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_headsB=3,L=10,d_model=8(丢弃),S=10,H=4。
B、L、d_model 在张量里的具体含义
queries 是一个三维张量,三个轴分别对应:
| 轴 | 变量 | toy 值 | 语义 |
|---|---|---|---|
| dim0 | B | 3 | 批次:同一批 3 条时序,互相独立 |
| dim1 | L | 10 | 序列:每条时序有 10 个 token(时间步) |
| dim2 | _ | 8 | d_model:每个 token 的 8 维特征向量 |
这是 PyTorch 的==标准轴序 (B, L, d)==,由两个约定确立:nn.Linear 只作用于最后一轴(所以特征在 dim2),GPU 并行沿 batch 轴切割(所以 B 在 dim0)。
d_model 用 _ 丢弃,因为后续 .view(B, L, H, -1) 用 -1 自动推断,不需要单独取出。H = self.n_heads 是超参数,直接读配置,不从 shape 里取。
为什么 L 和 S 要分两行取
自注意力时 Q、K 来自同一张量,L=S;但 cross-attention 时 Q 来自 decoder(L=12),K 来自 encoder(S=6),两者不等。若共用同一个变量,keys.view(B, L, H, -1) 会把 S=6 误当 L=12,元素总数不匹配,.view() 直接报错。
步骤二:Q/K/V projection + .view() 拆多头
python
queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)query_projection = nn.Linear(d_model=8, d_model=8),公式 (3, 10, 8)。
.view(B=3, L=10, H=4, -1) 把最后一维 8 重新解读为
==Linear 输出维度必须恰好等于 .view() 乘积不匹配会报错。
图解 — .view() 把 8 个连续 float 重新分组给 4 个 head:
.view() 是零拷贝操作
.view() 是零拷贝操作
.view()不分配新内存,只改变 ==stride==("如何步进")——同一块个 float,重新按 (3,10,4,2)解读。.split()/.chunk()会实际复制数据,在注意力这种计算密集的路径上代价更高。前提:内存必须连续。
nn.Linear的输出通常是连续的;若不连续(如前面有 transpose 且未.contiguous()),.view()会报错。
toy 数值追踪(batch=0, token t=0):
设 queries[0, 0, :] = [1.2, -0.3, 0.5, 0.8, -1.1, 0.2, 0.7, -0.4],= [0.1, -0.2, 0.3, 0.1, -0.1, 0.2, 0.0, 0.1]:
设 projection_out[0, 0, :] = [0.52, -0.18, 0.31, 0.67, -0.44, 0.25, 0.13, -0.29],.view(3, 10, 4, 2) 后:
queries[0, 0, 0, :] = [ 0.52, -0.18] ← Head 0
queries[0, 0, 1, :] = [ 0.31, 0.67] ← Head 1
queries[0, 0, 2, :] = [-0.44, 0.25] ← Head 2
queries[0, 0, 3, :] = [ 0.13, -0.29] ← Head 3
shape: (3, 10, 4, 2) ✓keys、values 同理:自注意力时 S=10 → (3, 10, 4, 2);cross-attention 时 S=6 → (3, 6, 4, 2)。
5.2 注意力计算(委托 ProbAttention,步骤三)
本节的作用
把格式转换后的 Q/K/V 交给 ProbAttention 做稀疏注意力计算;本层是纯粹的"接口桥梁",不包含任何注意力逻辑。
python
out, attn = self.inner_attention(
queries, keys, values, attn_mask, tau=tau, delta=delta
)传入 queries=(3,10,4,2),keys=(3,10,4,2),values=(3,10,4,2),attn_mask=None。
ProbAttention 内部先 transpose(2,1) 变成 (B,H,L,D) 格式处理,
返回 out (context): (3, 4, 10, 2) = (B, H, L_Q, D),attn=None。
→ ProbAttention 的稀疏筛选逻辑见 04A-Layer5-ProbAttention。
5.3 格式转换 OUT:多头 → d_model(步骤四)
本节的作用
ProbAttention 返回的
(B, H, L_Q, D)经.view()合并多头回(B, L, d_model=8),再经 out_projection 做最终线性混合。
python
out = out.view(B, L, -1)
return self.out_projection(out), attnProbAttention 返回的 out 是 (3, 4, 10, 2) = (B, H, L_Q, D) 格式。
.view(3, 10, -1) 要求内存连续,且将 (H=4, D=2) 合并回 8。
⚠️ 注意轴顺序:
(B, H, L, D)的.view(B, L, H*D)不等价于(B, L, H, D)的.view(B, L, H*D)——两者内存布局不同。
ProbAttention 在context.contiguous()后返回,确保内存连续,.view()才能正确执行。
out: (3, 4, 10, 2).contiguous() → .view(3, 10, -1) → (3, 10, 8)
out_projection: Linear(8, 8) → (3, 10, 8) = new_x6. 下钻子组件
| 子组件 | 职责 | 下层文档 |
|---|---|---|
ProbAttention(self.inner_attention) | ProbSparse 稀疏注意力:M 分数筛选 top-u query + mean(V) 初始化 + 精确更新 | 04A-Layer5-ProbAttention |