
ReformerLayer: An LSH Attention Wrapper

Abstract

ReformerLayer is a thin wrapper around reformer_pytorch.LSHSelfAttention.

Its most important code is not the attention formula but fit_length, which pads the sequence length up to a multiple of bucket_size * 2.

1. Diagram

![[zdocs/pytorch-basics/assets/self_attention_reformer_twostage.svg]]

The left half of the diagram shows ReformerLayer.

2. Source code

python
class ReformerLayer(nn.Module):
    def __init__(..., bucket_size=4, n_hashes=4):
        ...
        self.bucket_size = bucket_size
        self.attn = LSHSelfAttention(
            dim=d_model,
            heads=n_heads,
            bucket_size=bucket_size,
            n_hashes=n_hashes,
            causal=causal,
        )

The part that really deserves a close read is:

python
def fit_length(self, queries):
    B, N, C = queries.shape
    if N % (self.bucket_size * 2) == 0:
        return queries
    else:
        fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2))
        return torch.cat(
            [queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1
        )

3. Length formula

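Internally, Reformer requires:

$$N \bmod (2 \cdot \text{bucket\_size}) = 0$$

If that does not hold, the input is padded by:

$$\text{fill\_len} = 2 \cdot \text{bucket\_size} - \bigl(N \bmod (2 \cdot \text{bucket\_size})\bigr)$$

Toy example: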

text
bucket_size = 4
2 * bucket_size = 8
N = 10

Then:

$$\text{fill\_len} = 8 - (10 \bmod 8) = 8 - 2 = 6$$

So:

text
(B,N,C) = (2,10,8)
-> cat 6 all-zero tokens
(2,16,8)
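The padding arithmetic above can be checked without any tensors. The following is a minimal sketch in plain Python; the helper names `fill_len` and `padded_len` are hypothetical and exist only for this illustration, they are not part of ReformerLayer or reformer_pytorch. It mirrors the if/else in fit_length, so an already aligned length gets no padding, and the padded length is always the next multiple of 2 * bucket_size.

```python
def fill_len(n: int, bucket_size: int) -> int:
    """Zero tokens needed so that n becomes a multiple of 2 * bucket_size.

    Hypothetical helper for illustration; it follows the same branch
    structure as fit_length in the source.
    """
    chunk = 2 * bucket_size
    remainder = n % chunk
    # Already aligned: fit_length returns the input unchanged.
    return 0 if remainder == 0 else chunk - remainder


def padded_len(n: int, bucket_size: int) -> int:
    """Sequence length after padding (hypothetical helper)."""
    return n + fill_len(n, bucket_size)


if __name__ == "__main__":
    for n in (8, 10, 16, 17):
        print(n, fill_len(n, 4), padded_len(n, 4))
```

For N = 10 and bucket_size = 4 this gives a fill length of 6 and a padded length of 16, matching the toy example.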

4. forward

Source:

python
def forward(self, queries, keys, values, attn_mask, tau, delta):
    B, N, C = queries.shape
    queries = self.attn(self.fit_length(queries))[:, :N, :]
    return queries, None

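Explanation:

text
1. Record the original length N
2. fit_length pads up to the length LSHSelfAttention requires
3. self.attn(...) computes the Reformer attention
4. [:, :N, :] slices the output back to the original length

Shapes: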

text
(2,10,8)
-> fit_length
(2,16,8)
-> LSHSelfAttention
(2,16,8)
-> slice
(2,10,8)
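The pad, attend, slice round trip can be reproduced without installing reformer_pytorch. The sketch below is an assumption-laden stand-in, not the original class: `PadCropWrapper` is a hypothetical name, `nn.Identity()` replaces LSHSelfAttention so that only the shape handling is exercised, and the padding uses `F.pad` instead of the `torch.cat` with `torch.zeros` seen in the source.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PadCropWrapper(nn.Module):
    """Hypothetical stand-in for ReformerLayer.

    `attn` can be any module mapping (B, L, C) -> (B, L, C); here it
    takes the place of LSHSelfAttention.
    """

    def __init__(self, attn: nn.Module, bucket_size: int = 4):
        super().__init__()
        self.attn = attn
        self.bucket_size = bucket_size

    def fit_length(self, x: torch.Tensor) -> torch.Tensor:
        chunk = 2 * self.bucket_size
        fill = (-x.shape[1]) % chunk  # 0 when the length is already aligned
        # F.pad pairs run from the last dim backwards:
        # (C_left, C_right, N_left, N_right)
        return F.pad(x, (0, 0, 0, fill))

    def forward(self, queries, keys=None, values=None):
        # keys and values are accepted but ignored, as in the source forward.
        n = queries.shape[1]
        out = self.attn(self.fit_length(queries))
        return out[:, :n, :], None


layer = PadCropWrapper(nn.Identity(), bucket_size=4)
x = torch.randn(2, 10, 8)
padded = layer.fit_length(x)
out, _ = layer(x, keys=torch.randn(2, 3, 8), values=torch.randn(2, 3, 8))
print(padded.shape, out.shape)
```

One design difference worth noting: `F.pad` produces padding with the same dtype and device as the input, whereas `torch.zeros(...)` in the source creates a default-dtype tensor and only moves it to the input's device.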

5. Common pitfall

ReformerLayer.forward never uses keys or values. The comment in the source says as much:

python
# in Reformer: defalut queries=keys

So it behaves like a self-attention layer rather than a general-purpose cross-attention layer.
