magic-pdf 0.5.13__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
- magic_pdf/cli/magicpdf.py +18 -7
- magic_pdf/libs/config_reader.py +10 -0
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/__init__.py +1 -0
- magic_pdf/model/doc_analyze_by_custom_model.py +38 -15
- magic_pdf/model/model_list.py +1 -0
- magic_pdf/model/pdf_extract_kit.py +196 -0
- magic_pdf/model/pek_sub_modules/__init__.py +0 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py +0 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py +179 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py +671 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py +476 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py +7 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py +2 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py +171 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py +124 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py +136 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py +284 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py +213 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py +7 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +24 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +60 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +1282 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +32 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +34 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py +150 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py +163 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py +1236 -0
- magic_pdf/model/pek_sub_modules/post_process.py +36 -0
- magic_pdf/model/pek_sub_modules/self_modify.py +260 -0
- magic_pdf/model/pp_structure_v2.py +7 -0
- magic_pdf/pipe/AbsPipe.py +8 -14
- magic_pdf/pipe/OCRPipe.py +12 -8
- magic_pdf/pipe/TXTPipe.py +12 -8
- magic_pdf/pipe/UNIPipe.py +9 -7
- magic_pdf/resources/model_config/UniMERNet/demo.yaml +46 -0
- magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +351 -0
- magic_pdf/resources/model_config/model_configs.yaml +9 -0
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/METADATA +18 -8
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/RECORD +44 -18
- magic_pdf/model/360_layout_analysis.py +0 -8
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/WHEEL +0 -0
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py
@@ -0,0 +1,476 @@
+"""
+Mostly copy-paste from DINO and timm library:
+https://github.com/facebookresearch/dino
+https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+"""
+import warnings
+
+import math
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import trunc_normal_, drop_path, to_2tuple
+from functools import partial
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic',
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        **kwargs
+    }
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads,
+                                      C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                       act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+
+        self.window_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+
+        self.num_patches_w, self.num_patches_h = self.window_size
+
+        self.num_patches = self.window_size[0] * self.window_size[1]
+        self.img_size = img_size
+        self.patch_size = patch_size
+
+        self.proj = nn.Conv2d(in_chans, embed_dim,
+                              kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        x = self.proj(x)
+        return x
+
+
+class HybridEmbed(nn.Module):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+    """
+
+    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
+        super().__init__()
+        assert isinstance(backbone, nn.Module)
+        img_size = to_2tuple(img_size)
+        self.img_size = img_size
+        self.backbone = backbone
+        if feature_size is None:
+            with torch.no_grad():
+                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
+                # map for all networks, the feature metadata has reliable channel and stride info, but using
+                # stride to calc feature dim requires info about padding of each stage that isn't captured.
+                training = backbone.training
+                if training:
+                    backbone.eval()
+                o = self.backbone(torch.zeros(
+                    1, in_chans, img_size[0], img_size[1]))[-1]
+                feature_size = o.shape[-2:]
+                feature_dim = o.shape[1]
+                backbone.train(training)
+        else:
+            feature_size = to_2tuple(feature_size)
+            feature_dim = self.backbone.feature_info.channels()[-1]
+        self.num_patches = feature_size[0] * feature_size[1]
+        self.proj = nn.Linear(feature_dim, embed_dim)
+
+    def forward(self, x):
+        x = self.backbone(x)[-1]
+        x = x.flatten(2).transpose(1, 2)
+        x = self.proj(x)
+        return x
+
+
+class ViT(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+
+    def __init__(self,
+                 model_name='vit_base_patch16_224',
+                 img_size=384,
+                 patch_size=16,
+                 in_chans=3,
+                 embed_dim=1024,
+                 depth=24,
+                 num_heads=16,
+                 num_classes=19,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.1,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 hybrid_backbone=None,
+                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
+                 norm_cfg=None,
+                 pos_embed_interp=False,
+                 random_init=False,
+                 align_corners=False,
+                 use_checkpoint=False,
+                 num_extra_tokens=1,
+                 out_features=None,
+                 **kwargs,
+                 ):
+
+        super(ViT, self).__init__()
+        self.model_name = model_name
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+        self.depth = depth
+        self.num_heads = num_heads
+        self.num_classes = num_classes
+        self.mlp_ratio = mlp_ratio
+        self.qkv_bias = qkv_bias
+        self.qk_scale = qk_scale
+        self.drop_rate = drop_rate
+        self.attn_drop_rate = attn_drop_rate
+        self.drop_path_rate = drop_path_rate
+        self.hybrid_backbone = hybrid_backbone
+        self.norm_layer = norm_layer
+        self.norm_cfg = norm_cfg
+        self.pos_embed_interp = pos_embed_interp
+        self.random_init = random_init
+        self.align_corners = align_corners
+        self.use_checkpoint = use_checkpoint
+        self.num_extra_tokens = num_extra_tokens
+        self.out_features = out_features
+        self.out_indices = [int(name[5:]) for name in out_features]
+
+        # self.num_stages = self.depth
+        # self.out_indices = tuple(range(self.num_stages))
+
+        if self.hybrid_backbone is not None:
+            self.patch_embed = HybridEmbed(
+                self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
+        else:
+            self.patch_embed = PatchEmbed(
+                img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
+        self.num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+
+        if self.num_extra_tokens == 2:
+            self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+
+        self.pos_embed = nn.Parameter(torch.zeros(
+            1, self.num_patches + self.num_extra_tokens, self.embed_dim))
+        self.pos_drop = nn.Dropout(p=self.drop_rate)
+
+        # self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
+        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate,
+                                                self.depth)]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,
+                qk_scale=self.qk_scale,
+                drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer)
+            for i in range(self.depth)])
+
+        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
+        # self.repr = nn.Linear(embed_dim, representation_size)
+        # self.repr_act = nn.Tanh()
+
+        if patch_size == 16:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+                nn.SyncBatchNorm(embed_dim),
+                nn.GELU(),
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn3 = nn.Identity()
+
+            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+        elif patch_size == 8:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Identity()
+
+            self.fpn3 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=2, stride=2),
+            )
+
+            self.fpn4 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=4, stride=4),
+            )
+
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        if self.num_extra_tokens==2:
+            trunc_normal_(self.dist_token, std=0.2)
+        self.apply(self._init_weights)
+        # self.fix_init_weight()
+
+    def fix_init_weight(self):
+        def rescale(param, layer_id):
+            param.div_(math.sqrt(2.0 * layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    '''
+    def init_weights(self):
+        logger = get_root_logger()
+
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+
+        if self.init_cfg is None:
+            logger.warn(f'No pre-trained weights for '
+                        f'{self.__class__.__name__}, '
+                        f'training start from scratch')
+        else:
+            assert 'checkpoint' in self.init_cfg, f'Only support ' \
+                                                  f'specify `Pretrained` in ' \
+                                                  f'`init_cfg` in ' \
+                                                  f'{self.__class__.__name__} '
+            logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
+            load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger)
+    '''
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def _conv_filter(self, state_dict, patch_size=16):
+        """ convert patch embedding weight from manual patchify + linear proj to conv"""
+        out_dict = {}
+        for k, v in state_dict.items():
+            if 'patch_embed.proj.weight' in k:
+                v = v.reshape((v.shape[0], 3, patch_size, patch_size))
+            out_dict[k] = v
+        return out_dict
+
+    def to_2D(self, x):
+        n, hw, c = x.shape
+        h = w = int(math.sqrt(hw))
+        x = x.transpose(1, 2).reshape(n, c, h, w)
+        return x
+
+    def to_1D(self, x):
+        n, c, h, w = x.shape
+        x = x.reshape(n, c, -1).transpose(1, 2)
+        return x
+
+    def interpolate_pos_encoding(self, x, w, h):
+        npatch = x.shape[1] - self.num_extra_tokens
+        N = self.pos_embed.shape[1] - self.num_extra_tokens
+        if npatch == N and w == h:
+            return self.pos_embed
+
+        class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens]
+
+        patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:]
+
+        dim = x.shape[-1]
+        w0 = w // self.patch_embed.patch_size[0]
+        h0 = h // self.patch_embed.patch_size[1]
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        w0, h0 = w0 + 0.1, h0 + 0.1
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            mode='bicubic',
+        )
+        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1)
+
+    def prepare_tokens(self, x, mask=None):
+        B, nc, w, h = x.shape
+        # patch linear embedding
+        x = self.patch_embed(x)
+
+        # mask image modeling
+        if mask is not None:
+            x = self.mask_model(x, mask)
+        x = x.flatten(2).transpose(1, 2)
+
+        # add the [CLS] token to the embed patch tokens
+        all_tokens = [self.cls_token.expand(B, -1, -1)]
+
+        if self.num_extra_tokens == 2:
+            dist_tokens = self.dist_token.expand(B, -1, -1)
+            all_tokens.append(dist_tokens)
+        all_tokens.append(x)
+
+        x = torch.cat(all_tokens, dim=1)
+
+        # add positional encoding to each token
+        x = x + self.interpolate_pos_encoding(x, w, h)
+
+        return self.pos_drop(x)
+
+    def forward_features(self, x):
+        # print(f"==========shape of x is {x.shape}==========")
+        B, _, H, W = x.shape
+        Hp, Wp = H // self.patch_size, W // self.patch_size
+        x = self.prepare_tokens(x)
+
+        features = []
+        for i, blk in enumerate(self.blocks):
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+            if i in self.out_indices:
+                xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
+                features.append(xp.contiguous())
+
+        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+        for i in range(len(features)):
+            features[i] = ops[i](features[i])
+
+        feat_out = {}
+
+        for name, value in zip(self.out_features, features):
+            feat_out[name] = value
+
+        return feat_out
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        return x
+
+
+def deit_base_patch16(pretrained=False, **kwargs):
+    model = ViT(
+        patch_size=16,
+        drop_rate=0.,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        num_classes=1000,
+        mlp_ratio=4.,
+        qkv_bias=True,
+        use_checkpoint=True,
+        num_extra_tokens=2,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def mae_base_patch16(pretrained=False, **kwargs):
+    model = ViT(
+        patch_size=16,
+        drop_rate=0.,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        num_classes=1000,
+        mlp_ratio=4.,
+        qkv_bias=True,
+        use_checkpoint=True,
+        num_extra_tokens=1,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
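The hunk above is the new DeiT-style ViT backbone that the LayoutLMv3 layout model taps for multi-scale features. A minimal construction sketch follows; it assumes the module path shown in the file list and `"layer<idx>"` names for `out_features` (which `ViT.__init__` parses via `int(name[5:])`) — the values actually supplied by `layoutlmv3_base_inference.yaml` may differ.

```python
# Sketch only: instantiate the backbone added in deit.py above.
# The out_features names are assumptions, not values from this diff.
from magic_pdf.model.pek_sub_modules.layoutlmv3.deit import deit_base_patch16

backbone = deit_base_patch16(
    img_size=224,  # pos_embed is sized for (224 // 16) ** 2 patches plus 2 extra tokens
    out_features=["layer3", "layer5", "layer7", "layer11"],  # assumed feature tap points
)
print(sum(p.numel() for p in backbone.parameters()))  # rough parameter-count sanity check
```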
--- /dev/null
+++ magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py
@@ -0,0 +1,171 @@
+'''
+Reference: https://huggingface.co/datasets/pierresi/cord/blob/main/cord.py
+'''
+
+
+import json
+import os
+from pathlib import Path
+import datasets
+from .image_utils import load_image, normalize_bbox
+logger = datasets.logging.get_logger(__name__)
+_CITATION = """\
+@article{park2019cord,
+  title={CORD: A Consolidated Receipt Dataset for Post-OCR Parsing},
+  author={Park, Seunghyun and Shin, Seung and Lee, Bado and Lee, Junyeop and Surh, Jaeheung and Seo, Minjoon and Lee, Hwalsuk}
+  booktitle={Document Intelligence Workshop at Neural Information Processing Systems}
+  year={2019}
+}
+"""
+_DESCRIPTION = """\
+https://github.com/clovaai/cord/
+"""
+
+def quad_to_box(quad):
+    # test 87 is wrongly annotated
+    box = (
+        max(0, quad["x1"]),
+        max(0, quad["y1"]),
+        quad["x3"],
+        quad["y3"]
+    )
+    if box[3] < box[1]:
+        bbox = list(box)
+        tmp = bbox[3]
+        bbox[3] = bbox[1]
+        bbox[1] = tmp
+        box = tuple(bbox)
+    if box[2] < box[0]:
+        bbox = list(box)
+        tmp = bbox[2]
+        bbox[2] = bbox[0]
+        bbox[0] = tmp
+        box = tuple(bbox)
+    return box
+
+def _get_drive_url(url):
+    base_url = 'https://drive.google.com/uc?id='
+    split_url = url.split('/')
+    return base_url + split_url[5]
+
+_URLS = [
+    _get_drive_url("https://drive.google.com/file/d/1MqhTbcj-AHXOqYoeoh12aRUwIprzTJYI/"),
+    _get_drive_url("https://drive.google.com/file/d/1wYdp5nC9LnHQZ2FcmOoC0eClyWvcuARU/")
+    # If you failed to download the dataset through the automatic downloader,
+    # you can download it manually and modify the code to get the local dataset.
+    # Or you can use the following links. Please follow the original LICENSE of CORD for usage.
+    # "https://layoutlm.blob.core.windows.net/cord/CORD-1k-001.zip",
+    # "https://layoutlm.blob.core.windows.net/cord/CORD-1k-002.zip"
+]
+
+class CordConfig(datasets.BuilderConfig):
+    """BuilderConfig for CORD"""
+    def __init__(self, **kwargs):
+        """BuilderConfig for CORD.
+        Args:
+            **kwargs: keyword arguments forwarded to super.
+        """
+        super(CordConfig, self).__init__(**kwargs)
+
+class Cord(datasets.GeneratorBasedBuilder):
+    BUILDER_CONFIGS = [
+        CordConfig(name="cord", version=datasets.Version("1.0.0"), description="CORD dataset"),
+    ]
+
+    def _info(self):
+        return datasets.DatasetInfo(
+            description=_DESCRIPTION,
+            features=datasets.Features(
+                {
+                    "id": datasets.Value("string"),
+                    "words": datasets.Sequence(datasets.Value("string")),
+                    "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
+                    "ner_tags": datasets.Sequence(
+                        datasets.features.ClassLabel(
+                            names=["O","B-MENU.NM","B-MENU.NUM","B-MENU.UNITPRICE","B-MENU.CNT","B-MENU.DISCOUNTPRICE","B-MENU.PRICE","B-MENU.ITEMSUBTOTAL","B-MENU.VATYN","B-MENU.ETC","B-MENU.SUB_NM","B-MENU.SUB_UNITPRICE","B-MENU.SUB_CNT","B-MENU.SUB_PRICE","B-MENU.SUB_ETC","B-VOID_MENU.NM","B-VOID_MENU.PRICE","B-SUB_TOTAL.SUBTOTAL_PRICE","B-SUB_TOTAL.DISCOUNT_PRICE","B-SUB_TOTAL.SERVICE_PRICE","B-SUB_TOTAL.OTHERSVC_PRICE","B-SUB_TOTAL.TAX_PRICE","B-SUB_TOTAL.ETC","B-TOTAL.TOTAL_PRICE","B-TOTAL.TOTAL_ETC","B-TOTAL.CASHPRICE","B-TOTAL.CHANGEPRICE","B-TOTAL.CREDITCARDPRICE","B-TOTAL.EMONEYPRICE","B-TOTAL.MENUTYPE_CNT","B-TOTAL.MENUQTY_CNT","I-MENU.NM","I-MENU.NUM","I-MENU.UNITPRICE","I-MENU.CNT","I-MENU.DISCOUNTPRICE","I-MENU.PRICE","I-MENU.ITEMSUBTOTAL","I-MENU.VATYN","I-MENU.ETC","I-MENU.SUB_NM","I-MENU.SUB_UNITPRICE","I-MENU.SUB_CNT","I-MENU.SUB_PRICE","I-MENU.SUB_ETC","I-VOID_MENU.NM","I-VOID_MENU.PRICE","I-SUB_TOTAL.SUBTOTAL_PRICE","I-SUB_TOTAL.DISCOUNT_PRICE","I-SUB_TOTAL.SERVICE_PRICE","I-SUB_TOTAL.OTHERSVC_PRICE","I-SUB_TOTAL.TAX_PRICE","I-SUB_TOTAL.ETC","I-TOTAL.TOTAL_PRICE","I-TOTAL.TOTAL_ETC","I-TOTAL.CASHPRICE","I-TOTAL.CHANGEPRICE","I-TOTAL.CREDITCARDPRICE","I-TOTAL.EMONEYPRICE","I-TOTAL.MENUTYPE_CNT","I-TOTAL.MENUQTY_CNT"]
+                        )
+                    ),
+                    "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
+                    "image_path": datasets.Value("string"),
+                }
+            ),
+            supervised_keys=None,
+            citation=_CITATION,
+            homepage="https://github.com/clovaai/cord/",
+        )
+
+    def _split_generators(self, dl_manager):
+        """Returns SplitGenerators."""
+        """Uses local files located with data_dir"""
+        downloaded_file = dl_manager.download_and_extract(_URLS)
+        # move files from the second URL together with files from the first one.
+        dest = Path(downloaded_file[0])/"CORD"
+        for split in ["train", "dev", "test"]:
+            for file_type in ["image", "json"]:
+                if split == "test" and file_type == "json":
+                    continue
+                files = (Path(downloaded_file[1])/"CORD"/split/file_type).iterdir()
+                for f in files:
+                    os.rename(f, dest/split/file_type/f.name)
+        return [
+            datasets.SplitGenerator(
+                name=datasets.Split.TRAIN, gen_kwargs={"filepath": dest/"train"}
+            ),
+            datasets.SplitGenerator(
+                name=datasets.Split.VALIDATION, gen_kwargs={"filepath": dest/"dev"}
+            ),
+            datasets.SplitGenerator(
+                name=datasets.Split.TEST, gen_kwargs={"filepath": dest/"test"}
+            ),
+        ]
+
+    def get_line_bbox(self, bboxs):
+        x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
+        y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
+
+        x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
+
+        assert x1 >= x0 and y1 >= y0
+        bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
+        return bbox
+
+    def _generate_examples(self, filepath):
+        logger.info("⏳ Generating examples from = %s", filepath)
+        ann_dir = os.path.join(filepath, "json")
+        img_dir = os.path.join(filepath, "image")
+        for guid, file in enumerate(sorted(os.listdir(ann_dir))):
+            words = []
+            bboxes = []
+            ner_tags = []
+            file_path = os.path.join(ann_dir, file)
+            with open(file_path, "r", encoding="utf8") as f:
+                data = json.load(f)
+            image_path = os.path.join(img_dir, file)
+            image_path = image_path.replace("json", "png")
+            image, size = load_image(image_path)
+            for item in data["valid_line"]:
+                cur_line_bboxes = []
+                line_words, label = item["words"], item["category"]
+                line_words = [w for w in line_words if w["text"].strip() != ""]
+                if len(line_words) == 0:
+                    continue
+                if label == "other":
+                    for w in line_words:
+                        words.append(w["text"])
+                        ner_tags.append("O")
+                        cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
+                else:
+                    words.append(line_words[0]["text"])
+                    ner_tags.append("B-" + label.upper())
+                    cur_line_bboxes.append(normalize_bbox(quad_to_box(line_words[0]["quad"]), size))
+                    for w in line_words[1:]:
+                        words.append(w["text"])
+                        ner_tags.append("I-" + label.upper())
+                        cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
+                # by default: --segment_level_layout 1
+                # if do not want to use segment_level_layout, comment the following line
+                cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
+                bboxes.extend(cur_line_bboxes)
+            # yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags, "image": image}
+            yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags,
+                         "image": image, "image_path": image_path}