TwoStageAttentionLayer: Two-Stage Attention
Abstract
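TwoStageAttentionLayer takes a 4-D input of shape (batch, variable_dim, segment_num, d_model). It first runs attention along the time segments inside each variable, then uses a learnable router to pass information across the variable dimension.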
1. Diagram
![[zdocs/pytorch-basics/assets/self_attention_reformer_twostage.svg]]
The right half of the figure shows TwoStageAttentionLayer.
2. Input and output
Source comment:
```python
# input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model]
```
Toy example:
```text
x.shape = (b, ts_d, seg_num, d_model) = (2,3,5,8)
factor = 2
```
Output:
```text
final_out.shape = (2,3,5,8)
```
3. Initialization structure
Core submodules:
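```python
# Excerpt from __init__; the FullAttention arguments are elided in this note
self.time_attention = AttentionLayer(FullAttention(...), d_model, n_heads)
self.dim_sender = AttentionLayer(FullAttention(...), d_model, n_heads)
self.dim_receiver = AttentionLayer(FullAttention(...), d_model, n_heads)
# `factor` learnable router tokens for each of the seg_num segment positions
self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))
```
Meaning: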
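| Module | Role |
|---|---|
| time_attention | Attention along seg_num inside each variable |
| dim_sender | The router aggregates information from all variables |
| dim_receiver | Each variable receives information from the router |
| router | Learnable relay tokens: factor of them per segment position, shape (seg_num, factor, d_model) |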
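The table can be mirrored in a small module skeleton. This is a minimal sketch, assuming `torch.nn.MultiheadAttention` as a stand-in for `AttentionLayer(FullAttention(...))`; the class name `TSASubmodulesSketch` is hypothetical. It shows that the router's size depends on seg_num, factor and d_model, but not on the number of variables ts_d.

```python
import torch
import torch.nn as nn


class TSASubmodulesSketch(nn.Module):
    """Hypothetical skeleton of the four submodules listed in the table.

    Assumption: nn.MultiheadAttention stands in for AttentionLayer(FullAttention(...)).
    """

    def __init__(self, seg_num: int, factor: int, d_model: int, n_heads: int):
        super().__init__()
        self.time_attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.dim_sender = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.dim_receiver = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        # factor learnable relay tokens for each of the seg_num segment positions
        self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))


layer = TSASubmodulesSketch(seg_num=5, factor=2, d_model=8, n_heads=2)
print(layer.router.shape)    # torch.Size([5, 2, 8])
print(layer.router.numel())  # 80
```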
4. Stage 1: Cross Time Stage
Source:
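```python
from einops import rearrange

# Excerpt from TwoStageAttentionLayer.forward
batch = x.shape[0]
# Fold the variable dim into the batch dim: each variable becomes its own sequence
time_in = rearrange(x, "b ts_d seg_num d_model -> (b ts_d) seg_num d_model")
time_enc, attn = self.time_attention(
    time_in, time_in, time_in, attn_mask=None, tau=None, delta=None
)
# Residual + LayerNorm, then MLP with residual + LayerNorm
dim_in = time_in + self.dropout(time_enc)
dim_in = self.norm1(dim_in)
dim_in = dim_in + self.dropout(self.MLP1(dim_in))
dim_in = self.norm2(dim_in)
```
Shapes: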
```text
x: (2,3,5,8)
time_in: (2*3,5,8) = (6,5,8)
time_enc, dim_in: (6,5,8)
```
The logic of this step:
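```text
Each variable looks only at its own 5 segments.
Variables do not exchange information yet.
```
Formally, the input $X \in \mathbb{R}^{B \times D \times L \times d_{model}}$ ($D$ = ts_d, $L$ = seg_num) is reshaped so that every (sample, variable) pair becomes its own length-$L$ sequence $X_{b,d} \in \mathbb{R}^{L \times d_{model}}$, stacked as a tensor of shape $(B \cdot D) \times L \times d_{model}$.
Self-attention is then applied to each $(b, d)$ independently: $Z_{b,d} = \mathrm{MSA}(X_{b,d}, X_{b,d}, X_{b,d})$, where MSA is multi-head self-attention.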
5. Stage 2: Cross Dimension Stage
Source:
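```python
from einops import rearrange, repeat

# Regroup so that, for each (sample, segment) pair, the sequence runs over the ts_d variables
dim_send = rearrange(
    dim_in, "(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model", b=batch
)
# Tile the router once per sample; `(repeat seg_num)` keeps seg_num as the inner index,
# matching the `(b seg_num)` order of dim_send
batch_router = repeat(
    self.router,
    "seg_num factor d_model -> (repeat seg_num) factor d_model",
    repeat=batch,
)
```
Shapes: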
```text
dim_in: (6,5,8)
dim_send: (2*5,3,8) = (10,3,8)
router: (5,2,8)
batch_router: (2*5,2,8) = (10,2,8)
```
Now each segment position looks at the variable dimension on its own:
```text
In dim_send, the sequence dimension ts_d=3 holds the 3 variables.
In batch_router, the sequence dimension factor=2 holds the 2 router tokens.
```
6. sender / receiver
sender:
```python
dim_buffer, attn = self.dim_sender(
    batch_router, dim_send, dim_send, attn_mask=None, tau=None, delta=None
)
```
Meaning:
```text
Q = router
K/V = variables
The router collects information from the variables.
```
Shapes:
```text
batch_router: (10,2,8)
dim_send: (10,3,8)
dim_buffer: (10,2,8)
```
receiver:
```python
dim_receive, attn = self.dim_receiver(
    dim_send, dim_buffer, dim_buffer, attn_mask=None, tau=None, delta=None
)
```
Meaning:
```text
Q = variables
K/V = router buffer
Each variable receives cross-variable information from the router.
```
Shapes:
```text
dim_send: (10,3,8)
dim_buffer: (10,2,8)
dim_receive: (10,3,8)
```
7. Final reshape
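Source (here dim_enc is dim_receive after a residual connection with dim_send, LayerNorm and an MLP, mirroring the post-processing in Stage 1):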
```python
final_out = rearrange(
    dim_enc, "(b seg_num) ts_d d_model -> b ts_d seg_num d_model", b=batch
)
```
Shapes:
```text
dim_enc: (10,3,8)
final_out: (2,3,5,8)
```
8. Core idea
The two stages can be summarized as:
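```text
Stage 1: fix the variable, run attention along the time segments
Stage 2: fix the time segment, pass messages between variables through the router
```
The router avoids direct pairwise attention among all variables: messages are first aggregated into a small number (factor) of relay tokens and then distributed back. Per segment position, this replaces one D × D score matrix with two smaller ones (factor × D in the sender, D × factor in the receiver), so the cost grows linearly rather than quadratically in the number of variables D.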
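To make the cost argument concrete, a hypothetical helper (not part of the source) can count query-key score entries for one segment position, ignoring heads and linear projections:

```python
def attention_score_entries(ts_d: int, factor: int) -> dict:
    """Hypothetical helper: query-key score entries for ONE segment position."""
    pairwise = ts_d * ts_d    # every variable attends to every variable
    sender = factor * ts_d    # router tokens (queries) x variables (keys)
    receiver = ts_d * factor  # variables (queries) x router tokens (keys)
    return {"pairwise": pairwise, "router": sender + receiver}


print(attention_score_entries(ts_d=3, factor=2))     # {'pairwise': 9, 'router': 12}
print(attention_score_entries(ts_d=300, factor=10))  # {'pairwise': 90000, 'router': 6000}
```

For the toy sizes the router is actually more expensive (12 vs 9 entries); the saving only appears once ts_d > 2 * factor. The ts_d=300, factor=10 case is an illustrative choice, not a configuration from the source.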