Skip to content

Level 3 forward 主链

Abstract

这一篇只讲一件事:

已经进入 Informer.forward(...) 以后,代码怎样把四输入一路送进 short_forecast(...),再从 short_forecast(...) 拿回整段 decoder 输出,并在 forward(...) 里截出最后 pred_len 段。

1. 上下文

上一层:

下一层:

这一层的入口接口是:

python
forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None)

这一层的输出是:

python
output.shape = (B, pred_len, c_out)

2. 当前层第一性

这一层存在的第一性是:

把外层统一四输入接到 Informer 的 forecasting 主体上,并明确最终只保留未来 pred_len 段。

3. 本层入口参数与输出含义

3.1 输入

  • x_enc
    • encoder 侧历史数值窗口,形状 (B, seq_len, enc_in)
  • x_mark_enc
    • encoder 侧历史时间特征,形状 (B, seq_len, time_dim)
  • x_dec
    • decoder 侧数值输入,形状 (B, label_len + pred_len, dec_in)
  • x_mark_dec
    • decoder 侧时间特征,形状 (B, label_len + pred_len, time_dim)
  • task_name
    • 决定 forward(...) 走哪条任务分支
  • pred_len
    • 决定最终保留多少步输出

3.2 输出

  • dec_out
    • short_forecast(...) 的整段 decoder 输出,长度还是 label_len + pred_len
  • output
    • forward(...) 最终返回值,只保留最后 pred_len

4. 顺序图

5. 抽象树

6. 当前真实例子与 toy 例子

6.1 真实运行例子

当前真实命令下,最关键的 Informer 参数是:

  • task_name = "short_term_forecast"
  • seq_len = 96
  • label_len = 48(adapter 默认)
  • pred_len = 24
  • enc_in = dec_in = c_out = 7
  • d_model = 32
  • embed = "timeF"
  • freq = "h"
  • e_layers = 2
  • d_layers = 1
  • n_heads = 8
  • distil = True

所以真实输入输出是:

text
x_enc:      (4, 96, 7)
x_mark_enc: (4, 96, 4)
x_dec:      (4, 72, 7)
x_mark_dec: (4, 72, 4)
-> short_forecast(...) 返回 dec_out: (4, 72, 7)
-> forward(...) 截取最后 pred_len=24 段
-> output: (4, 24, 7)

6.2 固定 toy 例子

后面每段代码都贴着这组 toy 张量讲:

  • B = 1
  • seq_len = 4
  • label_len = 2
  • pred_len = 2
  • enc_in = dec_in = c_out = 2
  • d_model = 4
python
x_enc = [
    [1, 10],
    [2, 11],
    [3, 12],
    [4, 13],
]  # (1, 4, 2)

x_mark_enc = [
    [0.10, 0.20, 0.30, 0.40],
    [0.20, 0.20, 0.30, 0.50],
    [0.30, 0.20, 0.30, 0.60],
    [0.40, 0.20, 0.30, 0.70],
]  # (1, 4, 4)

x_dec = [
    [3, 12],
    [4, 13],
    [0, 0],
    [0, 0],
]  # (1, 4, 2)

x_mark_dec = [
    [0.30, 0.20, 0.30, 0.60],
    [0.40, 0.20, 0.30, 0.70],
    [0.50, 0.20, 0.30, 0.80],
    [0.60, 0.20, 0.30, 0.90],
]  # (1, 4, 4)

7. 代码块 1:forward(...)

位置:

完整代码:

python
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
    if self.task_name == "long_term_forecast":
        dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return dec_out[:, -self.pred_len :, :]  # [B, L, D]
    if self.task_name == "short_term_forecast":
        dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return dec_out[:, -self.pred_len :, :]  # [B, L, D]
    if self.task_name == "imputation":
        dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
        return dec_out  # [B, L, D]
    if self.task_name == "anomaly_detection":
        dec_out = self.anomaly_detection(x_enc)
        return dec_out  # [B, L, D]
    if self.task_name == "classification":
        dec_out = self.classification(x_enc, x_mark_enc)
        return dec_out  # [B, N]
    return None

带中文注释的完整代码:

python
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
    # task_name 决定当前到底走哪条任务分支
    if self.task_name == "long_term_forecast":
        dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return dec_out[:, -self.pred_len :, :]

    if self.task_name == "short_term_forecast":
        # 当前 benchmark 例子真正进入这里
        dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        # 最后只保留未来 pred_len 段
        return dec_out[:, -self.pred_len :, :]

    if self.task_name == "imputation":
        dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
        return dec_out

    if self.task_name == "anomaly_detection":
        dec_out = self.anomaly_detection(x_enc)
        return dec_out

    if self.task_name == "classification":
        dec_out = self.classification(x_enc, x_mark_enc)
        return dec_out

    return None

7.1 贴着这段代码讲:分支选择

当前真实例子里:

  • task_name = "short_term_forecast"

所以这段代码实际只会执行:

python
dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len :, :]

7.2 这一段的 toy 张量演变图

text
输入:
  x_enc      = (1, 4, 2)
  x_mark_enc = (1, 4, 4)
  x_dec      = (1, 4, 2)
  x_mark_dec = (1, 4, 4)

步骤 1: task_name 判断
  short_term_forecast -> 进入 short_forecast(...)
  张量本身不变,只是控制流分支被固定

步骤 2: short_forecast(...) 返回整段 dec_out
  dec_out = (1, 4, 2)

步骤 3: 只保留最后 pred_len=2 步
  output = dec_out[:, -2:, :]
         = (1, 2, 2)

7.3 这一段的 input / output 语义

  • 输入 x_enc
    • encoder 侧历史数值窗口
  • 输入 x_mark_enc
    • encoder 侧历史时间特征
  • 输入 x_dec
    • decoder 侧“历史尾部 + 未来占位”数值输入
  • 输入 x_mark_dec
    • decoder 侧对应时间特征
  • 中间 dec_out
    • decoder 整段输出,长度还是 label_len + pred_len
  • 输出 output
    • 真正要拿来和监督对齐的未来 pred_len

8. 代码块 2:short_forecast(...)

完整代码:

python
def short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
    # Normalization
    mean_enc = x_enc.mean(1, keepdim=True).detach()  # B x 1 x E
    x_enc = x_enc - mean_enc
    std_enc = torch.sqrt(
        torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5
    ).detach()  # B x 1 x E
    x_enc = x_enc / std_enc

    enc_out = self.enc_embedding(x_enc, x_mark_enc)
    dec_out = self.dec_embedding(x_dec, x_mark_dec)
    enc_out, attns = self.encoder(enc_out, attn_mask=None)

    dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)

    dec_out = dec_out * std_enc + mean_enc
    return dec_out  # [B, L, D]

带中文注释的完整代码:

python
def short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
    # 1. 只对 encoder 侧数值输入做标准化
    mean_enc = x_enc.mean(1, keepdim=True).detach()
    x_enc = x_enc - mean_enc
    std_enc = torch.sqrt(
        torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5
    ).detach()
    x_enc = x_enc / std_enc

    # 2. encoder / decoder 两侧分别做 embedding
    enc_out = self.enc_embedding(x_enc, x_mark_enc)
    dec_out = self.dec_embedding(x_dec, x_mark_dec)

    # 3. 历史窗口先过 encoder
    enc_out, attns = self.encoder(enc_out, attn_mask=None)

    # 4. decoder 同时看自己的输入和 encoder 输出
    dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)

    # 5. 把结果映回原始数值尺度
    dec_out = dec_out * std_enc + mean_enc
    return dec_out

8.1 子块 A:标准化

对应代码:

python
mean_enc = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - mean_enc
std_enc = torch.sqrt(
    torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5
).detach()
x_enc = x_enc / std_enc

toy 张量演变图

text
原始 x_enc:
[
  [1, 10],
  [2, 11],
  [3, 12],
  [4, 13],
]  (1, 4, 2)

步骤 1: 按时间维求均值
mean_enc = [ [2.5, 11.5] ]  (1, 1, 2)

步骤 2: 去均值
x_centered =
[
  [-1.5, -1.5],
  [-0.5, -0.5],
  [ 0.5,  0.5],
  [ 1.5,  1.5],
]  (1, 4, 2)

步骤 3: 求标准差
std_enc ≈ [ [1.118, 1.118] ]  (1, 1, 2)

步骤 4: 标准化
x_norm ≈
[
  [-1.342, -1.342],
  [-0.447, -0.447],
  [ 0.447,  0.447],
  [ 1.342,  1.342],
]  (1, 4, 2)

这一步的 input / output 语义

  • 输入 x_enc
    • 原始 encoder 数值窗口
  • 中间 mean_enc
    • 每个通道在整段历史上的均值
  • 中间 std_enc
    • 每个通道在整段历史上的尺度
  • 输出 x_norm
    • 已被拉到稳定尺度的 encoder 输入

8.2 子块 B:两侧 embedding

对应代码:

python
enc_out = self.enc_embedding(x_enc, x_mark_enc)
dec_out = self.dec_embedding(x_dec, x_mark_dec)

toy 张量演变图

text
encoder 侧:
  x_enc      = (1, 4, 2)
  x_mark_enc = (1, 4, 4)
  -> enc_embedding(...)
  -> enc_out = (1, 4, 4)

decoder 侧:
  x_dec      = (1, 4, 2)
  x_mark_dec = (1, 4, 4)
  -> dec_embedding(...)
  -> dec_out = (1, 4, 4)

DataEmbedding 内部三条支路怎么一步步编码,放到 04B-DataEmbedding 细讲。

这一步的 input / output 语义

  • 输入 x_enc/x_dec
    • 原始数值窗口
  • 输入 x_mark_enc/x_mark_dec
    • 时间特征
  • 输出 enc_out/dec_out
    • 映射到统一 d_model 隐空间后的表示

8.3 子块 C:encoder

对应代码:

python
enc_out, attns = self.encoder(enc_out, attn_mask=None)

toy 张量演变图

text
输入:
  enc_out = (1, 4, 4)

步骤 1: 第 1 层 EncoderLayer
  -> 仍是 (1, 4, 4)

步骤 2: distil=True 时经过 ConvLayer 压缩长度
  -> 可能变成 (1, 2, 4) 或 (1, 3, 4)

步骤 3: 最后一层 EncoderLayer
  -> 仍保持最后一维 d_model=4

输出:
  enc_out = (1, L', 4)
  attns   = 注意力权重集合

这一步的 input / output 语义

  • 输入 enc_out
    • embedding 后的历史隐藏表示
  • 输出 enc_out
    • 供 decoder 读取的上下文表示
  • 输出 attns
    • 各层 attention 权重集合

8.4 子块 D:decoder

对应代码:

python
dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)

toy 张量演变图

text
输入:
  dec_out = (1, 4, 4)
  enc_out = (1, L', 4)

步骤 1: decoder 自己先在内部做 self-attention
  -> (1, 4, 4)

步骤 2: 再读取 encoder 上下文 cross-attention
  -> (1, 4, 4)

步骤 3: 最后 projection: d_model=4 -> c_out=2
  -> dec_out = (1, 4, 2)

这一步的 input / output 语义

  • 输入 dec_out
    • decoder 当前隐藏表示
  • 输入 enc_out
    • encoder 上下文
  • 输出 dec_out
    • 已回到输出通道空间的 decoder 结果

8.5 子块 E:反标准化

对应代码:

python
dec_out = dec_out * std_enc + mean_enc

toy 张量演变图

text
输入:
  decoder 输出 dec_out = (1, 4, 2)
  std_enc  = (1, 1, 2)
  mean_enc = (1, 1, 2)

步骤:
  每个时间步、每个通道做
  dec_out[channel] * std_enc[channel] + mean_enc[channel]

输出:
  反标准化后的 dec_out = (1, 4, 2)

这一步的 input / output 语义

  • 输入 dec_out
    • 标准化尺度上的 decoder 结果
  • 输入 std_enc / mean_enc
    • 从 encoder 输入侧保存下来的尺度信息
  • 输出 dec_out
    • 回到原始数值尺度的 decoder 结果

9. 当前层真正要固定什么

  1. forward(...) 当前真实例子只走 short_term_forecast
  2. short_forecast(...) 先返回整段 dec_out
  3. 真正的最终输出还要回到 forward(...) 再截 pred_len
  4. 当前真实例子里最关键的控制参数是:
    • task_name
    • pred_len
    • d_model
    • enc_in / dec_in / c_out
    • embed = timeF
    • e_layers / d_layers / n_heads / d_ff / distil

10. 下一步

继续看:

*记录并在线阅读我的笔记*