Skip to content

Layer 2B — DSAttention(De-stationary Attention)

1. 在父层中的位置

forecast() 中 Encoder 和 Decoder 的每个 AttentionLayer 内嵌了 DSAttention 作为 inner_attentionAttentionLayer 完成 Q/K/V 的线性投影和多头拆分后,把 tau/delta 一并传入 DSAttention.forward()

2. I/O 接口定义

AttentionLayer 接口(封装层,负责 Q/K/V 投影):

参数Shape含义
queries(2, 12, 8)encoder self-attn:Q 来自 enc_out
keys(2, 12, 8)encoder self-attn:K 来自 enc_out
values(2, 12, 8)encoder self-attn:V 来自 enc_out
tau(2, 1)从 forecast() 传入
delta(2, 12)从 forecast() 传入(encoder 中使用)
输出(2, 12, 8)注意力输出,形状不变

DSAttention.forward 接口(投影之后):

参数Shape含义
queries(2, 12, 4, 2)(B, L, H, d_keys),多头拆分后
keys(2, 12, 4, 2)(B, S, H, d_keys)
values(2, 12, 4, 2)(B, S, H, d_values)
tau(2, 1) → 内部扩展 (2,1,1,1)
delta(2, 12) → 内部扩展 (2,1,1,12)
输出 V(2, 12, 4, 2) → view → (2,12,8)

3. 顺序图

4. 语义分组图

5. 逐步骤精读

§5.0 完整原始代码

python
class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_attention(
            queries, keys, values, attn_mask, tau=tau, delta=delta
        )
        out = out.view(B, L, -1)

        return self.out_projection(out), attn


class DSAttention(nn.Module):
    """De-stationary Attention"""

    def __init__(self, mask_flag=True, factor=5, scale=None,
                 attention_dropout=0.1, output_attention=False):
        super(DSAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1.0 / sqrt(E)

        tau = 1.0 if tau is None else tau.unsqueeze(1).unsqueeze(1)  # B x 1 x 1 x 1
        delta = (
            0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1)
        )  # B x 1 x 1 x S

        scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)

        if self.output_attention:
            return V.contiguous(), A
        else:
            return V.contiguous(), None

§5.1 宏观逻辑

核心设计意图:在标准 Scaled Dot-Product Attention 的分子上引入两个可学习的调制项,使注意力分布能反映原始信号的非平稳特性。

标准 Attention 公式:A=softmax(QKTE)

DSAttention 公式:A=softmax(QKTτ+δE)

两者的差别只在 softmax 的输入:

QKT标准得分τ全局温度+δ位置偏置

τ 的物理直觉τ>1 放大所有 score,softmax 后分布更尖锐(更关注高权重位置);τ<1 压缩 score,分布更均匀。标准差 std_enc 大的序列(波动剧烈),学到的 tau 倾向于更大,使注意力更专注于少数关键时间步。

δ 的物理直觉:δ 是 shape (B, 1, 1, S) 的向量,对每个 key 位置 s 施加独立的加法偏置。均值 mean_enc 大的序列(整体偏高),delta 可能在趋势末端 (S≈12) 加正偏置,使模型更关注较新的时间步。

小例子(B=1, L=S=4, H=1, d_keys=1):

scores = QK^T = [[0.5, 0.1, 0.2, 0.8]]  (1×4 矩阵,query=t₃ 对4个key的得分)

若 tau=2.0, delta=[0.3, 0.1, -0.1, 0.5]:
  scores' = scores * 2.0 + delta
          = [1.0+0.3, 0.2+0.1, 0.4-0.1, 1.6+0.5]
          = [1.3, 0.3, 0.3, 2.1]

scale = 1/sqrt(1) = 1.0
softmax([1.3, 0.3, 0.3, 2.1]) ≈ [0.18, 0.07, 0.07, 0.68]

与无调制版的 softmax([0.5, 0.1, 0.2, 0.8]) ≈ [0.29, 0.19, 0.21, 0.31] 对比,调制后模型更集中关注 key=t₃(最后位置),且 tau 的放大使分布更尖锐。

注意力复杂度分析

DSAttention 使用标准全局注意力(FullAttention 的变种),scores (B,H,L,S) 中 L=S=seq_len=12。 复杂度为 O(L2)=O(144)(toy 参数)。 与 Informer 的 ProbSparse O(LlogL)=O(12×3.5843) 相比,复杂度更高。 Non-stationary Transformer 没有解决效率问题,其核心贡献是非平稳信息恢复,不是效率优化。

§5.2 AttentionLayer — Q/K/V 投影与多头拆分

python
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads

queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)

toy 参数(Encoder self-attention):L=S=seq_len=12,H=n_heads=4,d_keys=2。

query_projection:Linear(d_model=8, d_keys×n_heads=2×4=8),输入 (2,12,8),输出 (2,12,8)。

.view(B, L, H, -1) = .view(2, 12, 4, -1) → -1 = 8/(4) = 2 → 输出 (2, 12, 4, 2)

key/value 同理:(2, 12, 4, 2)。

toy 数值:queries[0, 0, :, :] 是 batch=0 时间步 t₀ 的 4 个 head 各 2 维的查询向量,形如 [[q₀₀, q₀₁], [q₁₀, q₁₁], [q₂₀, q₂₁], [q₃₀, q₃₁]]。

§5.3 DSAttention — tau/delta 扩展

python
tau = 1.0 if tau is None else tau.unsqueeze(1).unsqueeze(1)  # B x 1 x 1 x 1
delta = (
    0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1)
)  # B x 1 x 1 x S

tauforecast() 传入为 (2, 1)(Projector 输出经 .exp())。

.unsqueeze(1) → (2, 1, 1),再 .unsqueeze(1)(2, 1, 1, 1)

deltaforecast() 传入为 (2, 12)(Projector 输出)。

.unsqueeze(1) → (2, 1, 12),再 .unsqueeze(1)(2, 1, 1, 12)

注意:两次 unsqueeze 发生在不同位置

  • tau: (2,1) → (2,1,1) → (2,1,1,1) — 在 dim=1 和 dim=1 连续插入
  • delta: (2,12) → (2,1,12) → (2,1,1,12) — 结果末尾 12 对应 S 维

默认值处理:tau is None 时取 1.0(相当于不缩放),delta is None 时取 0.0(相当于不偏移)。Decoder self-attention 传 delta=None 就走这个路径。

§5.4 DSAttention — scores 计算

python
scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta

einsum("blhe,bshe->bhls"):queries (2,12,4,2) × keys (2,12,4,2) → scores (2,4,12,12)。

维度解释:b=batch,l=query时间步,h=head,e=d_keys;s=key时间步。输出 bhls = (batch, head, query_len, key_len)。

toy 数值:scores[0,0,0,:] 是 batch=0,head=0,query时间步 t₀ 对所有 12 个 key 位置的原始点积,范围约 [-2, 2](随机初始化范围)。

* tau:(2,4,12,12) × (2,1,1,1) → 广播 → (2,4,12,12)。tau 是正标量(.exp() 保证),乘完后 scores 整体缩放。

toy 值:若 tau[0,0] = exp(0.3) ≈ 1.35,则 scores[0,0,0,3] 从 0.8 → 0.8×1.35 = 1.08。

+ delta:(2,4,12,12) + (2,1,1,12) → 广播 → (2,4,12,12)。delta 的最后维度 12 对应 key 位置 s,每个 key 位置加独立偏置。

toy 值:delta[0,0,0,:] = [δ₀, δ₁, ..., δ₁₁],scores[0,0,0,3] 再加 δ₃。

§5.5 DSAttention — mask、softmax、V 加权

python
if self.mask_flag:
    if attn_mask is None:
        attn_mask = TriangularCausalMask(B, L, device=queries.device)
    scores.masked_fill_(attn_mask.mask, -np.inf)

A = self.dropout(torch.softmax(scale * scores, dim=-1))
V = torch.einsum("bhls,bshd->blhd", A, values)

mask_flag:Encoder 的 DSAttention(mask_flag=False) 不做 mask(全局注意力);Decoder 的 self-attention DSAttention(mask_flag=True) 做下三角 causal mask(防止看未来)。Cross-attention DSAttention(mask_flag=False) 也不 mask。

scale = 1 / sqrt(E) = 1 / sqrt(d_keys=2) ≈ 0.707。

torch.softmax(scale * scores, dim=-1):在最后一维(key 位置 S=12)做 softmax → A shape (2, 4, 12, 12),每行(每个 query)的权重之和为 1。

einsum("bhls,bshd->blhd"):A (2,4,12,12) × values (2,12,4,2) → V_out (2,12,4,2)。

toy 数值:V_out[0,0,0,:] 是 batch=0,时间步 t₀,head=0 的 2 维输出,是 12 个 value 向量按注意力权重加权求和的结果。

out.view + out_projection(回到 AttentionLayer)

python
out = out.view(B, L, -1)
return self.out_projection(out), attn

out.view(2, 12, -1) = out.view(2, 12, 8)(4 heads × d_values=2 = 8)→ (2, 12, 8)

out_projection:Linear(d_values×n_heads=8, d_model=8) → (2, 12, 8),形状不变(因为 d_model = d_values×n_heads)。


6. DSAttention vs FullAttention 对比

维度FullAttentionDSAttention
scores 计算einsum * scale → softmaxeinsum * tau + delta → * scale → softmax
tau无(等效 1.0)从 Projector(x_raw, std_enc) 学习
delta无(等效 0.0)从 Projector(x_raw, mean_enc) 学习
非平稳信息丢失从统计量恢复注入
新增参数量2 × Projector(约 37K,默认配置)
代码差异scores = einsum * scalescores = einsum * tau + delta,然后 scale * scores

注意代码中 scale 的施加顺序:先 einsum * tau + delta,再 scale * scores。展开写是 QKTτ+δE,等价于 QKTτE+δE——scale 同时作用在两项上。

*记录并在线阅读我的笔记*