Skip to content

Layer 3 — DecoderLayer 精读

父层(Layer 2C)Decoder.forward 的循环体调用 layer(x, cross, ...)
本文档只覆盖 DecoderLayer.forward 这一层(三段式:masked self-attn + cross-attn + FFN)。
子层 AttentionLayer 及以下见 04-Layer4-AttentionLayer


1. 在父层中的位置

Decoder.forward
  └─ for layer in self.layers:
         x = layer(x, cross, ...)   ← 本文档
              └─ self.self_attention(x, x, x, ...)    → 详见 Layer4(mask_flag=True)
              └─ self.cross_attention(x, cross, cross) → 详见 Layer4(mask_flag=False)
              └─ conv1 → relu → conv2

2. I/O 接口定义

python
def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
shape(toy)含义
输入 x(3, 12, 8) = (B, dec_len, d_model)decoder embedding 输出(label_len=5 + pred_len=7)
输入 cross(3, 6, 8) = (B, enc_seq, d_model)encoder 最终输出,贯穿所有段
输出 x(3, 12, 8)三段处理后的 decoder token 表示,形状不变

x_mask / cross_mask / tau / delta 全为 None


3. 顺序图(具体层)


4. 语义分组图(索引层)

三段都有残差跳接 + LayerNorm,总共 3 对。核心是第二段 cross-attention:decoder 的 Q 去问 encoder 的 K/V,把历史序列信息拉入预测。


5. 宏观逻辑

Generative Decoder 的工作方式

decoder 输入 x(3,12,8) 中,前 5 步是真实历史(start token),后 7 步是全零占位。三段处理后这 12 步全都被计算,父层只取后 7 步作为预测。这与 autoregressive 不同——整个 pred_len 一次性并行计算。

dec_input = [h1 h2 h3 h4 h5 | 0  0  0  0  0  0  0]
              ←label_len=5→    ←pred_len=7→

decoder 输出: [r1 r2 r3 r4 r5 | p1 p2 p3 p4 p5 p6 p7]
long_forecast() 取: [:, -7:, :] → (3, 7, 6)

masked self-attention 为什么不破坏并行性?

训练时已经用 causal mask 让各预测位置不互相看,推理时直接喂零占位也能并行预测全部 pred_len 步,不需要逐步生成。


6. 逐步解析

6.0 完整原始代码

python
class DecoderLayer(nn.Module):
    def __init__(
        self,
        self_attention,
        cross_attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation="relu",
    ):
        super(DecoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu


	def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
	    x = x + self.dropout(
	        self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
	    )
	    x = self.norm1(x)
	
	    x = x + self.dropout(
	        self.cross_attention(
	            x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
	        )[0]
	    )
	
	    y = x = self.norm2(x)
	    y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
	    y = self.dropout(self.conv2(y).transpose(-1, 1))
	
	    return self.norm3(x + y)

6.1 段 1:masked self-attention

本节的作用

Decoder token 的因果自注意力(Q=K=V=x),mask_flag=True 确保预测位置只看左边,保持并行预测的同时维持因果约束。

python
x = x + self.dropout(
    self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
)
x = self.norm1(x)

self.self_attention = AttentionLayer(ProbAttention(mask_flag=True, ...))。三个 x 都是 decoder token(自注意力 Q=K=V=x),[0] 取 out 丢弃 attn weights。

mask_flag=True → ProbAttention 内部用 cumsum 初始化 context,并施加因果掩码,确保每个 decoder 位置只能看到左边的 token。

new_x: (3,12,8) → dropout → 残差① → norm1 → x: (3,12,8)

→ AttentionLayer 细节见 04-Layer4-AttentionLayer


6.2 段 2:cross-attention

本节的作用

Q 来自 decoder(12步),K/V 来自 encoder 输出(6步);decoder 的每个位置通过 cross-attention 检索 encoder 的历史表示,是模型获取"输入序列上下文"的核心机制。

python
x = x + self.dropout(
    self.cross_attention(
        x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
    )[0]
)
y = x = self.norm2(x)

self.cross_attention = AttentionLayer(ProbAttention(mask_flag=False, ...))

这里 Q ≠ K/V,是 cross-attention:

Q ← x:     (3, 12, 8)  ← decoder 的 12 步 queries
K ← cross: (3,  6, 8)  ← encoder 的 6 步 keys(distilling 后)
V ← cross: (3,  6, 8)  ← encoder 的 6 步 values

注意力矩阵为 (B, H, L_Q=12, L_K=6):每个 decoder 位置用自己的 query 检索 encoder 的全部 key/value,吸取历史序列信息。mask_flag=False → encoder 所有位置对 decoder 全部可见。

ProbAttention 参数推导(cross-attn):
  U_part = ceil(ln(L_K=6)) = 2  (随机采样 2 个 encoder key)
  u      = ceil(ln(L_Q=12)) = 3  (选出 3 个活跃 decoder query)
new_x: (3,12,8) → dropout → 残差② → y = x = norm2 → (3,12,8)

y = x = self.norm2(x):norm2 后 xy 同时指向结果,后续 FFN 修改 yx 保持 norm2 值用于残差③。


6.3 段 3:FFN

本节的作用

与 EncoderLayer 结构完全相同的 Position-wise FFN(Conv1d k=1),对每个 decoder 位置独立做 8→24→8 非线性变换 + 残差③。

python
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))

return self.norm3(x + y)

与 EncoderLayer 完全相同的 Position-wise FFN:

y.transpose(-1,1): (3,12,8) → (3,8,12)
conv1(8→24, k=1): (3,8,12) → (3,24,12)
relu + dropout: (3,24,12)
conv2(24→8, k=1): (3,24,12) → (3,8,12)
transpose(-1,1): (3,8,12) → (3,12,8)
dropout: (3,12,8)

norm3(x + y): (3,12,8)  ← 残差③ + LayerNorm,最终输出

7. 下钻子组件

子组件调用场景mask_flagQ / K / V 来源下层文档
self.self_attention段 1True(因果)Q=K=V=x (decoder)04-Layer4-AttentionLayer
self.cross_attention段 2FalseQ=x (dec), K/V=cross (enc)04-Layer4-AttentionLayer

*记录并在线阅读我的笔记*