Skip to content

Layer 0 — 接入界面

本文回答一个问题:TFB 如何把 PatchTST 当黑盒使用?

包含两面:实例化侧(config → PatchTST.__init__)和调用侧(_processforward 的 I/O 契约)。

1. 在父层中的位置

TFB benchmark 的最外层循环最终走到这里:TransformerAdapter 负责实例化 PatchTST(config) 并在每个 batch 里调用 PatchTST.forward()。本文档覆盖这一整个"接入边界"。


2. I/O 接口定义

实例化侧

python
model = PatchTST(config)
config 字段来源PatchTST 里的用途
seq_len命令行 --seq_len(默认96)self.seq_len,序列长度
pred_len命令行 --pred_len(即 horizon)self.pred_len,FlattenHead 输出维度
patch_lenMODEL_HYPER_PARAMS["patch_len"]=16self.patch_len,patch 窗口大小
strideMODEL_HYPER_PARAMS["stride"]=8self.stride,patch 步长;同时 padding=stride
d_modelMODEL_HYPER_PARAMS["d_model"]=512embedding 和注意力维度
n_headsMODEL_HYPER_PARAMS["n_heads"]=8多头注意力头数
e_layersMODEL_HYPER_PARAMS["e_layers"]=2Encoder 层数
d_ffMODEL_HYPER_PARAMS["d_ff"]=2048FFN 中间维度
enc_in框架自动推断(数据集特征列数)config.enc_in,FlattenHead 的 n_vars
task_name命令行 --task_nameforward() 分支判断

注意padding 不是独立参数,直接赋值 padding = self.stride(PatchTST.py:41)。
toy 中 padding = stride = 2,真实值 padding = stride = 8

调用侧 — _process 构造的四元组

python
output = self.model(input, input_mark, dec_input, target_mark)
参数shape(toy)含义
input (= x_enc)(2, 12, 4) = (B, seq_len, enc_in)encoder 历史输入
input_mark (= x_mark_enc)(2, 12, 4)时间戳特征,PatchTST 不使用
dec_input (= x_dec)(2, label_len+3, 4)decoder 输入,PatchTST 不使用
target_mark (= x_mark_dec)(2, label_len+3, 4)时间戳特征,PatchTST 不使用
返回(2, 3, 4) = (B, pred_len, enc_in)预测输出

PatchTST 的 forecast() 接收四个参数,但实际只用了 x_enc,其余三个完全忽略(forecast 函数签名接收但内部不读取 mark 和 x_dec)。


3. 顺序图(具体层)


4. 语义分组图(索引层)


5. 精读

5.1 参数合并:Config 对象如何构建

优先级:命令行参数 > MODEL_HYPER_PARAMS > DEFAULT_HYPER_PARAMS

python
# MODEL_HYPER_PARAMS 中与 PatchTST 直接相关的值(toy 中被覆盖为小数值):
"patch_len": 16,   # toy: 4
"stride": 8,       # toy: 2
"d_model": 512,    # toy: 16
"n_heads": 8,      # toy: 2
"e_layers": 2,     # toy: 1
"d_ff": 2048,      # toy: 64

enc_in 不在命令行里,由框架在 forecast_fit 里根据数据集特征列数自动注入,toy 值 = 4。

5.2 PatchTST.__init__ 关键注册

原始代码:

python
# PatchTST.py:30-93
def __init__(self, config):
    super(PatchTST, self).__init__()
    self.task_name = config.task_name
    self.seq_len = config.seq_len
    self.pred_len = config.pred_len
    self.patch_len = config.patch_len
    self.stride = config.stride
    padding = self.stride

    # patching and embedding
    self.patch_embedding = PatchEmbedding(
        config.d_model, self.patch_len, self.stride, padding, config.dropout
    )

    # Encoder
    self.encoder = Encoder(
        [
            EncoderLayer(
                AttentionLayer(
                    FullAttention(
                        False,
                        config.factor,
                        attention_dropout=config.dropout,
                        output_attention=config.output_attention,
                    ),
                    config.d_model,
                    config.n_heads,
                ),
                config.d_model,
                config.d_ff,
                dropout=config.dropout,
                activation=config.activation,
            )
            for l in range(config.e_layers)
        ],
        norm_layer=torch.nn.LayerNorm(config.d_model),
    )

    # Prediction Head
    self.head_nf = config.d_model * int(
        (config.seq_len - self.patch_len) / self.stride + 2
    )
    if (
        self.task_name == "long_term_forecast"
        or self.task_name == "short_term_forecast"
    ):
        self.head = FlattenHead(
            config.enc_in,
            self.head_nf,
            config.pred_len,
            head_dropout=config.dropout,
        )
    elif self.task_name == "imputation" or self.task_name == "anomaly_detection":
        self.head = FlattenHead(
            config.enc_in, self.head_nf, config.seq_len, head_dropout=config.dropout
        )
    elif self.task_name == "classification":
        self.flatten = nn.Flatten(start_dim=-2)
        self.dropout = nn.Dropout(config.dropout)
        self.projection = nn.Linear(self.head_nf * config.enc_in, config.num_class)

注解版(聚焦 task_name if/else 和 toy 数值):

python
self.task_name = config.task_name  # TFB 中 = "long_term_forecast" 或 "short_term_forecast"
self.patch_len = config.patch_len  # toy: 4
self.stride = config.stride        # toy: 2
padding = self.stride              # ⚠️ padding 不是独立参数,直接等于 stride,toy: 2

# patch_embedding / encoder 与 task_name 无关,无论什么任务都注册
self.patch_embedding = PatchEmbedding(d_model=16, patch_len=4, stride=2, padding=2, ...)
self.encoder = Encoder([EncoderLayer(...1], norm_layer=LayerNorm(16))

# head_nf 计算(所有 task 共用)
self.head_nf = 16 * int((12-4)/2 + 2) = 16 * 6 = 96

# ── if/else:根据 task_name 注册不同的 head ──

# TFB 实际走这条(forecast):
self.head = FlattenHead(n_vars=4, nf=96, target_window=pred_len=3)
# Linear(96, 3),输出 pred_len 步预测

# imputation/anomaly_detection(不走):
self.head = FlattenHead(n_vars=4, nf=96, target_window=seq_len=12)
# Linear(96, 12),输出补全/重建全序列

# classification(不走):
self.flatten = nn.Flatten(start_dim=-2)  # 不用 FlattenHead
self.projection = nn.Linear(head_nf × enc_in=96×4=384, num_class)
# 全局特征拼接后直接分类

两条路径对比

forecast/short_forecast(TFB 实际)imputation/anomalyclassification
self.headFlattenHead(target=pred_len=3)FlattenHead(target=seq_len=12)不注册 head
Linear 输出维度pred_len=3seq_len=12num_class
目标预测未来 pred_len 步重建完整序列全局分类
TFB 走哪条✓ 此路径

toy 数值head.linear.weight.shape = (pred_len=3, head_nf=96)

5.3 _process() 构造 dec_input

python
# adapters_for_transformers.py
dec_input = torch.zeros_like(target[:, -self.config.horizon:, :]).float()
# zeros shape: (B, pred_len, enc_in) = (2, 3, 4)

dec_input = torch.cat(
    [target[:, :self.config.label_len, :], dec_input], dim=1
).float().to(input.device)
# cat 后 shape: (B, label_len+3, enc_in)

PatchTST 的 forecast() 签名接收这个 dec_input,但完全不使用它。这是框架统一接口的开销。

5.4 forward() 分支判断

原始代码:

python
# PatchTST.py:222-238
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
    if (
        self.task_name == "long_term_forecast"
        or self.task_name == "short_term_forecast"
    ):
        dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return dec_out[:, -self.pred_len :, :]  # [B, L, D]
    if self.task_name == "imputation":
        dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
        return dec_out  # [B, L, D]
    if self.task_name == "anomaly_detection":
        dec_out = self.anomaly_detection(x_enc)
        return dec_out  # [B, L, D]
    if self.task_name == "classification":
        dec_out = self.classification(x_enc, x_mark_enc)
        return dec_out  # [B, N]
    return None

4 条分支对比

task_name调用函数输入返回 shapeTFB 实际
long/short_term_forecastforecast(x_enc,...)x_enc(B, pred_len, C) = (2,3,4)✓ 此路径
imputationimputation(x_enc,...,mask)x_enc + mask(B, seq_len, C)
anomaly_detectionanomaly_detection(x_enc)只用 x_enc(B, seq_len, C)
classificationclassification(x_enc,x_mark_enc)x_enc + mark(B, num_class)

forecast 返回 dec_out[:, -self.pred_len:, :];由于 forecast() 本身就只返回 (B, pred_len, C),这里的切片实际上取全部,是防御性写法。

toy 数值:dec_out shape 为 (2, 3, 4) = (B, pred_len, enc_in)[:, -3:, :] 取全部。


6. 下钻子组件

子组件职责文档
forecast(x_enc, ...)PatchTST 的主计算:归一化+patch+encoder+head02-Layer1-forecast主链

*记录并在线阅读我的笔记*