openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,753 @@
1
+ import io
2
+ import json
3
+ import re
4
+ import random
5
+ import traceback
6
+ import ast
7
+ from PIL import Image
8
+ import fitz
9
+ import lmdb
10
+ import numpy as np
11
+ from torch.utils.data import Dataset
12
+ from torchvision import transforms as T
13
+ from torchvision.transforms import functional as F
14
+ from openrec.preprocess import create_operators, transform
15
+ from openrec.preprocess.rec_aug import DocAug #, PARSeqAugPIL
16
+ from pdf2image import convert_from_path
17
+
18
+ import torch.distributed as dist
19
+
20
+ import unicodedata
21
+
22
# Full table of diacritic marks. Each key is an accent character that may
# appear *before* a letter in extracted text; each value is the Unicode
# combining mark (a '\u03xx' escape) that fix_diacritics_regex appends *after*
# the letter instead.
COMBINING_DIACRITICS = {
    '´': '\u0301',
    '`': '\u0300',
    'ˆ': '\u0302',
    '¨': '\u0308',
    'ˇ': '\u030C',
    '˘': '\u0306',
    '¯': '\u0304',
    '˚': '\u030A',
    '˜': '\u0303',
    '˙': '\u0307',
    '¸': '\u0327',
    '˛': '\u0328',
    '̇': '\u0307',
    '̱': '\u0331',
    '̥': '\u0325',
    '̩': '\u0329',
    '̯': '\u032F',
    '̄': '\u0304',
    '̋': '\u030B',
    '̨': '\u0328',
}

# Regex matching "diacritic mark + letter": group 1 is any single key of
# COMBINING_DIACRITICS (escaped for use inside a character class), group 2 is
# the ASCII letter immediately following it.
pattern = re.compile(r'([' + re.escape(''.join(COMBINING_DIACRITICS.keys())) +
                     r'])([A-Za-z])')
49
+
50
+
51
def fix_diacritics_regex(text):
    """Move leading accent marks after their letter and compose them.

    Every match of the module-level ``pattern`` -- a key of
    ``COMBINING_DIACRITICS`` directly followed by an ASCII letter -- is
    replaced by the letter plus the corresponding combining mark, normalised
    to NFC so a precomposed character is used where one exists.

    Args:
        text: Input string.

    Returns:
        The string with all such accent/letter pairs rewritten.
    """

    def _compose(m):
        letter = m.group(2)
        combining = COMBINING_DIACRITICS[m.group(1)]
        return unicodedata.normalize('NFC', letter + combining)

    return pattern.sub(_compose, text)
59
+
60
+
61
def random_color():
    """Return a random RGB colour as a tuple of three floats in [0.0, 1.0)."""
    return tuple(random.random() for _ in range(3))
64
+
65
+
66
def load_pdf_as_image(pdf_path, dpi=300):
    """Render the first page of a PDF file as an image.

    Args:
        pdf_path: Path to the PDF file.
        dpi: Rendering resolution forwarded to ``convert_from_path``.

    Returns:
        The image object returned by ``pdf2image`` for page 1.

    Raises:
        IndexError: If no page could be rendered.
    """
    # Only rasterise page 1. Rendering every page at high dpi just to keep the
    # first one wastes time and memory on multi-page documents.
    images = convert_from_path(pdf_path, dpi=dpi, first_page=1, last_page=1)
    return images[0]
70
+
71
+
72
def get_masked_image(
    image,
    page_info,
    box_padding=3,
    mask_color=(255, 255, 255),
    token_num_limit=(15, 5000),
    overlap_ratio_limit=0.003,
):
    """Paint every pixel of a page image that no token box covers.

    Args:
        image: Page image exposing ``.size`` as (width, height) and convertible
            with ``np.array`` (a PIL image in practice).
        page_info: List of block dicts. Each has ``'token_num'`` and
            optionally ``'vision_content'`` with ``'box_list'`` and
            ``'tex_token_info_list'``. Both may be Python-literal strings or
            already-parsed lists.
        box_padding: Pixels added on every side of each box (clamped to the
            image).
        mask_color: Colour written into uncovered pixels.
        token_num_limit: Inclusive (min, max) total token count; pages outside
            it are rejected.
        overlap_ratio_limit: Maximum fraction of the page that may be covered
            by more than one non-tabular box.

    Returns:
        The masked PIL image, or ``None`` if the page is rejected (token count
        out of range, no non-tabular box, or too much overlap).
    """

    def _parse(value):
        # Stored either as a literal string (e.g. "[[1, 2, 3, 4]]") or as an
        # already-decoded list; accept both.
        return ast.literal_eval(value) if isinstance(value, str) else value

    tot_num = sum(block['token_num'] for block in page_info)
    if tot_num < token_num_limit[0] or tot_num > token_num_limit[1]:
        return None

    box_list = []
    tex_token_info_list = []
    for block in page_info:
        vision_content = block.get('vision_content', {})
        # box = (top, bottom, left, right)
        box_list.extend(_parse(vision_content.get('box_list', '[]')))
        # tex_token_info = (chars, pos, len, attr)
        tex_token_info_list.extend(
            _parse(vision_content.get('tex_token_info_list', '[]')))

    assert len(box_list) == len(tex_token_info_list)

    width, height = image.size
    # uint16 (not uint8) so heavily stacked boxes cannot wrap a pixel's
    # coverage count back to 0.
    mask_matrix = np.zeros((height, width), dtype=np.uint16)
    additional_mask_matrix = np.zeros((height, width), dtype=np.uint16)
    for tex_token_info, box in zip(tex_token_info_list, box_list):
        if not any(box):  # [0, 0, 0, 0] placeholder: token has no box
            continue
        top, bottom, left, right = box
        top = max(0, top - box_padding)
        bottom = min(height - 1, bottom + box_padding)
        left = max(0, left - box_padding)
        right = min(width - 1, right + box_padding)
        # Tabular tokens stay visible but are excluded from the overlap check.
        if tex_token_info[3] == 'tabular':
            target = additional_mask_matrix
        else:
            target = mask_matrix
        target[top:bottom + 1, left:right + 1] += 1

    if not mask_matrix.any():
        return None
    overlap_ratio = (mask_matrix > 1).sum() / (width * height)
    if overlap_ratio > overlap_ratio_limit:
        return None

    uncovered = (mask_matrix + additional_mask_matrix) == 0
    img_array = np.array(image)
    img_array[uncovered] = np.array(mask_color, img_array.dtype)
    return Image.fromarray(img_array)
139
+
140
+
141
# Maps fine-grained block labels (a LaTeX environment or sectioning level
# suffixed with '_paragraph' / '_title') to coarser layout categories such as
# 'paragraph', 'figure', 'caption', 'table', 'abstract' and 'footnote'.
# NOTE(review): not referenced anywhere else in this module; presumably
# imported by other code -- verify before renaming or changing entries.
layout_map = {
    'section_paragraph': 'paragraph',
    'subsection_paragraph': 'paragraph',
    'subsubsection_paragraph': 'paragraph',
    'document_paragraph': 'paragraph',
    'enumerate_paragraph': 'paragraph',
    'enumerate*_paragraph': 'paragraph',
    'itemize_paragraph': 'paragraph',
    'itemize*_paragraph': 'paragraph',
    'paragraph_paragraph': 'paragraph',
    'subparagraph_paragraph': 'paragraph',
    'chapter_paragraph': 'paragraph',
    'part_paragraph': 'paragraph',
    'figure_paragraph': 'figure',
    'figure*_paragraph': 'figure',
    'caption_paragraph': 'caption',
    'caption*_paragraph': 'caption',
    'table_paragraph': 'table',
    'table*_paragraph': 'table',
    'deluxetable_paragraph': 'table',
    'deluxetable*_paragraph': 'table',
    'section_title': 'section_title',
    'subsection_title': 'subsection_title',
    'subsubsection_title': 'subsubsection_title',
    'subparagraph_title': 'subsection_title',
    'title_paragraph': 'title',
    'abstract_paragraph': 'abstract',
    'abstract*_paragraph': 'abstract',
    'footnote_paragraph': 'footnote',
    'tablenotes_paragraph': 'tablenotes',
    # NOTE(review): 'footer' is only mentioned here as a comment; footer
    # blocks currently have no mapping.
}
173
+
174
# LaTeX spacing commands: \, \; \: \! and a backslash followed by whitespace
# (control space). Kept unchanged for backward compatibility; note that using
# it directly also matches the second backslash of a '\\' line break.
pattern_indent = re.compile(r'\\,|\\;|\\:|\\!|\\\s+')

# Same spacing commands, but a literal '\\' is consumed first (group 1) and
# kept, so the second backslash of a line break is never read as the start of
# '\ ' or '\,' etc.
_latex_spacing_pattern = re.compile(r'(\\\\)|\\[,;:!]|\\\s+')


def rm_indent_in_latex(text):
    r"""Remove LaTeX spacing commands from *text*.

    Strips ``\,``, ``\;``, ``\:``, ``\!`` and a backslash followed by
    whitespace. A ``\\`` line break is left intact. Matching ``\\`` first
    stops ``a \\ b`` from turning into ``a \b``, which would be a bogus
    command.

    Args:
        text: LaTeX source string.

    Returns:
        The string with spacing commands removed.
    """
    return _latex_spacing_pattern.sub(lambda m: m.group(1) or '', text)
180
+
181
+
182
def resize_image(original_width, original_height, max_width, max_height):
    """Fit a (width, height) size inside a bounding box, keeping aspect ratio.

    Both sides are first raised to at least 64. If the result already fits
    within ``max_width`` x ``max_height`` it is returned as is. Otherwise the
    binding side is set to its limit and the other side is scaled and
    truncated to an int.

    Returns:
        Tuple ``(new_width, new_height)``.
    """
    w = max(original_width, 64)
    h = max(original_height, 64)
    ratio = w / h

    if w <= max_width and h <= max_height:
        # Already within bounds: no resize needed.
        return w, h
    if max_width / max_height < ratio:
        # Relatively wider than the box: width is the limiting side.
        return max_width, int(max_width / ratio)
    # Relatively taller (or same shape): height is the limiting side.
    return int(max_height * ratio), max_height
202
+
203
+
204
class NaSizeDataSet(Dataset):
    """Native-size document OCR dataset whose samples are bucketed by size.

    Label JSON files map ``'<w>_<h>'`` size keys to lists of sample dicts.
    Each size is fitted into ``max_side`` and floored to a multiple of
    ``divided_factor``. Samples with the same resulting size share a bucket
    in ``self.img_label_pair_list``, which is padded and sharded across
    distributed ranks. ``__getitem__`` takes
    ``(img_id, w_r, h_r, zoom_time[, resume_batch])``. These are presumably
    produced by a bucket-aware sampler (verify against the sampler).

    Images come from per-source LMDB stores (line data, test data) or are
    cropped on the fly from PDFs stored in LMDB.

    NOTE(review): the indentation of this class could not be recovered from
    the listing it was reviewed from. Block nesting was rebuilt from data
    flow; the genuinely ambiguous spots are marked "confirm".
    """

    def __init__(self, config, mode, logger, seed=None, epoch=0, task='Rec'):
        super(NaSizeDataSet, self).__init__()
        self.logger = logger
        self.mode = mode.lower()

        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
        else:
            world_size = 1
            rank = 0
        num_replicas = world_size

        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']
        self.seed = seed if seed is not None else epoch
        random.seed(self.seed)
        self.e2e_info = dataset_config.get('e2e_info', True)
        self.layout_info = dataset_config.get('layout_info', False)
        self.add_return = dataset_config.get('add_return', True)
        self.zoom_min_factor = dataset_config.get('zoom_min_factor', 10)
        self.use_zoom = dataset_config.get('use_zoom', False)
        self.all_data = dataset_config.get('all_data', False)
        self.use_linedata = dataset_config.get('use_linedata', False)
        self.test_data = dataset_config.get('test_data', False)
        self.use_aug = dataset_config.get('use_aug', True)
        self.use_table = dataset_config.get('use_table', False)
        # layout_info forces e2e_info off.
        self.e2e_info = False if self.layout_info else self.e2e_info

        self.use_math_norm = dataset_config.get('use_math_norm', False)
        self.root_path = dataset_config.get('root_path', None)
        if self.root_path is None:
            assert False, 'root_path is None'
        self.env = None  # LMDB environment
        self.do_shuffle = loader_config['shuffle']

        self.max_side = dataset_config.get('max_side',
                                           [64 * 15, 64 * 22])  # w, h
        self.divided_factor = dataset_config.get('divided_factor',
                                                 [64., 64.])  # w, h
        self.use_region = dataset_config.get('use_region', False)
        self.use_ch = dataset_config.get('use_ch', False)
        logger.info('Initialize indexs of doc datasets')
        if self.all_data:
            self.__init_lmdb()

        if self.test_data:
            # HierText evaluation; the word-level label file rotates through
            # 10 per-epoch variants.
            epoch_current = (epoch - 1) % 10
            label_json_list = [
                f'{self.root_path}/hiertext_lmdb/label_key_char_line_para.json',
                f'{self.root_path}/hiertext_lmdb/label_key_word_{epoch_current}ep.json'
            ]
            test_lmdb_path = f'{self.root_path}/hiertext_lmdb/image_lmdb'
            self.env_test = lmdb.open(test_lmdb_path,
                                      max_readers=32,
                                      readonly=True,
                                      lock=False,
                                      readahead=False,
                                      meminit=False)
            self.txn_test = self.env_test.begin()
        else:
            # Several label sets rotate per epoch over 10 / 20 / 5 variants.
            epoch_current_10ep = (epoch - 1) % 10
            epoch_current_20ep = (epoch - 1) % 20
            epoch_current_5ep = (epoch - 1) % 5
            label_json_list = [
                f'{self.root_path}/tex_en_label/block_pymu_fix_none_line/label_key.json',
                f'{self.root_path}/tex_en_label/block_pymu_case_fix_none_line_10ep/{epoch_current_10ep}/label_key.json',
                f'{self.root_path}/tex_en_label/rec_label_all_region_fix_none_line_20ep/{epoch_current_20ep}/label_key.json',
                f'{self.root_path}/tex_en_label/math_all_5ep/{epoch_current_5ep}/label_key.json'
            ]
            if self.use_ch:
                epoch_current = (epoch - 1) % 4
                label_json_list += [
                    f'{self.root_path}/tex_ch_label/40w_e2e_test1_croppdf_ch/math_inline/label_key.json',
                    f'{self.root_path}/tex_ch_label/40w_e2e_test1_croppdf_ch/math_display/label_key.json',
                    f'{self.root_path}/tex_ch_label/40w_e2e_test1_croppdf_ch/plain_text_4ep/{epoch_current}/label_key.json',
                    f'{self.root_path}/tex_ch_label/40w_e2e_test1_croppdf_ch/region_4ep/{epoch_current}/label_key.json',
                    f'{self.root_path}/tex_ch_label/40w_e2e_test1_croppdf_ch_random_new/math_inline/label_key.json',
                    f'{self.root_path}/tex_ch_label/40w_e2e_test1_croppdf_ch_random_new/math_display/label_key.json',
                    f'{self.root_path}/tex_ch_label/40w_e2e_test1_croppdf_ch_random_new/plain_text_4ep/{epoch_current}/label_key.json',
                    f'{self.root_path}/tex_ch_label/40w_e2e_test1_croppdf_ch_random_new/region_4ep/{epoch_current}/label_key.json'
                ]

        # Every label file collected so far is sampled once per epoch.
        ratio_sample = [1 for _ in label_json_list]

        self.docaug = DocAug()
        if self.use_linedata and self.all_data:
            epoch_current = (epoch - 1) % 10
            # Trailing comments appear to be: approx. sample count ('w' =
            # 10k), then the intended repeat ratio.
            line_json_list = [
                f'{self.root_path}/K-12/label_key_qwen.json',  # 25w tex_norm 1
                f'{self.root_path}/LSVT-2019/label_key.json',  # 260868 5
                f'{self.root_path}/webdata_MTWI/label_key.json',  # 147160 7
                f'{self.root_path}/HWDB2Train/label_key.json',  # 376029 1 largefile 180g
                f'{self.root_path}/HWDB2Train/label_key_region.json',  # 36684 10
                f'{self.root_path}/TAL_OCR_HW/label_key.json',  # 2w 10
                f'{self.root_path}/K-12_exam/label_key_qwen.json',  # 5w tex_norm 3
                f'{self.root_path}/hw_pdf/label_key_qwen_tex_norm.json',  # 5.5k 5368 10
                f'{self.root_path}/hw_pdf/label_key_qwen_crop_tex_norm.json',  # 8.1w 83723 83755 5
                f'{self.root_path}/hw_xhs/label_key_qwen_tex_norm.json',  # 73826 73911 5
                f'{self.root_path}/dfcf_pdf_dpi300/label_key.json',  # 349907 fix huice dpi300 3
                f'{self.root_path}/huaxue_jiaoyu/label_key.json',  # 263000 # formula 2
                f'{self.root_path}/latex_aug_new_circled/label_key.json',  # 215000 223000 # rm only textcircled 2
                f'{self.root_path}/nongminribao_pdf_dpi300/label_keys.json',  #126963 dpi300 2
                f'{self.root_path}/hiertext_lmdb/label_key_char_line_para.json',  # 25w 1
                f'{self.root_path}/hiertext_lmdb/label_key_word_{epoch_current}ep.json',  # 8w 1
            ]
            # NOTE(review): entries 12-13 (huaxue_jiaoyu, latex_aug) use
            # ratio 5 here while their comments above say 2 -- confirm.
            ratio_sample += [
                1, 5, 7, 1, 10, 10, 3, 10, 5, 5, 3, 5, 5, 2, 1, 1
            ]  # ratio_sample for line_json_list
            if self.use_table:
                line_json_list += [
                    f'{self.root_path}/hw_pdf_table/label_key.json',  #677
                    f'{self.root_path}/dfcf_table/label_key_refine.json',  # 218330
                    f'{self.root_path}/jiaoyu_table/label_key.json',  # 3233
                    f'{self.root_path}/pubtab1m_table/label_key_refine.json',  # 131463
                ]
                ratio_sample += [50, 3, 50, 3]  # for table

            self.__init_line_lmdb(epoch_current)
            label_json_list += line_json_list

        img_label_pair_list = self.load_label_json(label_json_list,
                                                   ratio_sample)

        self.need_reset = True
        self.img_label_pair_list = {}
        self.img_label_pair_list_small = {}
        for key in img_label_pair_list:
            json_data_list = img_label_pair_list[key]
            if self.mode == 'train':
                # Drop buckets too small to give each rank two samples.
                if len(json_data_list) < num_replicas * 2:
                    continue
                if len(json_data_list) <= 8 * num_replicas:
                    # Aliases json_data_list, so the padding appended below
                    # also lands here: this keeps the full, padded, unsharded
                    # bucket, which __getitem__ indexes when img_id exceeds
                    # the local shard.
                    self.img_label_pair_list_small[key] = json_data_list
            # NOTE(review): padding/sharding is placed at loop level (both
            # modes); it may have been train-only originally -- confirm.
            # Pad to a multiple of num_replicas.
            fill_num = num_replicas - len(json_data_list) % num_replicas
            if fill_num < num_replicas:
                for i in range(fill_num):
                    json_data_list.append(
                        json_data_list[i % len(json_data_list)])
            # Shard the data by GPU count and rank.
            json_data_list = json_data_list[rank::num_replicas]
            random.shuffle(json_data_list)
            self.img_label_pair_list[key] = json_data_list

        del img_label_pair_list
        self.ops = create_operators(dataset_config['transforms'],
                                    global_config)
        self.interpolation = T.InterpolationMode.BICUBIC
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])
        # Labels containing inline \( or display \[ math are treated as math.
        self.math_pattern = re.compile(r'\\\(|\\\[')
        # (regex, replacement) pairs applied in order by clean_label.
        # NOTE(review): the original comments say runs of 5+ <|unk|> / \uffff
        # become "4 spaces", but the replacement literal is a single space --
        # confirm which is intended.
        self.rules = [
            # 5+ consecutive <|unk|> -> space(s)
            (r'(?:<\|unk\|>){5,}', ' '),
            # 5+ consecutive \uffff -> space(s)
            (r'(?:\uffff){5,}', ' '),
            # (optional) single <|unk|> -> space
            (r'<\|unk\|>', ' '),
            # (optional) single \uffff -> space
            (r'\uffff', ' '),
            (r'_{6,}', '______'),
            (r'\.{6,}', '......'),
        ]

    def load_label_json(self, label_json_list, ratio_sample):
        """Load label files and merge their samples into size buckets.

        Args:
            label_json_list: Paths of JSON files mapping ``'<w>_<h>'`` to a
                list of sample dicts.
            ratio_sample: Per-file repeat factor; values > 1 replicate that
                file's samples.

        Returns:
            Dict mapping ``'<w_r>_<h_r>'`` (resized, floored, int sizes) to
            the merged sample list.
        """
        img_label_pair_list = {}
        for line_json, rti_sam in zip(label_json_list, ratio_sample):
            # JSON is UTF-8 by spec; don't depend on the locale's encoding.
            with open(line_json, 'r', encoding='utf-8') as f:
                json_data_list = json.load(f)
            for keywh, value_list in json_data_list.items():
                if rti_sam > 1:
                    value_list = value_list * rti_sam

                w_r, h_r = keywh.split('_')
                w_r, h_r = resize_image(int(w_r), int(h_r), self.max_side[0],
                                        self.max_side[1])
                # Cast to int: with float divided_factor (the default),
                # max(int(..), 64.) could yield 64.0 and a '64.0_...' key.
                h_r = int(
                    max(
                        int(h_r // self.divided_factor[1] *
                            self.divided_factor[1]), self.divided_factor[1]))
                w_r = int(
                    max(
                        int(w_r // self.divided_factor[0] *
                            self.divided_factor[0]), self.divided_factor[0]))

                key = str(w_r) + '_' + str(h_r)
                if key in img_label_pair_list:
                    img_label_pair_list[key].extend(value_list)
                else:
                    img_label_pair_list[key] = value_list
        return img_label_pair_list

    def __init_lmdb(self):
        """Open the English and Chinese PDF LMDB stores (read-only)."""
        if self.env is None:
            # NOTE(review): the env is opened in the constructor; LMDB
            # environments should not be shared across fork() -- confirm
            # DataLoader workers are handled.
            en_lmdb_path = f'{self.root_path}/en_pdf_lmdb'
            self.env = lmdb.open(
                en_lmdb_path,
                max_readers=32,
                readonly=True,
                lock=False,
                readahead=False,
                meminit=False,
            )
            self.txn = self.env.begin()
            ch_lmdb_path = f'{self.root_path}/ch_pdf_lmdb'
            self.env_ch = lmdb.open(
                ch_lmdb_path,
                max_readers=32,
                readonly=True,
                lock=False,
                readahead=False,
                meminit=False,
            )
            self.txn_ch = self.env_ch.begin()

    def __init_line_lmdb(self, epoch_current):
        """Open one read-only LMDB per line-data source, keyed by prefix.

        ``__getitem__`` selects the store whose prefix starts the sample's
        ``file_name``. ``epoch_current`` is accepted but not used here.
        """
        lmdb_paths = {
            'k12_': f'{self.root_path}/K-12/image_lmdb_qwen',
            'lsvt_': f'{self.root_path}/LSVT-2019/image_lmdb',
            'mtwi_': f'{self.root_path}/webdata_MTWI/image_lmdb',
            'tal_': f'{self.root_path}/TAL_OCR_HW/image_lmdb',
            'hwdb_': f'{self.root_path}/HWDB2Train/image_lmdb',
            'exam_': f'{self.root_path}/K-12_exam/image_lmdb_qwen',
            'page_hwpdf': f'{self.root_path}/hw_pdf/image_lmdb_qwen_tex_norm',
            'crop_hwpdf':
            f'{self.root_path}/hw_pdf/image_lmdb_qwen_crop_tex_norm',
            'xhs_hw_': f'{self.root_path}/hw_xhs/image_lmdb_qwen_tex_norm',
            'dfcf_pdf_':
            f'{self.root_path}/dfcf_pdf_dpi300/image_lmdb',  # fix huice
            'hauxuejiaoyu_': f'{self.root_path}/huaxue_jiaoyu/image_lmdb',
            'augtex_':
            f'{self.root_path}/latex_aug_new_circled/image_lmdb',  # rm only textcircled
            'nongminribao_pdf_':
            f'{self.root_path}/nongminribao_pdf_dpi300/image_lmdb',  #126963
            'hiertext_': f'{self.root_path}/hiertext_lmdb/image_lmdb',
        }

        if self.use_table:
            lmdb_paths[
                'hw_table_'] = f'{self.root_path}/hw_pdf_table/image_lmdb'
            lmdb_paths[
                'dfcf_table_'] = f'{self.root_path}/dfcf_table/image_lmdb_refine'
            lmdb_paths[
                'jiaoyu_table_'] = f'{self.root_path}/jiaoyu_table/image_lmdb'
            lmdb_paths[
                'pubtab1m_table_'] = f'{self.root_path}/pubtab1m_table/image_lmdb_refine'

        self.txns = {}
        lmdb_args = dict(max_readers=32,
                         readonly=True,
                         lock=False,
                         readahead=False,
                         meminit=False)

        for prefix, path in lmdb_paths.items():
            env = lmdb.open(path, **lmdb_args)
            self.txns[prefix] = env.begin()

    def crop_pdf_as_image(self, data_info, dpi=300, is_math=False):
        """Render the ``bbox`` region of page 0 of an LMDB-stored PDF.

        Args:
            data_info: Sample dict with ``'file_name'`` (LMDB key) and
                ``'bbox'`` (x0, y0, x1, y1), or a list of such boxes, which
                are merged into their union. Boxes appear to be in pixels at
                ``dpi`` (they are scaled by 72/dpi to PDF points). Optional
                ``'lines_bbox'``/``'class_name'`` restrict the visible area to
                the listed line boxes.
            dpi: Rendering resolution.
            is_math: Truthy to skip the random citation-box augmentation.

        Returns:
            ``(image, line_data)``, or ``(None, False)`` when the PDF is not
            in LMDB. Previously a bare ``None`` was returned here, which made
            the caller's tuple unpacking raise.
        """
        file_name = data_info['file_name']

        bbox_crop = data_info['bbox']
        if isinstance(bbox_crop[0], list):
            # Union of all boxes.
            bbox = [
                bbox_crop[0][0], bbox_crop[0][1], bbox_crop[0][2],
                bbox_crop[0][3]
            ]
            for i in range(1, len(bbox_crop)):
                bbox[0] = min(bbox[0], bbox_crop[i][0])
                bbox[1] = min(bbox[1], bbox_crop[i][1])
                bbox[2] = max(bbox[2], bbox_crop[i][2])
                bbox[3] = max(bbox[3], bbox_crop[i][3])
        else:
            bbox = bbox_crop

        # Use LMDB to read the PDF file.
        if '/home/ubuntu/bigdiskdata/' in file_name:  # for ch pdf
            pdf_data = self.txn_ch.get(file_name.encode('utf-8'))
        else:
            pdf_data = self.txn.get(file_name.encode('utf-8'))
        if pdf_data is None:
            return None, False

        doc = fitz.open(stream=pdf_data, filetype='pdf')
        try:
            # Crop the PDF with bbox (converted from dpi pixels to points).
            page = doc[0]
            rect = fitz.Rect(*[x * 72. / dpi for x in bbox])
            if not is_math and random.random() < 0.2:
                # Augmentation: draw coloured boxes around citation markers
                # like "(12)" / "[3]" and, rarely, around arbitrary words.
                text_dict = page.get_text('words', clip=rect)
                box_color = random_color()
                for x0, y0, x1, y1, text, _, _, _ in text_dict:
                    if text.strip():  # skip empty text
                        text = re.sub(r'[,:;.]', '', text)
                        if text[1:-1].isdigit() and (
                                text[0] == '(' and text[-1] == ')'
                                or text[0] == '[' and text[-1] == ']'):
                            bbox_color = [x0 + 3, y0, x1 - 3, y1]
                            rect_bbox = fitz.Rect(*bbox_color)
                            page.draw_rect(rect_bbox,
                                           color=box_color,
                                           width=1.0)
                        elif random.random() < 0.01:
                            bbox_color = [x0, y0, x1, y1]
                            rect_bbox = fitz.Rect(*bbox_color)
                            page.draw_rect(rect_bbox,
                                           color=box_color,
                                           width=1.0)
            crop_img = page.get_pixmap(clip=rect, dpi=dpi)
            # Convert to a PIL image.
            image = Image.frombytes('RGB', [crop_img.width, crop_img.height],
                                    crop_img.samples)
        finally:
            doc.close()

        line_data = False
        if 'lines_bbox' in data_info:
            lines_bbox_ = data_info['lines_bbox']
            if data_info['class_name'] == 'region':
                # Region samples nest line boxes one level deeper.
                lines_bbox = []
                for bbox_item in lines_bbox_:
                    lines_bbox.extend(bbox_item)
            else:
                lines_bbox = lines_bbox_
            line_data = len(lines_bbox) > 3
            if len(lines_bbox) > 0:
                # Keep only the listed line boxes; paint the rest white.
                np_img = np.array(image)
                zero_img = np.zeros_like(np_img) + 255
                for bbox_item in lines_bbox:
                    x1, y1, x2, y2 = bbox_item
                    x1 = x1 - bbox[0]
                    y1 = y1 - bbox[1]
                    x2 = x2 - bbox[0]
                    y2 = y2 - bbox[1]
                    zero_img[y1:y2, x1:x2] = np_img[y1:y2, x1:x2]
                image = Image.fromarray(zero_img)
        return image, line_data

    def resize_norm_img(self,
                        data,
                        imgW,
                        imgH,
                        zoom_time,
                        padding=False,
                        line_data=False):
        """Resize ``data['image']`` to the bucket size and normalise it.

        ``zoom_time`` is given in tenths (divided by 10, floored at 1). With
        ``use_zoom`` the target is shrunk by that factor (clamped by bucket
        size and floored to ``divided_factor``, min 64). Otherwise the image
        is resized to exactly ``(imgH, imgW)``. ``padding`` and ``line_data``
        are unused.

        Returns:
            ``data`` with ``'image'`` replaced by the normalised tensor and
            ``'valid_ratio'`` set to ``min(1, original_width / imgW)``.
        """
        img = data['image']

        zoom_time = float(zoom_time) / 10.

        if zoom_time <= 1.:
            zoom_time = 1.
        w, h = img.size
        if self.use_zoom:
            if imgW / self.divided_factor[
                    0] >= 5 or imgH / self.divided_factor[1] >= 5:
                zoom_time = min(imgW / self.divided_factor[0] - 2,
                                imgH / self.divided_factor[1] - 2, zoom_time)
            else:
                zoom_time = min(imgW / self.divided_factor[0] - 1,
                                imgH / self.divided_factor[1] - 1, zoom_time)
            # NOTE(review): nesting of this clamp (both branches vs. else only)
            # was reconstructed -- confirm.
            if imgW >= self.max_side[0] or imgH >= self.max_side[1]:
                zoom_time = max(zoom_time, 1.0)
            else:
                zoom_time = max(zoom_time, 2.0)
            imgW_r = imgW / float(zoom_time)
            imgH_r = imgH / float(zoom_time)
            imgW_r = max(
                int(imgW_r // self.divided_factor[0] * self.divided_factor[0]),
                64)
            imgH_r = max(
                int(imgH_r // self.divided_factor[1] * self.divided_factor[1]),
                64)
            resized_image = F.resize(img, (imgH_r, imgW_r),
                                     interpolation=self.interpolation)
        else:
            resized_image = F.resize(img, (imgH, imgW),
                                     interpolation=self.interpolation)

        img = self.transforms(resized_image)
        valid_ratio = min(1.0, float(w / imgW))
        data['image'] = img
        data['valid_ratio'] = valid_ratio
        return data

    def remove_space_before_sn(self, text, rep_str):
        """Replace "CJK char + optional whitespace + <|sn|>" with char+rep_str."""
        # \u4e00-\u9fff is the CJK Unified Ideographs range.
        return re.sub(r'([\u4e00-\u9fff])\s*<\|sn\|>', r'\1' + rep_str, text)

    def clean_label(self, text):
        """Apply ``self.rules`` in order, then fix detached diacritics."""
        for rule in self.rules:
            text = re.sub(rule[0], rule[1], text)
        text = fix_diacritics_regex(text)
        return text

    def _random_properties(self, key, w_r, h_r, zoom_time):
        """Pick a random replacement sample from the same size bucket.

        Buckets with at most 8 local entries draw from the small (unsharded)
        list; others draw from the local shard.
        """
        if len(self.img_label_pair_list[key]) <= 8:
            pool = self.img_label_pair_list_small[key]
        else:
            pool = self.img_label_pair_list[key]
        return [random.randint(0, len(pool) - 1), w_r, h_r, zoom_time]

    def __getitem__(self, properties):
        """Load, augment and encode one sample.

        Args:
            properties: ``(img_id, w_r, h_r, zoom_time)``, optionally followed
                by ``resume_batch``. If that is > 0 a dummy array is returned,
                presumably to skip batches when resuming (verify).

        Returns:
            The output of the final transform op. Failed samples are replaced
            by a random one from the same bucket.
        """
        if len(properties) != 5:
            img_id, w_r, h_r, zoom_time = properties
        else:
            img_id, w_r, h_r, zoom_time, resume_batch = properties
            if resume_batch > 0:
                return np.zeros((1, 3), dtype=np.float32)
        key = str(w_r) + '_' + str(h_r)
        if img_id > len(self.img_label_pair_list[key]) - 1:
            data_info = self.img_label_pair_list_small[key][img_id]
        else:
            data_info = self.img_label_pair_list[key][img_id]
        label = data_info['label']
        if isinstance(label, list):
            label = '\n\n'.join(label)
        rep_str = '<|sn|>' if self.add_return else ''
        label = label.replace('<<<hyphen>>>', '')
        label = label.replace('<<<change_line_token_wrap>>>', rep_str)
        label = label.replace('<<<change_line_token_split>>>', rep_str)
        label = label.replace('<<<null>>>', '')
        label = self.remove_space_before_sn(label, rep_str)
        label = self.clean_label(label)
        if not self.add_return:
            label = label.replace('<|sn|>', '')
        try:
            file_name = data_info['file_name']
            img_data = None
            if self.test_data:
                img_data = self.txn_test.get(file_name.encode('utf-8'))
                image = Image.open(io.BytesIO(img_data)).convert('RGB')
                line_data = True
            else:
                if self.use_linedata:
                    for prefix, txn in self.txns.items():
                        if file_name.startswith(prefix):
                            img_data = txn.get(file_name.encode('utf-8'))
                            break
                if img_data is not None or 'bbox' not in data_info:
                    image = Image.open(io.BytesIO(img_data)).convert('RGB')
                    line_data = True
                else:
                    image, line_data = self.crop_pdf_as_image(
                        data_info,
                        dpi=300,
                        is_math=self.math_pattern.search(label))
            if image is None or label is None:
                return self.__getitem__(
                    self._random_properties(key, w_r, h_r, zoom_time))
            # NOTE(review): no table LMDB prefix starts with 'table_'
            # (they are hw_/dfcf_/jiaoyu_/pubtab1m_table_), so only dfcf
            # tables match here -- confirm intent.
            if self.use_table and file_name.startswith(
                ('table_', 'dfcf_table_')):
                line_data = False
            data = {'image': image, 'label': label, 'arxiv': not line_data}
            data = transform(data, self.ops[:-1])
            if data is None:
                return self.__getitem__(
                    self._random_properties(key, w_r, h_r, zoom_time))

            if self.use_aug:
                w, h = image.size
                if w > self.max_side[0] or h > self.max_side[1]:
                    w, h = resize_image(w, h, self.max_side[0],
                                        self.max_side[1])
                    # int(): PIL rejects float sizes such as 64.0.
                    h = int(
                        max(
                            int(h // self.divided_factor[1] *
                                self.divided_factor[1]),
                            self.divided_factor[1]))
                    w = int(
                        max(
                            int(w // self.divided_factor[0] *
                                self.divided_factor[0]),
                            self.divided_factor[0]))
                    data['image'] = data['image'].resize((w, h))

                # Sources listed here are marked non-arxiv for DocAug.
                data['arxiv'] = not file_name.startswith(
                    ('lsvt_', 'mtwi_', 'tal_', 'page_hwpdf', 'crop_hwpdf',
                     'xhs_hw_', 'hiertext_', 'hw_table_'))
                data = self.docaug(data)

            data = self.resize_norm_img(data,
                                        w_r,
                                        h_r,
                                        zoom_time,
                                        line_data=line_data)
            outs = transform(data, self.ops[-1:])
        except Exception:
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit
            # propagate; .get() so logging itself cannot raise KeyError.
            self.logger.error(
                'When parsing line {}, error happened with msg: {}'.format(
                    data_info.get('file_name'), traceback.format_exc()))
            outs = None
        if outs is None:
            # During evaluation, the idx should ideally be fixed so repeated
            # evaluations give the same results.
            return self.__getitem__(
                self._random_properties(key, w_r, h_r, zoom_time))
        return outs

    def __len__(self):
        """Total number of samples across all local size buckets."""
        return sum(len(v) for v in self.img_label_pair_list.values())