Skip to content

Layer 0 — 接入界面

本文回答一个问题:TFB 如何把 DLinear 当黑盒使用?

包含两面:实例化侧(config → DLinear.__init__)和调用侧(_processforward 的 I/O 契约)。

1. 在父层中的位置

TFB benchmark 的最外层循环最终走到这里:TransformerAdapter 负责实例化 DLinear(config) 并在每个 batch 里调用 DLinear.forward()。本文档覆盖这一整个"接入边界"。


2. I/O 接口定义

实例化侧

python
model = DLinear(config)
config 字段来源DLinear 里的用途
seq_len命令行 --seq_len(默认96)self.seq_len,Linear 的输入维度
pred_len命令行 --pred_len(即 horizon)self.pred_len,Linear 的输出维度
moving_avgMODEL_HYPER_PARAMS["moving_avg"]=25series_decomp 的窗口大小
enc_in框架自动推断(数据集特征列数)self.channels,特征数
task_name命令行 --task_name(默认 short_term_forecast)forward 里的分支判断
individual⚠️ 硬编码 False,adapter 不传这个参数是否给每个变量独立训练线性头

注意individual=False_init_model() 调用 DLinear(self.config) 时的默认值,命令行无法覆盖。

调用侧 — _process 构造的四元组

python
output = self.model(input, input_mark, dec_input, target_mark)
参数shape(toy)含义
input (= x_enc)(2, 6, 3) = (B, seq_len, enc_in)encoder 历史输入
input_mark (= x_mark_enc)(2, 6, 4)时间戳特征,DLinear 不使用
dec_input (= x_dec)(2, 5, 3) = (B, label_len+pred_len, enc_in)decoder 输入,DLinear 不使用
target_mark (= x_mark_dec)(2, 5, 4)时间戳特征,DLinear 不使用
返回(2, 2, 3) = (B, pred_len, enc_in)预测输出

DLinear 实际只用了 x_enc,其余三个参数完全忽略。


3. 顺序图(具体层)


4. 语义分组图(索引层)


5. 精读

5.1 参数合并:Config 对象如何构建

python
# adapters_for_transformers.py:70
def __init__(self, model_name, model_class, **kwargs):
    super(TransformerAdapter, self).__init__(MODEL_HYPER_PARAMS, **kwargs)
python
# deep_forecasting_model_base.py:75
self.config = Config(model_config, **kwargs)
python
# deep_forecasting_model_base.py:42-58(Config.__init__)
for key, value in DEFAULT_HYPER_PARAMS.items():  # 基础默认值
    setattr(self, key, value)
for key, value in model_config.items():           # MODEL_HYPER_PARAMS 覆盖
    setattr(self, key, value)
for key, value in kwargs.items():                 # 命令行参数再覆盖
    setattr(self, key, value)
if hasattr(self, "horizon"):                      # horizon → pred_len 向后兼容
    setattr(self, "pred_len", self.horizon)

优先级:命令行参数 > MODEL_HYPER_PARAMS > DEFAULT_HYPER_PARAMS

enc_in 不在命令行里,由框架在 forecast_fit 里根据数据集特征列数自动注入,toy 值 = 3。

5.2 DLinear.__init__ 关键注册

原始代码:

python
# DLinear.py:28-55
self.decompsition = series_decomp(configs.moving_avg)
self.individual = individual
self.channels = configs.enc_in

if self.individual:
    self.Linear_Seasonal = nn.ModuleList()
    self.Linear_Trend = nn.ModuleList()

    for i in range(self.channels):
        self.Linear_Seasonal.append(nn.Linear(self.seq_len, self.pred_len))
        self.Linear_Trend.append(nn.Linear(self.seq_len, self.pred_len))

        self.Linear_Seasonal[i].weight = nn.Parameter(
            (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
        )
        self.Linear_Trend[i].weight = nn.Parameter(
            (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
        )
else:
    self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
    self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)

    self.Linear_Seasonal.weight = nn.Parameter(
        (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
    )
    self.Linear_Trend.weight = nn.Parameter(
        (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
    )

注解版:

python
self.individual = individual   # TFB adapter 不传此参数,默认 False
self.channels = configs.enc_in # toy: 3

# ── individual=False(TFB 实际路径)──
self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)  # toy: Linear(6, 2)
self.Linear_Trend    = nn.Linear(self.seq_len, self.pred_len)  # toy: Linear(6, 2)
# 两个单独的 Linear,所有变量共享同一组参数
# weight 初始化为全 1/seq_len:weight.shape = (pred_len, seq_len) = (2, 6)

# ── individual=True(代码存在,TFB 不走)──
self.Linear_Seasonal = nn.ModuleList()   # 长度 = channels 的列表
self.Linear_Trend    = nn.ModuleList()
for i in range(self.channels):           # toy: i = 0, 1, 2
    self.Linear_Seasonal.append(nn.Linear(self.seq_len, self.pred_len))
    # 每个变量 i 有专属的 Linear(6,2),参数互不共享
    # 共 channels 个 Linear_Seasonal + channels 个 Linear_Trend

toy 数值(individual=False):
Linear_Seasonal.weight[0] = [0.167, 0.167, 0.167, 0.167, 0.167, 0.167](对历史6步均匀加权)。

toy 数值(individual=True,如果走此路径):
Linear_Seasonal[0].weight.shape = (2, 6)Linear_Seasonal[1].weight.shape = (2, 6)Linear_Seasonal[2].weight.shape = (2, 6)——三组独立参数,训练后各自学出适合自己变量的映射。

5.3 _process() 构造 dec_input

python
# adapters_for_transformers.py:85-93
dec_input = torch.zeros_like(target[:, -self.config.horizon:, :]).float()
# zeros shape: (B, pred_len, enc_in) = (2, 2, 3)

dec_input = torch.cat(
    [target[:, :self.config.label_len, :], dec_input], dim=1
).float().to(input.device)
# cat 后 shape: (B, label_len+pred_len, enc_in) = (2, label_len+2, 3)

DLinear 的 forward 接收这个 dec_input,但完全不使用它。这是框架统一接口的开销。

5.4 forward() 分支判断

python
# DLinear.py:111-120
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
    if (
        self.task_name == "long_term_forecast"
        or self.task_name == "short_term_forecast"
    ):
        dec_out = self.forecast(x_enc)          # 只传 x_enc,其余丢弃
        return dec_out[:, -self.pred_len :, :]  # 截取最后 pred_len 步

toy 数值:dec_out 的 shape 为 (2, 6, 3)(encoder 输出 seq_len 步),截取 [:, -2:, :](2, 2, 3)

为什么 encoder 输出 seq_len 步而不是 pred_len 步?
encoder()Linear(seq_len, pred_len) 输出维度是 pred_len,所以直接输出 (B, pred_len, C)[:, -pred_len:, :] 实际上等于取全部。


6. 下钻子组件

子组件职责文档
encoder(x_enc)DLinear 的主计算:分解 + 双路 Linear + 合并02-Layer1-encoder主链

*记录并在线阅读我的笔记*