Appearance
Layer 3 — DecoderLayer 精读
父层(Layer 2C)
Decoder.forward的循环体调用layer(x, cross, ...)。
本文档只覆盖DecoderLayer.forward这一层(三段式:masked self-attn + cross-attn + FFN)。
子层 AttentionLayer 及以下见 04-Layer4-AttentionLayer。
1. 在父层中的位置
Decoder.forward
└─ for layer in self.layers:
x = layer(x, cross, ...) ← 本文档
└─ self.self_attention(x, x, x, ...) → 详见 Layer4(mask_flag=True)
└─ self.cross_attention(x, cross, cross) → 详见 Layer4(mask_flag=False)
└─ conv1 → relu → conv22. I/O 接口定义
python
def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):| shape(toy) | 含义 | |
|---|---|---|
输入 x | (3, 12, 8) = (B, dec_len, d_model) | decoder embedding 输出(label_len=5 + pred_len=7) |
输入 cross | (3, 6, 8) = (B, enc_seq, d_model) | encoder 最终输出,贯穿所有段 |
输出 x | (3, 12, 8) | 三段处理后的 decoder token 表示,形状不变 |
x_mask / cross_mask / tau / delta全为None。
3. 顺序图(具体层)
4. 语义分组图(索引层)
三段都有残差跳接 + LayerNorm,总共 3 对。核心是第二段 cross-attention:decoder 的 Q 去问 encoder 的 K/V,把历史序列信息拉入预测。
5. 宏观逻辑
Generative Decoder 的工作方式:
decoder 输入 x(3,12,8) 中,前 5 步是真实历史(start token),后 7 步是全零占位。三段处理后这 12 步全都被计算,父层只取后 7 步作为预测。这与 autoregressive 不同——整个 pred_len 一次性并行计算。
dec_input = [h1 h2 h3 h4 h5 | 0 0 0 0 0 0 0]
←label_len=5→ ←pred_len=7→
decoder 输出: [r1 r2 r3 r4 r5 | p1 p2 p3 p4 p5 p6 p7]
long_forecast() 取: [:, -7:, :] → (3, 7, 6)masked self-attention 为什么不破坏并行性?
训练时已经用 causal mask 让各预测位置不互相看,推理时直接喂零占位也能并行预测全部 pred_len 步,不需要逐步生成。
6. 逐步解析
6.0 完整原始代码
python
class DecoderLayer(nn.Module):
def __init__(
self,
self_attention,
cross_attention,
d_model,
d_ff=None,
dropout=0.1,
activation="relu",
):
super(DecoderLayer, self).__init__()
d_ff = d_ff or 4 * d_model
self.self_attention = self_attention
self.cross_attention = cross_attention
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.activation = F.relu if activation == "relu" else F.gelu
def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
x = x + self.dropout(
self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
)
x = self.norm1(x)
x = x + self.dropout(
self.cross_attention(
x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
)[0]
)
y = x = self.norm2(x)
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
return self.norm3(x + y)6.1 段 1:masked self-attention
本节的作用
Decoder token 的因果自注意力(Q=K=V=x),
mask_flag=True确保预测位置只看左边,保持并行预测的同时维持因果约束。
python
x = x + self.dropout(
self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
)
x = self.norm1(x)self.self_attention = AttentionLayer(ProbAttention(mask_flag=True, ...))。三个 x 都是 decoder token(自注意力 Q=K=V=x),[0] 取 out 丢弃 attn weights。
mask_flag=True → ProbAttention 内部用 cumsum 初始化 context,并施加因果掩码,确保每个 decoder 位置只能看到左边的 token。
new_x: (3,12,8) → dropout → 残差① → norm1 → x: (3,12,8)→ AttentionLayer 细节见 04-Layer4-AttentionLayer。
6.2 段 2:cross-attention
本节的作用
Q 来自 decoder(12步),K/V 来自 encoder 输出(6步);decoder 的每个位置通过 cross-attention 检索 encoder 的历史表示,是模型获取"输入序列上下文"的核心机制。
python
x = x + self.dropout(
self.cross_attention(
x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
)[0]
)
y = x = self.norm2(x)self.cross_attention = AttentionLayer(ProbAttention(mask_flag=False, ...))。
这里 Q ≠ K/V,是 cross-attention:
Q ← x: (3, 12, 8) ← decoder 的 12 步 queries
K ← cross: (3, 6, 8) ← encoder 的 6 步 keys(distilling 后)
V ← cross: (3, 6, 8) ← encoder 的 6 步 values注意力矩阵为 (B, H, L_Q=12, L_K=6):每个 decoder 位置用自己的 query 检索 encoder 的全部 key/value,吸取历史序列信息。mask_flag=False → encoder 所有位置对 decoder 全部可见。
ProbAttention 参数推导(cross-attn):
U_part = ceil(ln(L_K=6)) = 2 (随机采样 2 个 encoder key)
u = ceil(ln(L_Q=12)) = 3 (选出 3 个活跃 decoder query)
new_x: (3,12,8) → dropout → 残差② → y = x = norm2 → (3,12,8)y = x = self.norm2(x):norm2 后 x 和 y 同时指向结果,后续 FFN 修改 y,x 保持 norm2 值用于残差③。
6.3 段 3:FFN
本节的作用
与 EncoderLayer 结构完全相同的 Position-wise FFN(Conv1d k=1),对每个 decoder 位置独立做
8→24→8非线性变换 + 残差③。
python
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
return self.norm3(x + y)与 EncoderLayer 完全相同的 Position-wise FFN:
y.transpose(-1,1): (3,12,8) → (3,8,12)
conv1(8→24, k=1): (3,8,12) → (3,24,12)
relu + dropout: (3,24,12)
conv2(24→8, k=1): (3,24,12) → (3,8,12)
transpose(-1,1): (3,8,12) → (3,12,8)
dropout: (3,12,8)
norm3(x + y): (3,12,8) ← 残差③ + LayerNorm,最终输出7. 下钻子组件
| 子组件 | 调用场景 | mask_flag | Q / K / V 来源 | 下层文档 |
|---|---|---|---|---|
self.self_attention | 段 1 | True(因果) | Q=K=V=x (decoder) | 04-Layer4-AttentionLayer |
self.cross_attention | 段 2 | False | Q=x (dec), K/V=cross (enc) | 04-Layer4-AttentionLayer |