magic-pdf 0.5.13__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, exactly as they appear in their public registry. It is provided for informational purposes only.
- magic_pdf/cli/magicpdf.py +18 -7
- magic_pdf/libs/config_reader.py +10 -0
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/__init__.py +1 -0
- magic_pdf/model/doc_analyze_by_custom_model.py +38 -15
- magic_pdf/model/model_list.py +1 -0
- magic_pdf/model/pdf_extract_kit.py +196 -0
- magic_pdf/model/pek_sub_modules/__init__.py +0 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py +0 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py +179 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py +671 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py +476 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py +7 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py +2 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py +171 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py +124 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py +136 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py +284 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py +213 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py +7 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +24 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +60 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +1282 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +32 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +34 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py +150 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py +163 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py +1236 -0
- magic_pdf/model/pek_sub_modules/post_process.py +36 -0
- magic_pdf/model/pek_sub_modules/self_modify.py +260 -0
- magic_pdf/model/pp_structure_v2.py +7 -0
- magic_pdf/pipe/AbsPipe.py +8 -14
- magic_pdf/pipe/OCRPipe.py +12 -8
- magic_pdf/pipe/TXTPipe.py +12 -8
- magic_pdf/pipe/UNIPipe.py +9 -7
- magic_pdf/resources/model_config/UniMERNet/demo.yaml +46 -0
- magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +351 -0
- magic_pdf/resources/model_config/model_configs.yaml +9 -0
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/METADATA +18 -8
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/RECORD +44 -18
- magic_pdf/model/360_layout_analysis.py +0 -8
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/WHEEL +0 -0
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/top_level.txt +0 -0
magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py (new file)
@@ -0,0 +1,671 @@
+""" Vision Transformer (ViT) in PyTorch
+
+A PyTorch implementation of Vision Transformers as described in
+'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
+
+The official jax code is released and available at https://github.com/google-research/vision_transformer
+
+Status/TODO:
+* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
+* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
+* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
+* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
+
+Acknowledgments:
+* The paper authors for releasing code and weights, thanks!
+* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
+for some einops/einsum fun
+* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
+* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import warnings
+import math
+import torch
+from functools import partial
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic',
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        **kwargs
+    }
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        # x = self.drop(x)
+        # commented out to follow the original BERT implementation
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(
+            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+            proj_drop=0., window_size=None, attn_head_dim=None):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        if attn_head_dim is not None:
+            head_dim = attn_head_dim
+        all_head_dim = head_dim * self.num_heads
+        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+        if qkv_bias:
+            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+        else:
+            self.q_bias = None
+            self.v_bias = None
+
+        if window_size:
+            self.window_size = window_size
+            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+            self.relative_position_bias_table = nn.Parameter(
+                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+            # cls to token & token 2 cls & cls to cls
+
+            # get pair-wise relative position index for each token inside the window
+            coords_h = torch.arange(window_size[0])
+            coords_w = torch.arange(window_size[1])
+            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+            relative_coords[:, :, 1] += window_size[1] - 1
+            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+            relative_position_index = \
+                torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+            relative_position_index[0, 0:] = self.num_relative_distance - 3
+            relative_position_index[0:, 0] = self.num_relative_distance - 2
+            relative_position_index[0, 0] = self.num_relative_distance - 1
+
+            self.register_buffer("relative_position_index", relative_position_index)
+
+            # trunc_normal_(self.relative_position_bias_table, std=.0)
+        else:
+            self.window_size = None
+            self.relative_position_bias_table = None
+            self.relative_position_index = None
+
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(all_head_dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x, rel_pos_bias=None, training_window_size=None):
+        B, N, C = x.shape
+        qkv_bias = None
+        if self.q_bias is not None:
+            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        if self.relative_position_bias_table is not None:
+            if training_window_size == self.window_size:
+                relative_position_bias = \
+                    self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                        self.window_size[0] * self.window_size[1] + 1,
+                        self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+                attn = attn + relative_position_bias.unsqueeze(0)
+            else:
+                training_window_size = tuple(training_window_size.tolist())
+                new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
+                # new_num_relative_distance counts all possible relative-position options, including cls-cls, tok-cls and cls-tok
+                new_relative_position_bias_table = F.interpolate(
+                    self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
+                                                                                 2 * self.window_size[0] - 1,
+                                                                                 2 * self.window_size[1] - 1),
+                    size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
+                    align_corners=False)
+                new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
+                                                                                          new_num_relative_distance - 3).permute(
+                    1, 0)
+                new_relative_position_bias_table = torch.cat(
+                    [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
+
+                # get pair-wise relative position index for each token inside the window
+                coords_h = torch.arange(training_window_size[0])
+                coords_w = torch.arange(training_window_size[1])
+                coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+                coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+                relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+                relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+                relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0
+                relative_coords[:, :, 1] += training_window_size[1] - 1
+                relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
+                relative_position_index = \
+                    torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
+                                dtype=relative_coords.dtype)
+                relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+                relative_position_index[0, 0:] = new_num_relative_distance - 3
+                relative_position_index[0:, 0] = new_num_relative_distance - 2
+                relative_position_index[0, 0] = new_num_relative_distance - 1
+
+                relative_position_bias = \
+                    new_relative_position_bias_table[relative_position_index.view(-1)].view(
+                        training_window_size[0] * training_window_size[1] + 1,
+                        training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+                attn = attn + relative_position_bias.unsqueeze(0)
+
+        if rel_pos_bias is not None:
+            attn = attn + rel_pos_bias
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+                 window_size=None, attn_head_dim=None):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if init_values is not None:
+            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+        else:
+            self.gamma_1, self.gamma_2 = None, None
+
+    def forward(self, x, rel_pos_bias=None, training_window_size=None):
+        if self.gamma_1 is None:
+            x = x + self.drop_path(
+                self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size))
+            x = x + self.drop_path(self.mlp(self.norm2(x)))
+        else:
+            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias,
+                                                            training_window_size=training_window_size))
+            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.num_patches_w = self.patch_shape[0]
+        self.num_patches_h = self.patch_shape[1]
+        # the so-called patch_shape is the patch shape during pre-training
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x, position_embedding=None, **kwargs):
+        # FIXME look at relaxing size constraints
+        # assert H == self.img_size[0] and W == self.img_size[1], \
+        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x)
+        Hp, Wp = x.shape[2], x.shape[3]
+
+        if position_embedding is not None:
+            # interpolate the position embedding to the corresponding size
+            position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3,
+                                                                                                                  1, 2)
+            position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')
+            x = x + position_embedding
+
+        x = x.flatten(2).transpose(1, 2)
+        return x, (Hp, Wp)
+
+
+class HybridEmbed(nn.Module):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+    """
+
+    def __init__(self, backbone, img_size=[224, 224], feature_size=None, in_chans=3, embed_dim=768):
+        super().__init__()
+        assert isinstance(backbone, nn.Module)
+        img_size = to_2tuple(img_size)
+        self.img_size = img_size
+        self.backbone = backbone
+        if feature_size is None:
+            with torch.no_grad():
+                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
+                # map for all networks, the feature metadata has reliable channel and stride info, but using
+                # stride to calc feature dim requires info about padding of each stage that isn't captured.
+                training = backbone.training
+                if training:
+                    backbone.eval()
+                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+                feature_size = o.shape[-2:]
+                feature_dim = o.shape[1]
+                backbone.train(training)
+        else:
+            feature_size = to_2tuple(feature_size)
+            feature_dim = self.backbone.feature_info.channels()[-1]
+        self.num_patches = feature_size[0] * feature_size[1]
+        self.proj = nn.Linear(feature_dim, embed_dim)
+
+    def forward(self, x):
+        x = self.backbone(x)[-1]
+        x = x.flatten(2).transpose(1, 2)
+        x = self.proj(x)
+        return x
+
+
+class RelativePositionBias(nn.Module):
+
+    def __init__(self, window_size, num_heads):
+        super().__init__()
+        self.window_size = window_size
+        self.num_heads = num_heads
+        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+        # cls to token & token 2 cls & cls to cls
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(window_size[0])
+        coords_w = torch.arange(window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+        relative_position_index = \
+            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        relative_position_index[0, 0:] = self.num_relative_distance - 3
+        relative_position_index[0:, 0] = self.num_relative_distance - 2
+        relative_position_index[0, 0] = self.num_relative_distance - 1
+
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+    def forward(self, training_window_size):
+        if training_window_size == self.window_size:
+            relative_position_bias = \
+                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                    self.window_size[0] * self.window_size[1] + 1,
+                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+        else:
+            training_window_size = tuple(training_window_size.tolist())
+            new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
+            # new_num_relative_distance counts all possible relative-position options, including cls-cls, tok-cls and cls-tok
+            new_relative_position_bias_table = F.interpolate(
+                self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
+                                                                             2 * self.window_size[0] - 1,
+                                                                             2 * self.window_size[1] - 1),
+                size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
+                align_corners=False)
+            new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
+                                                                                      new_num_relative_distance - 3).permute(
+                1, 0)
+            new_relative_position_bias_table = torch.cat(
+                [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
+
+            # get pair-wise relative position index for each token inside the window
+            coords_h = torch.arange(training_window_size[0])
+            coords_w = torch.arange(training_window_size[1])
+            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+            relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0
+            relative_coords[:, :, 1] += training_window_size[1] - 1
+            relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
+            relative_position_index = \
+                torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
+                            dtype=relative_coords.dtype)
+            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+            relative_position_index[0, 0:] = new_num_relative_distance - 3
+            relative_position_index[0:, 0] = new_num_relative_distance - 2
+            relative_position_index[0, 0] = new_num_relative_distance - 1
+
+            relative_position_bias = \
+                new_relative_position_bias_table[relative_position_index.view(-1)].view(
+                    training_window_size[0] * training_window_size[1] + 1,
+                    training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+
+        return relative_position_bias
+
+
+class BEiT(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+
+    def __init__(self,
+                 img_size=[224, 224],
+                 patch_size=16,
+                 in_chans=3,
+                 num_classes=80,
+                 embed_dim=768,
+                 depth=12,
+                 num_heads=12,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 hybrid_backbone=None,
+                 norm_layer=None,
+                 init_values=None,
+                 use_abs_pos_emb=False,
+                 use_rel_pos_bias=False,
+                 use_shared_rel_pos_bias=False,
+                 use_checkpoint=True,
+                 pretrained=None,
+                 out_features=None,
+                 ):
+
+        super(BEiT, self).__init__()
+
+        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.use_checkpoint = use_checkpoint
+
+        if hybrid_backbone is not None:
+            self.patch_embed = HybridEmbed(
+                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+        else:
+            self.patch_embed = PatchEmbed(
+                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+        self.out_features = out_features
+        self.out_indices = [int(name[5:]) for name in out_features]
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        if use_abs_pos_emb:
+            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        else:
+            self.pos_embed = None
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        self.use_shared_rel_pos_bias = use_shared_rel_pos_bias
+        if use_shared_rel_pos_bias:
+            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+        else:
+            self.rel_pos_bias = None
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.use_rel_pos_bias = use_rel_pos_bias
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
+            for i in range(depth)])
+
+        # trunc_normal_(self.mask_token, std=.02)
+
+        if patch_size == 16:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+                # nn.SyncBatchNorm(embed_dim),
+                nn.BatchNorm2d(embed_dim),
+                nn.GELU(),
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn3 = nn.Identity()
+
+            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+        elif patch_size == 8:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Identity()
+
+            self.fpn3 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=2, stride=2),
+            )
+
+            self.fpn4 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=4, stride=4),
+            )
+
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+        self.fix_init_weight()
+
+    def fix_init_weight(self):
+        def rescale(param, layer_id):
+            param.div_(math.sqrt(2.0 * layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    '''
+    def init_weights(self):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        logger = get_root_logger()
+
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+        self.fix_init_weight()
+
+        if self.init_cfg is None:
+            logger.warn(f'No pre-trained weights for '
+                        f'{self.__class__.__name__}, '
+                        f'training start from scratch')
+        else:
+            assert 'checkpoint' in self.init_cfg, f'Only support ' \
+                                                  f'specify `Pretrained` in ' \
+                                                  f'`init_cfg` in ' \
+                                                  f'{self.__class__.__name__} '
+            logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
+            load_checkpoint(self,
+                            filename=self.init_cfg['checkpoint'],
+                            strict=False,
+                            logger=logger,
+                            beit_spec_expand_rel_pos = self.use_rel_pos_bias,
+                            )
+    '''
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def forward_features(self, x):
+        B, C, H, W = x.shape
+        x, (Hp, Wp) = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
+        # Hp, Wp are HW for patches
+        batch_size, seq_len, _ = x.size()
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        if self.pos_embed is not None:
+            cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
+        x = torch.cat((cls_tokens, x), dim=1)
+        x = self.pos_drop(x)
+
+        features = []
+        training_window_size = torch.tensor([Hp, Wp])
+
+        rel_pos_bias = self.rel_pos_bias(training_window_size) if self.rel_pos_bias is not None else None
+
+        for i, blk in enumerate(self.blocks):
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, rel_pos_bias, training_window_size)
+            else:
+                x = blk(x, rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)
+            if i in self.out_indices:
+                xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
+                features.append(xp.contiguous())
+
+        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+        for i in range(len(features)):
+            features[i] = ops[i](features[i])
+
+        feat_out = {}
+
+        for name, value in zip(self.out_features, features):
+            feat_out[name] = value
+
+        return feat_out
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        return x
+
+
+def beit_base_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=None,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def beit_large_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=None,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def dit_base_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=0.1,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def dit_large_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=1e-5,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+if __name__ == '__main__':
+    model = BEiT(use_checkpoint=True, use_shared_rel_pos_bias=True)
+    model = model.to("cuda:0")
+    input1 = torch.rand(2, 3, 512, 762).to("cuda:0")
+    input2 = torch.rand(2, 3, 800, 1200).to("cuda:0")
+    input3 = torch.rand(2, 3, 720, 1000).to("cuda:0")
+    output1 = model(input1)
+    output2 = model(input2)
+    output3 = model(input3)
+    print("all done")
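For context, below is a minimal sketch (not part of the package) of how the backbone factories added in beit.py could be exercised on their own. The img_size, input tensor, and the "layer3"/"layer5"/"layer7"/"layer11" out_features names are illustrative assumptions; only the BEiT class and dit_base_patch16 factory come from the diff above (BEiT parses the feature index from each name via int(name[5:])).

# Hedged usage sketch: instantiate the new dit_base_patch16 backbone and run a dummy image.
# Requires torch, timm, and magic-pdf 0.6.0 installed; all argument values here are assumptions.
import torch
from magic_pdf.model.pek_sub_modules.layoutlmv3.beit import dit_base_patch16

backbone = dit_base_patch16(
    img_size=[512, 512],
    out_features=["layer3", "layer5", "layer7", "layer11"],  # assumed FPN tap points
    use_checkpoint=False,
)
backbone.eval()
with torch.no_grad():
    # forward() returns a dict keyed by out_features after the fpn1..fpn4 heads
    feats = backbone(torch.rand(1, 3, 512, 512))
for name, tensor in feats.items():
    print(name, tuple(tensor.shape))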