Appearance
FullAttention:标准缩放点积注意力
Abstract
FullAttention是标准 Transformer attention:所有 query 都和所有 key 计算相似度。PatchTST 的 Encoder 中使用的就是
AttentionLayer(FullAttention(...))。
1. 图解
![[zdocs/pytorch-basics/assets/self_attention_full_ds.svg]]
本图上半部分就是 FullAttention:scores = QK^T,再 softmax,最后乘 V。
2. 源码
python
class FullAttention(nn.Module):
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1.0 / sqrt(E)
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
if self.mask_flag:
if attn_mask is None:
attn_mask = TriangularCausalMask(B, L, device=queries.device)
scores.masked_fill_(attn_mask.mask, -np.inf)
A = self.dropout(torch.softmax(scale * scores, dim=-1))
V = torch.einsum("bhls,bshd->blhd", A, values)
if self.output_attention:
return V.contiguous(), A
else:
return V.contiguous(), None3. Shape 链
toy:
text
queries: (B,L,H,E) = (2,5,2,3)
keys: (B,S,H,E) = (2,6,2,3)
values: (B,S,H,D) = (2,6,2,4)第一步:
python
scores = torch.einsum("blhe,bshe->bhls", queries, keys)输出:
text
scores: (B,H,L,S) = (2,2,5,6)第二步:
python
A = torch.softmax(scale * scores, dim=-1)输出:
text
A: (2,2,5,6)第三步:
python
V = torch.einsum("bhls,bshd->blhd", A, values)输出:
text
V: (B,L,H,D) = (2,5,2,4)4. 数学公式
点积分数:
缩放:
注意力权重:
加权求和:
5. toy 数字
只看一个 batch、一个 head、一个 query:
text
scores = [2.0, 1.0, 0.0]softmax 后:
text
A = [0.665, 0.245, 0.090]如果:
text
v0 = [10, 0]
v1 = [0, 20]
v2 = [10, 10]则:
6. mask 的意义
如果 mask_flag=True,并且没传 attn_mask:
python
attn_mask = TriangularCausalMask(B, L, device=queries.device)
scores.masked_fill_(attn_mask.mask, -np.inf)这会把不允许看的位置设成 -inf,softmax 后对应权重接近 0。
在 PatchTST encoder self-attention 里通常不需要 causal mask;在 decoder 自回归场景才常见。