BiFormer

  • Paper: BiFormer: Vision Transformer with Bi-Level Routing Attention

  • Authors: Lei Zhu, Xinjiang Wang, Zhanghan Ke, Wayne Zhang, Rynson Lau

  • Code: GitHub

  • Framework: BiFormer

Transformer

Strengths

  • long-range dependency
  • inductive-bias-free
  • high parallelism

Weaknesses

  • High computational cost

  • High memory footprint

  • Existing remedies: introduce sparsity

    • local window attention
    • axial attention
    • dilated attention
  • Remaining problem

    • Key/value filtering is query-agnostic: every query is restricted to the same hand-crafted subset of keys and values (see the sketch after this list)
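
As a minimal illustration of this point (not from the paper; `local_window_mask` and its arguments are made-up names, assuming PyTorch), the sketch below builds the attention mask of a plain local-window scheme: which keys a token may attend to is fixed by position alone, so every query gets the same pattern regardless of its content.

```python
# Hypothetical helper, for illustration only: a hand-crafted sparse pattern
# that ignores query content entirely.
import torch

def local_window_mask(h: int, w: int, window: int) -> torch.Tensor:
    """Boolean (h*w, h*w) mask: token i may attend to token j only if both lie
    in the same non-overlapping window. Assumes h and w are divisible by window."""
    ids = torch.arange(h * w)
    row, col = ids // w, ids % w
    win_id = (row // window) * (w // window) + (col // window)
    return win_id[:, None] == win_id[None, :]

mask = local_window_mask(8, 8, window=4)
print(mask.shape)  # torch.Size([64, 64]) -- identical for every input, every query
```
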

Bi-level Routing Attention (BRA)

Sparsity

  • Exploits sparsity to save computation and memory, while involving only GPU-friendly dense matrix multiplications

Query-aware

  • Selects the semantically most relevant key-value pairs for each individual query

Pseudocode

# input: features (H, W, C). Assume H==W.
# output: features (H, W, C).
# S: square root of number of regions.
# k: number of regions to attend.

# patchify input (H, W, C) -> (S^2, HW/S^2, C)
x = patchify(input, patch_size=H//S)

# linear projection of query, key, value
query, key, value = linear_qkv(x).chunk(3, dim=-1)

# regional query and key (S^2, C)
query_r, key_r = query.mean(dim=1), key.mean(dim=1)

# adjacency matrix for regional graph (S^2, S^2)
A_r = mm(query_r, key_r.transpose(-1, -2))
# compute index matrix of routed regions (S^2, k)
I_r = topk(A_r, k).index
# gather key-value pairs of the routed regions (S^2, kHW/S^2, C)
key_g = gather(key, I_r)
value_g = gather(value, I_r)
# token-to-token attention
A = bmm(query, key_g.transpose(-2, -1))
A = softmax(A, dim=-1)
output = bmm(A, value_g) + dwconv(value)
# recover to (H, W, C) shape
output = unpatchify(output, patch_size=H//S)
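
For reference, here is a minimal runnable PyTorch sketch of the pseudocode above, under some simplifying assumptions: a single attention head, no relative position bias, square inputs with H == W divisible by S, and a 5x5 depthwise convolution standing in for dwconv(value). The class name `BiLevelRoutingAttention` and its arguments are illustrative, not the official implementation (see the authors' GitHub repository for that).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BiLevelRoutingAttention(nn.Module):
    """Single-head sketch of Bi-Level Routing Attention (illustrative only)."""
    def __init__(self, dim, n_regions=7, topk=4):
        super().__init__()
        self.S, self.k = n_regions, topk
        self.qkv = nn.Linear(dim, 3 * dim)
        # 5x5 depthwise conv standing in for dwconv(value) in the pseudocode
        self.lce = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)

    def forward(self, x):                                  # x: (B, H, W, C)
        B, H, W, C = x.shape
        S, k = self.S, self.k
        # patchify: (B, H, W, C) -> (B, S^2, HW/S^2, C)
        r = x.view(B, S, H // S, S, W // S, C).permute(0, 1, 3, 2, 4, 5)
        r = r.reshape(B, S * S, -1, C)
        q, kk, v = self.qkv(r).chunk(3, dim=-1)
        # region-level query/key by mean pooling: (B, S^2, C)
        q_r, k_r = q.mean(dim=2), kk.mean(dim=2)
        # region-to-region affinity and top-k routing indices: (B, S^2, k)
        idx = (q_r @ k_r.transpose(-1, -2)).topk(k, dim=-1).indices
        # gather keys/values of the k routed regions: (B, S^2, k*HW/S^2, C)
        batch = torch.arange(B, device=x.device)[:, None, None]
        k_g = kk[batch, idx].flatten(2, 3)
        v_g = v[batch, idx].flatten(2, 3)
        # fine-grained token-to-token attention within the routed regions
        attn = F.softmax(q @ k_g.transpose(-1, -2) * C ** -0.5, dim=-1)
        out = attn @ v_g                                   # (B, S^2, HW/S^2, C)
        # local-context branch: depthwise conv on the value map
        v_map = v.reshape(B, S, S, H // S, W // S, C)
        v_map = v_map.permute(0, 5, 1, 3, 2, 4).reshape(B, C, H, W)
        lce = self.lce(v_map).permute(0, 2, 3, 1)          # (B, H, W, C)
        # unpatchify the attention output back to (B, H, W, C) and fuse
        out = out.reshape(B, S, S, H // S, W // S, C)
        out = out.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)
        return out + lce

x = torch.randn(2, 28, 28, 64)                             # H == W == 28, S = 7 -> 4x4 tokens per region
y = BiLevelRoutingAttention(dim=64, n_regions=7, topk=4)(x)
print(y.shape)                                             # torch.Size([2, 28, 28, 64])
```

The two matrix products mirror the two levels of the pseudocode: the region-level affinity picks which regions each query region routes to, and the token-level attention only touches the gathered keys/values, so every multiplication remains a dense, GPU-friendly matmul.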
