Skip to content

DUET · Layer 3 — Linear_extractor(单专家)

§1 在父层中的位置

Linear_extractor_cluster.__init__ 实例化 num_experts=6Linear_extractor

python
self.experts = nn.ModuleList([expert(config) for _ in range(self.num_experts)])

这里 expert = Linear_extractor(import 时重命名)。每个专家独立持有自己的 Linear_SeasonalLinear_Trend 参数。

forward() 中:

python
expert_outputs = [self.experts[i](expert_inputs[i]) for i in range(self.num_experts)]

§2 I/O 接口定义

python
Linear_extractor.forward(x_enc) -> Tensor
参数shape含义
x_enc(n_i, L, 1)分发给当前专家的子 batch(n_i 个样本)
返回(n_i, d_model, 1)每条序列的时序特征向量

全局 toy:n_i 随路由结果变化(约 3–4),L=16d_model=8moving_avg=3

⚠️ self.pred_len 实际存的是 d_model

源码 self.pred_len = configs.d_model——这里的 "pred_len" 不是预测步数,而是专家输出的特征维度 d_model。这是为了复用 DLinear 的 encoder 骨架(DLinear 中 pred_len 是真实预测长度),但在 DUET MoE 中专家只做特征压缩,输出维度改为了 d_model。阅读代码时凡遇到 self.pred_len 均应理解为 d_model=8


§3 顺序图(具体层)


§4 语义分组图(索引层)


§5 逐步骤精读

§5.0 完整原始代码

python
class Linear_extractor(nn.Module):

    def __init__(self, configs, individual=False):
        super(Linear_extractor, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.d_model
        self.decompsition = series_decomp(configs.moving_avg)
        self.individual = individual
        self.channels = configs.enc_in
        self.enc_in = 1 if configs.CI else configs.enc_in
        if self.individual:
            self.Linear_Seasonal = nn.ModuleList()
            self.Linear_Trend = nn.ModuleList()
            for i in range(self.channels):
                self.Linear_Seasonal.append(nn.Linear(self.seq_len, self.pred_len))
                self.Linear_Trend.append(nn.Linear(self.seq_len, self.pred_len))
                self.Linear_Seasonal[i].weight = nn.Parameter(
                    (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
                )
                self.Linear_Trend[i].weight = nn.Parameter(
                    (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
                )
        else:
            self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
            self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)
            self.Linear_Seasonal.weight = nn.Parameter(
                (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
            )
            self.Linear_Trend.weight = nn.Parameter(
                (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
            )

    def encoder(self, x):
        seasonal_init, trend_init = self.decompsition(x)
        seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(
            0, 2, 1
        )
        if self.individual:
            seasonal_output = torch.zeros(
                [seasonal_init.size(0), seasonal_init.size(1), self.pred_len],
                dtype=seasonal_init.dtype,
            ).to(seasonal_init.device)
            trend_output = torch.zeros(
                [trend_init.size(0), trend_init.size(1), self.pred_len],
                dtype=trend_init.dtype,
            ).to(trend_init.device)
            for i in range(self.channels):
                seasonal_output[:, i, :] = self.Linear_Seasonal[i](seasonal_init[:, i, :])
                trend_output[:, i, :] = self.Linear_Trend[i](trend_init[:, i, :])
        else:
            seasonal_output = self.Linear_Seasonal(seasonal_init)
            trend_output = self.Linear_Trend(trend_init)
        x = seasonal_output + trend_output
        return x.permute(0, 2, 1)

    def forecast(self, x_enc):
        return self.encoder(x_enc)

    def forward(self, x_enc):
        if x_enc.shape[0] == 0:
            return torch.empty((0, self.pred_len, self.enc_in)).to(x_enc.device)
        dec_out = self.forecast(x_enc)
        return dec_out[:, -self.pred_len :, :]
⚠️ decompsition 拼写错误

源码属性名为 self.decompsition(缺少一个 'o'),正确拼写应为 decomposition。这是原始代码里的 typo,不影响运行。


§5.1 宏观逻辑

一句话目标:用 DLinear 骨架(trend + seasonal 两路线性)把一条长 L 的时序压缩为 d_model 维特征向量——这里不是在"预测未来",而是在"提取当前窗口的时序模式表示"。

DLinear 原始设计 vs DUET 专家的区别

项目原始 DLinearDUET 中的 Linear_extractor
Linear 输出维度pred_len(实际预测步数)d_model(特征维度,非预测步)
任务直接输出预测值输出特征向量,后续由通道 Transformer 处理
self.pred_len 含义真实预测长度d_model(⚠️ 命名复用)
归一化无(DLinear 本体无 RevIN)由 cluster 负责(RevIN 在外层)

用小例子(n_i=2, L=4, d_model=2, moving_avg=3)串起来

输入 x: (2, 4, 1)  ← 2 个样本,每条序列长 4

Step 1: series_decomp(kernel=3)
  front pad: 复制 x[:,0:1,:] 1 次
  end pad:   复制 x[:,-1:,:] 1 次
  padded 长度: 1+4+1 = 6
  AvgPool1d(k=3, stride=1) on 6 → 输出长度: (6-3)/1+1 = 4 ✓

  seasonal_init = x - trend   shape (2, 4, 1)
  trend_init    = moving_avg  shape (2, 4, 1)

Step 2: permute(0,2,1)
  seasonal_init → (2, 1, 4)   ← C=1 通道,L=4 时间步移到最后
  trend_init    → (2, 1, 4)

Step 3: Linear(4→2) 独立作用于最后一轴
  seasonal_output = Linear_S(seasonal_init) → (2, 1, 2)
  trend_output    = Linear_T(trend_init)    → (2, 1, 2)
  x = seasonal_output + trend_output        → (2, 1, 2)

Step 4: permute(0,2,1)
  x → (2, 2, 1)  ← d_model=2 在中间,C=1 在末尾

Step 5: [:, -2:, :]  ← no-op(形状已是 (2,2,1))
  返回 (2, 2, 1)

shape 变化全链

(n_i, L, 1)
  → series_decomp → seasonal(n_i,L,1), trend(n_i,L,1)
  → permute → (n_i, 1, L)
  → Linear_S / Linear_T → (n_i, 1, d_model)
  → sum → (n_i, 1, d_model)
  → permute → (n_i, d_model, 1)
  → [:, -d_model:, :] → (n_i, d_model, 1)

§5.2 __init__:两条分支

individual 参数

在 DUET 中,Linear_extractor_clusterexpert(config) 实例化专家,未传 individual 参数,故默认 individual=False(共享权重分支)。individual=True 的代码路径在 DUET 运行时不会执行。

共享权重分支(individual=False,DUET 实际走的路径)

python
self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
self.Linear_Trend    = nn.Linear(self.seq_len, self.pred_len)

两个线性层:输入 seq_len=16,输出 pred_len=d_model=8。参数量:各 16×8+8 = 136,合计 272 个参数/专家,6 个专家共 1632 个参数。

权重初始化

python
self.Linear_Seasonal.weight = nn.Parameter(
    (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
)

初始化为全 116 的常数矩阵(shape (8, 16))。这使初始输出等于输入序列的均值——即初始状态下专家退化为全局平均,训练时再从此基础学习偏移。bias 未初始化,使用 PyTorch 默认(均匀分布)。


§5.3 series_decomp:趋势-季节分解

series_decomp 来自 duet/layers/Autoformer_EncDec.py(与 Autoformer 共用)。

完整代码:

python
class moving_avg(nn.Module):
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x

class series_decomp(nn.Module):
    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean

moving_avg.forward 详解(toy:kernel=3, seq=16)

padding 步数 = (31)//2=1

front = x[:, 0:1, :].repeat(1, 1, 1)   ← 复制首个时间步 1 次
end   = x[:, -1:, :].repeat(1, 1, 1)   ← 复制末个时间步 1 次
padded = cat([front, x, end], dim=1)    → shape (n_i, 18, 1)

AvgPool1d 需要 (B, C, L) 格式:

x.permute(0,2,1) → (n_i, 1, 18)
AvgPool1d(k=3, stride=1, pad=0) → 输出长度: (18-3)/1+1 = 16
结果 (n_i, 1, 16)
.permute(0,2,1) → (n_i, 16, 1) = moving_mean

toy 数值追踪(取 x[0, :, 0] = [1, 3, 5, 7, 9, 11, 13, 15, 14, 12, 10, 8, 6, 4, 2, 0]):

padded[0, :, 0] = [1, 1, 3, 5, 7, 9, 11, 13, 15, 14, 12, 10, 8, 6, 4, 2, 0, 0]
                   ^                                                         ^
                  front pad                                            end pad

moving_mean[0, :, 0]:
  位置 0: mean(1,1,3) = 5/3 ≈ 1.67
  位置 1: mean(1,3,5) = 9/3 = 3.00
  位置 2: mean(3,5,7) = 15/3 = 5.00
  ...
  位置 15: mean(2,0,0) = 2/3 ≈ 0.67

seasonal[0, :, 0] = x - moving_mean:
  位置 0: 1 - 1.67 = -0.67
  位置 1: 3 - 3.00 =  0.00
  位置 2: 5 - 5.00 =  0.00
  ...(线性序列的季节分量接近 0)
为什么用 replication padding 而不是 zero padding?

Zero padding 会在序列两端引入人为的"趋势跳变"(均值被拉向 0)。复制端点值(replication padding)假设序列在边界外保持恒定,使移动平均在边界处不失真,趋势估计更稳定。这是时序分解的标准技巧。


§5.4 encoder:permute + Linear + 求和

python
def encoder(self, x):
    seasonal_init, trend_init = self.decompsition(x)
    seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)
    ...
    seasonal_output = self.Linear_Seasonal(seasonal_init)
    trend_output = self.Linear_Trend(trend_init)
    x = seasonal_output + trend_output
    return x.permute(0, 2, 1)

为什么 permute 再做 Linear?

nn.Linear 只作用于输入的最后一维。分解输出是 (n_i, L, C) = (n_i, 16, 1),最后一维是 C=1(通道),不是我们想要映射的 L=16。permute 后变成 (n_i, C, L) = (n_i, 1, 16),最后一维是 L,Linear 就能把 16 个时间步映射到 8 个特征维度。

ASCII 图解 — permute 目的:

输入  (n_i, 16, 1)         →  permute(0,2,1)  →  (n_i, 1, 16)
       ↑    ↑  ↑                                      ↑   ↑
     batch  L  C                                    batch  C   L

                                                    Linear 作用于此维

全局 toy 数值追踪(设 n_i=3,取其中第 0 个样本):

seasonal_init[0, :, 0] = [-0.67, 0.00, 0.00, ...] (16 个值)
seasonal_init[0] shape (16, 1) → permute → (1, 16)

Linear_Seasonal.weight shape (8, 16),初始全 1/16
seasonal_output[0] = (1,16) @ (16,8).T = (1, 8)
  每个输出维 = mean(seasonal_init[0, :, 0]) × 1 + bias
           ≈ 0 + bias(季节分量均值接近 0,初始权重均等)

Linear_Trend.weight shape (8, 16),初始全 1/16
trend_init[0, :, 0] = [1.67, 3.00, 5.00, ...] (16 个值)
trend_output[0] = mean(trend_init[0,:,0]) × 8 = 均值 × 8 个输出
  均值 ≈ (1.67 + ... + 0.67)/16 ≈ 7.5(线性序列的均值)
  初始 trend_output[0, 0, :] ≈ [7.5, 7.5, ..., 7.5]

x = seasonal_output + trend_output → (n_i, 1, 8)
x.permute(0,2,1) → (n_i, 8, 1)

§5.5 forward:空 batch 保护与切片

python
def forward(self, x_enc):
    if x_enc.shape[0] == 0:
        return torch.empty((0, self.pred_len, self.enc_in)).to(x_enc.device)
    dec_out = self.forecast(x_enc)
    return dec_out[:, -self.pred_len :, :]

空 batch 保护:当某专家在当前 mini-batch 中没有被路由到任何样本时(_part_sizes[e]=0),dispatch 返回 shape (0, 16, 1) 的空 tensor。空 tensor 无法做卷积/线性运算,这里提前拦截并返回 (0, 8, 1) 的空 tensor。SparseDispatcher.combine 的 torch.cat 可以 cat 含 0 行的 tensor,stitched 形状不受影响。

[:, -self.pred_len:, :] 切片

dec_out shape (n_i, d_model, 1) = (n_i, 8, 1)

self.pred_len = d_model = 8,故切片 [:, -8:, :] 取最后 8 行 = 全部行 → no-op。这个切片是从 DLinear 代码继承来的写法,在 DLinear 原始用途中 pred_len 可能小于序列某些维度,这里不产生任何截断效果。


§6 下钻子组件

子组件职责说明
series_decompmoving_avg + 残差分解内联于 §5.3,不另开文档(仅 2 步,已完整精读)

创建:2026-04-24

*记录并在线阅读我的笔记*