Skip to content

Layer 4 — AttentionLayer 精读

由 EncoderLayer(03B1-Layer3-EncoderLayer)和 DecoderLayer(03C1-Layer3-DecoderLayer)调用,共计 4 种场景。
本文档只覆盖 AttentionLayer.forward 这一层(d_model ↔ 多头格式桥梁)。
子层 ProbAttention 见 04A-Layer5-ProbAttention


1. 在父层中的位置

AttentionLayer 被 4 个地方调用,统一接口,差异只在 mask_flag 和 Q/K/V 来源:

调用位置mask_flagQ 来源K/V 来源L_QL_K
EncoderLayer 0 自注意力Falseenc_out (3,10,8)enc_out (3,10,8)1010
EncoderLayer 1 自注意力Falseenc_out (3,6,8)enc_out (3,6,8)66
DecoderLayer masked 自注意力Truedec_out (3,12,8)dec_out (3,12,8)1212
DecoderLayer cross 注意力Falsedec_out (3,12,8)enc_out (3,6,8)126
EncoderLayer.forward
  └─ self.attention(x, x, x, ...)              ← AttentionLayer (mask_flag=False)
       └─ self.inner_attention(Q, K, V, ...)   → 详见 Layer5 ProbAttention

DecoderLayer.forward
  ├─ self.self_attention(x, x, x, ...)         ← AttentionLayer (mask_flag=True)
  └─ self.cross_attention(x, cross, cross, ...) ← AttentionLayer (mask_flag=False)

2. I/O 接口定义

python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):

以 EncoderLayer 0 自注意力(toy 基准)为例:

shape(toy)含义
输入 queries(3, 10, 8) = (B, L_Q, d_model)encoder token(自注意力时 Q=K=V)
输入 keys(3, 10, 8)同上
输入 values(3, 10, 8)同上
输出 out(3, 10, 8)注意力聚合后,out_projection 投影回 d_model
输出 attnNoneoutput_attention=False 时为 None

3. 顺序图(具体层)

⚠️ ProbAttention 返回的 context 格式是 (B, H, L_Q, D) = (3, 4, 10, 2)
与 PatchTST 的 FullAttention 返回 (B, L, H, D) 不同。
.view(B, L, -1) 前需要先理解 context 的轴顺序。


4. 语义分组图(索引层)

本层的核心是**"进来拆开,算完合并"**的对称结构,与 PatchTST 的 AttentionLayer 完全相同,只是内部注意力换成了 ProbAttention。


5. 逐步解析

5.0 完整原始代码

python
class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads


	def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
	    B, L, _ = queries.shape
	    _, S, _ = keys.shape
	    H = self.n_heads
	
	    queries = self.query_projection(queries).view(B, L, H, -1)
	    keys = self.key_projection(keys).view(B, S, H, -1)
	    values = self.value_projection(values).view(B, S, H, -1)
	
	    out, attn = self.inner_attention(
	        queries, keys, values, attn_mask, tau=tau, delta=delta
	    )
	    out = out.view(B, L, -1)
	
	    return self.out_projection(out), attn

5.1 格式转换 IN:d_model → 多头

本节的作用

query_projection + .view(B, L, H, -1)(B, L, d_model=8) 拆成 (B, L, H=4, d_keys=2)。本节重点解释:变量名 B/L/S/H 为什么对应张量的某个具体维度,以及这一对应关系的第一性原理。


步骤一:提取形状变量

python
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads

B=3L=10d_model=8(丢弃),S=10H=4


B、L、d_model 在张量里的具体含义

queries 是一个三维张量,三个轴分别对应:

变量toy 值语义
dim0B3批次:同一批 3 条时序,互相独立
dim1L10序列:每条时序有 10 个 token(时间步)
dim2_8d_model:每个 token 的 8 维特征向量

这是 PyTorch 的==标准轴序 (B, L, d)==,由两个约定确立:nn.Linear 只作用于最后一轴(所以特征在 dim2),GPU 并行沿 batch 轴切割(所以 B 在 dim0)。

d_model_ 丢弃,因为后续 .view(B, L, H, -1)-1 自动推断,不需要单独取出。H = self.n_heads 是超参数,直接读配置,不从 shape 里取。


为什么 L 和 S 要分两行取

自注意力时 Q、K 来自同一张量,L=S;但 cross-attention 时 Q 来自 decoder(L=12),K 来自 encoder(S=6),两者不等。若共用同一个变量,keys.view(B, L, H, -1) 会把 S=6 误当 L=12,元素总数不匹配,.view() 直接报错。


步骤二:Q/K/V projection + .view() 拆多头

python
queries = self.query_projection(queries).view(B, L, H, -1)
keys    = self.key_projection(keys).view(B, S, H, -1)
values  = self.value_projection(values).view(B, S, H, -1)

query_projection = nn.Linear(d_model=8, d_model=8),公式 y=xWQTWQR8×8,作用在最后一维,输出仍为 (3, 10, 8)

.view(B=3, L=10, H=4, -1) 把最后一维 8 重新解读为 (H=4, dkeys=?),PyTorch 由 1=8÷4=2 自动推断:

dkeys=dmodelH=84=2

==Linear 输出维度必须恰好等于 H×dkeys==,否则 .view() 乘积不匹配会报错。

图解 — .view() 把 8 个连续 float 重新分组给 4 个 head:

.view() 是零拷贝操作

.view() 不分配新内存,只改变 ==stride==("如何步进")——同一块 3×10×8=240 个 float,重新按 (3,10,4,2) 解读。.split() / .chunk() 会实际复制数据,在注意力这种计算密集的路径上代价更高。

前提:内存必须连续。nn.Linear 的输出通常是连续的;若不连续(如前面有 transpose 且未 .contiguous()),.view() 会报错。

toy 数值追踪(batch=0, token t=0):

queries[0, 0, :] = [1.2, -0.3, 0.5, 0.8, -1.1, 0.2, 0.7, -0.4]WQ 第 0 行 = [0.1, -0.2, 0.3, 0.1, -0.1, 0.2, 0.0, 0.1]

p0=1.2×0.1+(0.3)×(0.2)+0.5×0.3+0.8×0.1+(1.1)×(0.1)+0.2×0.2+0.7×0+(0.4)×0.1=0.52

projection_out[0, 0, :] = [0.52, -0.18, 0.31, 0.67, -0.44, 0.25, 0.13, -0.29].view(3, 10, 4, 2) 后:

queries[0, 0, 0, :] = [ 0.52, -0.18]   ← Head 0
queries[0, 0, 1, :] = [ 0.31,  0.67]   ← Head 1
queries[0, 0, 2, :] = [-0.44,  0.25]   ← Head 2
queries[0, 0, 3, :] = [ 0.13, -0.29]   ← Head 3
shape: (3, 10, 4, 2) ✓

keysvalues 同理:自注意力时 S=10(3, 10, 4, 2);cross-attention 时 S=6(3, 6, 4, 2)


5.2 注意力计算(委托 ProbAttention,步骤三)

本节的作用

把格式转换后的 Q/K/V 交给 ProbAttention 做稀疏注意力计算;本层是纯粹的"接口桥梁",不包含任何注意力逻辑。

python
out, attn = self.inner_attention(
    queries, keys, values, attn_mask, tau=tau, delta=delta
)

传入 queries=(3,10,4,2)keys=(3,10,4,2)values=(3,10,4,2)attn_mask=None

ProbAttention 内部先 transpose(2,1) 变成 (B,H,L,D) 格式处理,
返回 out (context): (3, 4, 10, 2) = (B, H, L_Q, D)attn=None

→ ProbAttention 的稀疏筛选逻辑见 04A-Layer5-ProbAttention


5.3 格式转换 OUT:多头 → d_model(步骤四)

本节的作用

ProbAttention 返回的 (B, H, L_Q, D).view() 合并多头回 (B, L, d_model=8),再经 out_projection 做最终线性混合。

python
out = out.view(B, L, -1)

return self.out_projection(out), attn

ProbAttention 返回的 out(3, 4, 10, 2) = (B, H, L_Q, D) 格式。

.view(3, 10, -1) 要求内存连续,且将 (H=4, D=2) 合并回 8

⚠️ 注意轴顺序:(B, H, L, D).view(B, L, H*D) 不等价于 (B, L, H, D).view(B, L, H*D)——两者内存布局不同。
ProbAttention 在 context.contiguous() 后返回,确保内存连续,.view() 才能正确执行。

out: (3, 4, 10, 2).contiguous() → .view(3, 10, -1) → (3, 10, 8)
out_projection: Linear(8, 8) → (3, 10, 8) = new_x

6. 下钻子组件

子组件职责下层文档
ProbAttentionself.inner_attentionProbSparse 稀疏注意力:M 分数筛选 top-u query + mean(V) 初始化 + 精确更新04A-Layer5-ProbAttention

*记录并在线阅读我的笔记*