magic-pdf 0.5.13__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff shows the changes between two package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (45)
  1. magic_pdf/cli/magicpdf.py +18 -7
  2. magic_pdf/libs/config_reader.py +10 -0
  3. magic_pdf/libs/version.py +1 -1
  4. magic_pdf/model/__init__.py +1 -0
  5. magic_pdf/model/doc_analyze_by_custom_model.py +38 -15
  6. magic_pdf/model/model_list.py +1 -0
  7. magic_pdf/model/pdf_extract_kit.py +196 -0
  8. magic_pdf/model/pek_sub_modules/__init__.py +0 -0
  9. magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py +0 -0
  10. magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py +179 -0
  11. magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py +671 -0
  12. magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py +476 -0
  13. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py +7 -0
  14. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py +2 -0
  15. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py +171 -0
  16. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py +124 -0
  17. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py +136 -0
  18. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py +284 -0
  19. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py +213 -0
  20. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py +7 -0
  21. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +24 -0
  22. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +60 -0
  23. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +1282 -0
  24. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +32 -0
  25. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +34 -0
  26. magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py +150 -0
  27. magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py +163 -0
  28. magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py +1236 -0
  29. magic_pdf/model/pek_sub_modules/post_process.py +36 -0
  30. magic_pdf/model/pek_sub_modules/self_modify.py +260 -0
  31. magic_pdf/model/pp_structure_v2.py +7 -0
  32. magic_pdf/pipe/AbsPipe.py +8 -14
  33. magic_pdf/pipe/OCRPipe.py +12 -8
  34. magic_pdf/pipe/TXTPipe.py +12 -8
  35. magic_pdf/pipe/UNIPipe.py +9 -7
  36. magic_pdf/resources/model_config/UniMERNet/demo.yaml +46 -0
  37. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +351 -0
  38. magic_pdf/resources/model_config/model_configs.yaml +9 -0
  39. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/METADATA +18 -8
  40. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/RECORD +44 -18
  41. magic_pdf/model/360_layout_analysis.py +0 -8
  42. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/LICENSE.md +0 -0
  43. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/WHEEL +0 -0
  44. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/entry_points.txt +0 -0
  45. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/top_level.txt +0 -0
magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py (new file, +671 lines)
@@ -0,0 +1,671 @@
+ """ Vision Transformer (ViT) in PyTorch
+
+ A PyTorch implementation of Vision Transformers as described in
+ 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
+
+ The official jax code is released and available at https://github.com/google-research/vision_transformer
+
+ Status/TODO:
+ * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
+ * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
+ * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
+ * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
+
+ Acknowledgments:
+ * The paper authors for releasing code and weights, thanks!
+ * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
+   for some einops/einsum fun
+ * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
+ * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
+
+ Hacked together by / Copyright 2020 Ross Wightman
+ """
+ import warnings
+ import math
+ import torch
+ from functools import partial
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as checkpoint
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+
+
+ def _cfg(url='', **kwargs):
+     return {
+         'url': url,
+         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+         'crop_pct': .9, 'interpolation': 'bicubic',
+         'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+         **kwargs
+     }
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     """
+
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
+
+     def extra_repr(self) -> str:
+         return 'p={}'.format(self.drop_prob)
+
+
+ class Mlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         # x = self.drop(x)
+         # commented out to match the original BERT implementation
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(
+             self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+             proj_drop=0., window_size=None, attn_head_dim=None):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         if attn_head_dim is not None:
+             head_dim = attn_head_dim
+         all_head_dim = head_dim * self.num_heads
+         # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+         if qkv_bias:
+             self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+             self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+         else:
+             self.q_bias = None
+             self.v_bias = None
+
+         if window_size:
+             self.window_size = window_size
+             self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+             self.relative_position_bias_table = nn.Parameter(
+                 torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+             # cls to token & token 2 cls & cls to cls
+
+             # get pair-wise relative position index for each token inside the window
+             coords_h = torch.arange(window_size[0])
+             coords_w = torch.arange(window_size[1])
+             coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+             coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+             relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+             relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+             relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+             relative_coords[:, :, 1] += window_size[1] - 1
+             relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+             relative_position_index = \
+                 torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+             relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+             relative_position_index[0, 0:] = self.num_relative_distance - 3
+             relative_position_index[0:, 0] = self.num_relative_distance - 2
+             relative_position_index[0, 0] = self.num_relative_distance - 1
+
+             self.register_buffer("relative_position_index", relative_position_index)
+
+             # trunc_normal_(self.relative_position_bias_table, std=.0)
+         else:
+             self.window_size = None
+             self.relative_position_bias_table = None
+             self.relative_position_index = None
+
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(all_head_dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x, rel_pos_bias=None, training_window_size=None):
+         B, N, C = x.shape
+         qkv_bias = None
+         if self.q_bias is not None:
+             qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+         # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+         q = q * self.scale
+         attn = (q @ k.transpose(-2, -1))
+
+         if self.relative_position_bias_table is not None:
+             if training_window_size == self.window_size:
+                 relative_position_bias = \
+                     self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                         self.window_size[0] * self.window_size[1] + 1,
+                         self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+                 relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+                 attn = attn + relative_position_bias.unsqueeze(0)
+             else:
+                 training_window_size = tuple(training_window_size.tolist())
+                 new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
+                 # new_num_relative_distance counts all possible relative position options, including cls-cls, tok-cls and cls-tok
+                 new_relative_position_bias_table = F.interpolate(
+                     self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
+                                                                                  2 * self.window_size[0] - 1,
+                                                                                  2 * self.window_size[1] - 1),
+                     size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
+                     align_corners=False)
+                 new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
+                                                                                           new_num_relative_distance - 3).permute(
+                     1, 0)
+                 new_relative_position_bias_table = torch.cat(
+                     [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
+
+                 # get pair-wise relative position index for each token inside the window
+                 coords_h = torch.arange(training_window_size[0])
+                 coords_w = torch.arange(training_window_size[1])
+                 coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+                 coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+                 relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+                 relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+                 relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0
+                 relative_coords[:, :, 1] += training_window_size[1] - 1
+                 relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
+                 relative_position_index = \
+                     torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
+                                 dtype=relative_coords.dtype)
+                 relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+                 relative_position_index[0, 0:] = new_num_relative_distance - 3
+                 relative_position_index[0:, 0] = new_num_relative_distance - 2
+                 relative_position_index[0, 0] = new_num_relative_distance - 1
+
+                 relative_position_bias = \
+                     new_relative_position_bias_table[relative_position_index.view(-1)].view(
+                         training_window_size[0] * training_window_size[1] + 1,
+                         training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+                 relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+                 attn = attn + relative_position_bias.unsqueeze(0)
+
+         if rel_pos_bias is not None:
+             attn = attn + rel_pos_bias
+
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class Block(nn.Module):
+
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+                  window_size=None, attn_head_dim=None):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+             attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
+         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+         if init_values is not None:
+             self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+             self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+         else:
+             self.gamma_1, self.gamma_2 = None, None
+
+     def forward(self, x, rel_pos_bias=None, training_window_size=None):
+         if self.gamma_1 is None:
+             x = x + self.drop_path(
+                 self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size))
+             x = x + self.drop_path(self.mlp(self.norm2(x)))
+         else:
+             x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias,
+                                                             training_window_size=training_window_size))
+             x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     """ Image to Patch Embedding
+     """
+
+     def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
+         super().__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+         self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+         self.num_patches_w = self.patch_shape[0]
+         self.num_patches_h = self.patch_shape[1]
+         # the so-called patch_shape is the patch shape during pre-training
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.num_patches = num_patches
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+     def forward(self, x, position_embedding=None, **kwargs):
+         # FIXME look at relaxing size constraints
+         # assert H == self.img_size[0] and W == self.img_size[1], \
+         #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+         x = self.proj(x)
+         Hp, Wp = x.shape[2], x.shape[3]
+
+         if position_embedding is not None:
+             # interpolate the position embedding to the corresponding size
+             position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3,
+                                                                                                                   1, 2)
+             position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')
+             x = x + position_embedding
+
+         x = x.flatten(2).transpose(1, 2)
+         return x, (Hp, Wp)
+
+
+ class HybridEmbed(nn.Module):
+     """ CNN Feature Map Embedding
+     Extract feature map from CNN, flatten, project to embedding dim.
+     """
+
+     def __init__(self, backbone, img_size=[224, 224], feature_size=None, in_chans=3, embed_dim=768):
+         super().__init__()
+         assert isinstance(backbone, nn.Module)
+         img_size = to_2tuple(img_size)
+         self.img_size = img_size
+         self.backbone = backbone
+         if feature_size is None:
+             with torch.no_grad():
+                 # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
+                 # map for all networks, the feature metadata has reliable channel and stride info, but using
+                 # stride to calc feature dim requires info about padding of each stage that isn't captured.
+                 training = backbone.training
+                 if training:
+                     backbone.eval()
+                 o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+                 feature_size = o.shape[-2:]
+                 feature_dim = o.shape[1]
+                 backbone.train(training)
+         else:
+             feature_size = to_2tuple(feature_size)
+             feature_dim = self.backbone.feature_info.channels()[-1]
+         self.num_patches = feature_size[0] * feature_size[1]
+         self.proj = nn.Linear(feature_dim, embed_dim)
+
+     def forward(self, x):
+         x = self.backbone(x)[-1]
+         x = x.flatten(2).transpose(1, 2)
+         x = self.proj(x)
+         return x
+
+
+ class RelativePositionBias(nn.Module):
+
+     def __init__(self, window_size, num_heads):
+         super().__init__()
+         self.window_size = window_size
+         self.num_heads = num_heads
+         self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+         self.relative_position_bias_table = nn.Parameter(
+             torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+         # cls to token & token 2 cls & cls to cls
+
+         # get pair-wise relative position index for each token inside the window
+         coords_h = torch.arange(window_size[0])
+         coords_w = torch.arange(window_size[1])
+         coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+         coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+         relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+         relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+         relative_coords[:, :, 1] += window_size[1] - 1
+         relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+         relative_position_index = \
+             torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+         relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+         relative_position_index[0, 0:] = self.num_relative_distance - 3
+         relative_position_index[0:, 0] = self.num_relative_distance - 2
+         relative_position_index[0, 0] = self.num_relative_distance - 1
+
+         self.register_buffer("relative_position_index", relative_position_index)
+
+         # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+     def forward(self, training_window_size):
+         if training_window_size == self.window_size:
+             relative_position_bias = \
+                 self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                     self.window_size[0] * self.window_size[1] + 1,
+                     self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+             relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+         else:
+             training_window_size = tuple(training_window_size.tolist())
+             new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
+             # new_num_relative_distance counts all possible relative position options, including cls-cls, tok-cls and cls-tok
+             new_relative_position_bias_table = F.interpolate(
+                 self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
+                                                                              2 * self.window_size[0] - 1,
+                                                                              2 * self.window_size[1] - 1),
+                 size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
+                 align_corners=False)
+             new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
+                                                                                       new_num_relative_distance - 3).permute(
+                 1, 0)
+             new_relative_position_bias_table = torch.cat(
+                 [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
+
+             # get pair-wise relative position index for each token inside the window
+             coords_h = torch.arange(training_window_size[0])
+             coords_w = torch.arange(training_window_size[1])
+             coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+             coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+             relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+             relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+             relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0
+             relative_coords[:, :, 1] += training_window_size[1] - 1
+             relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
+             relative_position_index = \
+                 torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
+                             dtype=relative_coords.dtype)
+             relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+             relative_position_index[0, 0:] = new_num_relative_distance - 3
+             relative_position_index[0:, 0] = new_num_relative_distance - 2
+             relative_position_index[0, 0] = new_num_relative_distance - 1
+
+             relative_position_bias = \
+                 new_relative_position_bias_table[relative_position_index.view(-1)].view(
+                     training_window_size[0] * training_window_size[1] + 1,
+                     training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+             relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+
+         return relative_position_bias
+
+
+ class BEiT(nn.Module):
+     """ Vision Transformer with support for patch or hybrid CNN input stage
+     """
+
+     def __init__(self,
+                  img_size=[224, 224],
+                  patch_size=16,
+                  in_chans=3,
+                  num_classes=80,
+                  embed_dim=768,
+                  depth=12,
+                  num_heads=12,
+                  mlp_ratio=4.,
+                  qkv_bias=False,
+                  qk_scale=None,
+                  drop_rate=0.,
+                  attn_drop_rate=0.,
+                  drop_path_rate=0.,
+                  hybrid_backbone=None,
+                  norm_layer=None,
+                  init_values=None,
+                  use_abs_pos_emb=False,
+                  use_rel_pos_bias=False,
+                  use_shared_rel_pos_bias=False,
+                  use_checkpoint=True,
+                  pretrained=None,
+                  out_features=None,
+                  ):
+
+         super(BEiT, self).__init__()
+
+         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+         self.num_classes = num_classes
+         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+         self.use_checkpoint = use_checkpoint
+
+         if hybrid_backbone is not None:
+             self.patch_embed = HybridEmbed(
+                 hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+         else:
+             self.patch_embed = PatchEmbed(
+                 img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+         num_patches = self.patch_embed.num_patches
+         self.out_features = out_features
+         self.out_indices = [int(name[5:]) for name in out_features]
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         if use_abs_pos_emb:
+             self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+         else:
+             self.pos_embed = None
+         self.pos_drop = nn.Dropout(p=drop_rate)
+
+         self.use_shared_rel_pos_bias = use_shared_rel_pos_bias
+         if use_shared_rel_pos_bias:
+             self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+         else:
+             self.rel_pos_bias = None
+
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+         self.use_rel_pos_bias = use_rel_pos_bias
+         self.blocks = nn.ModuleList([
+             Block(
+                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                 init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
+             for i in range(depth)])
+
+         # trunc_normal_(self.mask_token, std=.02)
+
+         if patch_size == 16:
+             self.fpn1 = nn.Sequential(
+                 nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+                 # nn.SyncBatchNorm(embed_dim),
+                 nn.BatchNorm2d(embed_dim),
+                 nn.GELU(),
+                 nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+             )
+
+             self.fpn2 = nn.Sequential(
+                 nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+             )
+
+             self.fpn3 = nn.Identity()
+
+             self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+         elif patch_size == 8:
+             self.fpn1 = nn.Sequential(
+                 nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+             )
+
+             self.fpn2 = nn.Identity()
+
+             self.fpn3 = nn.Sequential(
+                 nn.MaxPool2d(kernel_size=2, stride=2),
+             )
+
+             self.fpn4 = nn.Sequential(
+                 nn.MaxPool2d(kernel_size=4, stride=4),
+             )
+
+         if self.pos_embed is not None:
+             trunc_normal_(self.pos_embed, std=.02)
+         trunc_normal_(self.cls_token, std=.02)
+         self.apply(self._init_weights)
+         self.fix_init_weight()
+
+     def fix_init_weight(self):
+         def rescale(param, layer_id):
+             param.div_(math.sqrt(2.0 * layer_id))
+
+         for layer_id, layer in enumerate(self.blocks):
+             rescale(layer.attn.proj.weight.data, layer_id + 1)
+             rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     '''
+     def init_weights(self):
+         """Initialize the weights in backbone.
+
+         Args:
+             pretrained (str, optional): Path to pre-trained weights.
+                 Defaults to None.
+         """
+         logger = get_root_logger()
+
+         if self.pos_embed is not None:
+             trunc_normal_(self.pos_embed, std=.02)
+         trunc_normal_(self.cls_token, std=.02)
+         self.apply(self._init_weights)
+         self.fix_init_weight()
+
+         if self.init_cfg is None:
+             logger.warn(f'No pre-trained weights for '
+                         f'{self.__class__.__name__}, '
+                         f'training start from scratch')
+         else:
+             assert 'checkpoint' in self.init_cfg, f'Only support ' \
+                                                   f'specify `Pretrained` in ' \
+                                                   f'`init_cfg` in ' \
+                                                   f'{self.__class__.__name__} '
+             logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
+             load_checkpoint(self,
+                             filename=self.init_cfg['checkpoint'],
+                             strict=False,
+                             logger=logger,
+                             beit_spec_expand_rel_pos = self.use_rel_pos_bias,
+                             )
+     '''
+
+     def get_num_layers(self):
+         return len(self.blocks)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'pos_embed', 'cls_token'}
+
+     def forward_features(self, x):
+         B, C, H, W = x.shape
+         x, (Hp, Wp) = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
+         # Hp, Wp are HW for patches
+         batch_size, seq_len, _ = x.size()
+
+         cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+         if self.pos_embed is not None:
+             cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
+         x = torch.cat((cls_tokens, x), dim=1)
+         x = self.pos_drop(x)
+
+         features = []
+         training_window_size = torch.tensor([Hp, Wp])
+
+         rel_pos_bias = self.rel_pos_bias(training_window_size) if self.rel_pos_bias is not None else None
+
+         for i, blk in enumerate(self.blocks):
+             if self.use_checkpoint:
+                 x = checkpoint.checkpoint(blk, x, rel_pos_bias, training_window_size)
+             else:
+                 x = blk(x, rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)
+             if i in self.out_indices:
+                 xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
+                 features.append(xp.contiguous())
+
+         ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+         for i in range(len(features)):
+             features[i] = ops[i](features[i])
+
+         feat_out = {}
+
+         for name, value in zip(self.out_features, features):
+             feat_out[name] = value
+
+         return feat_out
+
+     def forward(self, x):
+         x = self.forward_features(x)
+         return x
+
+
+ def beit_base_patch16(pretrained=False, **kwargs):
+     model = BEiT(
+         patch_size=16,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         mlp_ratio=4,
+         qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         init_values=None,
+         **kwargs)
+     model.default_cfg = _cfg()
+     return model
+
+ def beit_large_patch16(pretrained=False, **kwargs):
+     model = BEiT(
+         patch_size=16,
+         embed_dim=1024,
+         depth=24,
+         num_heads=16,
+         mlp_ratio=4,
+         qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         init_values=None,
+         **kwargs)
+     model.default_cfg = _cfg()
+     return model
+
+ def dit_base_patch16(pretrained=False, **kwargs):
+     model = BEiT(
+         patch_size=16,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         mlp_ratio=4,
+         qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         init_values=0.1,
+         **kwargs)
+     model.default_cfg = _cfg()
+     return model
+
+ def dit_large_patch16(pretrained=False, **kwargs):
+     model = BEiT(
+         patch_size=16,
+         embed_dim=1024,
+         depth=24,
+         num_heads=16,
+         mlp_ratio=4,
+         qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         init_values=1e-5,
+         **kwargs)
+     model.default_cfg = _cfg()
+     return model
+
+ if __name__ == '__main__':
+     model = BEiT(use_checkpoint=True, use_shared_rel_pos_bias=True)
+     model = model.to("cuda:0")
+     input1 = torch.rand(2, 3, 512, 762).to("cuda:0")
+     input2 = torch.rand(2, 3, 800, 1200).to("cuda:0")
+     input3 = torch.rand(2, 3, 720, 1000).to("cuda:0")
+     output1 = model(input1)
+     output2 = model(input2)
+     output3 = model(input3)
+     print("all done")
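
Usage note: beit.py only defines the backbone and its factory functions (beit_base_patch16, beit_large_patch16, dit_base_patch16, dit_large_patch16); they are presumably consumed by the detection code added alongside it (model_init.py, rcnn_vl.py). As a rough standalone smoke test, the sketch below instantiates dit_base_patch16 directly. The out_features names ("layerN", which the constructor parses with int(name[5:])) and the input resolution are illustrative assumptions, not values taken from this diff.

    import torch

    from magic_pdf.model.pek_sub_modules.layoutlmv3.beit import dit_base_patch16

    # Hypothetical smoke test: the feature names and image size below are
    # illustrative assumptions, not values prescribed by this release.
    model = dit_base_patch16(
        img_size=[224, 224],
        out_features=["layer3", "layer5", "layer7", "layer11"],  # block indices parsed via int(name[5:])
        use_checkpoint=False,  # skip gradient checkpointing for a plain forward pass
    )
    model.eval()

    with torch.no_grad():
        # Any H and W divisible by the 16-pixel patch size work; PatchEmbed derives Hp, Wp at runtime.
        feats = model(torch.rand(1, 3, 512, 768))

    for name, value in feats.items():
        # fpn1..fpn4 rescale the selected blocks to strides 4, 8, 16 and 32.
        print(name, tuple(value.shape))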