Skip to content

_eval_batch 代码精读

Abstract

入口:

python
ForecastingStrategy.execute(...)
-> RollingForecast._execute(...)
-> RollingForecast._eval_batch(...)

这篇只做一件事:

按源码真实顺序,把 _eval_batch(...) 从头到尾讲清楚。

1. 这篇和现有文档的关系

这篇不是新的层级文档。

它建立在这些文档之上:

  • [[24-Level4-rolling任务主体]]
  • [[25-Level5-_eval_batch四段总览]]
  • [[26-Level6-5B-训练子块-fit_method与forecast_fit]]
  • [[27-Level6-5C-预测子块]]
  • [[28-Level6-5D-打分与收尾子块]]

这篇的作用是:

25/26/27/28 的抽象分块,重新对齐回源码顺序。

2. 第一性

_eval_batch(...) 的第一性不是“训练模型”,也不是“算指标”。

它真正做的是:

在一条时间序列上,按 rolling forecast 协议,先 fit 一次模型,再批量生成 rolling 预测,最后统一打分并组装结果。

所以这一个函数里同时包含三件事:

  1. 训练入口
  2. 预测入口
  3. 结果评测与封装

3. 源码入口

文件:

函数签名:

python
def _eval_batch(
    self,
    series: pd.DataFrame,
    meta_info: Optional[pd.Series],
    model: ModelBase,
    series_name: str,
) -> List:

输入接口:

  • series
  • meta_info
  • model
  • series_name

输出接口:

  • single_series_results: List

4. 当前最小例子

这篇都用当前最小例子解释:

  • series.shape = (17420, 7)
  • tv_ratio = 0.8
  • train_ratio_in_tv = 0.75
  • horizon = 24
  • stride = 24
  • num_rollings = 48

由此得到:

  • train_length = 13936
  • test_length = 3484

5. 源码顺序总图

这张图只表示:

真实执行顺序。

6. 第 1 段:读取配置

代码:

python
target_channel = self._get_scalar_config_value("target_channel", series_name)
stride = self._get_scalar_config_value("stride", series_name)
horizon = self._get_scalar_config_value("horizon", series_name)
num_rollings = self._get_scalar_config_value("num_rollings", series_name)
train_ratio_in_tv = self._get_scalar_config_value("train_ratio_in_tv", series_name)
tv_ratio = self._get_scalar_config_value("tv_ratio", series_name)

这一段的作用:

  • 从 strategy 配置里取当前序列真正使用的参数

输出的关键变量:

  • target_channel
  • stride
  • horizon
  • num_rollings
  • train_ratio_in_tv
  • tv_ratio

当前例子里最关键的是:

  • horizon = 24
  • stride = 24
  • num_rollings = 48
  • tv_ratio = 0.8

7. 第 2 段:切 train/test

代码:

python
train_length, test_length = self._get_split_lens(series, meta_info, tv_ratio)
train_valid_data, test_data = split_time(series, train_length)

这一段的作用:

  • 先算长度
  • 再按时间顺序切成:
    • train_valid_data
    • test_data

当前例子里:

  • train_length = 13936
  • test_length = 3484
  • train_valid_data.shape = (13936, 7)
  • test_data.shape = (3484, 7)

8. 第 3 段:构造训练目标与协变量

代码:

python
target_train_valid_data, exog_train_valid_data = split_channel(
    train_valid_data, target_channel
)
target4batch, exog_data4batch = split_channel(series, target_channel)
covariates_train, covariates4batch = {}, {}
covariates_train["exog"] = exog_train_valid_data
covariates4batch["exog"] = exog_data4batch

这一段要分清两个层次:

8.1 训练用

  • target_train_valid_data
  • covariates_train

这是给后面的 fit_method(...) 用的。

8.2 预测用

  • target4batch
  • covariates4batch

这是给后面的 RollingForecastEvalBatchMaker(...) 用的。

所以这里不是重复切两次,而是:

一次为训练准备,一次为 rolling 预测准备。

9. 第 4 段:进入训练子块

代码:

python
start_fit_time = time.time()
fit_method = model.forecast_fit if hasattr(model, "forecast_fit") else model.fit
fit_method(
    target_train_valid_data,
    covariates=covariates_train,
    train_ratio_in_tv=train_ratio_in_tv,
)
end_fit_time = time.time()

这一段是:

_eval_batch(...) 正式进入训练线的入口。

当前例子里:

  • model 不是裸模型
  • 而是 TransformerAdapter(...)

所以这里实际走的是:

python
model.forecast_fit(...)

再展开就是:

python
DeepForecastingModelBase.forecast_fit(...)

这正对应:

  • [[26-Level6-5B-训练子块-fit_method与forecast_fit]]

输出的关键结果:

  • 模型已经 fit 完成
  • fit_time = end_fit_time - start_fit_time

10. 第 5 段:准备评测 scaler

代码:

python
eval_scaler = self._get_eval_scaler(target_train_valid_data, train_ratio_in_tv)

这一段的作用:

  • 为后面 metric 计算准备统一缩放器

这一步不进入模型内部,但它直接影响:

  • evaluator.evaluate(...)

11. 第 6 段:生成 rolling 预测请求

代码:

python
index_list = self._get_index(train_length, test_length, horizon, stride)
index_list = index_list[:num_rollings]

作用:

  • 把测试区间变成一组 rolling 起点

当前例子里:

  • stride = 24
  • horizon = 24
  • num_rollings = 48

所以这里得到的是:

前 48 个 rolling 预测窗口的起始索引列表。

12. 第 7 段:构造 batch maker

代码:

python
batch_maker = RollingForecastEvalBatchMaker(
    target4batch,
    index_list,
    covariates4batch,
)
predict_batch_maker = RollingForecastPredictBatchMaker(batch_maker)

作用:

  • batch_maker:掌握全部 rolling 目标窗口
  • predict_batch_maker:提供一个能持续吐出预测 batch 的接口

这一步是:

_eval_batch(...) 正式进入预测子块前的桥。

13. 第 8 段:批量 rolling 预测

代码:

python
all_predicts = []
total_inference_time = 0
while predict_batch_maker.has_more_batches():
    start_inference_time = time.time()
    predicts = model.batch_forecast(horizon, predict_batch_maker)
    end_inference_time = time.time()
    total_inference_time += end_inference_time - start_inference_time
    all_predicts.append(predicts)
all_predicts = np.concatenate(all_predicts, axis=0)

这一段是:

_eval_batch(...) 进入预测主链的核心入口。

当前例子里,这里实际会进入:

python
model.batch_forecast(...)

再展开就是:

python
DeepForecastingModelBase.batch_forecast(...)
-> _perform_rolling_predictions(...)
-> _process(...)
-> DLinear.forward(...) / Informer.forward(...)

这正对应:

  • [[27-Level6-5C-预测子块]]
  • [[31-Level7-batch_forecast主链]]
  • [[32-Level8-_perform_rolling_predictions]]

14. 第 9 段:生成真实目标

代码:

python
targets = batch_maker.make_batch_eval(horizon)["target"]

作用:

  • 生成与 all_predicts 一一对齐的真实 rolling 目标

这里的 targets 不是训练集,也不是整个测试集。
它是:

与每个 rolling 预测窗口对应的真实目标片段。

15. 第 10 段:逐窗口打分

代码:

python
all_test_results = []
for predicts, target in zip(all_predicts, targets):
    single_series_results = self.evaluator.evaluate(
        target,
        predicts,
        eval_scaler,
        target_train_valid_data.values,
    )
    all_test_results.append(single_series_results)
single_series_results = np.mean(np.stack(all_test_results), axis=0).tolist()

这一段的作用:

  • 每个 rolling 窗口各算一次指标
  • 再对所有 rolling 窗口的指标取平均

这一步就是:

rolling forecast 评测协议真正落到数字上的地方。

16. 第 11 段:组装最终结果

代码:

python
save_true_pred = self._get_scalar_config_value("save_true_pred", series_name)
actual_data_encoded = self._encode_data(targets) if save_true_pred else np.nan
inference_data_encoded = (
    self._encode_data(all_predicts) if save_true_pred else np.nan
)

single_series_results += [
    series_name,
    end_fit_time - start_fit_time,
    total_inference_time,
    actual_data_encoded,
    inference_data_encoded,
    "",
]
return single_series_results

这一段的作用:

  • 把指标结果和附加字段拼起来

最终返回的内容包含:

  1. metric 结果
  2. 文件名
  3. fit 时间
  4. inference 时间
  5. 可选真实值编码
  6. 可选预测值编码
  7. log 信息

17. 抽象分块怎么和源码顺序对应

前面你已经有抽象分块:

  • 5A 配置与切数
  • 5B 训练子块
  • 5C 预测子块
  • 5D 打分与收尾子块

现在把它们对回源码顺序:

抽象分块源码顺序
5A第 1 段 到 第 3 段
5B第 4 段 到 第 5 段
5C第 6 段 到 第 8 段
5D第 9 段 到 第 11 段

所以旧文档的“四段”没有错,但它是抽象分块视角
这篇则是真实代码顺序视角

18. 你接下来该怎么用这篇

这篇读完后,下一步不要再反复啃 _eval_batch

接下来应该固定两件事:

  1. _eval_batch(...) 只是外层协议主体
  2. 真正需要进入模型阅读的入口有两个:
    • fit_method -> forecast_fit(...)
    • model.batch_forecast(...)

然后你就可以开始做你新的 roadmap:

在固定 _eval_batch 外层协议的前提下,细读模型代码,建立“代码 - 论文理论”映射。

19. 最后压成 6 句话

  1. _eval_batch(...) 是 rolling forecast 单序列评测的业务主体。
  2. 它先读配置、切数据,再进入训练、预测、打分三大动作。
  3. 训练入口是 fit_method(...),当前例子里实际走到 forecast_fit(...)
  4. 预测入口是 model.batch_forecast(...)
  5. 打分不是对整段测试集一次完成,而是对每个 rolling 窗口逐个评测再求平均。
  6. 读懂 _eval_batch(...) 的意义,是固定 benchmark 协议外壳,之后把注意力转向模型本身。

*记录并在线阅读我的笔记*