Skip to content

DSAttention:去平稳注意力

Abstract

DSAttentionFullAttention 几乎相同。

唯一核心差异是:softmax 前的 score 不再只是 QK^T,而是 QK^T * tau + delta

1. 图解

![[zdocs/pytorch-basics/assets/self_attention_full_ds.svg]]

图中下半部分就是 DSAttention:在标准 attention 分数上乘 tau、加 delta

2. 源码关键段

python
tau = 1.0 if tau is None else tau.unsqueeze(1).unsqueeze(1)
delta = 0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1)

scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta

其余逻辑和 FullAttention 一致:

python
A = self.dropout(torch.softmax(scale * scores, dim=-1))
V = torch.einsum("bhls,bshd->blhd", A, values)

3. Shape 解释

toy:

text
queries: (2,5,2,3)
keys:    (2,6,2,3)
values:  (2,6,2,4)

原始 score:

text
QK: (B,H,L,S) = (2,2,5,6)

tau

python
tau.unsqueeze(1).unsqueeze(1)

如果原始 tau.shape=(B,)=(2,),则:

text
(2,) -> (2,1) -> (2,1,1)

广播到:

text
(B,H,L,S) = (2,2,5,6)

delta

python
delta.unsqueeze(1).unsqueeze(1)

如果 delta.shape=(B,S)=(2,6),则:

text
(2,6) -> (2,1,6) -> (2,1,1,6)

广播到:

text
(2,2,5,6)

4. 数学公式

标准 attention:

scoresb,h,l,s=eQb,l,h,eKb,s,h,e

DSAttention:

scoresb,h,l,sDS=τbscoresb,h,l,s+δb,s

然后:

Ab,h,l,:=softmax(scoresb,h,l,:DSE)

5. toy 数字直觉

假设某个 query 的原始 score 是:

text
scores = [2.0, 1.0, 0.0]

如果:

text
tau = 0.5
delta = [0.0, 1.0, 0.0]

则:

scoresDS=0.5[2,1,0]+[0,1,0]=[1.0,1.5,0.0]

原来第 0 个 key 分数最高,修正后第 1 个 key 分数最高。

6. 什么时候看它

如果模型是普通 Informer / PatchTST,通常不会走 DSAttention

如果看到 Non-stationary Transformer 或带 tau/delta 的去平稳机制,就要回到这个类。

*记录并在线阅读我的笔记*