
Level 1: How the configuration reaches Informer

Abstract

This note covers exactly one thing:

how the configuration in the current command line becomes Informer(config), step by step.

1. Context

Previous level:

Next level:

The exit point of this level is:

python
model = Informer(config)

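1.5 Entry parameters of this level

  • args.model_name
    • decides which model class is ultimately imported
  • args.adapter
    • decides which wrapper shell goes around the model class
  • args.model_hyper_params
    • supplies the explicit hyper-parameter dictionary
  • config_data
    • supplies the default skeleton from the config file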

1.6 Output of this level

The output of this level is not a forecast, but:

python
model = Informer(config)

Which means:

the current command has been resolved into a real, callable Informer model object; all that is left is to feed data into it.

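2. The main question here

The real entry point here is not get_model_info(...).

The real entry point is run_benchmark.py plus the current command-line arguments,

because that is the first thing that reads your command.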

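3. Sequence diagram

4. Abstraction tree

5. First, pin down which parameters of this command act at this level

The parts of the current command that bear directly on this level are: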

text
--model-name time_series_library.Informer
--adapter transformer_adapter
--model-hyper-params {"batch_size":4,"d_model":32,"d_ff":128,"dropout":0.0,"lr":0.0001,"num_epochs":1,"num_workers":0,"seq_len":96,"horizon":24}

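At this level, keep their roles apart:

  • model-name
    • decides which model class is looked up in the end
    • here it is Informer
  • adapter
    • decides which wrapper shell goes around the model class
    • here it is transformer_adapter
  • model-hyper-params
    • at this point it is only collected into model_config
    • most of it is not actually consumed yet
    • for example, d_model=32 and seq_len=96 only take effect later, in config and inside the model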
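A minimal argparse sketch of how these three flags could end up as the per-model lists shown in the next section. It assumes each flag is declared with nargs="+" (one value per model); the real declaration in run_benchmark.py may differ.

```python
import argparse
import json

# Hypothetical parser: assumes each flag accepts one value per model (nargs="+")
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", nargs="+", required=True)
parser.add_argument("--adapter", nargs="+", default=None)
parser.add_argument("--model-hyper-params", nargs="+", default=None)

hyper = json.dumps({"batch_size": 4, "d_model": 32, "d_ff": 128, "dropout": 0.0,
                    "lr": 0.0001, "num_epochs": 1, "num_workers": 0,
                    "seq_len": 96, "horizon": 24})
args = parser.parse_args([
    "--model-name", "time_series_library.Informer",
    "--adapter", "transformer_adapter",
    "--model-hyper-params", hyper,
])

# argparse turns dashes into underscores: --model-hyper-params -> args.model_hyper_params
print(args.model_name)  # ['time_series_library.Informer']
print(args.adapter)     # ['transformer_adapter']
# At this level the hyper-params are still one raw JSON string per model
print(type(args.model_hyper_params[0]).__name__)  # str
```

The JSON stays a plain string here; nothing reads d_model until much later.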

So do not misread this level as:

build_model_config(...) is already "using" d_model

It is not.
At this point it only packs these parameters into the configuration object.

6. Code block 1: build_model_config(...)

Location:

Full code:

python
def build_model_config(args: argparse.Namespace, config_data: Dict) -> Dict:
    model_config = config_data.get("model_config", None)

    if args.adapter is not None:
        args.adapter = [None if item == "None" else item for item in args.adapter]
        if len(args.model_name) > len(args.adapter):
            args.adapter.extend([None] * (len(args.model_name) - len(args.adapter)))
    else:
        args.adapter = [None] * len(args.model_name)

    if args.model_hyper_params is not None:
        args.model_hyper_params = [
            None if item == "None" else item for item in args.model_hyper_params
        ]
        if len(args.model_name) > len(args.model_hyper_params):
            args.model_hyper_params.extend(
                [None] * (len(args.model_name) - len(args.model_hyper_params))
            )
    else:
        args.model_hyper_params = [None] * len(args.model_name)

    for adapter, model_name, model_hyper_params in zip(
        args.adapter, args.model_name, args.model_hyper_params
    ):
        model_config["models"].append(
            {
                "adapter": adapter,
                "model_name": model_name,
                "model_hyper_params": (
                    json.loads(model_hyper_params)
                    if model_hyper_params is not None
                    else {}
                ),
            }
        )

    return model_config

Full code with annotations:

python
def build_model_config(args: argparse.Namespace, config_data: Dict) -> Dict:
    # Start from the model_config skeleton in the JSON config file
    model_config = config_data.get("model_config", None)

    # Normalize the command-line adapter values into a list aligned with model_name
    if args.adapter is not None:
        args.adapter = [None if item == "None" else item for item in args.adapter]
        if len(args.model_name) > len(args.adapter):
            args.adapter.extend([None] * (len(args.model_name) - len(args.adapter)))
    else:
        args.adapter = [None] * len(args.model_name)

    # Normalize the command-line model_hyper_params into a list aligned with model_name
    if args.model_hyper_params is not None:
        args.model_hyper_params = [
            None if item == "None" else item for item in args.model_hyper_params
        ]
        if len(args.model_name) > len(args.model_hyper_params):
            args.model_hyper_params.extend(
                [None] * (len(args.model_name) - len(args.model_hyper_params))
            )
    else:
        args.model_hyper_params = [None] * len(args.model_name)

    # Assemble each set of command-line values into one model entry in the models list
    for adapter, model_name, model_hyper_params in zip(
        args.adapter, args.model_name, args.model_hyper_params
    ):
        model_config["models"].append(
            {
                "adapter": adapter,
                "model_name": model_name,
                "model_hyper_params": (
                    json.loads(model_hyper_params)
                    if model_hyper_params is not None
                    else {}
                ),
            }
        )

    return model_config
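Both if/else blocks follow the same alignment pattern: map the literal string "None" to None, then pad with None up to len(args.model_name). A hypothetical helper `align_to_models` (not part of the original code) isolates that pattern:

```python
from typing import List, Optional


def align_to_models(values: Optional[List[str]], n_models: int) -> List[Optional[str]]:
    """Hypothetical helper mirroring the alignment logic in build_model_config."""
    if values is None:
        # Flag not given at all: one None per model
        return [None] * n_models
    # The literal string "None" on the command line means "no value for this model"
    aligned = [None if item == "None" else item for item in values]
    # Pad shorter lists with None so the later zip() covers every model
    if n_models > len(aligned):
        aligned.extend([None] * (n_models - len(aligned)))
    return aligned


print(align_to_models(["transformer_adapter"], 2))  # ['transformer_adapter', None]
print(align_to_models(["None", "x"], 2))            # [None, 'x']
print(align_to_models(None, 1))                     # [None]
```

Note that a list longer than model_name is not trimmed here; the later zip(...) simply stops at the shortest list and drops the extras.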

Walking through the current example with this code

By this point your command line has become:

python
args.model_name = ["time_series_library.Informer"]
args.adapter = ["transformer_adapter"]
args.model_hyper_params = [
    "{\"batch_size\":4,\"d_model\":32,\"d_ff\":128,\"dropout\":0.0,"
    "\"lr\":0.0001,\"num_epochs\":1,\"num_workers\":0,"
    "\"seq_len\":96,\"horizon\":24}"
]

The code then assembles them into:

python
model_config = {
    "models": [
        {
            "adapter": "transformer_adapter",
            "model_name": "time_series_library.Informer",
            "model_hyper_params": {
                "batch_size": 4,
                "d_model": 32,
                "d_ff": 128,
                "dropout": 0.0,
                "lr": 0.0001,
                "num_epochs": 1,
                "num_workers": 0,
                "seq_len": 96,
                "horizon": 24,
            },
        }
    ],
    # ... (other keys from the config-file skeleton)
}
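To check the result above end to end, here is a standalone run of a condensed copy of build_model_config on this example. The config_data skeleton `{"model_config": {"models": []}}` is an assumption about what the JSON config file provides.

```python
import argparse
import json
from typing import Dict


def build_model_config(args: argparse.Namespace, config_data: Dict) -> Dict:
    # Condensed copy of the function above; [None] * k is empty when k <= 0
    model_config = config_data.get("model_config", None)
    n_models = len(args.model_name)

    if args.adapter is not None:
        args.adapter = [None if item == "None" else item for item in args.adapter]
        args.adapter.extend([None] * (n_models - len(args.adapter)))
    else:
        args.adapter = [None] * n_models

    if args.model_hyper_params is not None:
        args.model_hyper_params = [
            None if item == "None" else item for item in args.model_hyper_params
        ]
        args.model_hyper_params.extend([None] * (n_models - len(args.model_hyper_params)))
    else:
        args.model_hyper_params = [None] * n_models

    for adapter, model_name, hyper in zip(args.adapter, args.model_name, args.model_hyper_params):
        model_config["models"].append({
            "adapter": adapter,
            "model_name": model_name,
            # The JSON string is only parsed into a dict here, not consumed
            "model_hyper_params": json.loads(hyper) if hyper is not None else {},
        })
    return model_config


# Assumption: the JSON config file supplies a skeleton whose "models" list starts empty
config_data = {"model_config": {"models": []}}
args = argparse.Namespace(
    model_name=["time_series_library.Informer"],
    adapter=["transformer_adapter"],
    model_hyper_params=['{"batch_size":4,"d_model":32,"d_ff":128,"dropout":0.0,'
                        '"lr":0.0001,"num_epochs":1,"num_workers":0,'
                        '"seq_len":96,"horizon":24}'],
)

model_config = build_model_config(args, config_data)
entry = model_config["models"][0]
print(entry["model_name"], entry["adapter"])   # time_series_library.Informer transformer_adapter
print(entry["model_hyper_params"]["d_model"])  # 32
```

One detail worth knowing: dict.get returns a reference, so model_config is the very dict stored in config_data["model_config"]. The append mutates the loaded config in place, and calling build_model_config twice on the same config_data would add the entries twice.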

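This already answers the two most important questions:

  1. Why does execution later reach Informer?
    Because model_name is fixed here as time_series_library.Informer.
  2. Why is there an adapter later?
    Because adapter is fixed here as transformer_adapter.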

7. Code block 2: pipeline(...) -> get_models(...)

Location:

Key code:

python
model_factory_list = get_models(model_config)

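The input to this line is the model_config assembled in the previous section.
Its job is not to instantiate the model, but to:

turn "model entries in the config" into "callable model factories".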

8. Code block 3: get_models(...)

Location:

Full code:

python
def get_models(all_model_config: Dict) -> List[ModelFactory]:
    model_factory_list = []
    for model_config in all_model_config["models"]:
        model_info = get_model_info(model_config)
        fallback_model_name = model_config["model_name"].split(".")[-1]

        if isinstance(model_info, dict):
            model_name = model_info.get("model_name", fallback_model_name)
            model_factory = model_info["model_factory"]
            required_hyper_params = model_info.get("required_hyper_params", {})
        elif callable(model_info):
            model_name = fallback_model_name
            model_factory = model_info
            required_hyper_params = {}
        else:
            raise TypeError("model info returned by get_model_info has invalid type")

        model_hyper_params = get_model_hyper_params(
            all_model_config.get("recommend_model_hyper_params", {}),
            required_hyper_params,
            model_config,
        )
        model_factory_list.append(
            ModelFactory(model_name, model_factory, model_hyper_params)
        )
    return model_factory_list

Full code with annotations:

python
def get_models(all_model_config: Dict) -> List[ModelFactory]:
    model_factory_list = []

    # Iterate over every model entry in the models list
    for model_config in all_model_config["models"]:
        # First resolve which model and which adapter this entry refers to
        model_info = get_model_info(model_config)
        fallback_model_name = model_config["model_name"].split(".")[-1]

        # Then extract the actual model_factory and the parameter-name mapping it requires
        if isinstance(model_info, dict):
            model_name = model_info.get("model_name", fallback_model_name)
            model_factory = model_info["model_factory"]
            required_hyper_params = model_info.get("required_hyper_params", {})
        elif callable(model_info):
            model_name = fallback_model_name
            model_factory = model_info
            required_hyper_params = {}
        else:
            raise TypeError("model info returned by get_model_info has invalid type")

        # Merge the recommended params, the params required by the adapter, and the explicit command-line params
        model_hyper_params = get_model_hyper_params(
            all_model_config.get("recommend_model_hyper_params", {}),
            required_hyper_params,
            model_config,
        )

        # Finally, package everything into a ModelFactory
        model_factory_list.append(
            ModelFactory(model_name, model_factory, model_hyper_params)
        )

    return model_factory_list

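Walking through the current example with this code

In the current example all_model_config["models"] holds a single model entry, so the loop runs only once.
The core of that single iteration is: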

python
model_info = get_model_info(
    {
        "adapter": "transformer_adapter",
        "model_name": "time_series_library.Informer",
        "model_hyper_params": {...}
    }
)

In other words:

get_model_info(...) is not the entry point; it is one step inside get_models(...).
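get_models(...) accepts two shapes from get_model_info(...): a dict carrying model_factory (plus optional model_name / required_hyper_params), or a bare callable. A standalone sketch of just that normalization step; the function name `normalize_model_info` and the `StubInformer` class are hypothetical and only for illustration:

```python
from typing import Callable, Dict, Tuple, Union


def normalize_model_info(model_info: Union[Dict, Callable],
                         full_name: str) -> Tuple[str, Callable, Dict]:
    # Fallback name is the last dotted component, e.g. "Informer"
    fallback = full_name.split(".")[-1]
    if isinstance(model_info, dict):
        return (model_info.get("model_name", fallback),
                model_info["model_factory"],
                model_info.get("required_hyper_params", {}))
    if callable(model_info):
        # A class is callable, so a bare model class lands in this branch
        return fallback, model_info, {}
    raise TypeError("model info returned by get_model_info has invalid type")


class StubInformer:  # illustrative stand-in for the real model class
    pass


print(normalize_model_info(StubInformer, "time_series_library.Informer")[0])  # Informer
print(normalize_model_info({"model_factory": StubInformer, "model_name": "X"}, "a.b")[0])  # X
```

If every import candidate fails and no adapter is configured, get_model_info(...) returns None, and this else branch is where the TypeError surfaces.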

9. Code block 4: get_model_info(...)

Location:

Full code:

python
def get_model_info(model_config: Dict) -> Union[Dict, Callable]:
    model_name_candidates = [
        (
            model_config["model_name"][7:]
            if model_config["model_name"].startswith("global.")
            else None
        ),
        "ts_benchmark.baselines." + model_config["model_name"],
        model_config["model_name"],
    ]
    model_name_candidates = list(filter(None, model_name_candidates))

    model_info = None
    for model_name in model_name_candidates:
        try:
            logger.info("Trying to load model %s", model_name)
            model_info = import_model_info(model_name)
        except (ImportError, AttributeError) as e:
            continue
        else:
            break

    adapter_name = model_config.get("adapter")
    if adapter_name is not None:
        model_info = _import_attribute(ADAPTER[adapter_name])(model_info)

    return model_info

Full code with annotations:

python
def get_model_info(model_config: Dict) -> Union[Dict, Callable]:
    # First build several candidate import paths
    model_name_candidates = [
        (
            model_config["model_name"][7:]
            if model_config["model_name"].startswith("global.")
            else None
        ),
        "ts_benchmark.baselines." + model_config["model_name"],
        model_config["model_name"],
    ]
    model_name_candidates = list(filter(None, model_name_candidates))

    model_info = None

    # Try to import the actual model class
    for model_name in model_name_candidates:
        try:
            logger.info("Trying to load model %s", model_name)
            model_info = import_model_info(model_name)
        except (ImportError, AttributeError) as e:
            continue
        else:
            break

    # If an adapter is configured, wrap one more layer around the model class
    adapter_name = model_config.get("adapter")
    if adapter_name is not None:
        model_info = _import_attribute(ADAPTER[adapter_name])(model_info)

    return model_info
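The candidate order matters, because the loop breaks at the first import that succeeds. This sketch reproduces only the candidate construction from get_model_info, without the import attempts (`model_name_candidates` as a standalone function is a hypothetical extraction):

```python
from typing import List


def model_name_candidates(model_name: str) -> List[str]:
    # Same candidate list as get_model_info; len("global.") == 7, matching [7:]
    candidates = [
        model_name[len("global."):] if model_name.startswith("global.") else None,
        "ts_benchmark.baselines." + model_name,
        model_name,
    ]
    # filter(None, ...) drops the None placeholder
    return list(filter(None, candidates))


print(model_name_candidates("time_series_library.Informer"))
# ['ts_benchmark.baselines.time_series_library.Informer', 'time_series_library.Informer']
```

So for this command the built-in baselines path is tried before the bare name.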

Walking through the current example with this code

What actually happens in this iteration is:

python
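# 1. Import the first candidate that loads: ts_benchmark.baselines.time_series_library.Informer
model_info = import_model_info("ts_benchmark.baselines.time_series_library.Informer")
# 2. Then wrap it: transformer_adapter(Informer)
model_info = _import_attribute(ADAPTER["transformer_adapter"])(model_info)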

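This lets us match each parameter to its role:

  • model_name = time_series_library.Informer
    • controls which model class is imported
  • adapter = transformer_adapter
    • controls which wrapper shell goes around the model class
  • d_model / seq_len / horizon
    • are not actually used here yet
    • they only enter config later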
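A toy model of the wrapping step. Everything in it is hypothetical: `Informer`, `TransformerAdapter` and `transformer_adapter` are stubs, and the real adapter's signature may differ. It only illustrates the point above: wrapping decides the shell, while d_model / seq_len are merely stored, not used.

```python
from typing import Any, Callable


class Informer:  # stub standing in for the real model class
    def __init__(self, config: Any) -> None:
        self.config = config


class TransformerAdapter:  # stub shell: keeps the class, does not build it yet
    def __init__(self, model_class: type, **hyper_params: Any) -> None:
        self.model_class = model_class
        self.hyper_params = hyper_params


def transformer_adapter(model_class: type) -> Callable[..., TransformerAdapter]:
    # Returns a factory; calling it later creates the shell around model_class
    def factory(**hyper_params: Any) -> TransformerAdapter:
        return TransformerAdapter(model_class, **hyper_params)
    return factory


model_info = transformer_adapter(Informer)  # what get_model_info would hand back
shell = model_info(d_model=32, seq_len=96)  # later step: hyper-params only stored
print(type(shell).__name__, shell.model_class.__name__)  # TransformerAdapter Informer
```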

10. Code block 5: ModelFactory -> Informer(config)

10.1 ModelFactory.__call__(...)

Location:

python
class ModelFactory:
    def __call__(self) -> Any:
        return self.model_factory(**self.model_hyper_params)

In the current example:

python
model = model_factory()

the result is not a bare Informer yet, but:

python
TransformerAdapter(...)
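ModelFactory is shown here only through __call__. A self-contained sketch, assuming an __init__ that simply stores the three arguments get_models passes in (ModelFactory(model_name, model_factory, model_hyper_params)); the real class may hold more. `StubAdapter` is a hypothetical stand-in for TransformerAdapter.

```python
from typing import Any, Callable, Dict


class ModelFactory:
    # Assumed constructor: stores what get_models passes in, nothing more
    def __init__(self, model_name: str, model_factory: Callable[..., Any],
                 model_hyper_params: Dict[str, Any]) -> None:
        self.model_name = model_name
        self.model_factory = model_factory
        self.model_hyper_params = model_hyper_params

    def __call__(self) -> Any:
        # Expands the merged hyper-params as keyword arguments
        return self.model_factory(**self.model_hyper_params)


class StubAdapter:  # hypothetical stand-in for TransformerAdapter
    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


factory = ModelFactory("Informer", StubAdapter, {"d_model": 32, "seq_len": 96})
model = factory()  # only now is anything instantiated
print(type(model).__name__, model.kwargs["d_model"])  # StubAdapter 32
```

Building the ModelFactory creates nothing; the model object appears only when the factory is called.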

10.2 _init_model()

Location:

python
def _init_model(self):
    return self.model_class(self.config)

In the current example:

python
self.model = self._init_model()

is equivalent to:

python
self.model = Informer(self.config)

This is the real exit point of this note.
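A hypothetical sketch of why `self.model_class(self.config)` is the moment Informer gets built: the adapter holds the class and a config, and only _init_model instantiates it. `StubInformer`, `StubAdapter` and the SimpleNamespace config are assumptions for illustration, not the real classes.

```python
from types import SimpleNamespace
from typing import Any


class StubInformer:  # stand-in for Informer; reads values from the config it receives
    def __init__(self, config: Any) -> None:
        self.d_model = config.d_model
        self.seq_len = config.seq_len


class StubAdapter:  # stand-in for TransformerAdapter
    def __init__(self, model_class: type, config: Any) -> None:
        self.model_class = model_class
        self.config = config
        self.model = None  # nothing built yet

    def _init_model(self) -> Any:
        # Same shape as the snippet above: class(config)
        return self.model_class(self.config)


adapter = StubAdapter(StubInformer, SimpleNamespace(d_model=32, seq_len=96))
print(adapter.model)  # None
adapter.model = adapter._init_model()
print(type(adapter.model).__name__, adapter.model.d_model)  # StubInformer 32
```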

10.3 This is where the parameters finally reach the model

Only at Informer(self.config) do these parameters actually become bound to the model object:

  • d_model = 32
  • d_ff = 128
  • seq_len = 96
  • pred_len = 24
  • label_len
  • enc_in / dec_in / c_out
  • e_layers / d_layers / n_heads

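Of these:

  • seq_len / pred_len / d_model / d_ff
    • passed explicitly by you, or specified indirectly (pred_len = 24 corresponds to horizon=24)
  • label_len / enc_in / dec_in / c_out
    • filled in later, during the run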
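A hypothetical sketch of how the explicit values and the later-filled ones could meet in a single config object. The horizon -> pred_len mapping matches the values listed above (horizon=24, pred_len=24), but the helper `build_config`, how the real adapter performs that mapping, and the placeholder numbers for label_len and the variable count are all assumptions; in the real run those are filled in from the data.

```python
from types import SimpleNamespace
from typing import Any, Dict


def build_config(hyper_params: Dict[str, Any], n_vars: int, label_len: int) -> SimpleNamespace:
    # Hypothetical helper: start from the explicit command-line values
    config = SimpleNamespace(**hyper_params)
    # Indirectly specified: the forecast horizon becomes pred_len
    config.pred_len = hyper_params["horizon"]
    # Filled in later in the real run; the values passed here are placeholders
    config.label_len = label_len
    config.enc_in = config.dec_in = config.c_out = n_vars
    return config


cfg = build_config({"d_model": 32, "d_ff": 128, "seq_len": 96, "horizon": 24},
                   n_vars=7, label_len=48)
print(cfg.seq_len, cfg.pred_len, cfg.d_model, cfg.enc_in)  # 96 24 32 7
```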

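11. What should be settled by the end of this level

  1. The real entry point is:
text
run_benchmark.py + the current command-line arguments
  2. get_model_info(...) is not the entry point but an intermediate resolution step in this chain:
text
run_benchmark.py
-> build_model_config(...)
-> pipeline(...)
-> get_models(...)
-> get_model_info(...)
  3. model-name decides that Informer is imported
  4. adapter decides that it is wrapped in transformer_adapter
  5. model_factory() returns a TransformerAdapter
  6. Only _init_model() actually produces:
python
Informer(config)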

12. Next step

Continue with:
