Appearance
Layer 0 — 接入界面
本文回答一个问题:TFB 如何把 PatchTST 当黑盒使用?
包含两面:实例化侧(config →
PatchTST.__init__)和调用侧(_process→forward的 I/O 契约)。
1. 在父层中的位置
TFB benchmark 的最外层循环最终走到这里:TransformerAdapter 负责实例化 PatchTST(config) 并在每个 batch 里调用 PatchTST.forward()。本文档覆盖这一整个"接入边界"。
2. I/O 接口定义
实例化侧
python
model = PatchTST(config)| config 字段 | 来源 | PatchTST 里的用途 |
|---|---|---|
seq_len | 命令行 --seq_len(默认96) | self.seq_len,序列长度 |
pred_len | 命令行 --pred_len(即 horizon) | self.pred_len,FlattenHead 输出维度 |
patch_len | MODEL_HYPER_PARAMS["patch_len"]=16 | self.patch_len,patch 窗口大小 |
stride | MODEL_HYPER_PARAMS["stride"]=8 | self.stride,patch 步长;同时 padding=stride |
d_model | MODEL_HYPER_PARAMS["d_model"]=512 | embedding 和注意力维度 |
n_heads | MODEL_HYPER_PARAMS["n_heads"]=8 | 多头注意力头数 |
e_layers | MODEL_HYPER_PARAMS["e_layers"]=2 | Encoder 层数 |
d_ff | MODEL_HYPER_PARAMS["d_ff"]=2048 | FFN 中间维度 |
enc_in | 框架自动推断(数据集特征列数) | config.enc_in,FlattenHead 的 n_vars |
task_name | 命令行 --task_name | forward() 分支判断 |
注意:
padding不是独立参数,直接赋值padding = self.stride(PatchTST.py:41)。
toy 中padding = stride = 2,真实值padding = stride = 8。
调用侧 — _process 构造的四元组
python
output = self.model(input, input_mark, dec_input, target_mark)| 参数 | shape(toy) | 含义 |
|---|---|---|
input (= x_enc) | (2, 12, 4) = (B, seq_len, enc_in) | encoder 历史输入 |
input_mark (= x_mark_enc) | (2, 12, 4) | 时间戳特征,PatchTST 不使用 |
dec_input (= x_dec) | (2, label_len+3, 4) | decoder 输入,PatchTST 不使用 |
target_mark (= x_mark_dec) | (2, label_len+3, 4) | 时间戳特征,PatchTST 不使用 |
| 返回 | (2, 3, 4) = (B, pred_len, enc_in) | 预测输出 |
PatchTST 的
forecast()接收四个参数,但实际只用了x_enc,其余三个完全忽略(forecast 函数签名接收但内部不读取 mark 和 x_dec)。
3. 顺序图(具体层)
4. 语义分组图(索引层)
5. 精读
5.1 参数合并:Config 对象如何构建
优先级:命令行参数 > MODEL_HYPER_PARAMS > DEFAULT_HYPER_PARAMS。
python
# MODEL_HYPER_PARAMS 中与 PatchTST 直接相关的值(toy 中被覆盖为小数值):
"patch_len": 16, # toy: 4
"stride": 8, # toy: 2
"d_model": 512, # toy: 16
"n_heads": 8, # toy: 2
"e_layers": 2, # toy: 1
"d_ff": 2048, # toy: 64enc_in 不在命令行里,由框架在 forecast_fit 里根据数据集特征列数自动注入,toy 值 = 4。
5.2 PatchTST.__init__ 关键注册
原始代码:
python
# PatchTST.py:30-93
def __init__(self, config):
super(PatchTST, self).__init__()
self.task_name = config.task_name
self.seq_len = config.seq_len
self.pred_len = config.pred_len
self.patch_len = config.patch_len
self.stride = config.stride
padding = self.stride
# patching and embedding
self.patch_embedding = PatchEmbedding(
config.d_model, self.patch_len, self.stride, padding, config.dropout
)
# Encoder
self.encoder = Encoder(
[
EncoderLayer(
AttentionLayer(
FullAttention(
False,
config.factor,
attention_dropout=config.dropout,
output_attention=config.output_attention,
),
config.d_model,
config.n_heads,
),
config.d_model,
config.d_ff,
dropout=config.dropout,
activation=config.activation,
)
for l in range(config.e_layers)
],
norm_layer=torch.nn.LayerNorm(config.d_model),
)
# Prediction Head
self.head_nf = config.d_model * int(
(config.seq_len - self.patch_len) / self.stride + 2
)
if (
self.task_name == "long_term_forecast"
or self.task_name == "short_term_forecast"
):
self.head = FlattenHead(
config.enc_in,
self.head_nf,
config.pred_len,
head_dropout=config.dropout,
)
elif self.task_name == "imputation" or self.task_name == "anomaly_detection":
self.head = FlattenHead(
config.enc_in, self.head_nf, config.seq_len, head_dropout=config.dropout
)
elif self.task_name == "classification":
self.flatten = nn.Flatten(start_dim=-2)
self.dropout = nn.Dropout(config.dropout)
self.projection = nn.Linear(self.head_nf * config.enc_in, config.num_class)注解版(聚焦 task_name if/else 和 toy 数值):
python
self.task_name = config.task_name # TFB 中 = "long_term_forecast" 或 "short_term_forecast"
self.patch_len = config.patch_len # toy: 4
self.stride = config.stride # toy: 2
padding = self.stride # ⚠️ padding 不是独立参数,直接等于 stride,toy: 2
# patch_embedding / encoder 与 task_name 无关,无论什么任务都注册
self.patch_embedding = PatchEmbedding(d_model=16, patch_len=4, stride=2, padding=2, ...)
self.encoder = Encoder([EncoderLayer(...)×1], norm_layer=LayerNorm(16))
# head_nf 计算(所有 task 共用)
self.head_nf = 16 * int((12-4)/2 + 2) = 16 * 6 = 96
# ── if/else:根据 task_name 注册不同的 head ──
# TFB 实际走这条(forecast):
self.head = FlattenHead(n_vars=4, nf=96, target_window=pred_len=3)
# Linear(96, 3),输出 pred_len 步预测
# imputation/anomaly_detection(不走):
self.head = FlattenHead(n_vars=4, nf=96, target_window=seq_len=12)
# Linear(96, 12),输出补全/重建全序列
# classification(不走):
self.flatten = nn.Flatten(start_dim=-2) # 不用 FlattenHead
self.projection = nn.Linear(head_nf × enc_in=96×4=384, num_class)
# 全局特征拼接后直接分类两条路径对比:
| forecast/short_forecast(TFB 实际) | imputation/anomaly | classification | |
|---|---|---|---|
self.head | FlattenHead(target=pred_len=3) | FlattenHead(target=seq_len=12) | 不注册 head |
| Linear 输出维度 | pred_len=3 | seq_len=12 | num_class |
| 目标 | 预测未来 pred_len 步 | 重建完整序列 | 全局分类 |
| TFB 走哪条 | ✓ 此路径 | ✗ | ✗ |
toy 数值:head.linear.weight.shape = (pred_len=3, head_nf=96)。
5.3 _process() 构造 dec_input
python
# adapters_for_transformers.py
dec_input = torch.zeros_like(target[:, -self.config.horizon:, :]).float()
# zeros shape: (B, pred_len, enc_in) = (2, 3, 4)
dec_input = torch.cat(
[target[:, :self.config.label_len, :], dec_input], dim=1
).float().to(input.device)
# cat 后 shape: (B, label_len+3, enc_in)PatchTST 的
forecast()签名接收这个dec_input,但完全不使用它。这是框架统一接口的开销。
5.4 forward() 分支判断
原始代码:
python
# PatchTST.py:222-238
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if (
self.task_name == "long_term_forecast"
or self.task_name == "short_term_forecast"
):
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len :, :] # [B, L, D]
if self.task_name == "imputation":
dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
return dec_out # [B, L, D]
if self.task_name == "anomaly_detection":
dec_out = self.anomaly_detection(x_enc)
return dec_out # [B, L, D]
if self.task_name == "classification":
dec_out = self.classification(x_enc, x_mark_enc)
return dec_out # [B, N]
return None4 条分支对比:
| task_name | 调用函数 | 输入 | 返回 shape | TFB 实际 |
|---|---|---|---|---|
long/short_term_forecast | forecast(x_enc,...) | x_enc | (B, pred_len, C) = (2,3,4) | ✓ 此路径 |
imputation | imputation(x_enc,...,mask) | x_enc + mask | (B, seq_len, C) | ✗ |
anomaly_detection | anomaly_detection(x_enc) | 只用 x_enc | (B, seq_len, C) | ✗ |
classification | classification(x_enc,x_mark_enc) | x_enc + mark | (B, num_class) | ✗ |
forecast返回dec_out[:, -self.pred_len:, :];由于forecast()本身就只返回(B, pred_len, C),这里的切片实际上取全部,是防御性写法。
toy 数值:dec_out shape 为 (2, 3, 4) = (B, pred_len, enc_in),[:, -3:, :] 取全部。
6. 下钻子组件
| 子组件 | 职责 | 文档 |
|---|---|---|
forecast(x_enc, ...) | PatchTST 的主计算:归一化+patch+encoder+head | 02-Layer1-forecast主链 |