Appearance
Layer4 AttentionLayer
覆盖
EncoderLayer.forward()中的self.attention(x, x, x, ...)调用。AttentionLayer 完成三件事:Linear 投影 Q/K/V →.view()拆分多头 →FullAttention计算注意力 → 还原形状 → out_projection。inner_attention的细节见 [[04A-Layer5-FullAttention]]。
§1 在父层中的位置
EncoderLayer.forward() 第一步:
python
new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)- Q = K = V =
x (3, 9, 8)— Self-Attention,输入都是同一张量 - 输出:
new_x (3, 9, 8),attn = None
§2 I/O 接口定义
| 参数 | shape(toy) | 含义 |
|---|---|---|
queries (=x) | (3, 9, 8) | B=3,token_count=9,d_model=8 |
keys (=x) | (3, 9, 8) | Self-Attention:与 queries 相同张量 |
values (=x) | (3, 9, 8) | Self-Attention:与 queries 相同张量 |
attn_mask | None | iTransformer 不使用掩码 |
输出 out | (3, 9, 8) | 注意力加权聚合后的 token 表示 |
输出 attn | None | output_attention=False |
§3 顺序图
§4 语义分组图
§5 逐步精读
§5.0 完整原始代码
python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads
queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)
out, attn = self.inner_attention(
queries, keys, values, attn_mask, tau=tau, delta=delta
)
out = out.view(B, L, -1)
return self.out_projection(out), attn§5.1 宏观逻辑
用小例子(B=1,L=S=3,d_model=4,n_heads=2,d_keys=2)串起来:
① Q/K/V 各做一次 Linear(4,4),形状不变
queries/keys/values: (1, 3, 4) → (1, 3, 4)
② .view(1, 3, 2, 2):把 d_model=4 拆成 n_heads=2 × d_keys=2
相当于"把 4 维分给 2 个头,每头 2 维"
Head 0: queries[:, :, 0, :] shape (1, 3, 2)
Head 1: queries[:, :, 1, :] shape (1, 3, 2)
③ FullAttention 对 4 个张量 (B,L,H,d_keys) 做多头注意力
每头独立算一个 3×3 矩阵,输出 (1, 3, 2, 2)
④ .view(1, 3, 4):把 n_heads × d_keys = 2×2=4 合并
再做 Linear(4,4):out_projection,最终 (1, 3, 4)为什么先 Linear 再 .view,不直接让 Linear 输出 (B,L,H,d_keys) 形状?
nn.Linear 只作用于最后一维,必须先产出 (B,L,d_model) 再用 .view() 重排。.view() 不做任何计算,只是重新解释内存布局(等效于 reshape),零代价。
§5.2 步骤一:解包形状
python
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_headsB=3, L=9(query token 数),S=9(key token 数),H=4- Self-Attention 时 L=S(queries、keys 来自同一 tensor)
_表示d_model=8,用不到(投影层内部定义了输入维度)
§5.3 步骤二:Q/K/V 投影
python
queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)Linear 的语义:
query_projection = nn.Linear(d_model=8, n_heads × d_keys = 4×2 = 8)
nn.Linear 作用于最后一维,前面所有维度视为 batch:(3,9,8) 视为 27 条长度 8 的行向量,各自乘 (3,9,8)。
.view(B, L, H, -1) 的语义:
.view(3, 9, 4, -1) 把最后一维 8 拆成 (n_heads=4, d_keys=2),即
- 拆分前:
queries_proj (3, 9, 8),连续内存,每 token 有 8 个数 - 拆分后:
queries (3, 9, 4, 2),每 token 的 8 个数被"分配"给 4 个头,每头 2 个数
toy 数值(batch=0,token=0):
queries[0, 0, :] 原始 = [q0, q1, q2, q3, q4, q5, q6, q7] 8 维
经过 Linear(8,8) → [q'0, ..., q'7](线性混合)
.view(4, 2) 后:
Head 0: [q'0, q'1] ← queries[0, 0, 0, :]
Head 1: [q'2, q'3] ← queries[0, 0, 1, :]
Head 2: [q'4, q'5] ← queries[0, 0, 2, :]
Head 3: [q'6, q'7] ← queries[0, 0, 3, :]keys, values 同理,各自有独立的投影矩阵
§5.4 步骤三:FullAttention
python
out, attn = self.inner_attention(
queries, keys, values, attn_mask, tau=tau, delta=delta
)self.inner_attention=FullAttention实例- 输入:
queries (3,9,4,2),keys (3,9,4,2),values (3,9,4,2) - 输出:
out (3,9,4,2),attn = None(iTransformer 的 FullAttention 用output_attention=False)
FullAttention 对每个 head 独立算 9×9 注意力矩阵,再用 V 加权求和。详见 [[04A-Layer5-FullAttention]]。
§5.5 步骤四:合并多头 + out_projection
python
out = out.view(B, L, -1)
return self.out_projection(out), attn.view(B, L, -1) 的语义:
.view(3, 9, -1):把 (n_heads=4, d_keys=2) 合并回 d_model=8,即
- 拆分前:
out (3, 9, 4, 2) - 合并后:
out (3, 9, 8)
这是 .view(B, L, H, -1) 的逆操作,内存布局不变。
out_projection = nn.Linear(d_model=8, d_model=8):
线性混合各头输出,得到最终 (3, 9, 8),作为 new_x 返回给 EncoderLayer。
toy 数值(batch=0,token=0):
out[0, 0, :, :] shape (4, 2) → .view(8) → [o0, ..., o7]
out_projection: [o0,...,o7] × W_O^T + b_O → [p0,...,p7]
new_x[0, 0, :] = [p0,...,p7] ← var_0 token 的注意力更新结果§6 初始化参数(toy)
AttentionLayer.__init__() 关键参数:
| 参数 | 值 | 含义 |
|---|---|---|
d_model | 8 | 输入/输出维度 |
n_heads | 4 | 注意力头数 |
d_keys | 2 | = d_model / n_heads = 8/4 |
d_values | 2 | = d_keys(通常相同) |
query_projection | Linear(8, 8) | 8 = n_heads × d_keys |
key_projection | Linear(8, 8) | 同上 |
value_projection | Linear(8, 8) | 同上 |
out_projection | Linear(8, 8) | 合并后输出投影 |
inner_attention | FullAttention(...) | 见下层 |
§7 下钻子组件
| 组件 | 输入 shape | 输出 shape | 下层文档 |
|---|---|---|---|
FullAttention | Q/K/V: (3,9,4,2) | (3,9,4,2) | [[04A-Layer5-FullAttention]] |