Appearance
03B2-Layer3-Inception_Block_V1
本文件位置
上层:[[03B-Layer2B-TimesBlock]]
入口代码:out = self.conv(out)
入口类:Inception_Block_V1
出口张量:保持二维周期网格的高宽不变,只改变 channel 数。
1. 本层顺序树
1.1 语义分组图
2. 输入输出接口
以 period=4 分支为例:
| 变量 | toy shape | 含义 |
|---|---|---|
out before conv | (3,6,4,4) | B=3,hidden channel 6,周期段数 4,周期长度 4 |
Inception_Block_V1(6,7) | (3,7,4,4) | 多尺度二维卷积,channel 从 d_model=6 到 d_ff=7 |
GELU | (3,7,4,4) | 非线性激活 |
Inception_Block_V1(7,6) | (3,6,4,4) | channel 从 d_ff=7 回到 d_model=6 |
3. 对照源码
位置:ts_benchmark/baselines/time_series_library/layers/Conv_Blocks.py
python
class Inception_Block_V1(nn.Module):
def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
super(Inception_Block_V1, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.num_kernels = num_kernels
kernels = []
for i in range(self.num_kernels):
kernels.append(
nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i)
)
self.kernels = nn.ModuleList(kernels)
if init_weight:
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
res_list = []
for i in range(self.num_kernels):
res_list.append(self.kernels[i](x))
res = torch.stack(res_list, dim=-1).mean(-1)
return resTimesBlock 中的调用:
python
self.conv = nn.Sequential(
Inception_Block_V1(
configs.d_model, configs.d_ff, num_kernels=configs.num_kernels
),
nn.GELU(),
Inception_Block_V1(
configs.d_ff, configs.d_model, num_kernels=configs.num_kernels
),
)4. 多尺度卷积核
num_kernels=3 时,代码创建:
i | kernel_size=2*i+1 | padding=i | 空间尺寸是否保持 |
|---|---|---|---|
| 0 | 1 | 0 | 保持 |
| 1 | 3 | 1 | 保持 |
| 2 | 5 | 2 | 保持 |
二维卷积输出大小公式:
这里 stride=1,padding=i,kernel=2i+1:
宽度 W 同理保持不变。
5. Conv2d 的 toy 数值例子
只看单通道输入和一个 3x3 卷积核,说明二维网格中一个位置如何被计算。
text
输入局部窗口:
[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]
卷积核:
[[0, 1, 0],
[1, 2, 1],
[0, 1, 0]]
输出中心位置 =
1*0 + 2*1 + 3*0
+ 4*1 + 5*2 + 6*1
+ 7*0 + 8*1 + 9*0
= 30真实 Conv2d(in_channels=6,out_channels=7,kernel_size=3,padding=1) 会对所有输入 channel 求和:
6. torch.stack(res_list, dim=-1).mean(-1) 的含义
源码:
python
res_list = []
for i in range(self.num_kernels):
res_list.append(self.kernels[i](x))
res = torch.stack(res_list, dim=-1).mean(-1)
return resshape:
text
kernel=1 输出: (3,7,4,4)
kernel=3 输出: (3,7,4,4)
kernel=5 输出: (3,7,4,4)
torch.stack(res_list, dim=-1): (3,7,4,4,3)
mean(-1): (3,7,4,4)toy 例子只看同一个位置:
text
kernel=1 分支输出: 4.0
kernel=3 分支输出: 5.0
kernel=5 分支输出: 6.0
stack后: [4.0, 5.0, 6.0]
mean: (4.0 + 5.0 + 6.0) / 3 = 5.0这一步的含义是把不同感受野的二维模式取平均。kernel=1 偏局部点变换,kernel=3/5 能看到更大周期邻域。
7. 出口接回上层
text
self.conv(out): (3,6,4,4)
回到 [[03B-Layer2B-TimesBlock]]
下一步: permute + reshape,把二维周期网格还原为一维时间序列