Appearance
Level 3 forward 主链
Abstract
这一篇只讲一件事:
已经进入
Informer.forward(...)以后,代码怎样把四输入一路送进short_forecast(...),再从short_forecast(...)拿回整段 decoder 输出,并在forward(...)里截出最后pred_len段。
1. 上下文
上一层:
下一层:
这一层的入口接口是:
python
forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None)这一层的输出是:
python
output.shape = (B, pred_len, c_out)2. 当前层第一性
这一层存在的第一性是:
把外层统一四输入接到 Informer 的 forecasting 主体上,并明确最终只保留未来
pred_len段。
3. 本层入口参数与输出含义
3.1 输入
x_enc- encoder 侧历史数值窗口,形状
(B, seq_len, enc_in)
- encoder 侧历史数值窗口,形状
x_mark_enc- encoder 侧历史时间特征,形状
(B, seq_len, time_dim)
- encoder 侧历史时间特征,形状
x_dec- decoder 侧数值输入,形状
(B, label_len + pred_len, dec_in)
- decoder 侧数值输入,形状
x_mark_dec- decoder 侧时间特征,形状
(B, label_len + pred_len, time_dim)
- decoder 侧时间特征,形状
task_name- 决定
forward(...)走哪条任务分支
- 决定
pred_len- 决定最终保留多少步输出
3.2 输出
dec_outshort_forecast(...)的整段 decoder 输出,长度还是label_len + pred_len
outputforward(...)最终返回值,只保留最后pred_len步
4. 顺序图
5. 抽象树
6. 当前真实例子与 toy 例子
6.1 真实运行例子
当前真实命令下,最关键的 Informer 参数是:
task_name = "short_term_forecast"seq_len = 96label_len = 48(adapter 默认)pred_len = 24enc_in = dec_in = c_out = 7d_model = 32embed = "timeF"freq = "h"e_layers = 2d_layers = 1n_heads = 8distil = True
所以真实输入输出是:
text
x_enc: (4, 96, 7)
x_mark_enc: (4, 96, 4)
x_dec: (4, 72, 7)
x_mark_dec: (4, 72, 4)
-> short_forecast(...) 返回 dec_out: (4, 72, 7)
-> forward(...) 截取最后 pred_len=24 段
-> output: (4, 24, 7)6.2 固定 toy 例子
后面每段代码都贴着这组 toy 张量讲:
B = 1seq_len = 4label_len = 2pred_len = 2enc_in = dec_in = c_out = 2d_model = 4
python
x_enc = [
[1, 10],
[2, 11],
[3, 12],
[4, 13],
] # (1, 4, 2)
x_mark_enc = [
[0.10, 0.20, 0.30, 0.40],
[0.20, 0.20, 0.30, 0.50],
[0.30, 0.20, 0.30, 0.60],
[0.40, 0.20, 0.30, 0.70],
] # (1, 4, 4)
x_dec = [
[3, 12],
[4, 13],
[0, 0],
[0, 0],
] # (1, 4, 2)
x_mark_dec = [
[0.30, 0.20, 0.30, 0.60],
[0.40, 0.20, 0.30, 0.70],
[0.50, 0.20, 0.30, 0.80],
[0.60, 0.20, 0.30, 0.90],
] # (1, 4, 4)7. 代码块 1:forward(...)
位置:
完整代码:
python
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if self.task_name == "long_term_forecast":
dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len :, :] # [B, L, D]
if self.task_name == "short_term_forecast":
dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len :, :] # [B, L, D]
if self.task_name == "imputation":
dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
return dec_out # [B, L, D]
if self.task_name == "anomaly_detection":
dec_out = self.anomaly_detection(x_enc)
return dec_out # [B, L, D]
if self.task_name == "classification":
dec_out = self.classification(x_enc, x_mark_enc)
return dec_out # [B, N]
return None带中文注释的完整代码:
python
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
# task_name 决定当前到底走哪条任务分支
if self.task_name == "long_term_forecast":
dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len :, :]
if self.task_name == "short_term_forecast":
# 当前 benchmark 例子真正进入这里
dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
# 最后只保留未来 pred_len 段
return dec_out[:, -self.pred_len :, :]
if self.task_name == "imputation":
dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
return dec_out
if self.task_name == "anomaly_detection":
dec_out = self.anomaly_detection(x_enc)
return dec_out
if self.task_name == "classification":
dec_out = self.classification(x_enc, x_mark_enc)
return dec_out
return None7.1 贴着这段代码讲:分支选择
当前真实例子里:
task_name = "short_term_forecast"
所以这段代码实际只会执行:
python
dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len :, :]7.2 这一段的 toy 张量演变图
text
输入:
x_enc = (1, 4, 2)
x_mark_enc = (1, 4, 4)
x_dec = (1, 4, 2)
x_mark_dec = (1, 4, 4)
步骤 1: task_name 判断
short_term_forecast -> 进入 short_forecast(...)
张量本身不变,只是控制流分支被固定
步骤 2: short_forecast(...) 返回整段 dec_out
dec_out = (1, 4, 2)
步骤 3: 只保留最后 pred_len=2 步
output = dec_out[:, -2:, :]
= (1, 2, 2)7.3 这一段的 input / output 语义
- 输入
x_enc- encoder 侧历史数值窗口
- 输入
x_mark_enc- encoder 侧历史时间特征
- 输入
x_dec- decoder 侧“历史尾部 + 未来占位”数值输入
- 输入
x_mark_dec- decoder 侧对应时间特征
- 中间
dec_out- decoder 整段输出,长度还是
label_len + pred_len
- decoder 整段输出,长度还是
- 输出
output- 真正要拿来和监督对齐的未来
pred_len段
- 真正要拿来和监督对齐的未来
8. 代码块 2:short_forecast(...)
完整代码:
python
def short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
# Normalization
mean_enc = x_enc.mean(1, keepdim=True).detach() # B x 1 x E
x_enc = x_enc - mean_enc
std_enc = torch.sqrt(
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5
).detach() # B x 1 x E
x_enc = x_enc / std_enc
enc_out = self.enc_embedding(x_enc, x_mark_enc)
dec_out = self.dec_embedding(x_dec, x_mark_dec)
enc_out, attns = self.encoder(enc_out, attn_mask=None)
dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)
dec_out = dec_out * std_enc + mean_enc
return dec_out # [B, L, D]带中文注释的完整代码:
python
def short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
# 1. 只对 encoder 侧数值输入做标准化
mean_enc = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - mean_enc
std_enc = torch.sqrt(
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5
).detach()
x_enc = x_enc / std_enc
# 2. encoder / decoder 两侧分别做 embedding
enc_out = self.enc_embedding(x_enc, x_mark_enc)
dec_out = self.dec_embedding(x_dec, x_mark_dec)
# 3. 历史窗口先过 encoder
enc_out, attns = self.encoder(enc_out, attn_mask=None)
# 4. decoder 同时看自己的输入和 encoder 输出
dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)
# 5. 把结果映回原始数值尺度
dec_out = dec_out * std_enc + mean_enc
return dec_out8.1 子块 A:标准化
对应代码:
python
mean_enc = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - mean_enc
std_enc = torch.sqrt(
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5
).detach()
x_enc = x_enc / std_enctoy 张量演变图
text
原始 x_enc:
[
[1, 10],
[2, 11],
[3, 12],
[4, 13],
] (1, 4, 2)
步骤 1: 按时间维求均值
mean_enc = [ [2.5, 11.5] ] (1, 1, 2)
步骤 2: 去均值
x_centered =
[
[-1.5, -1.5],
[-0.5, -0.5],
[ 0.5, 0.5],
[ 1.5, 1.5],
] (1, 4, 2)
步骤 3: 求标准差
std_enc ≈ [ [1.118, 1.118] ] (1, 1, 2)
步骤 4: 标准化
x_norm ≈
[
[-1.342, -1.342],
[-0.447, -0.447],
[ 0.447, 0.447],
[ 1.342, 1.342],
] (1, 4, 2)这一步的 input / output 语义
- 输入
x_enc- 原始 encoder 数值窗口
- 中间
mean_enc- 每个通道在整段历史上的均值
- 中间
std_enc- 每个通道在整段历史上的尺度
- 输出
x_norm- 已被拉到稳定尺度的 encoder 输入
8.2 子块 B:两侧 embedding
对应代码:
python
enc_out = self.enc_embedding(x_enc, x_mark_enc)
dec_out = self.dec_embedding(x_dec, x_mark_dec)toy 张量演变图
text
encoder 侧:
x_enc = (1, 4, 2)
x_mark_enc = (1, 4, 4)
-> enc_embedding(...)
-> enc_out = (1, 4, 4)
decoder 侧:
x_dec = (1, 4, 2)
x_mark_dec = (1, 4, 4)
-> dec_embedding(...)
-> dec_out = (1, 4, 4)DataEmbedding 内部三条支路怎么一步步编码,放到 04B-DataEmbedding 细讲。
这一步的 input / output 语义
- 输入
x_enc/x_dec- 原始数值窗口
- 输入
x_mark_enc/x_mark_dec- 时间特征
- 输出
enc_out/dec_out- 映射到统一
d_model隐空间后的表示
- 映射到统一
8.3 子块 C:encoder
对应代码:
python
enc_out, attns = self.encoder(enc_out, attn_mask=None)toy 张量演变图
text
输入:
enc_out = (1, 4, 4)
步骤 1: 第 1 层 EncoderLayer
-> 仍是 (1, 4, 4)
步骤 2: distil=True 时经过 ConvLayer 压缩长度
-> 可能变成 (1, 2, 4) 或 (1, 3, 4)
步骤 3: 最后一层 EncoderLayer
-> 仍保持最后一维 d_model=4
输出:
enc_out = (1, L', 4)
attns = 注意力权重集合这一步的 input / output 语义
- 输入
enc_out- embedding 后的历史隐藏表示
- 输出
enc_out- 供 decoder 读取的上下文表示
- 输出
attns- 各层 attention 权重集合
8.4 子块 D:decoder
对应代码:
python
dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)toy 张量演变图
text
输入:
dec_out = (1, 4, 4)
enc_out = (1, L', 4)
步骤 1: decoder 自己先在内部做 self-attention
-> (1, 4, 4)
步骤 2: 再读取 encoder 上下文 cross-attention
-> (1, 4, 4)
步骤 3: 最后 projection: d_model=4 -> c_out=2
-> dec_out = (1, 4, 2)这一步的 input / output 语义
- 输入
dec_out- decoder 当前隐藏表示
- 输入
enc_out- encoder 上下文
- 输出
dec_out- 已回到输出通道空间的 decoder 结果
8.5 子块 E:反标准化
对应代码:
python
dec_out = dec_out * std_enc + mean_enctoy 张量演变图
text
输入:
decoder 输出 dec_out = (1, 4, 2)
std_enc = (1, 1, 2)
mean_enc = (1, 1, 2)
步骤:
每个时间步、每个通道做
dec_out[channel] * std_enc[channel] + mean_enc[channel]
输出:
反标准化后的 dec_out = (1, 4, 2)这一步的 input / output 语义
- 输入
dec_out- 标准化尺度上的 decoder 结果
- 输入
std_enc / mean_enc- 从 encoder 输入侧保存下来的尺度信息
- 输出
dec_out- 回到原始数值尺度的 decoder 结果
9. 当前层真正要固定什么
forward(...)当前真实例子只走short_term_forecastshort_forecast(...)先返回整段dec_out- 真正的最终输出还要回到
forward(...)再截pred_len - 当前真实例子里最关键的控制参数是:
task_namepred_lend_modelenc_in / dec_in / c_outembed = timeFe_layers / d_layers / n_heads / d_ff / distil
10. 下一步
继续看: