Appearance
DUET · Layer 3 — Linear_extractor(单专家)
§1 在父层中的位置
Linear_extractor_cluster.__init__ 实例化 num_experts=6 个 Linear_extractor:
python
self.experts = nn.ModuleList([expert(config) for _ in range(self.num_experts)])这里 expert = Linear_extractor(import 时重命名)。每个专家独立持有自己的 Linear_Seasonal 和 Linear_Trend 参数。
forward() 中:
python
expert_outputs = [self.experts[i](expert_inputs[i]) for i in range(self.num_experts)]§2 I/O 接口定义
python
Linear_extractor.forward(x_enc) -> Tensor| 参数 | shape | 含义 |
|---|---|---|
x_enc | (n_i, L, 1) | 分发给当前专家的子 batch(n_i 个样本) |
| 返回 | (n_i, d_model, 1) | 每条序列的时序特征向量 |
全局 toy:n_i 随路由结果变化(约 3–4),L=16,d_model=8,moving_avg=3。
⚠️ self.pred_len 实际存的是 d_model
self.pred_len 实际存的是 d_model源码
self.pred_len = configs.d_model——这里的 "pred_len" 不是预测步数,而是专家输出的特征维度 d_model。这是为了复用 DLinear 的 encoder 骨架(DLinear 中 pred_len 是真实预测长度),但在 DUET MoE 中专家只做特征压缩,输出维度改为了 d_model。阅读代码时凡遇到self.pred_len均应理解为d_model=8。
§3 顺序图(具体层)
§4 语义分组图(索引层)
§5 逐步骤精读
§5.0 完整原始代码
python
class Linear_extractor(nn.Module):
def __init__(self, configs, individual=False):
super(Linear_extractor, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.d_model
self.decompsition = series_decomp(configs.moving_avg)
self.individual = individual
self.channels = configs.enc_in
self.enc_in = 1 if configs.CI else configs.enc_in
if self.individual:
self.Linear_Seasonal = nn.ModuleList()
self.Linear_Trend = nn.ModuleList()
for i in range(self.channels):
self.Linear_Seasonal.append(nn.Linear(self.seq_len, self.pred_len))
self.Linear_Trend.append(nn.Linear(self.seq_len, self.pred_len))
self.Linear_Seasonal[i].weight = nn.Parameter(
(1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
)
self.Linear_Trend[i].weight = nn.Parameter(
(1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
)
else:
self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)
self.Linear_Seasonal.weight = nn.Parameter(
(1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
)
self.Linear_Trend.weight = nn.Parameter(
(1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
)
def encoder(self, x):
seasonal_init, trend_init = self.decompsition(x)
seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(
0, 2, 1
)
if self.individual:
seasonal_output = torch.zeros(
[seasonal_init.size(0), seasonal_init.size(1), self.pred_len],
dtype=seasonal_init.dtype,
).to(seasonal_init.device)
trend_output = torch.zeros(
[trend_init.size(0), trend_init.size(1), self.pred_len],
dtype=trend_init.dtype,
).to(trend_init.device)
for i in range(self.channels):
seasonal_output[:, i, :] = self.Linear_Seasonal[i](seasonal_init[:, i, :])
trend_output[:, i, :] = self.Linear_Trend[i](trend_init[:, i, :])
else:
seasonal_output = self.Linear_Seasonal(seasonal_init)
trend_output = self.Linear_Trend(trend_init)
x = seasonal_output + trend_output
return x.permute(0, 2, 1)
def forecast(self, x_enc):
return self.encoder(x_enc)
def forward(self, x_enc):
if x_enc.shape[0] == 0:
return torch.empty((0, self.pred_len, self.enc_in)).to(x_enc.device)
dec_out = self.forecast(x_enc)
return dec_out[:, -self.pred_len :, :]⚠️ decompsition 拼写错误
decompsition 拼写错误源码属性名为
self.decompsition(缺少一个 'o'),正确拼写应为decomposition。这是原始代码里的 typo,不影响运行。
§5.1 宏观逻辑
一句话目标:用 DLinear 骨架(trend + seasonal 两路线性)把一条长 L 的时序压缩为 d_model 维特征向量——这里不是在"预测未来",而是在"提取当前窗口的时序模式表示"。
DLinear 原始设计 vs DUET 专家的区别:
| 项目 | 原始 DLinear | DUET 中的 Linear_extractor |
|---|---|---|
Linear 输出维度 | pred_len(实际预测步数) | d_model(特征维度,非预测步) |
| 任务 | 直接输出预测值 | 输出特征向量,后续由通道 Transformer 处理 |
self.pred_len 含义 | 真实预测长度 | d_model(⚠️ 命名复用) |
| 归一化 | 无(DLinear 本体无 RevIN) | 由 cluster 负责(RevIN 在外层) |
用小例子(n_i=2, L=4, d_model=2, moving_avg=3)串起来:
输入 x: (2, 4, 1) ← 2 个样本,每条序列长 4
Step 1: series_decomp(kernel=3)
front pad: 复制 x[:,0:1,:] 1 次
end pad: 复制 x[:,-1:,:] 1 次
padded 长度: 1+4+1 = 6
AvgPool1d(k=3, stride=1) on 6 → 输出长度: (6-3)/1+1 = 4 ✓
seasonal_init = x - trend shape (2, 4, 1)
trend_init = moving_avg shape (2, 4, 1)
Step 2: permute(0,2,1)
seasonal_init → (2, 1, 4) ← C=1 通道,L=4 时间步移到最后
trend_init → (2, 1, 4)
Step 3: Linear(4→2) 独立作用于最后一轴
seasonal_output = Linear_S(seasonal_init) → (2, 1, 2)
trend_output = Linear_T(trend_init) → (2, 1, 2)
x = seasonal_output + trend_output → (2, 1, 2)
Step 4: permute(0,2,1)
x → (2, 2, 1) ← d_model=2 在中间,C=1 在末尾
Step 5: [:, -2:, :] ← no-op(形状已是 (2,2,1))
返回 (2, 2, 1)shape 变化全链:
(n_i, L, 1)
→ series_decomp → seasonal(n_i,L,1), trend(n_i,L,1)
→ permute → (n_i, 1, L)
→ Linear_S / Linear_T → (n_i, 1, d_model)
→ sum → (n_i, 1, d_model)
→ permute → (n_i, d_model, 1)
→ [:, -d_model:, :] → (n_i, d_model, 1)§5.2 __init__:两条分支
individual 参数:
在 DUET 中,Linear_extractor_cluster 以 expert(config) 实例化专家,未传 individual 参数,故默认 individual=False(共享权重分支)。individual=True 的代码路径在 DUET 运行时不会执行。
共享权重分支(individual=False,DUET 实际走的路径):
python
self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)两个线性层:输入 seq_len=16,输出 pred_len=d_model=8。参数量:各 16×8+8 = 136,合计 272 个参数/专家,6 个专家共 1632 个参数。
权重初始化:
python
self.Linear_Seasonal.weight = nn.Parameter(
(1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
)初始化为全 (8, 16))。这使初始输出等于输入序列的均值——即初始状态下专家退化为全局平均,训练时再从此基础学习偏移。bias 未初始化,使用 PyTorch 默认(均匀分布)。
§5.3 series_decomp:趋势-季节分解
series_decomp 来自 duet/layers/Autoformer_EncDec.py(与 Autoformer 共用)。
完整代码:
python
class moving_avg(nn.Module):
def __init__(self, kernel_size, stride):
super(moving_avg, self).__init__()
self.kernel_size = kernel_size
self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
def forward(self, x):
front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
x = torch.cat([front, x, end], dim=1)
x = self.avg(x.permute(0, 2, 1))
x = x.permute(0, 2, 1)
return x
class series_decomp(nn.Module):
def __init__(self, kernel_size):
super(series_decomp, self).__init__()
self.moving_avg = moving_avg(kernel_size, stride=1)
def forward(self, x):
moving_mean = self.moving_avg(x)
res = x - moving_mean
return res, moving_meanmoving_avg.forward 详解(toy:kernel=3, seq=16):
padding 步数 =
front = x[:, 0:1, :].repeat(1, 1, 1) ← 复制首个时间步 1 次
end = x[:, -1:, :].repeat(1, 1, 1) ← 复制末个时间步 1 次
padded = cat([front, x, end], dim=1) → shape (n_i, 18, 1)AvgPool1d 需要 (B, C, L) 格式:
x.permute(0,2,1) → (n_i, 1, 18)
AvgPool1d(k=3, stride=1, pad=0) → 输出长度: (18-3)/1+1 = 16
结果 (n_i, 1, 16)
.permute(0,2,1) → (n_i, 16, 1) = moving_meantoy 数值追踪(取 x[0, :, 0] = [1, 3, 5, 7, 9, 11, 13, 15, 14, 12, 10, 8, 6, 4, 2, 0]):
padded[0, :, 0] = [1, 1, 3, 5, 7, 9, 11, 13, 15, 14, 12, 10, 8, 6, 4, 2, 0, 0]
^ ^
front pad end pad
moving_mean[0, :, 0]:
位置 0: mean(1,1,3) = 5/3 ≈ 1.67
位置 1: mean(1,3,5) = 9/3 = 3.00
位置 2: mean(3,5,7) = 15/3 = 5.00
...
位置 15: mean(2,0,0) = 2/3 ≈ 0.67
seasonal[0, :, 0] = x - moving_mean:
位置 0: 1 - 1.67 = -0.67
位置 1: 3 - 3.00 = 0.00
位置 2: 5 - 5.00 = 0.00
...(线性序列的季节分量接近 0)为什么用 replication padding 而不是 zero padding?
Zero padding 会在序列两端引入人为的"趋势跳变"(均值被拉向 0)。复制端点值(replication padding)假设序列在边界外保持恒定,使移动平均在边界处不失真,趋势估计更稳定。这是时序分解的标准技巧。
§5.4 encoder:permute + Linear + 求和
python
def encoder(self, x):
seasonal_init, trend_init = self.decompsition(x)
seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)
...
seasonal_output = self.Linear_Seasonal(seasonal_init)
trend_output = self.Linear_Trend(trend_init)
x = seasonal_output + trend_output
return x.permute(0, 2, 1)为什么 permute 再做 Linear?
nn.Linear 只作用于输入的最后一维。分解输出是 (n_i, L, C) = (n_i, 16, 1),最后一维是 C=1(通道),不是我们想要映射的 L=16。permute 后变成 (n_i, C, L) = (n_i, 1, 16),最后一维是 L,Linear 就能把 16 个时间步映射到 8 个特征维度。
ASCII 图解 — permute 目的:
输入 (n_i, 16, 1) → permute(0,2,1) → (n_i, 1, 16)
↑ ↑ ↑ ↑ ↑
batch L C batch C L
↑
Linear 作用于此维全局 toy 数值追踪(设 n_i=3,取其中第 0 个样本):
seasonal_init[0, :, 0] = [-0.67, 0.00, 0.00, ...] (16 个值)
seasonal_init[0] shape (16, 1) → permute → (1, 16)
Linear_Seasonal.weight shape (8, 16),初始全 1/16
seasonal_output[0] = (1,16) @ (16,8).T = (1, 8)
每个输出维 = mean(seasonal_init[0, :, 0]) × 1 + bias
≈ 0 + bias(季节分量均值接近 0,初始权重均等)
Linear_Trend.weight shape (8, 16),初始全 1/16
trend_init[0, :, 0] = [1.67, 3.00, 5.00, ...] (16 个值)
trend_output[0] = mean(trend_init[0,:,0]) × 8 = 均值 × 8 个输出
均值 ≈ (1.67 + ... + 0.67)/16 ≈ 7.5(线性序列的均值)
初始 trend_output[0, 0, :] ≈ [7.5, 7.5, ..., 7.5]
x = seasonal_output + trend_output → (n_i, 1, 8)
x.permute(0,2,1) → (n_i, 8, 1)§5.5 forward:空 batch 保护与切片
python
def forward(self, x_enc):
if x_enc.shape[0] == 0:
return torch.empty((0, self.pred_len, self.enc_in)).to(x_enc.device)
dec_out = self.forecast(x_enc)
return dec_out[:, -self.pred_len :, :]空 batch 保护:当某专家在当前 mini-batch 中没有被路由到任何样本时(_part_sizes[e]=0),dispatch 返回 shape (0, 16, 1) 的空 tensor。空 tensor 无法做卷积/线性运算,这里提前拦截并返回 (0, 8, 1) 的空 tensor。SparseDispatcher.combine 的 torch.cat 可以 cat 含 0 行的 tensor,stitched 形状不受影响。
[:, -self.pred_len:, :] 切片:
dec_out shape (n_i, d_model, 1) = (n_i, 8, 1)。
self.pred_len = d_model = 8,故切片 [:, -8:, :] 取最后 8 行 = 全部行 → no-op。这个切片是从 DLinear 代码继承来的写法,在 DLinear 原始用途中 pred_len 可能小于序列某些维度,这里不产生任何截断效果。
§6 下钻子组件
| 子组件 | 职责 | 说明 |
|---|---|---|
series_decomp | moving_avg + 残差分解 | 内联于 §5.3,不另开文档(仅 2 步,已完整精读) |
创建:2026-04-24