Skip to content

Level 6 5B 训练子块:fit_method 与 forecast_fit

Abstract

入口:

python
fit_method(...)

当前例子里,它实际走到:

python
model.forecast_fit(...)

这一层解释:

_eval_batch(...) 里的训练子块,怎样从统一训练入口 fit_method 进入 forecast_fit(...) 主链。

1. 上一层与当前层位置

上一层是:

当前层是:

  • Level 6:只细分 5B 训练子块

这一层还不进入:

  • 预测子块 5C
  • 打分子块 5D
  • 训练循环内部每一行

1.5 这一层的关系类型说明

这一层是在继续下钻:

  • Level 5 里的 5B 训练子块

所以这里的 6A / 6B / 6C / 6D / 6E

  • 训练子块内部的逻辑分块
  • 不是 Level 5 的新同级主节点
  • 也不是外层对象包含关系

1.6 同层文件关系

这一层和下面两篇是同层兄弟文档:

三者之间没有“前一篇调用后一篇”的代码关系。
它们只是共同对应 25 里的三个兄弟子块:

  • 5B
  • 5C
  • 5D

2. 当前层第一性

5B 训练子块 的第一性是:

_eval_batch(...) 里已经切好的训练数据,交给模型的统一训练入口,完成一次正式 fit。

2.5 当前例子里 model 到底是什么

当前例子里,_eval_batch(...) 里拿到的 model 不是裸的 DLinear,而是:

python
TransformerAdapter(...)

它的来源是:

python
model = model_factory()

TransformerAdapter 的继承关系是:

python
TransformerAdapter(DeepForecastingModelBase)

所以当前 model 实例天然就带着:

  • forecast_fit(...)
  • batch_forecast(...)
  • _init_model(...)
  • _process(...)

这些方法。

这里要分清两层:

2.1 fit_method

它不是一个新算法,而是一个兼容入口变量

python
fit_method = model.forecast_fit if hasattr(model, "forecast_fit") else model.fit

它的意义是:

无论底层模型提供的是 forecast_fit(...) 还是 fit(...),策略层都只写一套调用逻辑。

更具体一点:

python
fit_method = model.forecast_fit if hasattr(model, "forecast_fit") else model.fit

这里的 fit_method 不是字符串,也不是新的函数定义。
它是一个方法变量,保存的是:

  • 如果 modelforecast_fit,就保存 model.forecast_fit
  • 否则就保存 model.fit

在当前例子里,因为:

  • model = TransformerAdapter(...)
  • TransformerAdapter 继承自 DeepForecastingModelBase
  • DeepForecastingModelBase 定义了 forecast_fit(...)

所以这里实际等价于:

python
fit_method = model.forecast_fit

而这里的 model.forecast_fit 又是一个绑定方法,也就是:

python
DeepForecastingModelBase.forecast_fit.__get__(model, TransformerAdapter)

你可以把它先简单理解成:

“已经绑定到当前 model 实例上的 forecast_fit 方法入口”。

2.2 forecast_fit(...)

它才是当前例子里真正执行训练的主函数。

所以这两句在当前例子里的真实含义就是:

python
fit_method = model.forecast_fit
fit_method(...)

也就是:

python
model.forecast_fit(...)

再展开一层就是:

python
DeepForecastingModelBase.forecast_fit(model, ...)

这里只是 Python 的绑定方法机制把 self=model 自动带进去了。

3. 当前命令的最小例子

这一层以下文这组数字作为例子:

  • series.shape = (17420, 7)
  • tv_ratio = 0.8
  • train_length = 13936
  • train_valid_data.shape = (13936, 7)
  • train_ratio_in_tv = 0.75
  • batch_size = 4
  • seq_len = 96
  • horizon = 24
  • model = TransformerAdapter(DLinear)

_eval_batch(...) 里,训练子块对应的是:

python
target_train_valid_data, exog_train_valid_data = split_channel(
    train_valid_data, target_channel
)
covariates_train = {"exog": exog_train_valid_data}

fit_method = model.forecast_fit if hasattr(model, "forecast_fit") else model.fit
fit_method(
    target_train_valid_data,
    covariates=covariates_train,
    train_ratio_in_tv=train_ratio_in_tv,
)

当前例子里,上面这段代码可以直接改写成:

python
model.forecast_fit(
    target_train_valid_data,
    covariates=covariates_train,
    train_ratio_in_tv=train_ratio_in_tv,
)

再按类定义展开就是:

python
DeepForecastingModelBase.forecast_fit(
    model,
    target_train_valid_data,
    covariates=covariates_train,
    train_ratio_in_tv=train_ratio_in_tv,
)

所以:

fit_method(...) 就是当前 RollingForecast 训练子块真正进入 DeepForecastingModelBase.forecast_fit(...) 的入口。

4. 先看顺序图

这张图回答的是:

训练子块在代码执行上怎样推进到训练循环入口。

5. 再看抽象索引树

这棵树回答的是:

为了理解 forecast_fit(...),这一层最重要的五个逻辑块是什么。

6. 职责树

6.1 6A 统一训练入口

职责:

  • 兼容 forecast_fit(...)fit(...)
  • 让上层策略代码不用关心具体模型接口差异

6.2 6B 数据准备

职责:

  • 合并 covariates["exog"]
  • 根据数据形态补齐 freq / enc_in / dec_in / c_out / label_len

6.3 6C 模型准备

职责:

  • 真正实例化底层 PyTorch 模型

6.4 6D DataLoader 准备

职责:

  • 切 train/valid
  • 做归一化
  • 构造 DatasetForTransformerDataLoader

6.5 6E 优化器与训练循环入口

职责:

  • 初始化 loss / optimizer / early stopping
  • 进入 epoch-batch 训练循环

7. 输入输出接口

7.1 本层入口接口

_eval_batch(...) 中的训练入口:

python
fit_method(
    target_train_valid_data,
    covariates=covariates_train,
    train_ratio_in_tv=train_ratio_in_tv,
)

7.2 forecast_fit(...) 的函数形参

python
forecast_fit(
    train_valid_data: pd.DataFrame,
    *,
    covariates: Optional[dict] = None,
    train_ratio_in_tv: float = 1.0,
    **kwargs,
)

7.3 本层关键输入

  • train_valid_data
  • covariates
  • train_ratio_in_tv

7.4 本层关键中间变量

  • series_dim
  • exog_data
  • train_data
  • valid_data
  • self.train_data_loader
  • criterion
  • optimizer

7.5 本层输出

  • 返回值语义上是“已完成 fit 的模型对象”
  • 更关键的是副作用:
    • self.model 已初始化并训练
    • self.scaler 已 fit
    • self.train_data_loader 已建立

8. 函数 / 文件关系图

9. 关键代码对应关系

9.1 统一训练入口

rolling_forecast.py

python
fit_method = model.forecast_fit if hasattr(model, "forecast_fit") else model.fit

当前例子里,因为 TransformerAdapter 提供了 forecast_fit(...),所以实际走:

python
model.forecast_fit(...)

9.2 数据依赖参数补齐

deep_forecasting_model_base.py

python
self.multi_forecasting_hyper_param_tune(train_valid_data)

当前例子里,这一步会补:

  • freq
  • enc_in = 7
  • dec_in = 7
  • c_out = 7
  • label_len = seq_len // 2 = 48

9.3 真正初始化底层模型

deep_forecasting_model_base.py

python
self.model = self._init_model()

再往下在 adapter 里会变成:

python
self.model_class(self.config)

当前例子最终就是:

python
DLinear(self.config)

9.4 建立 DataLoader

deep_forecasting_model_base.py

python
train_dataset, self.train_data_loader = forecasting_data_provider(...)

这一步会正式回到你阶段一学过的那条链:

  • DatasetForTransformer
  • DataLoader
  • batch 四元组

10. 当前例子的具体落地

10.1 输入到 forecast_fit(...)

  • train_valid_data.shape 约为 (13936, 7)
  • covariates_train["exog"] 是外生变量部分

10.2 参数补齐后

  • enc_in = 7
  • dec_in = 7
  • c_out = 7
  • label_len = 48
  • freq = 'h'

10.3 训练数据侧

train_ratio_in_tv = 0.75 再切出:

  • train_data
  • valid_data

然后建立:

  • self.train_data_loader
  • valid_data_loader

10.4 最后进入训练循环前

已经具备:

  • 模型 self.model
  • 训练 batch 流 self.train_data_loader
  • criterion
  • optimizer
  • early_stopping

11. 当前层最重要的认知

这一层最需要稳定下来的不是训练循环细节,而是:

fit_method 只是统一入口;真正的训练主链在 forecast_fit(...);而 forecast_fit(...) 又负责把“原始训练数据”推进到“epoch-batch 训练循环入口”。

12. 下一层入口

下一层最自然的 BFS 细分是:

python
for epoch in range(config.num_epochs):
    for i, (input, target, input_mark, target_mark) in enumerate(self.train_data_loader):

也就是继续写:

  • Level 7 训练循环

对应下层文档:

13. 只留一句

Level 6 只看训练子块怎样从 fit_method 进入 forecast_fit(...),并推进到训练循环入口。

*记录并在线阅读我的笔记*