magic-pdf 0.5.13__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in its public registry (here, PyPI). It is provided for informational purposes only.
Files changed (45)
  1. magic_pdf/cli/magicpdf.py +18 -7
  2. magic_pdf/libs/config_reader.py +10 -0
  3. magic_pdf/libs/version.py +1 -1
  4. magic_pdf/model/__init__.py +1 -0
  5. magic_pdf/model/doc_analyze_by_custom_model.py +38 -15
  6. magic_pdf/model/model_list.py +1 -0
  7. magic_pdf/model/pdf_extract_kit.py +196 -0
  8. magic_pdf/model/pek_sub_modules/__init__.py +0 -0
  9. magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py +0 -0
  10. magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py +179 -0
  11. magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py +671 -0
  12. magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py +476 -0
  13. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py +7 -0
  14. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py +2 -0
  15. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py +171 -0
  16. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py +124 -0
  17. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py +136 -0
  18. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py +284 -0
  19. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py +213 -0
  20. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py +7 -0
  21. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +24 -0
  22. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +60 -0
  23. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +1282 -0
  24. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +32 -0
  25. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +34 -0
  26. magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py +150 -0
  27. magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py +163 -0
  28. magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py +1236 -0
  29. magic_pdf/model/pek_sub_modules/post_process.py +36 -0
  30. magic_pdf/model/pek_sub_modules/self_modify.py +260 -0
  31. magic_pdf/model/pp_structure_v2.py +7 -0
  32. magic_pdf/pipe/AbsPipe.py +8 -14
  33. magic_pdf/pipe/OCRPipe.py +12 -8
  34. magic_pdf/pipe/TXTPipe.py +12 -8
  35. magic_pdf/pipe/UNIPipe.py +9 -7
  36. magic_pdf/resources/model_config/UniMERNet/demo.yaml +46 -0
  37. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +351 -0
  38. magic_pdf/resources/model_config/model_configs.yaml +9 -0
  39. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/METADATA +18 -8
  40. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/RECORD +44 -18
  41. magic_pdf/model/360_layout_analysis.py +0 -8
  42. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/LICENSE.md +0 -0
  43. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/WHEEL +0 -0
  44. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/entry_points.txt +0 -0
  45. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/top_level.txt +0 -0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py (new file)
@@ -0,0 +1,213 @@
+ import os
+ import json
+
+ import torch
+ from torch.utils.data.dataset import Dataset
+ from torchvision import transforms
+ from PIL import Image
+
+ from .image_utils import Compose, RandomResizedCropAndInterpolationWithTwoPic
+
+ XFund_label2ids = {
+     "O":0,
+     'B-HEADER':1,
+     'I-HEADER':2,
+     'B-QUESTION':3,
+     'I-QUESTION':4,
+     'B-ANSWER':5,
+     'I-ANSWER':6,
+ }
+
+ class xfund_dataset(Dataset):
+     def box_norm(self, box, width, height):
+         def clip(min_num, num, max_num):
+             return min(max(num, min_num), max_num)
+
+         x0, y0, x1, y1 = box
+         x0 = clip(0, int((x0 / width) * 1000), 1000)
+         y0 = clip(0, int((y0 / height) * 1000), 1000)
+         x1 = clip(0, int((x1 / width) * 1000), 1000)
+         y1 = clip(0, int((y1 / height) * 1000), 1000)
+         assert x1 >= x0
+         assert y1 >= y0
+         return [x0, y0, x1, y1]
+
+     def get_segment_ids(self, bboxs):
+         segment_ids = []
+         for i in range(len(bboxs)):
+             if i == 0:
+                 segment_ids.append(0)
+             else:
+                 if bboxs[i - 1] == bboxs[i]:
+                     segment_ids.append(segment_ids[-1])
+                 else:
+                     segment_ids.append(segment_ids[-1] + 1)
+         return segment_ids
+
+     def get_position_ids(self, segment_ids):
+         position_ids = []
+         for i in range(len(segment_ids)):
+             if i == 0:
+                 position_ids.append(2)
+             else:
+                 if segment_ids[i] == segment_ids[i - 1]:
+                     position_ids.append(position_ids[-1] + 1)
+                 else:
+                     position_ids.append(2)
+         return position_ids
+
+     def load_data(
+             self,
+             data_file,
+     ):
+         # re-org data format
+         total_data = {"id": [], "lines": [], "bboxes": [], "ner_tags": [], "image_path": []}
+         for i in range(len(data_file['documents'])):
+             width, height = data_file['documents'][i]['img']['width'], data_file['documents'][i]['img'][
+                 'height']
+
+             cur_doc_lines, cur_doc_bboxes, cur_doc_ner_tags, cur_doc_image_path = [], [], [], []
+             for j in range(len(data_file['documents'][i]['document'])):
+                 cur_item = data_file['documents'][i]['document'][j]
+                 cur_doc_lines.append(cur_item['text'])
+                 cur_doc_bboxes.append(self.box_norm(cur_item['box'], width=width, height=height))
+                 cur_doc_ner_tags.append(cur_item['label'])
+             total_data['id'] += [len(total_data['id'])]
+             total_data['lines'] += [cur_doc_lines]
+             total_data['bboxes'] += [cur_doc_bboxes]
+             total_data['ner_tags'] += [cur_doc_ner_tags]
+             total_data['image_path'] += [data_file['documents'][i]['img']['fname']]
+
+         # tokenize text and get bbox/label
+         total_input_ids, total_bboxs, total_label_ids = [], [], []
+         for i in range(len(total_data['lines'])):
+             cur_doc_input_ids, cur_doc_bboxs, cur_doc_labels = [], [], []
+             for j in range(len(total_data['lines'][i])):
+                 cur_input_ids = self.tokenizer(total_data['lines'][i][j], truncation=False, add_special_tokens=False, return_attention_mask=False)['input_ids']
+                 if len(cur_input_ids) == 0: continue
+
+                 cur_label = total_data['ner_tags'][i][j].upper()
+                 if cur_label == 'OTHER':
+                     cur_labels = ["O"] * len(cur_input_ids)
+                     for k in range(len(cur_labels)):
+                         cur_labels[k] = self.label2ids[cur_labels[k]]
+                 else:
+                     cur_labels = [cur_label] * len(cur_input_ids)
+                     cur_labels[0] = self.label2ids['B-' + cur_labels[0]]
+                     for k in range(1, len(cur_labels)):
+                         cur_labels[k] = self.label2ids['I-' + cur_labels[k]]
+                 assert len(cur_input_ids) == len([total_data['bboxes'][i][j]] * len(cur_input_ids)) == len(cur_labels)
+                 cur_doc_input_ids += cur_input_ids
+                 cur_doc_bboxs += [total_data['bboxes'][i][j]] * len(cur_input_ids)
+                 cur_doc_labels += cur_labels
+             assert len(cur_doc_input_ids) == len(cur_doc_bboxs) == len(cur_doc_labels)
+             assert len(cur_doc_input_ids) > 0
+
+             total_input_ids.append(cur_doc_input_ids)
+             total_bboxs.append(cur_doc_bboxs)
+             total_label_ids.append(cur_doc_labels)
+         assert len(total_input_ids) == len(total_bboxs) == len(total_label_ids)
+
+         # split text to several slices because of over-length
+         input_ids, bboxs, labels = [], [], []
+         segment_ids, position_ids = [], []
+         image_path = []
+         for i in range(len(total_input_ids)):
+             start = 0
+             cur_iter = 0
+             while start < len(total_input_ids[i]):
+                 end = min(start + 510, len(total_input_ids[i]))
+
+                 input_ids.append([self.tokenizer.cls_token_id] + total_input_ids[i][start: end] + [self.tokenizer.sep_token_id])
+                 bboxs.append([[0, 0, 0, 0]] + total_bboxs[i][start: end] + [[1000, 1000, 1000, 1000]])
+                 labels.append([-100] + total_label_ids[i][start: end] + [-100])
+
+                 cur_segment_ids = self.get_segment_ids(bboxs[-1])
+                 cur_position_ids = self.get_position_ids(cur_segment_ids)
+                 segment_ids.append(cur_segment_ids)
+                 position_ids.append(cur_position_ids)
+                 image_path.append(os.path.join(self.args.data_dir, "images", total_data['image_path'][i]))
+
+                 start = end
+                 cur_iter += 1
+
+         assert len(input_ids) == len(bboxs) == len(labels) == len(segment_ids) == len(position_ids)
+         assert len(segment_ids) == len(image_path)
+
+         res = {
+             'input_ids': input_ids,
+             'bbox': bboxs,
+             'labels': labels,
+             'segment_ids': segment_ids,
+             'position_ids': position_ids,
+             'image_path': image_path,
+         }
+         return res
+
+     def __init__(
+             self,
+             args,
+             tokenizer,
+             mode
+     ):
+         self.args = args
+         self.mode = mode
+         self.cur_la = args.language
+         self.tokenizer = tokenizer
+         self.label2ids = XFund_label2ids
+
+
+         self.common_transform = Compose([
+             RandomResizedCropAndInterpolationWithTwoPic(
+                 size=args.input_size, interpolation=args.train_interpolation,
+             ),
+         ])
+
+         self.patch_transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize(
+                 mean=torch.tensor((0.5, 0.5, 0.5)),
+                 std=torch.tensor((0.5, 0.5, 0.5)))
+         ])
+
+         data_file = json.load(
+             open(os.path.join(args.data_dir, "{}.{}.json".format(self.cur_la, 'train' if mode == 'train' else 'val')),
+                  'r'))
+
+         self.feature = self.load_data(data_file)
+
+     def __len__(self):
+         return len(self.feature['input_ids'])
+
+     def __getitem__(self, index):
+         input_ids = self.feature["input_ids"][index]
+
+         # attention_mask = self.feature["attention_mask"][index]
+         attention_mask = [1] * len(input_ids)
+         labels = self.feature["labels"][index]
+         bbox = self.feature["bbox"][index]
+         segment_ids = self.feature['segment_ids'][index]
+         position_ids = self.feature['position_ids'][index]
+
+         img = pil_loader(self.feature['image_path'][index])
+         for_patches, _ = self.common_transform(img, augmentation=False)
+         patch = self.patch_transform(for_patches)
+
+         assert len(input_ids) == len(attention_mask) == len(labels) == len(bbox) == len(segment_ids)
+
+         res = {
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             "labels": labels,
+             "bbox": bbox,
+             "segment_ids": segment_ids,
+             "position_ids": position_ids,
+             "images": patch,
+         }
+         return res
+
+ def pil_loader(path: str) -> Image.Image:
+     # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+     with open(path, 'rb') as f:
+         img = Image.open(f)
+         return img.convert('RGB')
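
For orientation (not part of the diff): the dataset above only needs an args object exposing data_dir, language, input_size and train_interpolation, plus a tokenizer whose __call__ returns input_ids for plain text. A minimal usage sketch follows; paths and values are illustrative assumptions, and while the package's own RoBERTa-style LayoutLMv3Tokenizer is presumably the intended tokenizer, any RoBERTa tokenizer satisfies the calls made here:

    from argparse import Namespace
    from transformers import RobertaTokenizer

    # Hypothetical arguments; the attribute names come from __init__ above,
    # the concrete values are assumptions.
    args = Namespace(
        data_dir="./xfund",             # expects {language}.train.json / {language}.val.json plus an images/ dir
        language="zh",                  # train mode then loads zh.train.json
        input_size=224,
        train_interpolation="bicubic",
    )
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    ds = xfund_dataset(args, tokenizer, mode="train")
    sample = ds[0]  # keys: input_ids, attention_mask, labels, bbox, segment_ids, position_ids, images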
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py (new file)
@@ -0,0 +1,7 @@
+ from .layoutlmv3 import (
+     LayoutLMv3Config,
+     LayoutLMv3ForTokenClassification,
+     LayoutLMv3ForQuestionAnswering,
+     LayoutLMv3ForSequenceClassification,
+     LayoutLMv3Tokenizer,
+ )
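
This re-export lets downstream code import the model classes from the models package directly, e.g. (illustrative):

    from magic_pdf.model.pek_sub_modules.layoutlmv3.layoutlmft.models import LayoutLMv3ForTokenClassification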
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py (new file)
@@ -0,0 +1,24 @@
+ from transformers import AutoConfig, AutoModel, AutoModelForTokenClassification, \
+     AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoTokenizer
+ from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, RobertaConverter
+
+ from .configuration_layoutlmv3 import LayoutLMv3Config
+ from .modeling_layoutlmv3 import (
+     LayoutLMv3ForTokenClassification,
+     LayoutLMv3ForQuestionAnswering,
+     LayoutLMv3ForSequenceClassification,
+     LayoutLMv3Model,
+ )
+ from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer
+ from .tokenization_layoutlmv3_fast import LayoutLMv3TokenizerFast
+
+
+ #AutoConfig.register("layoutlmv3", LayoutLMv3Config)
+ #AutoModel.register(LayoutLMv3Config, LayoutLMv3Model)
+ #AutoModelForTokenClassification.register(LayoutLMv3Config, LayoutLMv3ForTokenClassification)
+ #AutoModelForQuestionAnswering.register(LayoutLMv3Config, LayoutLMv3ForQuestionAnswering)
+ #AutoModelForSequenceClassification.register(LayoutLMv3Config, LayoutLMv3ForSequenceClassification)
+ #AutoTokenizer.register(
+ #    LayoutLMv3Config, slow_tokenizer_class=LayoutLMv3Tokenizer, fast_tokenizer_class=LayoutLMv3TokenizerFast
+ #)
+ SLOW_TO_FAST_CONVERTERS.update({"LayoutLMv3Tokenizer": RobertaConverter})
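
With the Auto* registrations commented out, the only side effect of importing this module is the converter mapping on the last line: it tells transformers that the slow LayoutLMv3Tokenizer can be converted to a fast (tokenizers-backed) tokenizer using the RoBERTa converter. A hedged sketch of what that enables, assuming a checkpoint directory with RoBERTa-style vocab.json/merges.txt (the path is hypothetical):

    from transformers.convert_slow_tokenizer import convert_slow_tokenizer

    slow = LayoutLMv3Tokenizer.from_pretrained("path/to/checkpoint")  # hypothetical path
    # Looks up type(slow).__name__ ("LayoutLMv3Tokenizer") in SLOW_TO_FAST_CONVERTERS,
    # finds RobertaConverter, and builds the fast tokenizer backend from it.
    fast_backend = convert_slow_tokenizer(slow)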
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py (new file)
@@ -0,0 +1,60 @@
+ # coding=utf-8
+ from transformers.models.bert.configuration_bert import BertConfig
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+ LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+     "layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/resolve/main/config.json",
+     "layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/resolve/main/config.json",
+     # See all LayoutLMv3 models at https://huggingface.co/models?filter=layoutlmv3
+ }
+
+
+ class LayoutLMv3Config(BertConfig):
+     model_type = "layoutlmv3"
+
+     def __init__(
+         self,
+         pad_token_id=1,
+         bos_token_id=0,
+         eos_token_id=2,
+         max_2d_position_embeddings=1024,
+         coordinate_size=None,
+         shape_size=None,
+         has_relative_attention_bias=False,
+         rel_pos_bins=32,
+         max_rel_pos=128,
+         has_spatial_attention_bias=False,
+         rel_2d_pos_bins=64,
+         max_rel_2d_pos=256,
+         visual_embed=True,
+         mim=False,
+         wpa_task=False,
+         discrete_vae_weight_path='',
+         discrete_vae_type='dall-e',
+         input_size=224,
+         second_input_size=112,
+         device='cuda',
+         **kwargs
+     ):
+         """Constructs LayoutLMv3Config."""
+         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+         self.max_2d_position_embeddings = max_2d_position_embeddings
+         self.coordinate_size = coordinate_size
+         self.shape_size = shape_size
+         self.has_relative_attention_bias = has_relative_attention_bias
+         self.rel_pos_bins = rel_pos_bins
+         self.max_rel_pos = max_rel_pos
+         self.has_spatial_attention_bias = has_spatial_attention_bias
+         self.rel_2d_pos_bins = rel_2d_pos_bins
+         self.max_rel_2d_pos = max_rel_2d_pos
+         self.visual_embed = visual_embed
+         self.mim = mim
+         self.wpa_task = wpa_task
+         self.discrete_vae_weight_path = discrete_vae_weight_path
+         self.discrete_vae_type = discrete_vae_type
+         self.input_size = input_size
+         self.second_input_size = second_input_size
+         self.device = device
+ self.device = device