Appearance
4D Decoder 主链
Abstract
这一篇是:
04-Level4-short_forecast五段总览里4D Decoder这个子块的下钻文档。只讲:
decoder 怎样把
x_dec自身状态与 encoder 输出的上下文融合起来,并最终投影回c_out个输出通道。
1. 上下文
上一层:
这一层的入口代码是:
python
dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)这一层的输出是:
python
dec_out.shape = (B, L_dec, c_out)2. 当前层第一性
这一层存在的第一性是:
把 decoder 侧输入与 encoder 上下文融合,并从隐藏空间
d_model回到最终输出通道空间c_out。
3. 本层入口参数与输出含义
3.1 输入
x- decoder 当前隐藏表示,形状
(B, L_dec, d_model)
- decoder 当前隐藏表示,形状
cross- encoder 输出,形状
(B, L', d_model)
- encoder 输出,形状
d_layers- decoder 层数
n_heads- attention 头数
d_ff- FFN 中间维
c_out- 最终输出通道数
3.2 输出
dec_out- 已回到
c_out通道空间的 decoder 整段输出
- 已回到
4. 顺序图
5. 抽象树
6. 当前真实例子与 toy 例子
6.1 真实运行例子
当前真实例子里默认:
d_layers = 1n_heads = 8d_ff = 128d_model = 32c_out = 7
所以真实 decoder 主链是:
text
1 个 DecoderLayer -> LayerNorm -> Linear(32, 7)6.2 固定 toy 例子
B = 1L_dec = 4d_model = 4d_ff = 8c_out = 2d_layers = 1
python
x0 = [
[d11, d12, d13, d14],
[d21, d22, d23, d24],
[d31, d32, d33, d34],
[d41, d42, d43, d44],
] # (1, 4, 4)
cross0 = [
[c11, c12, c13, c14],
[c21, c22, c23, c24],
[c31, c32, c33, c34],
] # (1, L', 4)7. 代码块 1:DecoderLayer.forward(...)
位置:
完整代码:
python
class DecoderLayer(nn.Module):
def __init__(
self,
self_attention,
cross_attention,
d_model,
d_ff=None,
dropout=0.1,
activation="relu",
):
super(DecoderLayer, self).__init__()
d_ff = d_ff or 4 * d_model
self.self_attention = self_attention
self.cross_attention = cross_attention
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.activation = F.relu if activation == "relu" else F.gelu
def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
x = x + self.dropout(
self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
)
x = self.norm1(x)
x = x + self.dropout(
self.cross_attention(
x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
)[0]
)
y = x = self.norm2(x)
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
return self.norm3(x + y)7.1 子块 A:self-attention
对应代码:
python
x = x + self.dropout(
self.self_attention(x, x, x, ...)[0]
)
x = self.norm1(x)toy 张量演变图
text
输入 x0 = (1, 4, 4)
步骤 1: decoder 自己先看自己的输入
self_attn(x0) -> s1 = (1, 4, 4)
固定第 3 个时间步的 toy 注意力权重:
beta_3 = [0.1, 0.2, 0.5, 0.2]
那么第 3 个时间步的新表示是:
s1[3] = 0.1*x0[1] + 0.2*x0[2] + 0.5*x0[3] + 0.2*x0[4]
步骤 2: residual
x1 = x0 + s1 = (1, 4, 4)
步骤 3: norm1
x2 = norm1(x1) = (1, 4, 4)这一步的 input / output 语义
- 输入
x0- decoder 当前隐藏状态
- 输出
x2- decoder 在自身序列内做过信息交换后的状态
7.2 子块 B:cross-attention
对应代码:
python
x = x + self.dropout(
self.cross_attention(
x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
)[0]
)toy 张量演变图
text
输入:
x2 = (1, 4, 4)
cross0 = (1, 3, 4)
步骤 1: decoder 用 x2 去读取 encoder 上下文 cross0
cross_attn(x2, cross0, cross0) -> s2 = (1, 4, 4)
这里可以拆成 AttentionLayer 的那套接口语义:
queries = x2 # decoder 当前状态
keys = cross0 # encoder 上下文位置
values = cross0 # encoder 上下文内容
固定第 1 个 decoder 时间步对 encoder 三个位置的 toy 权重:
gamma_1 = [0.6, 0.3, 0.1]
如果 encoder 三个 value 向量是:
c1 = [c11, c12, c13, c14]
c2 = [c21, c22, c23, c24]
c3 = [c31, c32, c33, c34]
那么该步 cross-attention 输出就是:
s2[1] = 0.6*c1 + 0.3*c2 + 0.1*c3
步骤 2: residual
x3 = x2 + s2 = (1, 4, 4)
步骤 3: 这一段结束后
第 1 个 decoder 时间步已经不只是“看自己”,
还把 encoder 历史上下文读进来了这一步的 input / output 语义
- 输入
x2- decoder 当前状态
- 输入
cross0- encoder 提供的历史上下文
- 输出
x3- 已融合 encoder 上下文的 decoder 状态
再把这一段压成一句
text
x2(自己当前的状态)
-> 拿 x2 当 query
-> 去 cross0 里找最相关的历史上下文
-> 得到 s2
-> 与 x2 残差相加
-> x37.3 子块 C:FFN
对应代码:
python
y = x = self.norm2(x)
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
return self.norm3(x + y)toy 张量演变图
text
输入 x3 = (1, 4, 4)
步骤 1: norm2
x4 = (1, 4, 4)
步骤 2: transpose 给 Conv1d
x4_t = (1, 4, 4)
步骤 3: conv1: d_model=4 -> d_ff=8
h = (1, 8, 4)
这一步的语义不是“时间卷积预测”,
而是对每个时间步的隐藏向量做通道维扩展。
因为 kernel_size = 1,
所以它更像“逐时间步的线性层”。
步骤 4: conv2: d_ff=8 -> d_model=4
y_t = (1, 4, 4)
步骤 5: transpose 回来
y = (1, 4, 4)
步骤 6: residual + norm3
x5 = norm3(x4 + y) = (1, 4, 4)一个可算的 toy 小算例
固定某个时间步经过 norm2 后的隐藏向量是:
text
x4[t] = [1, 2, 0, 1]为了看懂 conv1 / conv2,先把它们当成 1x1 的线性变换。
假设 conv1 的前两行权重是:
text
w1 = [1, 0, 0, 0]
w2 = [0, 1, 1, 0]那么:
text
h1 = 1*1 + 0*2 + 0*0 + 0*1 = 1
h2 = 0*1 + 1*2 + 1*0 + 0*1 = 2也就是说,conv1 在每个时间步上把 4 维隐藏表示先扩到更大的 d_ff 空间。
再假设 conv2 取其中两行做回投影:
text
u1 = h1 + h2 = 3
u2 = h2 = 2那它就是在做:
text
更高维中间特征
-> 再映回 d_model这一步的 input / output 语义
- 输入
x3- 已融合上下文的 decoder 状态
- 输出
x5- 单层 decoder 最终隐藏状态
8. 代码块 2:Decoder.forward(...)
完整代码:
python
class Decoder(nn.Module):
def __init__(self, layers, norm_layer=None, projection=None):
super(Decoder, self).__init__()
self.layers = nn.ModuleList(layers)
self.norm = norm_layer
self.projection = projection
def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
for layer in self.layers:
x = layer(
x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta
)
if self.norm is not None:
x = self.norm(x)
if self.projection is not None:
x = self.projection(x)
return x8.1 toy 张量演变图
text
输入:
x0 = (1, 4, 4)
cross0 = (1, 3, 4)
步骤 1: 经过 1 个 DecoderLayer
x5 = (1, 4, 4)
步骤 2: 最终 LayerNorm
x6 = (1, 4, 4)
步骤 3: projection: d_model=4 -> c_out=2
为了理解 projection,固定 toy 权重矩阵:
W = [ [1, 0, 0, 1],
[0, 1, 1, 0] ]
如果某个时间步 decoder 隐状态是 [u1, u2, u3, u4],
那么投影后两维输出是:
y1 = u1 + u4
y2 = u2 + u3
所以 projection 不只是改 shape,
它是在最后一步把 4 维隐藏表示重新组合成 2 个输出通道。
dec_out = (1, 4, 2)再把 projection 拆成一步一步
text
某个时间步输入隐藏向量:
x6[t] = [u1, u2, u3, u4]
projection 权重:
W =
[ [1, 0, 0, 1],
[0, 1, 1, 0] ]
输出第 1 维:
y1 = 1*u1 + 0*u2 + 0*u3 + 1*u4
输出第 2 维:
y2 = 0*u1 + 1*u2 + 1*u3 + 0*u4
所以最终:
x6[t] = [u1,u2,u3,u4]
-> dec_out[t] = [u1+u4, u2+u3]8.2 这一段的 input / output 语义
- 输入
x0- decoder 侧隐藏表示
- 输入
cross0- encoder 上下文表示
- 输出
dec_out- 已回到
c_out通道空间的 decoder 结果
- 已回到
9. 当前层真正要固定什么
- decoder 内部顺序是:
- self-attention
- cross-attention
- FFN
cross就是 encoder 输出- FFN 里的
conv1/conv2在这里更应该理解成“逐时间步通道变换” - 最后一层 projection 真正把最后一维从
d_model变成c_out - 当前真实例子里最终会把:
(B, 72, 32)变成(B, 72, 7)
10. 下一步
继续看: