Level 1: How the Configuration Reaches Informer
Abstract
This page covers exactly one thing: how the configuration on the current command line turns, step by step, into `Informer(config)`.
1. Context
Previous level:
Next level:
The exit point of this level is:
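```python
model = Informer(config)
```

1.5 What the input parameters of this level mean
- `args.model_name`: decides which model class is ultimately imported
- `args.adapter`: decides which integration wrapper goes around the model class
- `args.model_hyper_params`: supplies the explicit hyperparameter dictionary
- `config_data`: supplies the default skeleton from the configuration file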
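1.6 What the output of this level means
The output of this level is not a prediction. It is:
```python
model = Informer(config)
```
This means the current command has been parsed into a real, callable Informer model object. The only thing left is to feed data into it.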
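2. The main question here
The real entry point here is not `get_model_info(...)`.
The real entry point is:
That is because it is the first thing that reads your command.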
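3. Sequence diagram
4. Abstraction tree
5. First, pin down which parameters in this command act at this level
The parts of the current command that bear directly on this level are: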
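```text
--model-name time_series_library.Informer
--adapter transformer_adapter
--model-hyper-params {"batch_size":4,"d_model":32,"d_ff":128,"dropout":0.0,"lr":0.0001,"num_epochs":1,"num_workers":0,"seq_len":96,"horizon":24}
```

At this level, keep their roles apart:
- `model-name`
  - decides which model class is ultimately looked up
  - here it is `Informer`
- `adapter`
  - decides which integration wrapper goes around the model class
  - here it is `transformer_adapter`
- `model-hyper-params`
  - at this point it is only collected into `model_config`; it has not yet been fully consumed
  - for example, `d_model=32` and `seq_len=96` only take effect later, in `config` and inside the model

So do not misread this level as "`build_model_config(...)` is already using `d_model`".
It is not.
For now it only packs these parameters into the configuration object.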
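In section 6 the walkthrough shows `args.model_name`, `args.adapter` and `args.model_hyper_params` as lists, one value per model. The sketch below is a minimal way to get that shape with `argparse`. It assumes `nargs="+"` for all three flags, and `make_parser` is a hypothetical name. This page does not show the real parser definition in `run_benchmark.py`. Note that after parsing, the hyperparameters are still a raw JSON string, so nothing has read `d_model` yet.

```python
import argparse


def make_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: flag names come from the command above;
    # nargs="+" is an assumption that yields one list entry per model.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", nargs="+", required=True)
    parser.add_argument("--adapter", nargs="+", default=None)
    parser.add_argument("--model-hyper-params", nargs="+", default=None)
    return parser


args = make_parser().parse_args([
    "--model-name", "time_series_library.Informer",
    "--adapter", "transformer_adapter",
    "--model-hyper-params", '{"d_model":32,"seq_len":96,"horizon":24}',
])

# argparse turns dashes into underscores in attribute names
print(args.model_name)
print(args.adapter)
# Still a plain string: collected, not yet consumed
print(type(args.model_hyper_params[0]).__name__)
```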
6. Code block 1: `build_model_config(...)`
Location:
Full code:
```python
def build_model_config(args: argparse.Namespace, config_data: Dict) -> Dict:
    model_config = config_data.get("model_config", None)
    if args.adapter is not None:
        args.adapter = [None if item == "None" else item for item in args.adapter]
        if len(args.model_name) > len(args.adapter):
            args.adapter.extend([None] * (len(args.model_name) - len(args.adapter)))
    else:
        args.adapter = [None] * len(args.model_name)
    if args.model_hyper_params is not None:
        args.model_hyper_params = [
            None if item == "None" else item for item in args.model_hyper_params
        ]
        if len(args.model_name) > len(args.model_hyper_params):
            args.model_hyper_params.extend(
                [None] * (len(args.model_name) - len(args.model_hyper_params))
            )
    else:
        args.model_hyper_params = [None] * len(args.model_name)
    for adapter, model_name, model_hyper_params in zip(
        args.adapter, args.model_name, args.model_hyper_params
    ):
        model_config["models"].append(
            {
                "adapter": adapter,
                "model_name": model_name,
                "model_hyper_params": (
                    json.loads(model_hyper_params)
                    if model_hyper_params is not None
                    else {}
                ),
            }
        )
    return model_config
```

Full code with comments:
```python
def build_model_config(args: argparse.Namespace, config_data: Dict) -> Dict:
    # Start from the model_config skeleton in the JSON configuration file
    model_config = config_data.get("model_config", None)
    # Normalize the command-line adapter argument into a list aligned with model_name
    if args.adapter is not None:
        args.adapter = [None if item == "None" else item for item in args.adapter]
        if len(args.model_name) > len(args.adapter):
            args.adapter.extend([None] * (len(args.model_name) - len(args.adapter)))
    else:
        args.adapter = [None] * len(args.model_name)
    # Normalize the command-line model_hyper_params into a list aligned with model_name
    if args.model_hyper_params is not None:
        args.model_hyper_params = [
            None if item == "None" else item for item in args.model_hyper_params
        ]
        if len(args.model_name) > len(args.model_hyper_params):
            args.model_hyper_params.extend(
                [None] * (len(args.model_name) - len(args.model_hyper_params))
            )
    else:
        args.model_hyper_params = [None] * len(args.model_name)
    # Assemble each group of command-line arguments into one entry of the models list
    for adapter, model_name, model_hyper_params in zip(
        args.adapter, args.model_name, args.model_hyper_params
    ):
        model_config["models"].append(
            {
                "adapter": adapter,
                "model_name": model_name,
                "model_hyper_params": (
                    json.loads(model_hyper_params)
                    if model_hyper_params is not None
                    else {}
                ),
            }
        )
    return model_config
```

Walking through the current example
At this point your command line has already become:
```python
args.model_name = ["time_series_library.Informer"]
args.adapter = ["transformer_adapter"]
args.model_hyper_params = [
    "{\"batch_size\":4,\"d_model\":32,\"d_ff\":128,\"dropout\":0.0,"
    "\"lr\":0.0001,\"num_epochs\":1,\"num_workers\":0,"
    "\"seq_len\":96,\"horizon\":24}"
]
```
This code then assembles them into:
```python
model_config = {
    "models": [
        {
            "adapter": "transformer_adapter",
            "model_name": "time_series_library.Informer",
            "model_hyper_params": {
                "batch_size": 4,
                "d_model": 32,
                "d_ff": 128,
                "dropout": 0.0,
                "lr": 0.0001,
                "num_epochs": 1,
                "num_workers": 0,
                "seq_len": 96,
                "horizon": 24,
            },
        }
    ],
    # ... other keys from the configuration-file skeleton
}
```
This already answers the two most important questions:
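- Why does the run end up at Informer?
  Because `model_name` has already been fixed to `time_series_library.Informer` here.
- Why is there an adapter later on?
  Because `adapter` has already been fixed to `transformer_adapter` here.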
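The padding rule in `build_model_config(...)` is easier to see in isolation. The helpers `align_to` and `build_entries` below are hypothetical names and do not come from the codebase. They apply the same steps: map the string `"None"` to `None`, pad with `None` up to the length of `model_name`, and parse each hyperparameter string with `json.loads`, falling back to `{}`. The second model name, `hypothetical.SecondModel`, is invented only to show the padding.

```python
import json
from typing import Dict, List, Optional


def align_to(names: List[str], values: Optional[List[str]]) -> List[Optional[str]]:
    # Hypothetical helper: same "None" handling and padding as build_model_config
    if values is None:
        return [None] * len(names)
    cleaned = [None if v == "None" else v for v in values]
    if len(names) > len(cleaned):
        cleaned.extend([None] * (len(names) - len(cleaned)))
    return cleaned


def build_entries(names: List[str], adapters: Optional[List[str]],
                  hyper_params: Optional[List[str]]) -> List[Dict]:
    entries = []
    for adapter, name, hp in zip(align_to(names, adapters), names,
                                 align_to(names, hyper_params)):
        entries.append({
            "adapter": adapter,
            "model_name": name,
            # Parse the JSON string; a missing entry becomes an empty dict
            "model_hyper_params": json.loads(hp) if hp is not None else {},
        })
    return entries


entries = build_entries(
    ["time_series_library.Informer", "hypothetical.SecondModel"],
    ["transformer_adapter"],
    ['{"d_model": 32, "seq_len": 96}', "None"],
)
for entry in entries:
    print(entry)
```

The second entry gets `adapter=None` because of the padding. It gets `{}` because the string `"None"` is turned into `None` before parsing.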
7. Code block 2: `pipeline(...) -> get_models(...)`
Location:
Key code:
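```python
model_factory_list = get_models(model_config)
```
The input to this line is the `model_config` assembled in the previous section.
Its job is not to instantiate the model. Instead, it turns "model entries in the configuration" into "callable model factories".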
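The difference between a model entry and a model factory is when construction happens. The sketch below shows this with hypothetical names: `ToyModel`, `REGISTRY` and `make_factories`. `functools.partial` stands in for `ModelFactory`, so it is a substitute and not the real class. Nothing is instantiated until a factory is called.

```python
import functools
from typing import Callable, Dict, List


class ToyModel:
    created = 0  # class-level counter of how many instances were built

    def __init__(self, **hyper_params):
        ToyModel.created += 1
        self.hyper_params = hyper_params


# Hypothetical name -> class lookup standing in for get_model_info(...)
REGISTRY: Dict[str, Callable] = {"toy.ToyModel": ToyModel}


def make_factories(entries: List[Dict]) -> List[Callable[[], object]]:
    # Resolve each entry and bind its hyperparameters, but call nothing yet
    return [
        functools.partial(REGISTRY[e["model_name"]], **e["model_hyper_params"])
        for e in entries
    ]


factories = make_factories(
    [{"model_name": "toy.ToyModel", "model_hyper_params": {"d_model": 32}}]
)
print(ToyModel.created)    # 0: only factories exist so far
model = factories[0]()     # construction happens here
print(ToyModel.created)    # 1
print(model.hyper_params)  # {'d_model': 32}
```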
8. Code block 3: `get_models(...)`
Location:
Full code:
```python
def get_models(all_model_config: Dict) -> List[ModelFactory]:
    model_factory_list = []
    for model_config in all_model_config["models"]:
        model_info = get_model_info(model_config)
        fallback_model_name = model_config["model_name"].split(".")[-1]
        if isinstance(model_info, dict):
            model_name = model_info.get("model_name", fallback_model_name)
            model_factory = model_info["model_factory"]
            required_hyper_params = model_info.get("required_hyper_params", {})
        elif callable(model_info):
            model_name = fallback_model_name
            model_factory = model_info
            required_hyper_params = {}
        else:
            raise TypeError("model info returned by get_model_info has invalid type")
        model_hyper_params = get_model_hyper_params(
            all_model_config.get("recommend_model_hyper_params", {}),
            required_hyper_params,
            model_config,
        )
        model_factory_list.append(
            ModelFactory(model_name, model_factory, model_hyper_params)
        )
    return model_factory_list
```

Full code with comments:
```python
def get_models(all_model_config: Dict) -> List[ModelFactory]:
    model_factory_list = []
    # Iterate over every model entry in the models list
    for model_config in all_model_config["models"]:
        # First resolve which model and which adapter this entry actually refers to
        model_info = get_model_info(model_config)
        fallback_model_name = model_config["model_name"].split(".")[-1]
        # Then obtain the actual model_factory and its required parameter-name mapping
        if isinstance(model_info, dict):
            model_name = model_info.get("model_name", fallback_model_name)
            model_factory = model_info["model_factory"]
            required_hyper_params = model_info.get("required_hyper_params", {})
        elif callable(model_info):
            model_name = fallback_model_name
            model_factory = model_info
            required_hyper_params = {}
        else:
            raise TypeError("model info returned by get_model_info has invalid type")
        # Merge the recommended params, the adapter's required params,
        # and the explicit command-line params
        model_hyper_params = get_model_hyper_params(
            all_model_config.get("recommend_model_hyper_params", {}),
            required_hyper_params,
            model_config,
        )
        # Finally, package everything into a ModelFactory
        model_factory_list.append(
            ModelFactory(model_name, model_factory, model_hyper_params)
        )
    return model_factory_list
```

Walking through the current example
`all_model_config["models"]` currently holds a single model entry, so the loop runs only once.
The core of that iteration is:
```python
model_info = get_model_info(
    {
        "adapter": "transformer_adapter",
        "model_name": "time_series_library.Informer",
        "model_hyper_params": {...}
    }
)
```
In other words:
`get_model_info(...)` is not the entry point. It is one step inside `get_models(...)`.
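The `isinstance` / `callable` branch in `get_models(...)` accepts two shapes of `model_info`. The sketch below isolates that normalization in a hypothetical helper, `normalize_model_info`. It returns the same three values the loop extracts, and it reads the same dict keys as the code above. `fake_factory` is a placeholder callable used only for illustration.

```python
from typing import Callable, Dict, Tuple, Union


def normalize_model_info(model_info: Union[Dict, Callable],
                         full_model_name: str) -> Tuple[str, Callable, Dict]:
    # Hypothetical helper isolating the branch inside get_models(...)
    fallback = full_model_name.split(".")[-1]  # last dotted component
    if isinstance(model_info, dict):
        return (
            model_info.get("model_name", fallback),
            model_info["model_factory"],
            model_info.get("required_hyper_params", {}),
        )
    if callable(model_info):
        # A bare class or function: use it directly, with no required params
        return fallback, model_info, {}
    raise TypeError("model info returned by get_model_info has invalid type")


def fake_factory(**hyper_params):
    return hyper_params


print(normalize_model_info(fake_factory, "time_series_library.Informer")[0])  # Informer
print(normalize_model_info({"model_name": "Custom", "model_factory": fake_factory}, "x.Y")[0])  # Custom
```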
9. Code block 4: `get_model_info(...)`
Location:
Full code:
```python
def get_model_info(model_config: Dict) -> Union[Dict, Callable]:
    model_name_candidates = [
        (
            model_config["model_name"][7:]
            if model_config["model_name"].startswith("global.")
            else None
        ),
        "ts_benchmark.baselines." + model_config["model_name"],
        model_config["model_name"],
    ]
    model_name_candidates = list(filter(None, model_name_candidates))
    model_info = None
    for model_name in model_name_candidates:
        try:
            logger.info("Trying to load model %s", model_name)
            model_info = import_model_info(model_name)
        except (ImportError, AttributeError) as e:
            continue
        else:
            break
    adapter_name = model_config.get("adapter")
    if adapter_name is not None:
        model_info = _import_attribute(ADAPTER[adapter_name])(model_info)
    return model_info
```

Full code with comments:
```python
def get_model_info(model_config: Dict) -> Union[Dict, Callable]:
    # First build several candidate import paths
    model_name_candidates = [
        (
            model_config["model_name"][7:]
            if model_config["model_name"].startswith("global.")
            else None
        ),
        "ts_benchmark.baselines." + model_config["model_name"],
        model_config["model_name"],
    ]
    model_name_candidates = list(filter(None, model_name_candidates))
    model_info = None
    # Try to import the actual model class, stopping at the first success
    for model_name in model_name_candidates:
        try:
            logger.info("Trying to load model %s", model_name)
            model_info = import_model_info(model_name)
        except (ImportError, AttributeError) as e:
            continue
        else:
            break
    # If an adapter is configured, wrap the model class in one more layer
    adapter_name = model_config.get("adapter")
    if adapter_name is not None:
        model_info = _import_attribute(ADAPTER[adapter_name])(model_info)
    return model_info
```

Walking through the current example
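What actually happens in this iteration is:
```python
# 1. Import ts_benchmark.baselines.time_series_library.Informer (the first candidate tried)
# 2. Call transformer_adapter(Informer), via _import_attribute(ADAPTER["transformer_adapter"])
```
This lets us match each parameter to its role:
- `model_name = time_series_library.Informer`: controls which model class is imported
- `adapter = transformer_adapter`: controls which integration wrapper goes around the model class
- `d_model / seq_len / horizon`: not actually used here yet; they only enter `config` later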
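The candidate-path loop in `get_model_info(...)` follows a general pattern: try several import paths and keep the first one that works. The sketch below implements that pattern with `importlib`. `candidates`, `load_attribute` and `first_importable` are hypothetical stand-ins for `import_model_info`, whose real implementation is not shown on this page. The candidate list mirrors the one built above. The usage lines import standard-library names, so the example runs without `ts_benchmark` installed.

```python
import importlib
from typing import Any, List, Optional


def candidates(model_name: str) -> List[str]:
    # Same three candidates as get_model_info, with empty entries dropped
    raw = [
        model_name[len("global."):] if model_name.startswith("global.") else None,
        "ts_benchmark.baselines." + model_name,
        model_name,
    ]
    return [c for c in raw if c]


def load_attribute(path: str) -> Any:
    # "pkg.module.Attr" -> import "pkg.module", then read attribute "Attr"
    module_name, _, attr = path.rpartition(".")
    if not module_name:
        raise ImportError(f"no module part in {path!r}")
    return getattr(importlib.import_module(module_name), attr)


def first_importable(model_name: str) -> Optional[Any]:
    for path in candidates(model_name):
        try:
            return load_attribute(path)
        except (ImportError, AttributeError):
            continue  # fall through to the next candidate
    return None  # like model_info staying None when nothing loads


print(candidates("time_series_library.Informer"))
print(first_importable("collections.OrderedDict"))
```

As in the original loop, a failed candidate is skipped silently. When nothing loads, the result stays `None`, so the caller has to check for that.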
10. Code block 5: `ModelFactory -> Informer(config)`
10.1 `ModelFactory.__call__(...)`
Location:
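```python
class ModelFactory:
    # __init__ not shown; get_models(...) constructs it as
    # ModelFactory(model_name, model_factory, model_hyper_params)

    def __call__(self) -> Any:
        return self.model_factory(**self.model_hyper_params)
```
In the current example: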
```python
model = model_factory()
```
What this returns is not yet a bare Informer but:
```python
TransformerAdapter(...)
```

10.2 `_init_model()`
Location:
```python
def _init_model(self):
    return self.model_class(self.config)
```
In the current example:
```python
self.model = self._init_model()
```
is equivalent to:
```python
self.model = Informer(self.config)
```
This is the real exit point of this page.
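Section 10 has two layers of indirection: a factory that builds an adapter, and an adapter that builds the model. The toy classes below reproduce that call order. Everything in them is hypothetical: `ToyInformer`, `ToyAdapter`, `toy_adapter`, the stand-in `ModelFactory`, and the use of `SimpleNamespace` as the config object. They illustrate the order of calls and are not the real `TransformerAdapter`. For brevity, this toy builds the model inside the adapter's constructor, whereas the real adapter may do so at a later point.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict


class ToyInformer:
    def __init__(self, config: SimpleNamespace):
        self.d_model = config.d_model  # the hyperparameter is finally consumed here


class ToyAdapter:
    def __init__(self, model_class: type, **hyper_params: Any):
        self.model_class = model_class
        self.config = SimpleNamespace(**hyper_params)
        self.model = self._init_model()  # eager here only for brevity

    def _init_model(self):
        # Same shape as the _init_model() excerpt: model_class(config)
        return self.model_class(self.config)


def toy_adapter(model_class: type) -> Callable[..., ToyAdapter]:
    # Plays the role of transformer_adapter(Informer): returns a factory
    def factory(**hyper_params: Any) -> ToyAdapter:
        return ToyAdapter(model_class, **hyper_params)
    return factory


class ModelFactory:
    # Minimal stand-in with the same __call__ as the excerpt in 10.1
    def __init__(self, model_name: str, model_factory: Callable,
                 model_hyper_params: Dict):
        self.model_name = model_name
        self.model_factory = model_factory
        self.model_hyper_params = model_hyper_params

    def __call__(self) -> Any:
        return self.model_factory(**self.model_hyper_params)


factory = ModelFactory("Informer", toy_adapter(ToyInformer),
                       {"d_model": 32, "seq_len": 96})
wrapper = factory()
print(type(wrapper).__name__)        # ToyAdapter
print(type(wrapper.model).__name__)  # ToyInformer
print(wrapper.model.d_model)         # 32
```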
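10.3 This is where the parameters finally enter the model
Only at the `Informer(self.config)` step do these parameters actually become bound to the model object:
- `d_model = 32`
- `d_ff = 128`
- `seq_len = 96`
- `pred_len = 24`
- `label_len`
- `enc_in / dec_in / c_out`
- `e_layers / d_layers / n_heads`

Of these:
- `seq_len / pred_len / d_model / d_ff`: set explicitly by you or specified indirectly (for example, `pred_len = 24` corresponds to the `horizon` of 24 in the command)
- `label_len / enc_in / dec_in / c_out`: filled in later during the run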
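The sketch below shows one way the two-phase filling described above could work. It rests on two labeled assumptions. The first is a hypothetical `horizon -> pred_len` rename in `RENAMES`. The second is that `enc_in`, `dec_in` and `c_out` are all set to the number of series columns once the data is known. The real adapter's rules may differ. `label_len` is left out because its default is not shown on this page, and the 7-column dataset is made up.

```python
from types import SimpleNamespace
from typing import Any, Dict

RENAMES = {"horizon": "pred_len"}  # hypothetical mapping, for illustration only


def config_from_hyper_params(hyper_params: Dict[str, Any]) -> SimpleNamespace:
    # Phase 1: explicit values from --model-hyper-params
    return SimpleNamespace(**{RENAMES.get(k, k): v for k, v in hyper_params.items()})


def fill_from_data(config: SimpleNamespace, n_columns: int) -> SimpleNamespace:
    # Phase 2 (assumed rule): sizes that depend on the data, known only at run time
    config.enc_in = config.dec_in = config.c_out = n_columns
    return config


config = config_from_hyper_params(
    {"d_model": 32, "d_ff": 128, "seq_len": 96, "horizon": 24}
)
print(config.pred_len, hasattr(config, "enc_in"))  # 24 False
fill_from_data(config, n_columns=7)
print(config.enc_in, config.c_out)  # 7 7
```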
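11. What should be settled by the end of this level
- The real entry point is:
  ```text
  run_benchmark.py + the current command-line arguments
  ```
- `get_model_info(...)` is not the entry point but a resolution step in the middle of this chain:
  ```text
  run_benchmark.py
  -> build_model_config(...)
  -> pipeline(...)
  -> get_models(...)
  -> get_model_info(...)
  ```
- `model-name` decides that `Informer` is imported
- `adapter` decides that it is wrapped with `transformer_adapter`
- `model_factory()` returns a `TransformerAdapter`
- only `_init_model()` actually produces:
  ```python
  Informer(config)
  ```

12. Next step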
Continue with: