TwoStageAttentionLayer: Two-Stage Attention

Abstract

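TwoStageAttentionLayer takes a four-dimensional input of shape (batch, variable_dim, segment_num, d_model).

It first applies attention along the time segments within each variable, then uses a learnable router to pass information across the variable dimension.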

1. Diagram

![[zdocs/pytorch-basics/assets/self_attention_reformer_twostage.svg]]

The right half of the figure shows TwoStageAttentionLayer.

2. Input and output

Source comment:

python
# input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model]

Toy example:

text
x.shape = (b, ts_d, seg_num, d_model) = (2,3,5,8)
factor = 2

Output:

text
final_out.shape = (2,3,5,8)

3. Initialization

Core submodules:

python
self.time_attention = AttentionLayer(FullAttention(...), d_model, n_heads)
self.dim_sender = AttentionLayer(FullAttention(...), d_model, n_heads)
self.dim_receiver = AttentionLayer(FullAttention(...), d_model, n_heads)
self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))

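Meaning:

| Module | Role |
| --- | --- |
| time_attention | Attention along seg_num within each variable |
| dim_sender | The router aggregates information from all variables |
| dim_receiver | Each variable receives information from the router |
| router | Learnable relay tokens, factor of them per segment position |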
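To make the router shape concrete, here is a minimal sketch using the toy sizes from section 2. It uses a NumPy array as a stand-in for the nn.Parameter (an assumption made only to keep the example light); the point is that each segment position owns its own set of factor relay tokens.

```python
import numpy as np

seg_num, factor, d_model = 5, 2, 8
rng = np.random.default_rng(0)

# Stand-in for nn.Parameter(torch.randn(seg_num, factor, d_model)).
router = rng.standard_normal((seg_num, factor, d_model))

# Each segment position has its own `factor` relay tokens of width d_model.
print(router[0].shape)  # (2, 8)

# Total number of learnable scalars in the router: 5 * 2 * 8.
print(router.size)  # 80
```

The router is shared across the batch; it is only tiled per sample at forward time.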

4. Stage 1: Cross Time Stage

Source:

python
batch = x.shape[0]
time_in = rearrange(x, "b ts_d seg_num d_model -> (b ts_d) seg_num d_model")
time_enc, attn = self.time_attention(
    time_in, time_in, time_in, attn_mask=None, tau=None, delta=None
)
dim_in = time_in + self.dropout(time_enc)
dim_in = self.norm1(dim_in)
dim_in = dim_in + self.dropout(self.MLP1(dim_in))
dim_in = self.norm2(dim_in)

shape:

text
x:       (2,3,5,8)
time_in: (2*3,5,8) = (6,5,8)

The logic of this step:

text
Each variable attends only over its own 5 segments.
Variables do not exchange information yet.

In formula form:

$$\mathrm{time\_in}_{(b,d),\,l,\,:} = x_{b,\,d,\,l,\,:}$$

Self-attention is then applied to each (b, d) pair independently.
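The independence claim can be checked numerically. The sketch below is an illustration under simplifying assumptions: it uses NumPy instead of einops/PyTorch, and a hypothetical helper self_attend that performs single-head attention with no learned projections, so it is not the AttentionLayer from the source. It flattens (b, ts_d) into one leading axis, then perturbs one variable and lists which rows of the output change.

```python
import numpy as np

def self_attend(z):
    """Single-head self-attention, applied to each row of the leading axis separately."""
    scores = z @ z.transpose(0, 2, 1) / np.sqrt(z.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ z

b, ts_d, seg_num, d_model = 2, 3, 5, 8
rng = np.random.default_rng(0)
x = rng.standard_normal((b, ts_d, seg_num, d_model))

# Same effect as rearrange(x, "b ts_d seg_num d_model -> (b ts_d) seg_num d_model"):
# b and ts_d are adjacent and in order, so a plain reshape is enough.
time_in = x.reshape(b * ts_d, seg_num, d_model)
out = self_attend(time_in)

# Perturb variable 0 of sample 0 only.
x2 = x.copy()
x2[0, 0] += 1.0
out2 = self_attend(x2.reshape(b * ts_d, seg_num, d_model))

# Only flattened row 0 (= b_idx * ts_d + d_idx = 0 * 3 + 0) should differ.
changed = [i for i in range(b * ts_d) if not np.allclose(out[i], out2[i])]
print(changed)  # [0]
```

Row index b_idx * ts_d + d_idx of time_in holds the segments of variable d_idx in sample b_idx, which is exactly the mapping in the formula above.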

5. Stage 2: Cross Dimension Stage

Source:

python
dim_send = rearrange(
    dim_in, "(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model", b=batch
)
batch_router = repeat(
    self.router,
    "seg_num factor d_model -> (repeat seg_num) factor d_model",
    repeat=batch,
)

shape:

text
dim_in:   (6,5,8)
dim_send: (2*5,3,8) = (10,3,8)
router:   (5,2,8)
batch_router: (2*5,2,8) = (10,2,8)

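At this point each segment position looks at the variable dimension on its own:

text
The length axis of dim_send is ts_d=3, i.e. 3 variables.
The length axis of batch_router is factor=2, i.e. 2 router tokens.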
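The two einops patterns above can be written out as plain reshape, transpose, and tile calls, which makes the index arithmetic explicit. This is a NumPy sketch with random stand-in data, not the source implementation.

```python
import numpy as np

batch, ts_d, seg_num, d_model, factor = 2, 3, 5, 8, 2
rng = np.random.default_rng(0)
dim_in = rng.standard_normal((batch * ts_d, seg_num, d_model))
router = rng.standard_normal((seg_num, factor, d_model))

# "(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model":
# split the merged axis, swap ts_d and seg_num, then merge b with seg_num.
dim_send = (
    dim_in.reshape(batch, ts_d, seg_num, d_model)
    .transpose(0, 2, 1, 3)
    .reshape(batch * seg_num, ts_d, d_model)
)

# "seg_num factor d_model -> (repeat seg_num) factor d_model":
# repeat is the outer index, so the whole router is tiled once per sample.
batch_router = np.tile(router, (batch, 1, 1))

print(dim_send.shape, batch_router.shape)  # (10, 3, 8) (10, 2, 8)
```

With this layout, row b_idx * seg_num + l of dim_send lines up with row b_idx * seg_num + l of batch_router, so every segment position is paired with its own router tokens.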

6. sender / receiver

sender:

python
dim_buffer, attn = self.dim_sender(
    batch_router, dim_send, dim_send, attn_mask=None, tau=None, delta=None
)

Meaning:

text
Q = router
K/V = variables
The router collects information from the variables

shape:

text
batch_router: (10,2,8)
dim_send:     (10,3,8)
dim_buffer:   (10,2,8)

receiver:

python
dim_receive, attn = self.dim_receiver(
    dim_send, dim_buffer, dim_buffer, attn_mask=None, tau=None, delta=None
)

Meaning:

text
Q = variables
K/V = router buffer
The variables receive cross-variable information from the router

shape:

text
dim_send:    (10,3,8)
dim_buffer:  (10,2,8)
dim_receive: (10,3,8)
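The shapes of the sender and receiver steps follow directly from which tensor plays the query role. A minimal sketch, assuming a hypothetical attend helper that does single-head scaled dot-product attention without the learned projections and multiple heads of the real AttentionLayer:

```python
import numpy as np

def attend(q, k, v):
    """Single-head scaled dot-product attention over batched inputs."""
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
dim_send = rng.standard_normal((10, 3, 8))      # (b*seg_num, ts_d, d_model)
batch_router = rng.standard_normal((10, 2, 8))  # (b*seg_num, factor, d_model)

# Sender: router tokens are the queries, variables are keys and values.
dim_buffer, send_w = attend(batch_router, dim_send, dim_send)

# Receiver: variables are the queries, the router buffer is keys and values.
dim_receive, recv_w = attend(dim_send, dim_buffer, dim_buffer)

print(dim_buffer.shape, send_w.shape)   # (10, 2, 8) (10, 2, 3)
print(dim_receive.shape, recv_w.shape)  # (10, 3, 8) (10, 3, 2)
```

The output length always matches the query length: factor for the sender, ts_d for the receiver. The attention maps are factor x ts_d and ts_d x factor, never ts_d x ts_d.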

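7. Final reshape

Note that dim_enc is not defined in the excerpts above. It is presumably obtained from dim_send and dim_receive through the same residual, normalization, and MLP pattern used in Stage 1.

Source: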

python
final_out = rearrange(
    dim_enc, "(b seg_num) ts_d d_model -> b ts_d seg_num d_model", b=batch
)

shape:

text
dim_enc:   (10,3,8)
final_out: (2,3,5,8)
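A quick way to convince yourself that this last rearrange restores the original layout is a round trip: go from (b, ts_d, seg_num, d_model) to the Stage 2 layout and back, and compare with the input. The sketch below uses NumPy reshape and transpose as a stand-in for einops.

```python
import numpy as np

batch, ts_d, seg_num, d_model = 2, 3, 5, 8
x = np.arange(batch * ts_d * seg_num * d_model).reshape(batch, ts_d, seg_num, d_model)

# Stage 2 layout: (b seg_num) ts_d d_model
dim_enc = x.transpose(0, 2, 1, 3).reshape(batch * seg_num, ts_d, d_model)

# "(b seg_num) ts_d d_model -> b ts_d seg_num d_model"
final_out = dim_enc.reshape(batch, seg_num, ts_d, d_model).transpose(0, 2, 1, 3)

print(final_out.shape)  # (2, 3, 5, 8)
print(np.array_equal(final_out, x))  # True
```

The b=batch argument in the einops call is needed for the same reason the explicit reshape needs batch here: the merged axis of length 10 cannot be split into (b, seg_num) without knowing one of the two factors.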

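8. The abstract picture

The two stages can be summarized as:

text
Stage 1: fix the variable, attend along the time segments
Stage 2: fix the time segment, pass messages between variables through the router

The benefit of the router is that it avoids direct pairwise attention between all variables: information is first aggregated into a small number (factor) of relay tokens and then distributed back.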
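As a rough count of attention scores per segment position (ignoring heads, projections, and constants), direct attention across D variables needs D * D scores, while the router needs factor * D for the sender plus D * factor for the receiver. The helper names below are hypothetical and exist only for this illustration.

```python
def direct_pairs(num_vars: int) -> int:
    """Attention scores for full variable-to-variable attention."""
    return num_vars * num_vars

def router_pairs(num_vars: int, factor: int) -> int:
    """Attention scores for sender (factor x D) plus receiver (D x factor)."""
    return 2 * factor * num_vars

for num_vars, factor in [(3, 2), (100, 10), (1000, 10)]:
    print(num_vars, factor, direct_pairs(num_vars), router_pairs(num_vars, factor))
# 3 2 9 12
# 100 10 10000 2000
# 1000 10 1000000 20000
```

For the toy sizes (3 variables, factor 2) the router actually computes more scores than direct attention; the saving only appears once the number of variables exceeds 2 * factor, and it grows linearly rather than quadratically from there.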
