Skip to content

FullAttention:标准缩放点积注意力

Abstract

FullAttention 是标准 Transformer attention:所有 query 都和所有 key 计算相似度。

PatchTST 的 Encoder 中使用的就是 AttentionLayer(FullAttention(...))

1. 图解

![[zdocs/pytorch-basics/assets/self_attention_full_ds.svg]]

本图上半部分就是 FullAttentionscores = QK^T,再 softmax,最后乘 V

2. 源码

python
class FullAttention(nn.Module):
    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1.0 / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)

            scores.masked_fill_(attn_mask.mask, -np.inf)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)

        if self.output_attention:
            return V.contiguous(), A
        else:
            return V.contiguous(), None

3. Shape 链

toy:

text
queries: (B,L,H,E) = (2,5,2,3)
keys:    (B,S,H,E) = (2,6,2,3)
values:  (B,S,H,D) = (2,6,2,4)

第一步:

python
scores = torch.einsum("blhe,bshe->bhls", queries, keys)

输出:

text
scores: (B,H,L,S) = (2,2,5,6)

第二步:

python
A = torch.softmax(scale * scores, dim=-1)

输出:

text
A: (2,2,5,6)

第三步:

python
V = torch.einsum("bhls,bshd->blhd", A, values)

输出:

text
V: (B,L,H,D) = (2,5,2,4)

4. 数学公式

点积分数:

scoresb,h,l,s=e=1EQb,l,h,eKb,s,h,e

缩放:

scale=1E

注意力权重:

Ab,h,l,:=softmax(scalescoresb,h,l,:)

加权求和:

Vb,l,h,d=s=1SAb,h,l,sVb,s,h,d

5. toy 数字

只看一个 batch、一个 head、一个 query:

text
scores = [2.0, 1.0, 0.0]

softmax 后:

text
A = [0.665, 0.245, 0.090]

如果:

text
v0 = [10, 0]
v1 = [0, 20]
v2 = [10, 10]

则:

out=0.665v0+0.245v1+0.090v2=[7.55,5.80]

6. mask 的意义

如果 mask_flag=True,并且没传 attn_mask

python
attn_mask = TriangularCausalMask(B, L, device=queries.device)
scores.masked_fill_(attn_mask.mask, -np.inf)

这会把不允许看的位置设成 -inf,softmax 后对应权重接近 0。

在 PatchTST encoder self-attention 里通常不需要 causal mask;在 decoder 自回归场景才常见。

*记录并在线阅读我的笔记*