Appearance
_eval_batch 代码精读
Abstract
入口:
pythonForecastingStrategy.execute(...) -> RollingForecast._execute(...) -> RollingForecast._eval_batch(...)这篇只做一件事:
按源码真实顺序,把
_eval_batch(...)从头到尾讲清楚。
1. 这篇和现有文档的关系
这篇不是新的层级文档。
它建立在这些文档之上:
- [[24-Level4-rolling任务主体]]
- [[25-Level5-_eval_batch四段总览]]
- [[26-Level6-5B-训练子块-fit_method与forecast_fit]]
- [[27-Level6-5C-预测子块]]
- [[28-Level6-5D-打分与收尾子块]]
这篇的作用是:
把
25/26/27/28的抽象分块,重新对齐回源码顺序。
2. 第一性
_eval_batch(...) 的第一性不是“训练模型”,也不是“算指标”。
它真正做的是:
在一条时间序列上,按 rolling forecast 协议,先 fit 一次模型,再批量生成 rolling 预测,最后统一打分并组装结果。
所以这一个函数里同时包含三件事:
- 训练入口
- 预测入口
- 结果评测与封装
3. 源码入口
文件:
函数签名:
python
def _eval_batch(
self,
series: pd.DataFrame,
meta_info: Optional[pd.Series],
model: ModelBase,
series_name: str,
) -> List:输入接口:
seriesmeta_infomodelseries_name
输出接口:
single_series_results: List
4. 当前最小例子
这篇都用当前最小例子解释:
series.shape = (17420, 7)tv_ratio = 0.8train_ratio_in_tv = 0.75horizon = 24stride = 24num_rollings = 48
由此得到:
train_length = 13936test_length = 3484
5. 源码顺序总图
这张图只表示:
真实执行顺序。
6. 第 1 段:读取配置
代码:
python
target_channel = self._get_scalar_config_value("target_channel", series_name)
stride = self._get_scalar_config_value("stride", series_name)
horizon = self._get_scalar_config_value("horizon", series_name)
num_rollings = self._get_scalar_config_value("num_rollings", series_name)
train_ratio_in_tv = self._get_scalar_config_value("train_ratio_in_tv", series_name)
tv_ratio = self._get_scalar_config_value("tv_ratio", series_name)这一段的作用:
- 从 strategy 配置里取当前序列真正使用的参数
输出的关键变量:
target_channelstridehorizonnum_rollingstrain_ratio_in_tvtv_ratio
当前例子里最关键的是:
horizon = 24stride = 24num_rollings = 48tv_ratio = 0.8
7. 第 2 段:切 train/test
代码:
python
train_length, test_length = self._get_split_lens(series, meta_info, tv_ratio)
train_valid_data, test_data = split_time(series, train_length)这一段的作用:
- 先算长度
- 再按时间顺序切成:
train_valid_datatest_data
当前例子里:
train_length = 13936test_length = 3484train_valid_data.shape = (13936, 7)test_data.shape = (3484, 7)
8. 第 3 段:构造训练目标与协变量
代码:
python
target_train_valid_data, exog_train_valid_data = split_channel(
train_valid_data, target_channel
)
target4batch, exog_data4batch = split_channel(series, target_channel)
covariates_train, covariates4batch = {}, {}
covariates_train["exog"] = exog_train_valid_data
covariates4batch["exog"] = exog_data4batch这一段要分清两个层次:
8.1 训练用
target_train_valid_datacovariates_train
这是给后面的 fit_method(...) 用的。
8.2 预测用
target4batchcovariates4batch
这是给后面的 RollingForecastEvalBatchMaker(...) 用的。
所以这里不是重复切两次,而是:
一次为训练准备,一次为 rolling 预测准备。
9. 第 4 段:进入训练子块
代码:
python
start_fit_time = time.time()
fit_method = model.forecast_fit if hasattr(model, "forecast_fit") else model.fit
fit_method(
target_train_valid_data,
covariates=covariates_train,
train_ratio_in_tv=train_ratio_in_tv,
)
end_fit_time = time.time()这一段是:
_eval_batch(...)正式进入训练线的入口。
当前例子里:
model不是裸模型- 而是
TransformerAdapter(...)
所以这里实际走的是:
python
model.forecast_fit(...)再展开就是:
python
DeepForecastingModelBase.forecast_fit(...)这正对应:
- [[26-Level6-5B-训练子块-fit_method与forecast_fit]]
输出的关键结果:
- 模型已经 fit 完成
fit_time = end_fit_time - start_fit_time
10. 第 5 段:准备评测 scaler
代码:
python
eval_scaler = self._get_eval_scaler(target_train_valid_data, train_ratio_in_tv)这一段的作用:
- 为后面 metric 计算准备统一缩放器
这一步不进入模型内部,但它直接影响:
evaluator.evaluate(...)
11. 第 6 段:生成 rolling 预测请求
代码:
python
index_list = self._get_index(train_length, test_length, horizon, stride)
index_list = index_list[:num_rollings]作用:
- 把测试区间变成一组 rolling 起点
当前例子里:
stride = 24horizon = 24num_rollings = 48
所以这里得到的是:
前 48 个 rolling 预测窗口的起始索引列表。
12. 第 7 段:构造 batch maker
代码:
python
batch_maker = RollingForecastEvalBatchMaker(
target4batch,
index_list,
covariates4batch,
)
predict_batch_maker = RollingForecastPredictBatchMaker(batch_maker)作用:
batch_maker:掌握全部 rolling 目标窗口predict_batch_maker:提供一个能持续吐出预测 batch 的接口
这一步是:
_eval_batch(...)正式进入预测子块前的桥。
13. 第 8 段:批量 rolling 预测
代码:
python
all_predicts = []
total_inference_time = 0
while predict_batch_maker.has_more_batches():
start_inference_time = time.time()
predicts = model.batch_forecast(horizon, predict_batch_maker)
end_inference_time = time.time()
total_inference_time += end_inference_time - start_inference_time
all_predicts.append(predicts)
all_predicts = np.concatenate(all_predicts, axis=0)这一段是:
_eval_batch(...)进入预测主链的核心入口。
当前例子里,这里实际会进入:
python
model.batch_forecast(...)再展开就是:
python
DeepForecastingModelBase.batch_forecast(...)
-> _perform_rolling_predictions(...)
-> _process(...)
-> DLinear.forward(...) / Informer.forward(...)这正对应:
- [[27-Level6-5C-预测子块]]
- [[31-Level7-batch_forecast主链]]
- [[32-Level8-_perform_rolling_predictions]]
14. 第 9 段:生成真实目标
代码:
python
targets = batch_maker.make_batch_eval(horizon)["target"]作用:
- 生成与
all_predicts一一对齐的真实 rolling 目标
这里的 targets 不是训练集,也不是整个测试集。
它是:
与每个 rolling 预测窗口对应的真实目标片段。
15. 第 10 段:逐窗口打分
代码:
python
all_test_results = []
for predicts, target in zip(all_predicts, targets):
single_series_results = self.evaluator.evaluate(
target,
predicts,
eval_scaler,
target_train_valid_data.values,
)
all_test_results.append(single_series_results)
single_series_results = np.mean(np.stack(all_test_results), axis=0).tolist()这一段的作用:
- 每个 rolling 窗口各算一次指标
- 再对所有 rolling 窗口的指标取平均
这一步就是:
rolling forecast 评测协议真正落到数字上的地方。
16. 第 11 段:组装最终结果
代码:
python
save_true_pred = self._get_scalar_config_value("save_true_pred", series_name)
actual_data_encoded = self._encode_data(targets) if save_true_pred else np.nan
inference_data_encoded = (
self._encode_data(all_predicts) if save_true_pred else np.nan
)
single_series_results += [
series_name,
end_fit_time - start_fit_time,
total_inference_time,
actual_data_encoded,
inference_data_encoded,
"",
]
return single_series_results这一段的作用:
- 把指标结果和附加字段拼起来
最终返回的内容包含:
- metric 结果
- 文件名
- fit 时间
- inference 时间
- 可选真实值编码
- 可选预测值编码
- log 信息
17. 抽象分块怎么和源码顺序对应
前面你已经有抽象分块:
5A 配置与切数5B 训练子块5C 预测子块5D 打分与收尾子块
现在把它们对回源码顺序:
| 抽象分块 | 源码顺序 |
|---|---|
5A | 第 1 段 到 第 3 段 |
5B | 第 4 段 到 第 5 段 |
5C | 第 6 段 到 第 8 段 |
5D | 第 9 段 到 第 11 段 |
所以旧文档的“四段”没有错,但它是抽象分块视角。
这篇则是真实代码顺序视角。
18. 你接下来该怎么用这篇
这篇读完后,下一步不要再反复啃 _eval_batch。
接下来应该固定两件事:
_eval_batch(...)只是外层协议主体- 真正需要进入模型阅读的入口有两个:
fit_method -> forecast_fit(...)model.batch_forecast(...)
然后你就可以开始做你新的 roadmap:
在固定
_eval_batch外层协议的前提下,细读模型代码,建立“代码 - 论文理论”映射。
19. 最后压成 6 句话
_eval_batch(...)是 rolling forecast 单序列评测的业务主体。- 它先读配置、切数据,再进入训练、预测、打分三大动作。
- 训练入口是
fit_method(...),当前例子里实际走到forecast_fit(...)。 - 预测入口是
model.batch_forecast(...)。 - 打分不是对整段测试集一次完成,而是对每个 rolling 窗口逐个评测再求平均。
- 读懂
_eval_batch(...)的意义,是固定 benchmark 协议外壳,之后把注意力转向模型本身。