ReformerLayer: LSH Attention Wrapper
Abstract
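`ReformerLayer` is a thin wrapper around `reformer_pytorch.LSHSelfAttention`. Its most important code is not the attention formula but `fit_length`, which pads the sequence length up to a multiple of `bucket_size * 2`.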
1. Diagram
![[zdocs/pytorch-basics/assets/self_attention_reformer_twostage.svg]]
The left half of the figure is `ReformerLayer`.
2. Source code
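```python
import torch.nn as nn
from reformer_pytorch import LSHSelfAttention


class ReformerLayer(nn.Module):
    def __init__(..., bucket_size=4, n_hashes=4):
        ...
        self.bucket_size = bucket_size
        self.attn = LSHSelfAttention(
            dim=d_model,
            heads=n_heads,
            bucket_size=bucket_size,  # size of each LSH bucket
            n_hashes=n_hashes,        # number of hashing rounds
            causal=causal,
        )
```

The part that really deserves a close read is: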
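```python
def fit_length(self, queries):
    B, N, C = queries.shape
    # Already a multiple of 2 * bucket_size: pass through unchanged
    if N % (self.bucket_size * 2) == 0:
        return queries
    else:
        # Number of zero tokens needed to reach the next multiple
        fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2))
        # Append fill_len all-zero tokens along the sequence dimension (dim=1).
        # Note: torch.zeros defaults to float32; .to() only matches the device.
        return torch.cat(
            [queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1
        )
```

3. Length formula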
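Internally, Reformer requires the sequence length to be a multiple of twice the bucket size:

$$
N \bmod (2\cdot\text{bucket\_size}) = 0
$$

If it is not, `fit_length` appends

$$
\text{fill\_len} = 2\cdot\text{bucket\_size} - \big(N \bmod (2\cdot\text{bucket\_size})\big)
$$

all-zero tokens along the sequence dimension, which rounds $N$ up to the next multiple of $2\cdot\text{bucket\_size}$.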
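The padding rule can be checked with a few lines of plain Python. This is a minimal sketch: `fill_len` and `padded_length` are hypothetical helper names introduced here, not functions from the source; they only restate the formulas above, including the early-return branch of `fit_length`.

```python
def fill_len(n: int, bucket_size: int) -> int:
    """Number of zero tokens fit_length would append to a length-n sequence."""
    block = 2 * bucket_size
    # Already a multiple: no padding (mirrors the early return in fit_length)
    if n % block == 0:
        return 0
    return block - n % block


def padded_length(n: int, bucket_size: int) -> int:
    """Sequence length after padding; equals ceil(n / block) * block."""
    return n + fill_len(n, bucket_size)


# Toy case: bucket_size = 4, N = 10
print(fill_len(10, 4), padded_length(10, 4))  # 6 16
```

The explicit `n % block == 0` branch matters: without it, N = 16 would give 8 - 0 = 8, i.e. a whole extra block of zeros.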
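Toy example:

```text
bucket_size = 4
2 * bucket_size = 8
N = 10
```

Then:

$$
10 \bmod 8 = 2,\qquad \text{fill\_len} = 8 - 2 = 6,\qquad N' = 10 + 6 = 16
$$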
So:

```text
(B,N,C) = (2,10,8)
-> cat 6 all-zero tokens
(2,16,8)
```

4. forward
Source:
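```python
def forward(self, queries, keys, values, attn_mask, tau, delta):
    B, N, C = queries.shape  # remember the original length N
    # pad -> LSH self-attention -> crop back to N tokens
    queries = self.attn(self.fit_length(queries))[:, :N, :]
    return queries, None  # no attention weights are returned
```

Explanation: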
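```text
1. Record the original length N.
2. fit_length pads the sequence to the length LSHSelfAttention requires.
3. self.attn(...) computes the Reformer (LSH) attention.
4. [:, :N, :] crops the output back to the original length.
```

Shape: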
```text
(2,10,8)
-> fit_length
(2,16,8)
-> LSHSelfAttention
(2,16,8)
-> crop
(2,10,8)
```

5. Common pitfalls
`ReformerLayer.forward` never uses `keys` or `values` (nor `attn_mask`, `tau`, or `delta`). The source comment hints at this too:
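```python
# in Reformer: defalut queries=keys
```

So it behaves as a self-attention layer rather than a general-purpose cross-attention layer: whatever is passed as `keys` and `values`, attention is computed over `queries` alone.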
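To see the pad → attend → crop flow and the ignored `keys`/`values` end to end without installing `reformer_pytorch`, here is a minimal sketch. It rests on stated assumptions: `PadCropSelfAttention` is a hypothetical stand-in that swaps `LSHSelfAttention` for ordinary full attention (`torch.nn.MultiheadAttention`), so only the wrapper logic is illustrated, not LSH hashing; and its `fit_length` uses `x.new_zeros(...)` (which also keeps the input dtype) instead of `torch.zeros(...).to(device)`.

```python
import torch
import torch.nn as nn


def fit_length(x: torch.Tensor, bucket_size: int) -> torch.Tensor:
    """Zero-pad dim 1 up to a multiple of 2 * bucket_size (same rule as above)."""
    B, N, C = x.shape
    block = 2 * bucket_size
    if N % block == 0:
        return x
    fill = block - N % block
    # new_zeros matches x's dtype and device
    return torch.cat([x, x.new_zeros(B, fill, C)], dim=1)


class PadCropSelfAttention(nn.Module):
    """Hypothetical stand-in for ReformerLayer: same pad -> attend -> crop flow,
    but with full nn.MultiheadAttention instead of LSHSelfAttention."""

    def __init__(self, d_model: int, n_heads: int, bucket_size: int = 4):
        super().__init__()
        self.bucket_size = bucket_size
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, queries, keys, values, attn_mask=None, tau=None, delta=None):
        B, N, C = queries.shape
        x = fit_length(queries, self.bucket_size)
        # Self-attention over the padded queries only; keys/values are ignored.
        # As in the original, no mask is passed, so the zero tokens are not masked out.
        out, _ = self.attn(x, x, x, need_weights=False)
        return out[:, :N, :], None


torch.manual_seed(0)
layer = PadCropSelfAttention(d_model=8, n_heads=2, bucket_size=4)
q = torch.randn(2, 10, 8)
out_a, _ = layer(q, torch.randn(2, 10, 8), torch.randn(2, 10, 8))
out_b, _ = layer(q, None, None)
print(out_a.shape)                   # torch.Size([2, 10, 8])
print(torch.allclose(out_a, out_b))  # True
```

Changing `keys` and `values` leaves the output unchanged, which is exactly the self-attention-only behavior described above.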