Appearance
02 · Layer 1 · forecast 主链
§1 在父层中的位置
forecast 是 iTransformer.forward() 的唯一实质分支(当 self.task_name == 'long_term_forecast' 时调用)。父层 forward 仅做任务路由,所有计算逻辑全部在此函数内完成。
forward(x_enc, x_mark_enc, x_dec, x_mark_dec)
└── forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) ← 本文档
├── enc_embedding(...) ← §3A,见 [[03A-Layer2A-DataEmbedding_inverted]]
└── encoder(...) ← §3B,见 [[03B-Layer2B-Encoder]]§2 I/O 接口定义
输入四元组
| 参数 | shape(toy) | 说明 |
|---|---|---|
x_enc | (3, 12, 5) | 历史序列,B=3,seq_len=12,N=5 个变量 |
x_mark_enc | (3, 12, 4) | 编码器时间标记,time_dims=4 |
x_dec | (3, ?, 5) | 传入但 forecast() 完全不使用 |
x_mark_dec | (3, ?, 4) | 传入但 forecast() 完全不使用 |
x_dec和x_mark_dec仅为保持统一函数签名而存在(其他任务或模型会用到),iTransformer 的 forecast 分支不读取这两个参数。
输出
| 返回值 | shape(toy) |
|---|---|
dec_out | (3, 6, 5) |
即 (B, pred_len, N),与输入的 (B, seq_len, N) 在最后两维的语义完全对称。
§3 顺序图(具体层)
- §3A
enc_embedding:输入(3,12,5)+(3,12,4),输出(3,9,8),详见 [[03A-Layer2A-DataEmbedding_inverted]] - §3B
encoder:输入(3,9,8),输出(3,9,8),详见 [[03B-Layer2B-Encoder]]
§4 语义分组图(索引层)
三组职责分明:
- 归一化组:消除每个实例的分布偏移,使网络看到零均值单位方差的输入。
- 特征提取组:将变量序列映射为 token,在变量维度做注意力,捕捉跨变量依赖。
- 输出组:将 token 的 d_model 维特征直接投影到 pred_len 步,裁剪掉时间标记 token,再还原原始分布。
§5 逐步精读
§5.0 完整原始代码
python
class iTransformer(nn.Module):
"""
Paper link: https://arxiv.org/abs/2310.06625
"""
def __init__(self, configs):
super(iTransformer, self).__init__()
self.task_name = configs.task_name
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
self.output_attention = configs.output_attention
# Embedding
self.enc_embedding = DataEmbedding_inverted(
configs.seq_len,
configs.d_model,
configs.embed,
configs.freq,
configs.dropout,
)
# Encoder
self.encoder = Encoder(
[
EncoderLayer(
AttentionLayer(
FullAttention(
False,
configs.factor,
attention_dropout=configs.dropout,
output_attention=configs.output_attention,
),
configs.d_model,
configs.n_heads,
),
configs.d_model,
configs.d_ff,
dropout=configs.dropout,
activation=configs.activation,
)
for l in range(configs.e_layers)
],
norm_layer=torch.nn.LayerNorm(configs.d_model),
)
# Decoder
if (
self.task_name == "long_term_forecast"
or self.task_name == "short_term_forecast"
):
self.projection = nn.Linear(configs.d_model, configs.pred_len, bias=True)
if self.task_name == "imputation":
self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
if self.task_name == "anomaly_detection":
self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
if self.task_name == "classification":
self.act = F.gelu
self.dropout = nn.Dropout(configs.dropout)
self.projection = nn.Linear(
configs.d_model * configs.enc_in, configs.num_class
)
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
# Normalization from Non-stationary Transformer
means = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - means
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_enc /= stdev
_, _, N = x_enc.shape
# Embedding
enc_out = self.enc_embedding(x_enc, x_mark_enc)
enc_out, attns = self.encoder(enc_out, attn_mask=None)
dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
# De-Normalization from Non-stationary Transformer
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
return dec_out§5.1 宏观逻辑
最小例子:B=1,N=2,seq_len=4,pred_len=2,time_dims=1,token_count=N+time_dims=3
原始输入
x_enc (1, 4, 2) — 1个样本,4个时间步,2个变量
x_mark_enc (1, 4, 1) — 1个时间标记特征
① Instance Norm
对每个变量的4步序列做零均值单位方差标准化
x_enc_norm (1, 4, 2)
② enc_embedding(inverted)
关键操作:把时间轴(L=4)当作特征维,变量轴(N+time_dims=3)当作序列维
输出 enc_out (1, 3, d_model) — 3个 token,每个 token 代表一个"变量"
③ encoder
在 3 个 token 之间做 Self-Attention
enc_out (1, 3, d_model) — token 之间交换了跨变量信息
④ projection + 裁剪
Linear(d_model→pred_len=2):每个 token 直接输出 2 步预测值
(1, 3, 2) → permute → (1, 2, 3) → [:,:,:N=2] → (1, 2, 2)
⑤ De-normalization
还原到原始分布
dec_out (1, 2, 2) — (B, pred_len, N)"token = 变量"的直觉
标准 Transformer 的 token 是时间步,attention 捕捉不同时刻的依赖。iTransformer 把整条时间序列看作一个 token 的特征向量,attention 改为捕捉不同变量之间的相关性。这是 iTransformer 最核心的设计翻转:序列长度 → 特征维,变量数 → 序列长度。
图解 — iTransformer forecast() 整体架构流:
§5.2 步骤一:Instance Normalization
python
means = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - means
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_enc /= stdev各参数作用
| 参数/操作 | 作用 |
|---|---|
dim=1 | 在时间轴(seq_len=12)上求统计量,保留 batch 和变量维 |
keepdim=True | 保持 shape 为 (B,1,N) 而非 (B,N),使后续广播减法/除法直接作用于 L 维 |
.detach() | 均值只作为统计量参与前向计算,梯度不流回均值本身,避免优化器通过均值绕开归一化 |
unbiased=False | 除以 |
+ 1e-5 | 防止 stdev 为零时出现除零(例如常值序列) |
公式
Toy 数值追踪(batch=0,变量=0,12个时间步)
原始序列:
去均值后:
归一化后:
shape 变化:x_enc (3,12,5),means (3,1,5),stdev (3,1,5),归一化后 x_enc (3,12,5) 不变。
⚠️ 冗余操作:squeeze+unsqueeze round-trip
反归一化处用了
stdev[:, 0, :].unsqueeze(1):
stdev的 shape 已经是(3,1,5)(因为keepdim=True)stdev[:, 0, :]取第 0 个(也是唯一的)时间位置 →(3,5)(squeeze 了 dim=1).unsqueeze(1)→(3,1,5)(重新加回 dim=1)这是一个无意义的 round-trip,等价于直接使用
stdev。结果完全相同,只是多了两次内存操作。
设计约束传递:keepdim=True 在 §5.6 中被隐式依赖
keepdim=True 在 §5.6 中被隐式依赖此处
mean(dim=1, keepdim=True)和var(dim=1, keepdim=True)保证了means和stdev的 dim=1 一定等于 1。这个保证在 §5.6 反归一化中被隐式依赖:
stdev[:, 0, :]正确的原因是 dim=1 一定是 1,索引0是唯一合法的索引- 若此处改为
keepdim=False,stdev变成(3,5),stdev[:,0,:]会取 batch=0 的切片 → shape(5,)而非(3,5),语义完全错误
写法 meansshapemeans[:,0,:]语义keepdim=True(当前)(3,1,5)取唯一的时间位置,等价于 means本身keepdim=False(错误)(3,5)取 batch=0 的行,从 3 个样本变成了 1 个样本
§5.3 步骤二:enc_embedding
python
enc_out = self.enc_embedding(x_enc, x_mark_enc)enc_embedding 是 DataEmbedding_inverted 的实例。它将时间轴翻转为特征维,变量轴翻转为序列维,并将时间标记 token 拼接进来。
- 输入:
x_enc (3,12,5),x_mark_enc (3,12,4) - 输出:
enc_out (3,9,8),即(B, N+time_dims, d_model)=(3, 5+4, 8)
详见 [[03A-Layer2A-DataEmbedding_inverted]]
§5.4 步骤三:encoder
python
enc_out, attns = self.encoder(enc_out, attn_mask=None)标准 Transformer Encoder 堆叠,共 e_layers=2 层,每层包含 Multi-Head Attention + FFN。注意力在 token 维(9 个变量/时间 token)之间计算,而非时间步之间。
- 输入:
enc_out (3,9,8) - 输出:
enc_out (3,9,8)(shape 不变,内容被注意力机制更新)
详见 [[03B-Layer2B-Encoder]]
§5.5 步骤四:projection + 裁剪
python
dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]self.projection 是 nn.Linear(d_model, pred_len),即 Linear(8, 6)。
ASCII 图:shape 变化全程
enc_out projection(Linear 8→6) permute(0,2,1) [:,:,:N=5]
(3, 9, 8) ─────────────────────────► (3, 9, 6) ──────► (3, 6, 9) ──────► (3, 6, 5)
每个 token(d_model=8) 轴0=B 丢弃最后
直接输出 pred_len=6 步 轴1=6步 4个时间
轴2=9tok 标记 token为什么 Linear 方向与标准 Transformer 相反?
| 模型 | Linear 作用 | 方向 |
|---|---|---|
| 标准 Transformer | 每个时间步 token 的 d_model 特征 → 变量数 c_out | Linear(d_model, c_out) |
| iTransformer | 每个变量 token 的 d_model 特征 → 预测步数 pred_len | Linear(d_model, pred_len) |
标准 Transformer 把 d_model 映射到输出变量数,因为每个 token 对应一个时间步,需要在变量维展开。iTransformer 把 token 颠倒了:每个 token 对应一个变量,d_model 编码了该变量在整段历史上的特征,因此直接将其映射到未来 pred_len 步是自然的。permute 方向随之相反。
[:,:,:N] 裁剪的必要性
embedding 阶段把 time_dims=4 个时间标记也 concat 成了 token(共 9 个 token = 5 变量 + 4 时间)。encoder 让这些时间 token 参与注意力以提供位置/时间先验,但最终预测只需要 N=5 个变量的输出,因此裁掉末尾 4 个时间 token 的预测结果。
Toy 数值
projection 输出 (3,9,6),permute 后 (3,6,9),裁剪后 (3,6,5),即 (B=3, pred_len=6, N=5)。
§5.6 步骤五:De-normalization
python
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))公式
其中
shape 展开
stdev 路径:
stdev:(3,1,5)stdev[:, 0, :]:(3,5),取唯一的时间位置(dim=1=1 保证安全).unsqueeze(1):(3,1,5).repeat(1, pred_len=6, 1):(3,6,5)
means 路径完全相同。
⚠️ 冗余:stdev[:, 0, :].unsqueeze(1) 等价于 stdev
stdev[:, 0, :].unsqueeze(1) 等价于 stdev因为
keepdim=True已经使 stdev 的 shape 固定为(B,1,N),[:,0,:]squeeze 后再 unsqueeze 得到的仍是(B,1,N)。可直接写:pythondec_out = dec_out * stdev.repeat(1, self.pred_len, 1) dec_out = dec_out + means.repeat(1, self.pred_len, 1)
设计约束传递:keepdim=True 为何关键
keepdim=True在第一步被写入后,在第五步被隐式依赖:
stdev[:, 0, :]的正确性依赖 dim=1 的大小恰好为 1- 若第一步改为
keepdim=False(shape 变为(B,N)),则[:,0,:]会取 batch=0 的切片,语义完全错误因此
keepdim=True是第一步传递给第五步的隐式约束,在阅读反归一化代码时必须回溯到第一步才能正确理解。
Toy 数值还原(batch=0,变量=0,pred_len 中某预测步
网络输出归一化空间下的值
例如若
§6 下钻子组件列表
| 编号 | 文件 | 对应代码 | 输入 shape | 输出 shape |
|---|---|---|---|---|
| §3A | [[03A-Layer2A-DataEmbedding_inverted]] | self.enc_embedding(x_enc, x_mark_enc) | (3,12,5)+(3,12,4) | (3,9,8) |
| §3B | [[03B-Layer2B-Encoder]] | self.encoder(enc_out, attn_mask=None) | (3,9,8) | (3,9,8) |