magic-pdf 0.5.13__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
Files changed (46)
  1. magic_pdf/cli/magicpdf.py +18 -7
  2. magic_pdf/dict2md/ocr_mkcontent.py +2 -2
  3. magic_pdf/libs/config_reader.py +10 -0
  4. magic_pdf/libs/version.py +1 -1
  5. magic_pdf/model/__init__.py +1 -0
  6. magic_pdf/model/doc_analyze_by_custom_model.py +38 -15
  7. magic_pdf/model/model_list.py +1 -0
  8. magic_pdf/model/pdf_extract_kit.py +200 -0
  9. magic_pdf/model/pek_sub_modules/__init__.py +0 -0
  10. magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py +0 -0
  11. magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py +179 -0
  12. magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py +671 -0
  13. magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py +476 -0
  14. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py +7 -0
  15. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py +2 -0
  16. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py +171 -0
  17. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py +124 -0
  18. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py +136 -0
  19. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py +284 -0
  20. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py +213 -0
  21. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py +7 -0
  22. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +24 -0
  23. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +60 -0
  24. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +1282 -0
  25. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +32 -0
  26. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +34 -0
  27. magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py +150 -0
  28. magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py +163 -0
  29. magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py +1236 -0
  30. magic_pdf/model/pek_sub_modules/post_process.py +36 -0
  31. magic_pdf/model/pek_sub_modules/self_modify.py +260 -0
  32. magic_pdf/model/pp_structure_v2.py +7 -0
  33. magic_pdf/pipe/AbsPipe.py +8 -14
  34. magic_pdf/pipe/OCRPipe.py +12 -8
  35. magic_pdf/pipe/TXTPipe.py +12 -8
  36. magic_pdf/pipe/UNIPipe.py +9 -7
  37. magic_pdf/resources/model_config/UniMERNet/demo.yaml +46 -0
  38. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +351 -0
  39. magic_pdf/resources/model_config/model_configs.yaml +9 -0
  40. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.1.dist-info}/METADATA +95 -12
  41. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.1.dist-info}/RECORD +45 -19
  42. magic_pdf/model/360_layout_analysis.py +0 -8
  43. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.1.dist-info}/LICENSE.md +0 -0
  44. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.1.dist-info}/WHEEL +0 -0
  45. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.1.dist-info}/entry_points.txt +0 -0
  46. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.1.dist-info}/top_level.txt +0 -0
magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py (new file)
@@ -0,0 +1,476 @@
+ """
+ Mostly copy-paste from DINO and timm library:
+ https://github.com/facebookresearch/dino
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+ """
+ import warnings
+
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.utils.checkpoint as checkpoint
+ from timm.models.layers import trunc_normal_, drop_path, to_2tuple
+ from functools import partial
+
+ def _cfg(url='', **kwargs):
+     return {
+         'url': url,
+         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+         'crop_pct': .9, 'interpolation': 'bicubic',
+         'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+         **kwargs
+     }
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     """
+
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
+
+     def extra_repr(self) -> str:
+         return 'p={}'.format(self.drop_prob)
+
+
+ class Mlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x):
+         B, N, C = x.shape
+         q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads,
+                                       C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class Block(nn.Module):
+
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+         self.drop_path = DropPath(
+             drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                        act_layer=act_layer, drop=drop)
+
+     def forward(self, x):
+         x = x + self.drop_path(self.attn(self.norm1(x)))
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     """ Image to Patch Embedding
+     """
+
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+         super().__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+
+         self.window_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+
+         self.num_patches_w, self.num_patches_h = self.window_size
+
+         self.num_patches = self.window_size[0] * self.window_size[1]
+         self.img_size = img_size
+         self.patch_size = patch_size
+
+         self.proj = nn.Conv2d(in_chans, embed_dim,
+                               kernel_size=patch_size, stride=patch_size)
+
+     def forward(self, x):
+         x = self.proj(x)
+         return x
+
+
+ class HybridEmbed(nn.Module):
+     """ CNN Feature Map Embedding
+     Extract feature map from CNN, flatten, project to embedding dim.
+     """
+
+     def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
+         super().__init__()
+         assert isinstance(backbone, nn.Module)
+         img_size = to_2tuple(img_size)
+         self.img_size = img_size
+         self.backbone = backbone
+         if feature_size is None:
+             with torch.no_grad():
+                 # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
+                 # map for all networks, the feature metadata has reliable channel and stride info, but using
+                 # stride to calc feature dim requires info about padding of each stage that isn't captured.
+                 training = backbone.training
+                 if training:
+                     backbone.eval()
+                 o = self.backbone(torch.zeros(
+                     1, in_chans, img_size[0], img_size[1]))[-1]
+                 feature_size = o.shape[-2:]
+                 feature_dim = o.shape[1]
+                 backbone.train(training)
+         else:
+             feature_size = to_2tuple(feature_size)
+             feature_dim = self.backbone.feature_info.channels()[-1]
+         self.num_patches = feature_size[0] * feature_size[1]
+         self.proj = nn.Linear(feature_dim, embed_dim)
+
+     def forward(self, x):
+         x = self.backbone(x)[-1]
+         x = x.flatten(2).transpose(1, 2)
+         x = self.proj(x)
+         return x
+
+
+ class ViT(nn.Module):
+     """ Vision Transformer with support for patch or hybrid CNN input stage
+     """
+
+     def __init__(self,
+                  model_name='vit_base_patch16_224',
+                  img_size=384,
+                  patch_size=16,
+                  in_chans=3,
+                  embed_dim=1024,
+                  depth=24,
+                  num_heads=16,
+                  num_classes=19,
+                  mlp_ratio=4.,
+                  qkv_bias=True,
+                  qk_scale=None,
+                  drop_rate=0.1,
+                  attn_drop_rate=0.,
+                  drop_path_rate=0.,
+                  hybrid_backbone=None,
+                  norm_layer=partial(nn.LayerNorm, eps=1e-6),
+                  norm_cfg=None,
+                  pos_embed_interp=False,
+                  random_init=False,
+                  align_corners=False,
+                  use_checkpoint=False,
+                  num_extra_tokens=1,
+                  out_features=None,
+                  **kwargs,
+                  ):
+
+         super(ViT, self).__init__()
+         self.model_name = model_name
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+         self.depth = depth
+         self.num_heads = num_heads
+         self.num_classes = num_classes
+         self.mlp_ratio = mlp_ratio
+         self.qkv_bias = qkv_bias
+         self.qk_scale = qk_scale
+         self.drop_rate = drop_rate
+         self.attn_drop_rate = attn_drop_rate
+         self.drop_path_rate = drop_path_rate
+         self.hybrid_backbone = hybrid_backbone
+         self.norm_layer = norm_layer
+         self.norm_cfg = norm_cfg
+         self.pos_embed_interp = pos_embed_interp
+         self.random_init = random_init
+         self.align_corners = align_corners
+         self.use_checkpoint = use_checkpoint
+         self.num_extra_tokens = num_extra_tokens
+         self.out_features = out_features
+         self.out_indices = [int(name[5:]) for name in out_features]
+
+         # self.num_stages = self.depth
+         # self.out_indices = tuple(range(self.num_stages))
+
+         if self.hybrid_backbone is not None:
+             self.patch_embed = HybridEmbed(
+                 self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
+         else:
+             self.patch_embed = PatchEmbed(
+                 img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
+         self.num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+
+         if self.num_extra_tokens == 2:
+             self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+
+         self.pos_embed = nn.Parameter(torch.zeros(
+             1, self.num_patches + self.num_extra_tokens, self.embed_dim))
+         self.pos_drop = nn.Dropout(p=self.drop_rate)
+
+         # self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
+         dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate,
+                                                 self.depth)]  # stochastic depth decay rule
+         self.blocks = nn.ModuleList([
+             Block(
+                 dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,
+                 qk_scale=self.qk_scale,
+                 drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer)
+             for i in range(self.depth)])
+
+         # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
+         # self.repr = nn.Linear(embed_dim, representation_size)
+         # self.repr_act = nn.Tanh()
+
+         if patch_size == 16:
+             self.fpn1 = nn.Sequential(
+                 nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+                 nn.SyncBatchNorm(embed_dim),
+                 nn.GELU(),
+                 nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+             )
+
+             self.fpn2 = nn.Sequential(
+                 nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+             )
+
+             self.fpn3 = nn.Identity()
+
+             self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+         elif patch_size == 8:
+             self.fpn1 = nn.Sequential(
+                 nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+             )
+
+             self.fpn2 = nn.Identity()
+
+             self.fpn3 = nn.Sequential(
+                 nn.MaxPool2d(kernel_size=2, stride=2),
+             )
+
+             self.fpn4 = nn.Sequential(
+                 nn.MaxPool2d(kernel_size=4, stride=4),
+             )
+
+         trunc_normal_(self.pos_embed, std=.02)
+         trunc_normal_(self.cls_token, std=.02)
+         if self.num_extra_tokens == 2:
+             trunc_normal_(self.dist_token, std=0.2)
+         self.apply(self._init_weights)
+         # self.fix_init_weight()
+
+     def fix_init_weight(self):
+         def rescale(param, layer_id):
+             param.div_(math.sqrt(2.0 * layer_id))
+
+         for layer_id, layer in enumerate(self.blocks):
+             rescale(layer.attn.proj.weight.data, layer_id + 1)
+             rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     '''
+     def init_weights(self):
+         logger = get_root_logger()
+
+         trunc_normal_(self.pos_embed, std=.02)
+         trunc_normal_(self.cls_token, std=.02)
+         self.apply(self._init_weights)
+
+         if self.init_cfg is None:
+             logger.warn(f'No pre-trained weights for '
+                         f'{self.__class__.__name__}, '
+                         f'training start from scratch')
+         else:
+             assert 'checkpoint' in self.init_cfg, f'Only support ' \
+                                                   f'specify `Pretrained` in ' \
+                                                   f'`init_cfg` in ' \
+                                                   f'{self.__class__.__name__} '
+             logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
+             load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger)
+     '''
+
+     def get_num_layers(self):
+         return len(self.blocks)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'pos_embed', 'cls_token'}
+
+     def _conv_filter(self, state_dict, patch_size=16):
+         """ convert patch embedding weight from manual patchify + linear proj to conv"""
+         out_dict = {}
+         for k, v in state_dict.items():
+             if 'patch_embed.proj.weight' in k:
+                 v = v.reshape((v.shape[0], 3, patch_size, patch_size))
+             out_dict[k] = v
+         return out_dict
+
+     def to_2D(self, x):
+         n, hw, c = x.shape
+         h = w = int(math.sqrt(hw))
+         x = x.transpose(1, 2).reshape(n, c, h, w)
+         return x
+
+     def to_1D(self, x):
+         n, c, h, w = x.shape
+         x = x.reshape(n, c, -1).transpose(1, 2)
+         return x
+
+     def interpolate_pos_encoding(self, x, w, h):
+         npatch = x.shape[1] - self.num_extra_tokens
+         N = self.pos_embed.shape[1] - self.num_extra_tokens
+         if npatch == N and w == h:
+             return self.pos_embed
+
+         class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens]
+
+         patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:]
+
+         dim = x.shape[-1]
+         w0 = w // self.patch_embed.patch_size[0]
+         h0 = h // self.patch_embed.patch_size[1]
+         # we add a small number to avoid floating point error in the interpolation
+         # see discussion at https://github.com/facebookresearch/dino/issues/8
+         w0, h0 = w0 + 0.1, h0 + 0.1
+         patch_pos_embed = nn.functional.interpolate(
+             patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+             scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+             mode='bicubic',
+         )
+         assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+         return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1)
+
+     def prepare_tokens(self, x, mask=None):
+         B, nc, w, h = x.shape
+         # patch linear embedding
+         x = self.patch_embed(x)
+
+         # mask image modeling
+         if mask is not None:
+             x = self.mask_model(x, mask)
+         x = x.flatten(2).transpose(1, 2)
+
+         # add the [CLS] token to the embed patch tokens
+         all_tokens = [self.cls_token.expand(B, -1, -1)]
+
+         if self.num_extra_tokens == 2:
+             dist_tokens = self.dist_token.expand(B, -1, -1)
+             all_tokens.append(dist_tokens)
+         all_tokens.append(x)
+
+         x = torch.cat(all_tokens, dim=1)
+
+         # add positional encoding to each token
+         x = x + self.interpolate_pos_encoding(x, w, h)
+
+         return self.pos_drop(x)
+
+     def forward_features(self, x):
+         # print(f"==========shape of x is {x.shape}==========")
+         B, _, H, W = x.shape
+         Hp, Wp = H // self.patch_size, W // self.patch_size
+         x = self.prepare_tokens(x)
+
+         features = []
+         for i, blk in enumerate(self.blocks):
+             if self.use_checkpoint:
+                 x = checkpoint.checkpoint(blk, x)
+             else:
+                 x = blk(x)
+             if i in self.out_indices:
+                 xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
+                 features.append(xp.contiguous())
+
+         ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+         for i in range(len(features)):
+             features[i] = ops[i](features[i])
+
+         feat_out = {}
+
+         for name, value in zip(self.out_features, features):
+             feat_out[name] = value
+
+         return feat_out
+
+     def forward(self, x):
+         x = self.forward_features(x)
+         return x
+
+
+ def deit_base_patch16(pretrained=False, **kwargs):
+     model = ViT(
+         patch_size=16,
+         drop_rate=0.,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         num_classes=1000,
+         mlp_ratio=4.,
+         qkv_bias=True,
+         use_checkpoint=True,
+         num_extra_tokens=2,
+         **kwargs)
+     model.default_cfg = _cfg()
+     return model
+
+ def mae_base_patch16(pretrained=False, **kwargs):
+     model = ViT(
+         patch_size=16,
+         drop_rate=0.,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         num_classes=1000,
+         mlp_ratio=4.,
+         qkv_bias=True,
+         use_checkpoint=True,
+         num_extra_tokens=1,
+         **kwargs)
+     model.default_cfg = _cfg()
+     return model
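The hunk above adds a feature-extractor ViT: forward() returns a dict of multi-scale feature maps keyed by the requested out_features names rather than classification logits. Below is a minimal usage sketch, not part of the diff; the "layerN" naming is an assumption inferred from out_indices = [int(name[5:]) for name in out_features], and model.eval() is used so that nn.SyncBatchNorm falls back to its single-process batch-norm path.

import torch
from magic_pdf.model.pek_sub_modules.layoutlmv3.deit import deit_base_patch16

# Hypothetical feature names: parsed as int(name[5:]) -> block indices 3, 5, 7, 11,
# matching the four fpn1..fpn4 heads defined for patch_size == 16.
model = deit_base_patch16(out_features=["layer3", "layer5", "layer7", "layer11"])
model.eval()  # keeps SyncBatchNorm on its non-distributed fallback

with torch.no_grad():
    feats = model(torch.zeros(1, 3, 384, 384))  # default img_size is 384, so no pos-embed interpolation

for name, value in feats.items():
    # fpn1..fpn4 rescale the four 24x24 maps into a feature pyramid (96, 48, 24, 12)
    print(name, tuple(value.shape))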
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py (new file)
@@ -0,0 +1,7 @@
+ from .models import (
+     LayoutLMv3Config,
+     LayoutLMv3ForTokenClassification,
+     LayoutLMv3ForQuestionAnswering,
+     LayoutLMv3ForSequenceClassification,
+     LayoutLMv3Tokenizer,
+ )
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py (new file)
@@ -0,0 +1,2 @@
+ # flake8: noqa
+ from .data_collator import DataCollatorForKeyValueExtraction
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py (new file)
@@ -0,0 +1,171 @@
+ '''
+ Reference: https://huggingface.co/datasets/pierresi/cord/blob/main/cord.py
+ '''
+
+
+ import json
+ import os
+ from pathlib import Path
+ import datasets
+ from .image_utils import load_image, normalize_bbox
+ logger = datasets.logging.get_logger(__name__)
+ _CITATION = """\
+ @article{park2019cord,
+   title={CORD: A Consolidated Receipt Dataset for Post-OCR Parsing},
+   author={Park, Seunghyun and Shin, Seung and Lee, Bado and Lee, Junyeop and Surh, Jaeheung and Seo, Minjoon and Lee, Hwalsuk}
+   booktitle={Document Intelligence Workshop at Neural Information Processing Systems}
+   year={2019}
+ }
+ """
+ _DESCRIPTION = """\
+ https://github.com/clovaai/cord/
+ """
+
+ def quad_to_box(quad):
+     # test 87 is wrongly annotated
+     box = (
+         max(0, quad["x1"]),
+         max(0, quad["y1"]),
+         quad["x3"],
+         quad["y3"]
+     )
+     if box[3] < box[1]:
+         bbox = list(box)
+         tmp = bbox[3]
+         bbox[3] = bbox[1]
+         bbox[1] = tmp
+         box = tuple(bbox)
+     if box[2] < box[0]:
+         bbox = list(box)
+         tmp = bbox[2]
+         bbox[2] = bbox[0]
+         bbox[0] = tmp
+         box = tuple(bbox)
+     return box
+
+ def _get_drive_url(url):
+     base_url = 'https://drive.google.com/uc?id='
+     split_url = url.split('/')
+     return base_url + split_url[5]
+
+ _URLS = [
+     _get_drive_url("https://drive.google.com/file/d/1MqhTbcj-AHXOqYoeoh12aRUwIprzTJYI/"),
+     _get_drive_url("https://drive.google.com/file/d/1wYdp5nC9LnHQZ2FcmOoC0eClyWvcuARU/")
+     # If you failed to download the dataset through the automatic downloader,
+     # you can download it manually and modify the code to get the local dataset.
+     # Or you can use the following links. Please follow the original LICENSE of CORD for usage.
+     # "https://layoutlm.blob.core.windows.net/cord/CORD-1k-001.zip",
+     # "https://layoutlm.blob.core.windows.net/cord/CORD-1k-002.zip"
+ ]
+
+ class CordConfig(datasets.BuilderConfig):
+     """BuilderConfig for CORD"""
+     def __init__(self, **kwargs):
+         """BuilderConfig for CORD.
+         Args:
+             **kwargs: keyword arguments forwarded to super.
+         """
+         super(CordConfig, self).__init__(**kwargs)
+
+ class Cord(datasets.GeneratorBasedBuilder):
+     BUILDER_CONFIGS = [
+         CordConfig(name="cord", version=datasets.Version("1.0.0"), description="CORD dataset"),
+     ]
+
+     def _info(self):
+         return datasets.DatasetInfo(
+             description=_DESCRIPTION,
+             features=datasets.Features(
+                 {
+                     "id": datasets.Value("string"),
+                     "words": datasets.Sequence(datasets.Value("string")),
+                     "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
+                     "ner_tags": datasets.Sequence(
+                         datasets.features.ClassLabel(
+                             names=["O","B-MENU.NM","B-MENU.NUM","B-MENU.UNITPRICE","B-MENU.CNT","B-MENU.DISCOUNTPRICE","B-MENU.PRICE","B-MENU.ITEMSUBTOTAL","B-MENU.VATYN","B-MENU.ETC","B-MENU.SUB_NM","B-MENU.SUB_UNITPRICE","B-MENU.SUB_CNT","B-MENU.SUB_PRICE","B-MENU.SUB_ETC","B-VOID_MENU.NM","B-VOID_MENU.PRICE","B-SUB_TOTAL.SUBTOTAL_PRICE","B-SUB_TOTAL.DISCOUNT_PRICE","B-SUB_TOTAL.SERVICE_PRICE","B-SUB_TOTAL.OTHERSVC_PRICE","B-SUB_TOTAL.TAX_PRICE","B-SUB_TOTAL.ETC","B-TOTAL.TOTAL_PRICE","B-TOTAL.TOTAL_ETC","B-TOTAL.CASHPRICE","B-TOTAL.CHANGEPRICE","B-TOTAL.CREDITCARDPRICE","B-TOTAL.EMONEYPRICE","B-TOTAL.MENUTYPE_CNT","B-TOTAL.MENUQTY_CNT","I-MENU.NM","I-MENU.NUM","I-MENU.UNITPRICE","I-MENU.CNT","I-MENU.DISCOUNTPRICE","I-MENU.PRICE","I-MENU.ITEMSUBTOTAL","I-MENU.VATYN","I-MENU.ETC","I-MENU.SUB_NM","I-MENU.SUB_UNITPRICE","I-MENU.SUB_CNT","I-MENU.SUB_PRICE","I-MENU.SUB_ETC","I-VOID_MENU.NM","I-VOID_MENU.PRICE","I-SUB_TOTAL.SUBTOTAL_PRICE","I-SUB_TOTAL.DISCOUNT_PRICE","I-SUB_TOTAL.SERVICE_PRICE","I-SUB_TOTAL.OTHERSVC_PRICE","I-SUB_TOTAL.TAX_PRICE","I-SUB_TOTAL.ETC","I-TOTAL.TOTAL_PRICE","I-TOTAL.TOTAL_ETC","I-TOTAL.CASHPRICE","I-TOTAL.CHANGEPRICE","I-TOTAL.CREDITCARDPRICE","I-TOTAL.EMONEYPRICE","I-TOTAL.MENUTYPE_CNT","I-TOTAL.MENUQTY_CNT"]
+                         )
+                     ),
+                     "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
+                     "image_path": datasets.Value("string"),
+                 }
+             ),
+             supervised_keys=None,
+             citation=_CITATION,
+             homepage="https://github.com/clovaai/cord/",
+         )
+
+     def _split_generators(self, dl_manager):
+         """Returns SplitGenerators."""
+         """Uses local files located with data_dir"""
+         downloaded_file = dl_manager.download_and_extract(_URLS)
+         # move files from the second URL together with files from the first one.
+         dest = Path(downloaded_file[0])/"CORD"
+         for split in ["train", "dev", "test"]:
+             for file_type in ["image", "json"]:
+                 if split == "test" and file_type == "json":
+                     continue
+                 files = (Path(downloaded_file[1])/"CORD"/split/file_type).iterdir()
+                 for f in files:
+                     os.rename(f, dest/split/file_type/f.name)
+         return [
+             datasets.SplitGenerator(
+                 name=datasets.Split.TRAIN, gen_kwargs={"filepath": dest/"train"}
+             ),
+             datasets.SplitGenerator(
+                 name=datasets.Split.VALIDATION, gen_kwargs={"filepath": dest/"dev"}
+             ),
+             datasets.SplitGenerator(
+                 name=datasets.Split.TEST, gen_kwargs={"filepath": dest/"test"}
+             ),
+         ]
+
+     def get_line_bbox(self, bboxs):
+         x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
+         y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
+
+         x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
+
+         assert x1 >= x0 and y1 >= y0
+         bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
+         return bbox
+
+     def _generate_examples(self, filepath):
+         logger.info("⏳ Generating examples from = %s", filepath)
+         ann_dir = os.path.join(filepath, "json")
+         img_dir = os.path.join(filepath, "image")
+         for guid, file in enumerate(sorted(os.listdir(ann_dir))):
+             words = []
+             bboxes = []
+             ner_tags = []
+             file_path = os.path.join(ann_dir, file)
+             with open(file_path, "r", encoding="utf8") as f:
+                 data = json.load(f)
+             image_path = os.path.join(img_dir, file)
+             image_path = image_path.replace("json", "png")
+             image, size = load_image(image_path)
+             for item in data["valid_line"]:
+                 cur_line_bboxes = []
+                 line_words, label = item["words"], item["category"]
+                 line_words = [w for w in line_words if w["text"].strip() != ""]
+                 if len(line_words) == 0:
+                     continue
+                 if label == "other":
+                     for w in line_words:
+                         words.append(w["text"])
+                         ner_tags.append("O")
+                         cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
+                 else:
+                     words.append(line_words[0]["text"])
+                     ner_tags.append("B-" + label.upper())
+                     cur_line_bboxes.append(normalize_bbox(quad_to_box(line_words[0]["quad"]), size))
+                     for w in line_words[1:]:
+                         words.append(w["text"])
+                         ner_tags.append("I-" + label.upper())
+                         cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
+                 # by default: --segment_level_layout 1
+                 # if do not want to use segment_level_layout, comment the following line
+                 cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
+                 bboxes.extend(cur_line_bboxes)
+             # yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags, "image": image}
+             yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags,
+                          "image": image, "image_path": image_path}