Skip to content

Layer 3 — DecoderLayer 精读

Decoder.forward() 循环调用([[03C-Layer2C-Decoder]])。
本层覆盖:masked self-attention + decomp1 + cross-attention + decomp2 + FFN + decomp3 + trend 路由。


1. 在父层中的位置

Decoder.forward()
  └─ for layer in self.layers:
       x, residual_trend = layer(x, cross)   ← DecoderLayer(本文档)
            ├─ self.self_attention(x,x,x)    → 详见 04-Layer4-AutoCorrelationLayer
            └─ self.cross_attention(x, enc_out, enc_out)  → 详见 04-Layer4-AutoCorrelationLayer

2. I/O 接口定义

shape含义
输入 x(2, 10, 8)seasonal 表示(decoder embedding 输出或上层输出)
输入 cross(2, 12, 8)encoder 输出(cross-attention 的 K/V 来源)
输出 x(2, 10, 8)更新后的 seasonal 表示
输出 residual_trend(2, 10, 5)3 次 decomp 提取的趋势增量之和,经 Conv1d 投影

3. 顺序图(具体层)


4. 语义分组图(索引层)


5. 逐步精读

5.0 完整原始代码

python
class DecoderLayer(nn.Module):
    def __init__(
        self,
        self_attention,
        cross_attention,
        d_model,
        c_out,
        d_ff=None,
        moving_avg=25,
        dropout=0.1,
        activation="relu",
    ):
        super(DecoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(
            in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False
        )
        self.conv2 = nn.Conv1d(
            in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False
        )
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.decomp3 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.projection = nn.Conv1d(
            in_channels=d_model,
            out_channels=c_out,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode="circular",
            bias=False,
        )
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        x = x + self.dropout(self.self_attention(x, x, x, attn_mask=x_mask)[0])
        x, trend1 = self.decomp1(x)
        x = x + self.dropout(
            self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
        )
        x, trend2 = self.decomp2(x)
        y = x
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        x, trend3 = self.decomp3(x + y)

        residual_trend = trend1 + trend2 + trend3
        residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(
            1, 2
        )
        return x, residual_trend

5.1 宏观逻辑:3 段 + 3 次 decomp 的设计

为什么每段后都做 decomp?

每次注意力或 FFN 操作都可能混入趋势(低频)信息。若不及时剥离,趋势会通过多次残差连接不断累积,最终污染 seasonal 路径。

每段后立即 decomp 相当于"实时滤波":seasonal 路径始终保持高频周期特征,trend 路径收集所有被剥离的低频分量。

shape 变化链:

x: (2,10,8)
→ +self_attn → (2,10,8) → decomp1 → x(2,10,8) + trend1(2,10,8)
→ +cross_attn → (2,10,8) → decomp2 → x(2,10,8) + trend2(2,10,8)
→ +FFN → (2,10,8) → decomp3 → x(2,10,8) + trend3(2,10,8)
trend1+trend2+trend3: (2,10,8) → Conv1d(8→5) → residual_trend(2,10,5)

5.2 段一:masked self-attention + decomp1

python
x = x + self.dropout(self.self_attention(x, x, x, attn_mask=x_mask)[0])
x, trend1 = self.decomp1(x)

self.self_attentionAutoCorrelationLayer,初始化时 mask_flag=True(对应 AutoCorrelation(True, ...)),用于 decoder 的自注意力 causal masking。返回值是 (output, attn) 元组,[0] 取 output。

残差相加后,decomp1(x) 分解出 seasonal x (2,10,8)trend1 (2,10,8)

mask_flag 的作用

Autoformer 的 AutoCorrelation 在 training 时基于整个序列计算 lag(batch-norm 风格),mask_flag 的具体效果在 AutoCorrelation 内部处理。Decoder 的 self-attention 对未来时间步 mask,防止信息泄露。

toy:x[0,0,:] = [0.3, -0.1, ...](8维 seasonal),trend1[0,0,:] = [0.05, 0.02, ...](moving avg 提取的低频分量)。


5.3 段二:cross-attention + decomp2

python
x = x + self.dropout(
    self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
)
x, trend2 = self.decomp2(x)

self.cross_attentionAutoCorrelationLayermask_flag=False
Q 来自 decoder x (2,10,8),K/V 来自 encoder 输出 cross (2,12,8)

cross-attention 中 LQ=10(decoder dec_len),LK=LV=12(encoder seq_len)。AutoCorrelation 的 forward() 在 L > S 时会对 V/K 补零,L < S 时截断,这里 LQ=10<LK=12,故 V/K 被截断到 10。

decomp2 提取 trend2 (2,10,8)


5.4 段三:FFN + decomp3

python
y = x
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
x, trend3 = self.decomp3(x + y)

与 EncoderLayer 的 FFN 完全相同:Conv1d(k=1) 实现 Position-wise FFN,8→16→8。decomp3 提取 trend3 (2,10,8)


5.5 trend 路由:汇聚 + Conv1d 投影

python
residual_trend = trend1 + trend2 + trend3
residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2)

self.projection = nn.Conv1d(in_channels=8, out_channels=5, kernel_size=3, stride=1, padding=1, padding_mode="circular")

为什么用 Conv1d 而不是 Linear 投影 trend?

Conv1d vs Linear 投影 trend 的区别

seasonal 路径用 nn.Linear(8→5) 投影——因为每个时间步独立,Position-wise 线性变换足够。

trend 路径用 Conv1d(k=3, pad=1, circular) 投影——kernel size=3 使 trend 投影有局部时间感受野,可以在投影时做轻微的时间平滑,与 trend 的低频平滑特性一致。

同时 Conv1d 的 circular padding 与 TokenEmbedding、moving_avg 等保持相同的边界处理风格。

Conv1d 需要 permute:

residual_trend: (2, 10, 8)   ← (B, L, d_model)
.permute(0, 2, 1): (2, 8, 10)   ← Conv1d 需要 (B, C, L)
Conv1d(8→5, k=3, p=1): (2, 5, 10)
  L_out = floor((10+2×1-3)/1+1) = 10 ✓
.transpose(1, 2): (2, 10, 5)   ← 还原 (B, L, c_out)

toy 数值:trend1[0,6,0]=0.02trend2[0,6,0]=0.01trend3[0,6,0]=0.00
residual_trend[0,6,0] (before proj) =0.03,经 Conv1d 投影后 residual_trend[0,6,0]=0.03(近似,kernel=3 有少量平滑)。


6. 下钻子组件

子组件职责下层文档
AutoCorrelationLayerLinear Q/K/V 投影 + view + FFT 互相关委托[[04-Layer4-AutoCorrelationLayer]]
series_decomp(decomp1, decomp2, decomp3)moving_avg → trend;x − trend → seasonal[[03D-Layer2D-SeriesDecomp]]

*记录并在线阅读我的笔记*