magic-pdf 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. magic_pdf/data/batch_build_dataset.py +156 -0
  2. magic_pdf/data/dataset.py +44 -24
  3. magic_pdf/data/utils.py +108 -9
  4. magic_pdf/dict2md/ocr_mkcontent.py +4 -3
  5. magic_pdf/libs/pdf_image_tools.py +11 -6
  6. magic_pdf/libs/performance_stats.py +12 -1
  7. magic_pdf/libs/version.py +1 -1
  8. magic_pdf/model/batch_analyze.py +175 -201
  9. magic_pdf/model/doc_analyze_by_custom_model.py +137 -92
  10. magic_pdf/model/pdf_extract_kit.py +5 -38
  11. magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
  12. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
  13. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
  14. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
  15. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
  16. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
  17. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
  18. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
  19. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
  20. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
  21. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
  22. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
  23. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
  24. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
  25. magic_pdf/model/sub_modules/model_init.py +50 -37
  26. magic_pdf/model/sub_modules/model_utils.py +17 -11
  27. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
  28. magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
  29. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
  30. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
  31. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
  32. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
  33. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
  34. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
  35. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
  36. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
  37. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
  38. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
  39. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
  40. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
  41. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
  42. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
  43. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
  44. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
  45. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
  46. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
  47. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
  48. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
  49. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
  50. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
  51. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
  52. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
  53. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
  54. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
  55. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
  56. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
  57. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
  58. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
  59. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
  60. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
  61. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
  62. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
  63. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
  64. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
  65. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
  66. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
  67. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
  68. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
  69. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
  70. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
  71. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
  72. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
  73. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
  74. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
  75. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
  76. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
  77. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
  78. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
  79. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +10 -18
  80. magic_pdf/pdf_parse_union_core_v2.py +112 -74
  81. magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
  82. magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
  83. magic_pdf/resources/model_config/model_configs.yaml +1 -1
  84. magic_pdf/tools/cli.py +30 -12
  85. magic_pdf/tools/common.py +90 -12
  86. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/METADATA +50 -40
  87. magic_pdf-1.3.0.dist-info/RECORD +202 -0
  88. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
  89. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
  90. magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
  91. magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
  92. magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
  93. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
  94. magic_pdf-1.2.2.dist-info/RECORD +0 -147
  95. /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
  96. /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
  97. /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
  98. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/LICENSE.md +0 -0
  99. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/WHEEL +0 -0
  100. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/entry_points.txt +0 -0
  101. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,690 @@
1
+ # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+ import torch
16
+
17
+
18
class BaseRecLabelDecode(object):
    """Convert between text-label and text-index.

    Builds the character table used by the recognition decoders, either from
    the built-in set ``0-9a-z`` or from a UTF-8 dictionary file with one
    entry per line. Subclasses add their special tokens through
    ``add_special_char`` and choose which indices are skipped through
    ``get_ignored_tokens``.
    """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False):
        """
        Args:
            character_dict_path: path of a dictionary file (one character per
                line). When None, the digits 0-9 and letters a-z are used.
            use_space_char: append a space entry to the dictionary. Only
                honored when ``character_dict_path`` is given.
        """
        self.beg_str = "sos"
        self.end_str = "eos"

        self.character_str = []
        if character_dict_path is None:
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            dict_character = list(self.character_str)
        else:
            with open(character_dict_path, "rb") as fin:
                lines = fin.readlines()
            for line in lines:
                line = line.decode('utf-8').strip("\n").strip("\r\n")
                self.character_str.append(line)
            if use_space_char:
                self.character_str.append(" ")
            dict_character = list(self.character_str)

        dict_character = self.add_special_char(dict_character)
        # Character -> index lookup; a later duplicate entry wins.
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.character = dict_character

    def add_special_char(self, dict_character):
        """Hook for subclasses to add special tokens; the base adds none."""
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index sequences into ``(text, confidence)`` tuples.

        Args:
            text_index: per-sample sequences of character indices.
            text_prob: optional per-sample sequences of probabilities aligned
                with ``text_index``. When None, every kept character counts
                with confidence 1.
            is_remove_duplicate: collapse consecutive repeated indices (CTC
                style; used for predictions only).

        Returns:
            A list with one ``(text, confidence)`` tuple per sample. The
            confidence is the mean over the kept characters, or 0.0 when no
            character was kept (``np.mean([])`` would yield nan and warn).
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            sequence = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(sequence)):
                token = sequence[idx]
                if token in ignored_tokens:
                    continue
                # Compare with the raw previous index (which may be a blank),
                # so "a blank a" still yields two characters.
                if is_remove_duplicate and idx > 0 and sequence[idx - 1] == token:
                    continue
                char_list.append(self.character[int(token)])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            score = np.mean(conf_list) if conf_list else 0.0
            result_list.append((text, score))
        return result_list

    def get_ignored_tokens(self):
        """Indices skipped while decoding; index 0 is the CTC blank."""
        return [0]  # for ctc blank
79
+
80
+
81
class CTCLabelDecode(BaseRecLabelDecode):
    """Decode CTC recognition output.

    A 'blank' entry is prepended to the dictionary (index 0). Predictions
    are greedily decoded: argmax per time step, consecutive repeats
    collapsed, blanks dropped.
    """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(CTCLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode ``preds`` and, when given, ``label``.

        ``preds`` is reduced with ``argmax``/``max`` over axis 2, so it is
        expected to be 3-D (presumably batch, time, classes).

        Returns:
            The decoded predictions, or ``(predictions, labels)`` when
            ``label`` is not None.
        """
        if isinstance(preds, torch.Tensor):
            # detach/cpu make CUDA or grad-tracking tensors convertible;
            # both are no-ops for a plain CPU tensor.
            preds = preds.detach().cpu().numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)

        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        """Prepend the CTC 'blank' token so it occupies index 0."""
        return ['blank'] + dict_character
106
+
107
+
108
class NRTRLabelDecode(BaseRecLabelDecode):
    """Decode NRTR (transformer) recognition output.

    The tokens 'blank', '<unk>', '<s>' and '</s>' are prepended to the
    dictionary (indices 0-3). Decoding stops at the first '</s>', and the
    decoded text is lower-cased.
    """

    def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
        super(NRTRLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode predictions (and optionally labels).

        ``preds`` is either a ``(token_ids, token_probs)`` pair or a 3-D
        score array/tensor reduced with argmax/max over axis 2.
        """
        # Require an actual pair: a score tensor whose batch size happens to
        # be 2 also has len() == 2 and must not be unpacked as (ids, probs).
        if isinstance(preds, (list, tuple)) and len(preds) == 2:
            preds_id, preds_prob = preds
            if isinstance(preds_id, torch.Tensor):
                preds_id = preds_id.detach().cpu().numpy()
            if isinstance(preds_prob, torch.Tensor):
                preds_prob = preds_prob.detach().cpu().numpy()
            # A leading '<s>' (index 2) is stripped; only the first sample is
            # inspected and the decision applies to the whole batch.
            if preds_id[0][0] == 2:
                preds_idx = preds_id[:, 1:]
                preds_prob = preds_prob[:, 1:]
            else:
                preds_idx = preds_id
        else:
            if isinstance(preds, torch.Tensor):
                preds = preds.detach().cpu().numpy()
            preds_idx = preds.argmax(axis=2)
            preds_prob = preds.max(axis=2)

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        # Labels start with '<s>'; drop that first column.
        label = self.decode(label[:, 1:])
        return text, label

    def add_special_char(self, dict_character):
        """Prepend 'blank', '<unk>', '<s>' and '</s>' (indices 0-3)."""
        return ['blank', '<unk>', '<s>', '</s>'] + dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index sequences into lower-cased ``(text, confidence)``.

        Indices outside the dictionary are skipped; decoding of a sample
        stops at '</s>'. ``is_remove_duplicate`` is accepted for interface
        compatibility and ignored. The confidence is the mean over kept
        characters, or 0.0 when none was kept.
        """
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                try:
                    char = self.character[int(text_index[batch_idx][idx])]
                except (IndexError, ValueError, TypeError):
                    # Index not representable in the dictionary: skip it.
                    continue
                if char == '</s>':  # end
                    break
                char_list.append(char)
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            score = np.mean(conf_list).tolist() if conf_list else 0.0
            result_list.append((text.lower(), score))
        return result_list
170
+
171
class ViTSTRLabelDecode(NRTRLabelDecode):
    """Decode ViTSTR recognition output.

    The dictionary is prefixed with '<s>' and '</s>'. Decoding reuses
    ``NRTRLabelDecode.decode``, which stops at '</s>' and lower-cases.
    """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(ViTSTRLabelDecode, self).__init__(character_dict_path,
                                                use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode 3-D scores after dropping the first position of axis 1."""
        preds = preds[:, 1:]
        if isinstance(preds, torch.Tensor):
            # detach/cpu so CUDA or grad-tracking tensors can be converted.
            preds = preds.detach().cpu().numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label[:, 1:])
        return text, label

    def add_special_char(self, dict_character):
        """Prepend '<s>' and '</s>' (indices 0 and 1)."""
        return ['<s>', '</s>'] + dict_character
195
+
196
+
197
class AttnLabelDecode(BaseRecLabelDecode):
    """Decode attention-based recognition output.

    The dictionary is wrapped as ['sos', *characters, 'eos']. Decoding
    skips 'sos' and stops at the first 'eos'.
    """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(AttnLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

    def add_special_char(self, dict_character):
        """Wrap the dictionary with 'sos' (first) and 'eos' (last)."""
        self.beg_str = "sos"
        self.end_str = "eos"
        return [self.beg_str] + dict_character + [self.end_str]

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index sequences into ``(text, confidence)`` tuples.

        The confidence is the mean over kept characters, or 0.0 when none
        was kept.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        end_idx = int(ignored_tokens[1])
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                token = text_index[batch_idx][idx]
                # Check for 'eos' before the ignore list: 'eos' is itself an
                # ignored token, so testing membership first would skip it
                # and keep decoding whatever follows the end of the text.
                if int(token) == end_idx:
                    break
                if token in ignored_tokens:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][idx - 1] == token:
                        continue
                char_list.append(self.character[int(token)])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            score = np.mean(conf_list) if conf_list else 0.0
            result_list.append((text, score))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode 3-D scores (argmax/max over axis 2) and optional labels.

        Returns:
            The decoded predictions, or ``(predictions, labels)`` when
            ``label`` is not None.
        """
        if isinstance(preds, torch.Tensor):
            preds = preds.detach().cpu().numpy()

        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        """Return ``[sos_index, eos_index]``."""
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the index of 'sos' ("beg") or 'eos' ("end") as a 0-d array.

        Raises:
            ValueError: for any other ``beg_or_end`` (an ``assert`` would be
                stripped under ``python -O``).
        """
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])
        raise ValueError("unsupport type %s in get_beg_end_flag_idx"
                         % beg_or_end)
277
+
278
+
279
class RFLLabelDecode(BaseRecLabelDecode):
    """Decode RFL recognition output.

    The dictionary is wrapped as ['sos', *characters, 'eos']. A
    ``(counting_output, sequence_output)`` pair is decoded to text. A bare
    counting output is reduced to per-sample predicted lengths.
    """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(RFLLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

    def add_special_char(self, dict_character):
        """Wrap the dictionary with 'sos' (first) and 'eos' (last)."""
        self.beg_str = "sos"
        self.end_str = "eos"
        return [self.beg_str] + dict_character + [self.end_str]

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index sequences into ``(text, confidence)`` tuples.

        Decoding stops at 'eos'. The confidence is the mean over kept
        characters, or 0.0 when none was kept.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        end_idx = int(ignored_tokens[1])
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                token = text_index[batch_idx][idx]
                # 'eos' is also in ignored_tokens; test it first, otherwise
                # the break is unreachable and decoding runs past the end.
                if int(token) == end_idx:
                    break
                if token in ignored_tokens:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][idx - 1] == token:
                        continue
                char_list.append(self.character[int(token)])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            score = np.mean(conf_list).tolist() if conf_list else 0.0
            result_list.append((text, score))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode text (pair input) or predicted lengths (single input)."""
        if isinstance(preds, (tuple, list)):
            _, seq_outputs = preds
            if isinstance(seq_outputs, torch.Tensor):
                seq_outputs = seq_outputs.detach().cpu().numpy()
            preds_idx = seq_outputs.argmax(axis=2)
            preds_prob = seq_outputs.max(axis=2)
            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

            if label is None:
                return text
            label = self.decode(label, is_remove_duplicate=False)
            return text, label

        cnt_outputs = preds
        if isinstance(cnt_outputs, torch.Tensor):
            cnt_outputs = cnt_outputs.detach().cpu().numpy()
        # Predicted length per sample = rounded sum of its counting output.
        cnt_length = [round(np.sum(lens)) for lens in cnt_outputs]
        if label is None:
            return cnt_length
        label = self.decode(label, is_remove_duplicate=False)
        length = [len(res[0]) for res in label]
        return cnt_length, length

    def get_ignored_tokens(self):
        """Return ``[sos_index, eos_index]``."""
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the index of 'sos' ("beg") or 'eos' ("end") as a 0-d array.

        Raises:
            ValueError: for any other ``beg_or_end``.
        """
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])
        raise ValueError("unsupport type %s in get_beg_end_flag_idx"
                         % beg_or_end)
366
+
367
+
368
class SRNLabelDecode(BaseRecLabelDecode):
    """Decode SRN recognition output.

    'sos' and 'eos' are appended to the dictionary and skipped while
    decoding. Note that this decoder does not stop at 'eos'.
    """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        # Number of decoding positions per sample (default 25).
        self.max_text_length = kwargs.get('max_text_length', 25)
        super(SRNLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode ``preds['predict']`` and, when given, ``label``.

        The prediction is reshaped to ``(-1, num_classes)`` and the argmax
        results are regrouped into rows of ``max_text_length``.
        """
        pred = preds['predict']
        # Dictionary characters plus the appended 'sos'/'eos'.
        char_num = len(self.character_str) + 2
        if isinstance(pred, torch.Tensor):
            pred = pred.detach().cpu().numpy()
        pred = np.reshape(pred, [-1, char_num])

        preds_idx = np.argmax(pred, axis=1)
        preds_prob = np.max(pred, axis=1)

        preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
        preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])

        # decode() defaults to is_remove_duplicate=False, so a single call
        # serves both the inference and the evaluation path.
        text = self.decode(preds_idx, preds_prob)

        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index sequences into ``(text, confidence)`` tuples.

        The confidence is the mean over kept characters, or 0.0 when none
        was kept.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)

        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                token = text_index[batch_idx][idx]
                if token in ignored_tokens:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][idx - 1] == token:
                        continue
                char_list.append(self.character[int(token)])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)

            text = ''.join(char_list)
            score = np.mean(conf_list) if conf_list else 0.0
            result_list.append((text, score))
        return result_list

    def add_special_char(self, dict_character):
        """Append 'sos' and 'eos' after the dictionary characters."""
        return dict_character + [self.beg_str, self.end_str]

    def get_ignored_tokens(self):
        """Return ``[sos_index, eos_index]``."""
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the index of 'sos' ("beg") or 'eos' ("end") as a 0-d array.

        Raises:
            ValueError: for any other ``beg_or_end``.
        """
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])
        raise ValueError("unsupport type %s in get_beg_end_flag_idx"
                         % beg_or_end)
447
+
448
+
449
class TableLabelDecode(object):
    """Decode table-structure recognition output.

    The dictionary file starts with a header line ``"<n_chars>\\t<n_elems>"``,
    followed by ``n_chars`` character lines and then ``n_elems`` structure
    element lines. Both lists are wrapped as ['sos', ..., 'eos'].
    """

    def __init__(self,
                 character_dict_path,
                 **kwargs):
        list_character, list_elem = self.load_char_elem_dict(character_dict_path)
        list_character = self.add_special_char(list_character)
        list_elem = self.add_special_char(list_elem)
        self.dict_character = {}
        self.dict_idx_character = {}
        for i, char in enumerate(list_character):
            self.dict_idx_character[i] = char
            self.dict_character[char] = i
        self.dict_elem = {}
        self.dict_idx_elem = {}
        for i, elem in enumerate(list_elem):
            self.dict_idx_elem[i] = elem
            self.dict_elem[elem] = i

    def load_char_elem_dict(self, character_dict_path):
        """Read the character list and the structure-element list."""
        list_character = []
        list_elem = []
        with open(character_dict_path, "rb") as fin:
            lines = fin.readlines()
        substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t")
        character_num = int(substr[0])
        elem_num = int(substr[1])
        for cno in range(1, 1 + character_num):
            character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
            list_character.append(character)
        for eno in range(1 + character_num, 1 + character_num + elem_num):
            elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
            list_elem.append(elem)
        return list_character, list_elem

    def add_special_char(self, list_character):
        """Wrap a token list with 'sos' (first) and 'eos' (last)."""
        self.beg_str = "sos"
        self.end_str = "eos"
        return [self.beg_str] + list_character + [self.end_str]

    def __call__(self, preds):
        """Decode ``preds['structure_probs']`` and gather cell locations.

        For every decoded '<td>' or '<td' token, the row of
        ``preds['loc_preds']`` at that token's position is collected.
        """
        structure_probs = preds['structure_probs']
        loc_preds = preds['loc_preds']
        if isinstance(structure_probs, torch.Tensor):
            structure_probs = structure_probs.detach().cpu().numpy()
        if isinstance(loc_preds, torch.Tensor):
            loc_preds = loc_preds.detach().cpu().numpy()
        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)
        structure_str, structure_pos, result_score_list, result_elem_idx_list = \
            self.decode(structure_idx, structure_probs, 'elem')
        res_html_code_list = []
        res_loc_list = []
        batch_num = len(structure_str)
        for bno in range(batch_num):
            res_loc = []
            for sno in range(len(structure_str[bno])):
                text = structure_str[bno][sno]
                if text in ['<td>', '<td']:
                    pos = structure_pos[bno][sno]
                    res_loc.append(loc_preds[bno, pos])
            res_html_code = ''.join(structure_str[bno])
            res_loc = np.array(res_loc)
            res_html_code_list.append(res_html_code)
            res_loc_list.append(res_loc)
        return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list,
                'res_score_list': result_score_list,
                'res_elem_idx_list': result_elem_idx_list,
                'structure_str_list': structure_str}

    def decode(self, text_index, structure_probs, char_or_elem):
        """Convert index sequences into token lists.

        Args:
            text_index: per-sample index sequences.
            structure_probs: 2-D scores indexed as ``[sample, position]``.
            char_or_elem: "char" decodes with the character dictionary;
                anything else uses the structure-element dictionary.

        Returns:
            ``(tokens, positions, scores, indices)``, each a per-sample list.
            A sample stops at its first 'eos' after position 0.
        """
        dict_kind = "char" if char_or_elem == "char" else "elem"
        if dict_kind == "char":
            current_dict = self.dict_idx_character
        else:
            current_dict = self.dict_idx_elem
        # Use the start/end tokens of the dictionary actually being decoded:
        # the character and element dictionaries have different 'eos' indices.
        ignored_tokens = self.get_ignored_tokens(dict_kind)
        end_idx = ignored_tokens[1]

        result_list = []
        result_pos_list = []
        result_score_list = []
        result_elem_idx_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            elem_pos_list = []
            elem_idx_list = []
            score_list = []
            for idx in range(len(text_index[batch_idx])):
                tmp_elem_idx = int(text_index[batch_idx][idx])
                if idx > 0 and tmp_elem_idx == end_idx:
                    break
                if tmp_elem_idx in ignored_tokens:
                    continue

                char_list.append(current_dict[tmp_elem_idx])
                elem_pos_list.append(idx)
                score_list.append(structure_probs[batch_idx, idx])
                elem_idx_list.append(tmp_elem_idx)
            result_list.append(char_list)
            result_pos_list.append(elem_pos_list)
            result_score_list.append(score_list)
            result_elem_idx_list.append(elem_idx_list)
        return result_list, result_pos_list, result_score_list, result_elem_idx_list

    def get_ignored_tokens(self, char_or_elem):
        """Return ``[sos_index, eos_index]`` for the chosen dictionary."""
        beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
        end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
        """Return the 'sos'/'eos' index in the "char" or "elem" dictionary.

        Raises:
            ValueError: for an unknown ``beg_or_end`` or ``char_or_elem``
                (an ``assert`` would be stripped under ``python -O``).
        """
        if char_or_elem == "char":
            lookup = self.dict_character
        elif char_or_elem == "elem":
            lookup = self.dict_elem
        else:
            raise ValueError("Unsupport type %s in char_or_elem"
                             % char_or_elem)
        if beg_or_end == "beg":
            return lookup[self.beg_str]
        if beg_or_end == "end":
            return lookup[self.end_str]
        raise ValueError("Unsupport type %s in get_beg_end_flag_idx of %s"
                         % (beg_or_end, char_or_elem))
582
+
583
+
584
class SARLabelDecode(BaseRecLabelDecode):
    """Decode SAR recognition output.

    '<UKN>', '<BOS/EOS>' and '<PAD>' are appended to the dictionary. The
    start and end tokens share a single index. Padding is skipped and
    decoding stops at the end token. With ``rm_symbol=True`` the text is
    lower-cased and filtered by a character-class regex.
    """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(SARLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

        self.rm_symbol = kwargs.get('rm_symbol', False)

    def add_special_char(self, dict_character):
        """Append unknown, begin/end (shared index) and padding tokens."""
        beg_end_str = "<BOS/EOS>"
        unknown_str = "<UKN>"
        padding_str = "<PAD>"
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index sequences into ``(text, confidence)`` tuples.

        When ``text_prob`` is None (label decoding), an end token at
        position 0 is treated as the leading start token and skipped. The
        confidence is the mean over kept characters, or 0.0 when none was
        kept.
        """
        symbol_pattern = None
        if self.rm_symbol:
            # ``re`` is not imported at module level in this file; without this
            # local import ``rm_symbol=True`` raised NameError. Compiled once
            # per call rather than once per sample.
            import re
            symbol_pattern = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')

        result_list = []
        ignored_tokens = self.get_ignored_tokens()

        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                token = text_index[batch_idx][idx]
                if token in ignored_tokens:
                    continue
                if int(token) == int(self.end_idx):
                    if text_prob is None and idx == 0:
                        continue
                    break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][idx - 1] == token:
                        continue
                char_list.append(self.character[int(token)])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            if symbol_pattern is not None:
                text = text.lower()
                text = symbol_pattern.sub('', text)
            score = np.mean(conf_list).tolist() if conf_list else 0.0
            result_list.append((text, score))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode 3-D scores (argmax/max over axis 2) and optional labels."""
        if isinstance(preds, torch.Tensor):
            preds = preds.detach().cpu().numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        """Only the padding token is skipped."""
        return [self.padding_idx]
658
+
659
+
660
class CANLabelDecode(BaseRecLabelDecode):
    """Convert between latex-symbol and symbol-index.

    Each sequence is cut at its first index 0, which acts as the
    terminator. The remaining symbols are joined with spaces.
    """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(CANLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

    def decode(self, text_index, preds_prob=None):
        """Convert index sequences into ``[latex_string, probs]`` pairs.

        Args:
            text_index: per-sample 1-D index sequences (array or tensor).
            preds_prob: optional per-sample probabilities; truncated to the
                number of decoded symbols.

        Returns:
            A list of ``[space-joined symbols, list of probabilities]``.
        """
        if isinstance(text_index, torch.Tensor):
            text_index = text_index.detach().cpu().numpy()
        if isinstance(preds_prob, torch.Tensor):
            preds_prob = preds_prob.detach().cpu().numpy()
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            sequence = np.asarray(text_index[batch_idx])
            # Cut at the first 0. ``argmin`` only finds it when a 0 exists;
            # otherwise it returns the position of the smallest symbol and
            # truncates the formula mid-way, so fall back to the full length.
            end_positions = np.flatnonzero(sequence == 0)
            seq_end = int(end_positions[0]) if end_positions.size else len(sequence)
            idx_list = sequence[:seq_end].tolist()
            symbol_list = [self.character[idx] for idx in idx_list]
            probs = []
            if preds_prob is not None:
                probs = np.asarray(preds_prob[batch_idx])[:len(symbol_list)].tolist()

            result_list.append([' '.join(symbol_list), probs])
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode the first of four model outputs (argmax over axis 2)."""
        pred_prob, _, _, _ = preds
        if isinstance(pred_prob, torch.Tensor):
            pred_prob = pred_prob.detach().cpu().numpy()
        preds_idx = pred_prob.argmax(axis=2)

        text = self.decode(preds_idx)
        if label is None:
            return text
        label = self.decode(label)
        return text, label