Skip to content

4C-1 AttentionLayer

Abstract

这一篇是:

04C-Encoder主链self.attention(x, x, x, ...) 这一步的第一层下钻。

只讲:

AttentionLayer 怎样把 (B, L, d_model) 的输入投影成多头 Q/K/V,再把内层 attention 的输出重新拼回 d_model

1. 上下文

上一层:

下一层:

入口代码:

python
out, attn = self.inner_attention(queries, keys, values, attn_mask, tau=tau, delta=delta)
return self.out_projection(out), attn

2. 当前层第一性

这一层存在的第一性是:

把单路隐藏表示拆成多头 Q/K/V,送进真正的 attention 机制,再把多头结果拼回模型主干需要的 d_model 维。

3. 本层输入输出含义

3.1 输入

  • queries
    • 形状 (B, L, d_model)
  • keys
    • 形状 (B, S, d_model)
  • values
    • 形状 (B, S, d_model)
  • n_heads
    • 头数
  • d_keys
    • 每头 query/key 维度,默认 d_model // n_heads
  • d_values
    • 每头 value 维度,默认 d_model // n_heads

3.2 输出

  • out
    • 形状 (B, L, d_model)
  • attn
    • 注意力权重信息,由内层 attention 决定是否返回

4. 顺序图

5. 抽象树

6. 当前真实例子与 toy 例子

6.1 真实运行例子

当前真实例子里:

  • d_model = 32
  • n_heads = 8
  • 所以默认:
    • d_keys = 4
    • d_values = 4

6.2 固定 toy 例子

为了算得动,这里固定:

  • B = 1
  • L = S = 4
  • d_model = 4
  • n_heads = 2
  • d_keys = d_values = 2
python
queries = [
    [1, 0, 2, 1],
    [0, 1, 1, 2],
    [1, 1, 0, 1],
    [2, 0, 1, 0],
]  # (1, 4, 4)

keys = queries
values = [
    [10, 11, 20, 21],
    [12, 13, 22, 23],
    [14, 15, 24, 25],
    [16, 17, 26, 27],
]  # (1, 4, 4)

7. 代码块 1:AttentionLayer.forward(...)

位置:

完整代码:

python
class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_attention(
            queries, keys, values, attn_mask, tau=tau, delta=delta
        )
        out = out.view(B, L, -1)

        return self.out_projection(out), attn

8. 子块 A:query_projection / key_projection / value_projection

对应代码:

python
queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)

8.1 toy 张量演变图

text
输入:
  queries = (1, 4, 4)
  keys    = (1, 4, 4)
  values  = (1, 4, 4)

当前 toy 配置:
  d_model = 4
  n_heads = 2
  d_keys = d_values = 2

步骤 1: Linear(4 -> 4)
  先把每个时间步的 4 维向量,映射到总共 4 维的 Q/K/V 空间

步骤 2: reshape 成多头
  (1, 4, 4) -> (1, 4, 2, 2)
  含义:
    4 个时间步
    每个时间步拆成 2 个头
    每个头 2 维

8.2 一个可算的 toy 小算例

固定 query_projection 的 toy 权重矩阵:

text
Wq =
[ [1,0,0,0],
 [0,1,0,0],
 [0,0,1,0],
 [0,0,0,1] ]

那第 1 个时间步:

text
q1_in = [1, 0, 2, 1]
q1_out = q1_in * Wq = [1, 0, 2, 1]

再按两头切开:

text
head1_q1 = [1, 0]
head2_q1 = [2, 1]

同理,k1v1 也会各自被切成两头。

8.3 这一段的 input / output 语义

  • 输入 queries/keys/values
    • 单路隐藏表示
  • 输出 queries/keys/values(重赋值后)
    • 已拆成多头的 Q/K/V

9. 子块 B:inner_attention(...)

对应代码:

python
out, attn = self.inner_attention(
    queries, keys, values, attn_mask, tau=tau, delta=delta
)

9.1 toy 张量演变图

text
输入:
  Q = (1, 4, 2, 2)
  K = (1, 4, 2, 2)
  V = (1, 4, 2, 2)

步骤:
  每个 head 各自做 attention
  head1: (4, 2) x (4, 2) x (4, 2)
  head2: (4, 2) x (4, 2) x (4, 2)

输出:
  out = (1, 4, 2, 2)
  attn = 注意力权重

9.2 这一段的 input / output 语义

  • 输入 Q/K/V
    • 多头 query/key/value
  • 输出 out
    • 每个 head 各自完成加权聚合后的结果
  • 输出 attn
    • 注意力权重

真正 scores -> softmax -> weighted sum 的内部细节,放到 04C-2-ProbAttention 展开。

10. 子块 C:拼回主干维度 + out_projection

对应代码:

python
out = out.view(B, L, -1)
return self.out_projection(out), attn

10.1 toy 张量演变图

text
输入:
  out = (1, 4, 2, 2)

步骤 1: view(B, L, -1)
  把 2 个 head 拼回最后一维
  -> (1, 4, 4)

步骤 2: out_projection: Linear(4 -> 4)
  -> (1, 4, 4)

10.2 一个可算的 toy 小算例

固定 out_projection 的 toy 权重矩阵:

text
Wo =
[ [1,0,0,0],
 [0,1,0,0],
 [0,0,1,0],
 [0,0,0,1] ]

如果某个时间步两个头的输出分别是:

text
head1 = [7, 8]
head2 = [9, 10]

拼接后:

text
concat = [7, 8, 9, 10]

再经过 Wo

text
out = [7, 8, 9, 10]

如果 Wo 不是单位阵,那它就会重新混合两个头的信息。

10.3 这一段的 input / output 语义

  • 输入 out(多头)
    • 各个 head 的 attention 输出
  • 输出 out_projection(out)
    • 回到模型主干 d_model 维空间的结果

11. 当前层真正要固定什么

  1. AttentionLayer 自己不算 attention 分数
  2. 它负责:
    • 把输入投影成多头 Q/K/V
    • 调用内层 attention
    • 把多头结果拼回 d_model
  3. 当前真实例子里:
    • d_model = 32
    • n_heads = 8
    • 所以每头 4

12. 下一步

继续看:

*记录并在线阅读我的笔记*