一、网络结构

二、详情

1、补丁嵌入

  • 原始图像:H×W×C,高度,宽度,通道数

  • 补丁(patch):N×(P^2×C),

    N=H×W/(P×P)叫补丁数或者输入序列长度

    (P,P)叫补丁的分辨率

将图像分割成多个补丁(patch),将补丁平铺成一维数据格式,再通过线性投影将补丁平坦化映射到低维空间

1
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width)

2、位置嵌入

  • 添加位置编码表示两个补丁之间的距离,即添加了补丁的位置信息

网络自动学习

1
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

sin-cos规则

  • 将向量的维度切分为奇数行和偶数行
  • 偶数行采用sin函数编码,奇数行采用cos函数编码
  • 然后按照原始行号拼接
1
2
3
4
5
6
7
8
9
10
def get_position_angle_vec(position):
# hid_j是0-511,d_hid是512,position表示单词位置0~N-1
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

# 每个单词位置0~N-1都可以编码得到512长度的向量
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
# 偶数列进行sin
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
# 奇数列进行cos
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1

3、类别标记

  • 输出特征加上一个线性分类器实现分类

cls_token方式,cls位置表示向量

1
2
3
4
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b) # 复制令牌,并生成新的转变为(B,N,D)格式的张量
x = torch.cat((cls_tokens, x), dim=1) # 类别标记拼接补丁嵌入

4、Transformer编码器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.): # dim_head每个头的维度
super().__init__()
inner_dim = dim_head * heads # 所有头的维度
project_out = not (heads == 1 and dim_head == dim) # 不是一个头和每个头的维度等于总维度同时成立时,返回true

self.heads = heads # 注意力头数
self.scale = dim_head ** -0.5 # 缩放因子

self.norm = nn.LayerNorm(dim) # 层归一化

self.attend = nn.Softmax(dim=-1) # 均匀分布在[0,1]之间
self.dropout = nn.Dropout(dropout) # 丢弃层

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
# inner_dim乘以三,因为输出分成查询,键,值,每个部分的维度都是inner_dim,过程相当于三次全连接层

self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity() # 若是多头注意力,则进入全连接层,丢弃层;若不是,则恒等映射,不做任何操作输出

def forward(self, x):
x = self.norm(x) # 层归一化

qkv = self.to_qkv(x).chunk(3, dim=-1) # 三次全连接后将最后一个维度分成三份
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
# b是batch_size,h是注意力头数,n是输入序列长度(输入图像分割成图块的数量patches)

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale # q乘以k的转置,再乘以1/根号dim_head

attn = self.attend(dots) # softmax
attn = self.dropout(attn) # 丢弃层

out = torch.matmul(attn, v) # 再乘以v
out = rearrange(out, 'b h n d -> b n (h d)') # 恢复原状
return self.to_out(out)

5、MLP多层感知机/前馈神经网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class FeedForward(nn.Module):  # 前馈神经网络FFN,也叫多层感知机MLP
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim), # 层归一化
nn.Linear(dim, hidden_dim), # 全连接层
nn.GELU(), # GELU激活函数
nn.Dropout(dropout), # 丢弃层
nn.Linear(hidden_dim, dim), # 全连接层
nn.Dropout(dropout) # 丢弃层
)

def forward(self, x):
return self.net(x)