Appearance
Layer 2B — DSAttention(De-stationary Attention)
1. 在父层中的位置
forecast() 中 Encoder 和 Decoder 的每个 AttentionLayer 内嵌了 DSAttention 作为 inner_attention。AttentionLayer 完成 Q/K/V 的线性投影和多头拆分后,把 tau/delta 一并传入 DSAttention.forward()。
2. I/O 接口定义
AttentionLayer 接口(封装层,负责 Q/K/V 投影):
| 参数 | Shape | 含义 |
|---|---|---|
queries | (2, 12, 8) | encoder self-attn:Q 来自 enc_out |
keys | (2, 12, 8) | encoder self-attn:K 来自 enc_out |
values | (2, 12, 8) | encoder self-attn:V 来自 enc_out |
tau | (2, 1) | 从 forecast() 传入 |
delta | (2, 12) | 从 forecast() 传入(encoder 中使用) |
| 输出 | (2, 12, 8) | 注意力输出,形状不变 |
DSAttention.forward 接口(投影之后):
| 参数 | Shape | 含义 |
|---|---|---|
queries | (2, 12, 4, 2) | (B, L, H, d_keys),多头拆分后 |
keys | (2, 12, 4, 2) | (B, S, H, d_keys) |
values | (2, 12, 4, 2) | (B, S, H, d_values) |
tau | (2, 1) → 内部扩展 (2,1,1,1) | |
delta | (2, 12) → 内部扩展 (2,1,1,12) | |
| 输出 V | (2, 12, 4, 2) → view → (2,12,8) |
3. 顺序图
4. 语义分组图
5. 逐步骤精读
§5.0 完整原始代码
python
class AttentionLayer(nn.Module):
def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
super(AttentionLayer, self).__init__()
d_keys = d_keys or (d_model // n_heads)
d_values = d_values or (d_model // n_heads)
self.inner_attention = attention
self.query_projection = nn.Linear(d_model, d_keys * n_heads)
self.key_projection = nn.Linear(d_model, d_keys * n_heads)
self.value_projection = nn.Linear(d_model, d_values * n_heads)
self.out_projection = nn.Linear(d_values * n_heads, d_model)
self.n_heads = n_heads
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads
queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)
out, attn = self.inner_attention(
queries, keys, values, attn_mask, tau=tau, delta=delta
)
out = out.view(B, L, -1)
return self.out_projection(out), attn
class DSAttention(nn.Module):
"""De-stationary Attention"""
def __init__(self, mask_flag=True, factor=5, scale=None,
attention_dropout=0.1, output_attention=False):
super(DSAttention, self).__init__()
self.scale = scale
self.mask_flag = mask_flag
self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout)
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1.0 / sqrt(E)
tau = 1.0 if tau is None else tau.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x 1
delta = (
0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1)
) # B x 1 x 1 x S
scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta
if self.mask_flag:
if attn_mask is None:
attn_mask = TriangularCausalMask(B, L, device=queries.device)
scores.masked_fill_(attn_mask.mask, -np.inf)
A = self.dropout(torch.softmax(scale * scores, dim=-1))
V = torch.einsum("bhls,bshd->blhd", A, values)
if self.output_attention:
return V.contiguous(), A
else:
return V.contiguous(), None§5.1 宏观逻辑
核心设计意图:在标准 Scaled Dot-Product Attention 的分子上引入两个可学习的调制项,使注意力分布能反映原始信号的非平稳特性。
标准 Attention 公式:
DSAttention 公式:
两者的差别只在 softmax 的输入:
τ 的物理直觉:
δ 的物理直觉:δ 是 shape (B, 1, 1, S) 的向量,对每个 key 位置 s 施加独立的加法偏置。均值 mean_enc 大的序列(整体偏高),delta 可能在趋势末端 (S≈12) 加正偏置,使模型更关注较新的时间步。
小例子(B=1, L=S=4, H=1, d_keys=1):
scores = QK^T = [[0.5, 0.1, 0.2, 0.8]] (1×4 矩阵,query=t₃ 对4个key的得分)
若 tau=2.0, delta=[0.3, 0.1, -0.1, 0.5]:
scores' = scores * 2.0 + delta
= [1.0+0.3, 0.2+0.1, 0.4-0.1, 1.6+0.5]
= [1.3, 0.3, 0.3, 2.1]
scale = 1/sqrt(1) = 1.0
softmax([1.3, 0.3, 0.3, 2.1]) ≈ [0.18, 0.07, 0.07, 0.68]与无调制版的 softmax([0.5, 0.1, 0.2, 0.8]) ≈ [0.29, 0.19, 0.21, 0.31] 对比,调制后模型更集中关注 key=t₃(最后位置),且 tau 的放大使分布更尖锐。
注意力复杂度分析
DSAttention 使用标准全局注意力(FullAttention 的变种),scores (B,H,L,S) 中 L=S=seq_len=12。 复杂度为
(toy 参数)。 与 Informer 的 ProbSparse 相比,复杂度更高。 Non-stationary Transformer 没有解决效率问题,其核心贡献是非平稳信息恢复,不是效率优化。
§5.2 AttentionLayer — Q/K/V 投影与多头拆分
python
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads
queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)toy 参数(Encoder self-attention):L=S=seq_len=12,H=n_heads=4,d_keys=2。
query_projection:Linear(d_model=8, d_keys×n_heads=2×4=8),输入 (2,12,8),输出 (2,12,8)。
.view(B, L, H, -1) = .view(2, 12, 4, -1) → -1 = 8/(4) = 2 → 输出 (2, 12, 4, 2)。
key/value 同理:(2, 12, 4, 2)。
toy 数值:queries[0, 0, :, :] 是 batch=0 时间步 t₀ 的 4 个 head 各 2 维的查询向量,形如 [[q₀₀, q₀₁], [q₁₀, q₁₁], [q₂₀, q₂₁], [q₃₀, q₃₁]]。
§5.3 DSAttention — tau/delta 扩展
python
tau = 1.0 if tau is None else tau.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x 1
delta = (
0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1)
) # B x 1 x 1 x Stau 从 forecast() 传入为 (2, 1)(Projector 输出经 .exp())。
.unsqueeze(1) → (2, 1, 1),再 .unsqueeze(1) → (2, 1, 1, 1)。
delta 从 forecast() 传入为 (2, 12)(Projector 输出)。
.unsqueeze(1) → (2, 1, 12),再 .unsqueeze(1) → (2, 1, 1, 12)。
注意:两次 unsqueeze 发生在不同位置:
- tau: (2,1) → (2,1,1) → (2,1,1,1) — 在 dim=1 和 dim=1 连续插入
- delta: (2,12) → (2,1,12) → (2,1,1,12) — 结果末尾 12 对应 S 维
默认值处理:tau is None 时取 1.0(相当于不缩放),delta is None 时取 0.0(相当于不偏移)。Decoder self-attention 传 delta=None 就走这个路径。
§5.4 DSAttention — scores 计算
python
scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + deltaeinsum("blhe,bshe->bhls"):queries (2,12,4,2) × keys (2,12,4,2) → scores (2,4,12,12)。
维度解释:b=batch,l=query时间步,h=head,e=d_keys;s=key时间步。输出 bhls = (batch, head, query_len, key_len)。
toy 数值:scores[0,0,0,:] 是 batch=0,head=0,query时间步 t₀ 对所有 12 个 key 位置的原始点积,范围约 [-2, 2](随机初始化范围)。
* tau:(2,4,12,12) × (2,1,1,1) → 广播 → (2,4,12,12)。tau 是正标量(.exp() 保证),乘完后 scores 整体缩放。
toy 值:若 tau[0,0] = exp(0.3) ≈ 1.35,则 scores[0,0,0,3] 从 0.8 → 0.8×1.35 = 1.08。
+ delta:(2,4,12,12) + (2,1,1,12) → 广播 → (2,4,12,12)。delta 的最后维度 12 对应 key 位置 s,每个 key 位置加独立偏置。
toy 值:delta[0,0,0,:] = [δ₀, δ₁, ..., δ₁₁],scores[0,0,0,3] 再加 δ₃。
§5.5 DSAttention — mask、softmax、V 加权
python
if self.mask_flag:
if attn_mask is None:
attn_mask = TriangularCausalMask(B, L, device=queries.device)
scores.masked_fill_(attn_mask.mask, -np.inf)
A = self.dropout(torch.softmax(scale * scores, dim=-1))
V = torch.einsum("bhls,bshd->blhd", A, values)mask_flag:Encoder 的 DSAttention(mask_flag=False) 不做 mask(全局注意力);Decoder 的 self-attention DSAttention(mask_flag=True) 做下三角 causal mask(防止看未来)。Cross-attention DSAttention(mask_flag=False) 也不 mask。
scale = 1 / sqrt(E) = 1 / sqrt(d_keys=2) ≈ 0.707。
torch.softmax(scale * scores, dim=-1):在最后一维(key 位置 S=12)做 softmax → A shape (2, 4, 12, 12),每行(每个 query)的权重之和为 1。
einsum("bhls,bshd->blhd"):A (2,4,12,12) × values (2,12,4,2) → V_out (2,12,4,2)。
toy 数值:V_out[0,0,0,:] 是 batch=0,时间步 t₀,head=0 的 2 维输出,是 12 个 value 向量按注意力权重加权求和的结果。
out.view + out_projection(回到 AttentionLayer):
python
out = out.view(B, L, -1)
return self.out_projection(out), attnout.view(2, 12, -1) = out.view(2, 12, 8)(4 heads × d_values=2 = 8)→ (2, 12, 8)。
out_projection:Linear(d_values×n_heads=8, d_model=8) → (2, 12, 8),形状不变(因为 d_model = d_values×n_heads)。
6. DSAttention vs FullAttention 对比
| 维度 | FullAttention | DSAttention |
|---|---|---|
| scores 计算 | einsum * scale → softmax | einsum * tau + delta → * scale → softmax |
| tau | 无(等效 1.0) | 从 Projector(x_raw, std_enc) 学习 |
| delta | 无(等效 0.0) | 从 Projector(x_raw, mean_enc) 学习 |
| 非平稳信息 | 丢失 | 从统计量恢复注入 |
| 新增参数量 | — | 2 × Projector(约 37K,默认配置) |
| 代码差异 | scores = einsum * scale | scores = einsum * tau + delta,然后 scale * scores |
注意代码中 scale 的施加顺序:先 einsum * tau + delta,再 scale * scores。展开写是