Appearance
Level 6 5B 训练子块:fit_method 与 forecast_fit
Abstract
入口:
pythonfit_method(...)当前例子里,它实际走到:
pythonmodel.forecast_fit(...)这一层解释:
_eval_batch(...)里的训练子块,怎样从统一训练入口fit_method进入forecast_fit(...)主链。
1. 上一层与当前层位置
上一层是:
当前层是:
- Level 6:只细分
5B 训练子块。
这一层还不进入:
- 预测子块
5C - 打分子块
5D - 训练循环内部每一行
1.5 这一层的关系类型说明
这一层是在继续下钻:
- Level 5 里的
5B 训练子块
所以这里的 6A / 6B / 6C / 6D / 6E:
- 是 训练子块内部的逻辑分块
- 不是 Level 5 的新同级主节点
- 也不是外层对象包含关系
1.6 同层文件关系
这一层和下面两篇是同层兄弟文档:
三者之间没有“前一篇调用后一篇”的代码关系。
它们只是共同对应 25 里的三个兄弟子块:
5B5C5D
2. 当前层第一性
5B 训练子块 的第一性是:
把
_eval_batch(...)里已经切好的训练数据,交给模型的统一训练入口,完成一次正式 fit。
2.5 当前例子里 model 到底是什么
当前例子里,_eval_batch(...) 里拿到的 model 不是裸的 DLinear,而是:
python
TransformerAdapter(...)它的来源是:
python
model = model_factory()而 TransformerAdapter 的继承关系是:
python
TransformerAdapter(DeepForecastingModelBase)所以当前 model 实例天然就带着:
forecast_fit(...)batch_forecast(...)_init_model(...)_process(...)
这些方法。
这里要分清两层:
2.1 fit_method
它不是一个新算法,而是一个兼容入口变量:
python
fit_method = model.forecast_fit if hasattr(model, "forecast_fit") else model.fit它的意义是:
无论底层模型提供的是
forecast_fit(...)还是fit(...),策略层都只写一套调用逻辑。
更具体一点:
python
fit_method = model.forecast_fit if hasattr(model, "forecast_fit") else model.fit这里的 fit_method 不是字符串,也不是新的函数定义。
它是一个方法变量,保存的是:
- 如果
model有forecast_fit,就保存model.forecast_fit - 否则就保存
model.fit
在当前例子里,因为:
model = TransformerAdapter(...)TransformerAdapter继承自DeepForecastingModelBaseDeepForecastingModelBase定义了forecast_fit(...)
所以这里实际等价于:
python
fit_method = model.forecast_fit而这里的 model.forecast_fit 又是一个绑定方法,也就是:
python
DeepForecastingModelBase.forecast_fit.__get__(model, TransformerAdapter)你可以把它先简单理解成:
“已经绑定到当前
model实例上的forecast_fit方法入口”。
2.2 forecast_fit(...)
它才是当前例子里真正执行训练的主函数。
所以这两句在当前例子里的真实含义就是:
python
fit_method = model.forecast_fit
fit_method(...)也就是:
python
model.forecast_fit(...)再展开一层就是:
python
DeepForecastingModelBase.forecast_fit(model, ...)这里只是 Python 的绑定方法机制把 self=model 自动带进去了。
3. 当前命令的最小例子
这一层以下文这组数字作为例子:
series.shape = (17420, 7)tv_ratio = 0.8train_length = 13936train_valid_data.shape = (13936, 7)train_ratio_in_tv = 0.75batch_size = 4seq_len = 96horizon = 24model = TransformerAdapter(DLinear)
在 _eval_batch(...) 里,训练子块对应的是:
python
target_train_valid_data, exog_train_valid_data = split_channel(
train_valid_data, target_channel
)
covariates_train = {"exog": exog_train_valid_data}
fit_method = model.forecast_fit if hasattr(model, "forecast_fit") else model.fit
fit_method(
target_train_valid_data,
covariates=covariates_train,
train_ratio_in_tv=train_ratio_in_tv,
)当前例子里,上面这段代码可以直接改写成:
python
model.forecast_fit(
target_train_valid_data,
covariates=covariates_train,
train_ratio_in_tv=train_ratio_in_tv,
)再按类定义展开就是:
python
DeepForecastingModelBase.forecast_fit(
model,
target_train_valid_data,
covariates=covariates_train,
train_ratio_in_tv=train_ratio_in_tv,
)所以:
fit_method(...)就是当前RollingForecast训练子块真正进入DeepForecastingModelBase.forecast_fit(...)的入口。
4. 先看顺序图
这张图回答的是:
训练子块在代码执行上怎样推进到训练循环入口。
5. 再看抽象索引树
这棵树回答的是:
为了理解
forecast_fit(...),这一层最重要的五个逻辑块是什么。
6. 职责树
6.1 6A 统一训练入口
职责:
- 兼容
forecast_fit(...)和fit(...) - 让上层策略代码不用关心具体模型接口差异
6.2 6B 数据准备
职责:
- 合并
covariates["exog"] - 根据数据形态补齐
freq / enc_in / dec_in / c_out / label_len
6.3 6C 模型准备
职责:
- 真正实例化底层 PyTorch 模型
6.4 6D DataLoader 准备
职责:
- 切 train/valid
- 做归一化
- 构造
DatasetForTransformer和DataLoader
6.5 6E 优化器与训练循环入口
职责:
- 初始化 loss / optimizer / early stopping
- 进入 epoch-batch 训练循环
7. 输入输出接口
7.1 本层入口接口
_eval_batch(...) 中的训练入口:
python
fit_method(
target_train_valid_data,
covariates=covariates_train,
train_ratio_in_tv=train_ratio_in_tv,
)7.2 forecast_fit(...) 的函数形参
python
forecast_fit(
train_valid_data: pd.DataFrame,
*,
covariates: Optional[dict] = None,
train_ratio_in_tv: float = 1.0,
**kwargs,
)7.3 本层关键输入
train_valid_datacovariatestrain_ratio_in_tv
7.4 本层关键中间变量
series_dimexog_datatrain_datavalid_dataself.train_data_loadercriterionoptimizer
7.5 本层输出
- 返回值语义上是“已完成 fit 的模型对象”
- 更关键的是副作用:
self.model已初始化并训练self.scaler已 fitself.train_data_loader已建立
8. 函数 / 文件关系图
9. 关键代码对应关系
9.1 统一训练入口
在 rolling_forecast.py:
python
fit_method = model.forecast_fit if hasattr(model, "forecast_fit") else model.fit当前例子里,因为 TransformerAdapter 提供了 forecast_fit(...),所以实际走:
python
model.forecast_fit(...)9.2 数据依赖参数补齐
在 deep_forecasting_model_base.py:
python
self.multi_forecasting_hyper_param_tune(train_valid_data)当前例子里,这一步会补:
freqenc_in = 7dec_in = 7c_out = 7label_len = seq_len // 2 = 48
9.3 真正初始化底层模型
在 deep_forecasting_model_base.py:
python
self.model = self._init_model()再往下在 adapter 里会变成:
python
self.model_class(self.config)当前例子最终就是:
python
DLinear(self.config)9.4 建立 DataLoader
在 deep_forecasting_model_base.py:
python
train_dataset, self.train_data_loader = forecasting_data_provider(...)这一步会正式回到你阶段一学过的那条链:
DatasetForTransformerDataLoader- batch 四元组
10. 当前例子的具体落地
10.1 输入到 forecast_fit(...)
train_valid_data.shape约为(13936, 7)covariates_train["exog"]是外生变量部分
10.2 参数补齐后
enc_in = 7dec_in = 7c_out = 7label_len = 48freq = 'h'
10.3 训练数据侧
按 train_ratio_in_tv = 0.75 再切出:
train_datavalid_data
然后建立:
self.train_data_loadervalid_data_loader
10.4 最后进入训练循环前
已经具备:
- 模型
self.model - 训练 batch 流
self.train_data_loader criterionoptimizerearly_stopping
11. 当前层最重要的认知
这一层最需要稳定下来的不是训练循环细节,而是:
fit_method只是统一入口;真正的训练主链在forecast_fit(...);而forecast_fit(...)又负责把“原始训练数据”推进到“epoch-batch 训练循环入口”。
12. 下一层入口
下一层最自然的 BFS 细分是:
python
for epoch in range(config.num_epochs):
for i, (input, target, input_mark, target_mark) in enumerate(self.train_data_loader):也就是继续写:
Level 7 训练循环
对应下层文档:
13. 只留一句
Level 6 只看训练子块怎样从
fit_method进入forecast_fit(...),并推进到训练循环入口。