openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,898 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+ from pathlib import Path
5
+
6
+ import os
7
+ import sys
8
+
9
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
10
+ sys.path.append(__dir__)
11
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
12
+
13
+ import queue
14
+ import threading
15
+ import time
16
+ from itertools import chain
17
+ from typing import Any, Dict, Optional, Tuple, Union
18
+ import cv2
19
+ import numpy as np
20
+ from PIL import Image
21
+ import random
22
+ import argparse
23
+ from multiprocessing import Process
24
+
25
+ from paddlex.utils import logging
26
+ from paddlex.inference.common.batch_sampler import ImageBatchSampler
27
+ from paddlex.inference.common.reader import ReadImage
28
+ from paddlex.inference.utils.hpi import HPIConfig
29
+ from paddlex.inference.utils.pp_option import PaddlePredictorOption
30
+ from paddlex.inference.pipelines import load_pipeline_config
31
+ from paddlex.inference.pipelines.base import BasePipeline
32
+ from paddlex.inference.pipelines.components import CropByBoxes
33
+ from paddlex.inference.pipelines.layout_parsing.utils import gather_imgs
34
+ from paddlex.inference.pipelines.paddleocr_vl.result import PaddleOCRVLBlock, PaddleOCRVLResult
35
+ from paddlex.inference.pipelines.paddleocr_vl.uilts import (
36
+ convert_otsl_to_html,
37
+ crop_margin,
38
+ filter_overlap_boxes,
39
+ merge_blocks,
40
+ tokenize_figure_of_table,
41
+ truncate_repetitive_content,
42
+ untokenize_figure_of_table,
43
+ )
44
+
45
+ from tools.engine.config import Config
46
+ from tools.utils.logging import get_logger
47
+ from tools.utils.utility import get_image_file_list
48
+ from tools.infer_rec import OpenRecognizer
49
+ from tools.to_markdown import MarkdownConverter
50
+
51
# Logger shared by everything in this module.
logger = get_logger(name='opendoc')

# Directory that holds this script; bundled configs are located relative to it.
root_dir = Path(__file__).resolve().parent
_UNIREC_CFG_RELATIVE = '../configs/rec/unirec/focalsvtr_ardecoder_unirec.yml'
# Default UniRec recognizer config, stored as a plain string path.
DEFAULT_CFG_PATH_UNIREC = str(root_dir / _UNIREC_CFG_RELATIVE)

# Layout labels whose blocks are kept as pictures instead of being sent to
# text recognition (see OpenDoc.get_layout_parsing_results).
IMAGE_LABELS = [
    'image',
    'header_image',
    'footer_image',
    'seal',
]

# Single converter instance reused to post-process recognized block text.
markdown_converter = MarkdownConverter()
60
+
61
+
62
+ class OpenDoc(BasePipeline):
63
+ """_UniRec Pipeline"""
64
+
65
def __init__(
    self,
    gpuId: Optional[int] = 0,
    pp_option: Optional[PaddlePredictorOption] = None,
    use_hpip: bool = False,
    hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
) -> None:
    """Build the OpenDoc pipeline from ``configs/rec/unirec/opendoc_pipeline.yml``.

    Depending on the pipeline config, this creates a document-preprocessing
    sub-pipeline and a ``PP-DocLayoutV2`` layout-detection model. It always
    creates the UniRec-0.1B recognizer. If the UniRec config has no usable
    ``Global.pretrained_model``, the weights are looked up in
    ``~/.cache/openocr/unirec-0.1b``. When they are missing there, they are
    downloaded from ModelScope first and then from Hugging Face Hub.

    Args:
        gpuId (int, optional): GPU index. It is passed to the Paddle
            sub-models as ``gpu:<id>`` and to the UniRec recognizer as
            ``numId``. A negative value passes ``device=None`` to the base
            pipeline. Defaults to 0.
        pp_option (PaddlePredictorOption, optional): PaddlePredictor options.
            Defaults to None.
        use_hpip (bool, optional): Whether to use the high-performance
            inference plugin (HPIP) by default. Defaults to False.
        hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
            The default high-performance inference configuration dictionary.
            Defaults to None.

    Raises:
        AssertionError: If layout detection is enabled but the configured
            model is not ``PP-DocLayoutV2``.
        ImportError: If the UniRec weights are absent and no download source
            produced ``model.pth``.
    """
    device = None if gpuId < 0 else f'gpu:{gpuId}'
    super().__init__(
        device=device,
        pp_option=pp_option,
        use_hpip=use_hpip,
        hpi_config=hpi_config,
    )
    config = load_pipeline_config(
        str(root_dir / '../configs/rec/unirec/opendoc_pipeline.yml'))

    self.use_doc_preprocessor = config.get('use_doc_preprocessor', True)
    if self.use_doc_preprocessor:
        doc_preprocessor_config = config.get('SubPipelines', {}).get(
            'DocPreprocessor',
            {
                'pipeline_config_error':
                'config error for doc_preprocessor_pipeline!'
            },
        )
        self.doc_preprocessor_pipeline = self.create_pipeline(
            doc_preprocessor_config)

    self.use_layout_detection = config.get('use_layout_detection', True)
    if self.use_layout_detection:
        layout_det_config = config.get('SubModules', {}).get(
            'LayoutDetection',
            {'model_config_error': 'config error for layout_det_model!'},
        )
        model_name = layout_det_config.get('model_name', None)
        assert (model_name is not None and model_name
                == 'PP-DocLayoutV2'), 'model_name must be PP-DocLayoutV2'
        # Forward only the layout options that the config explicitly sets.
        layout_kwargs = {}
        for key in ('threshold', 'layout_nms', 'layout_unclip_ratio',
                    'layout_merge_bboxes_mode'):
            value = layout_det_config.get(key, None)
            if value is not None:
                layout_kwargs[key] = value
        self.layout_det_model = self.create_model(layout_det_config,
                                                  **layout_kwargs)

    self.use_chart_recognition = config.get('use_chart_recognition', True)

    def _resolve_unirec_weights() -> str:
        """Return the cached UniRec ``model.pth`` path, downloading it if absent."""
        model_path = Path.home() / '.cache' / 'openocr' / 'unirec-0.1b'
        weights_file = str(model_path) + '/model.pth'
        if os.path.exists(weights_file):
            logger.info(f'UniRec-0.1B Model already exists in {model_path}')
            return weights_file

        def _from_modelscope():
            from modelscope.hub.snapshot_download import snapshot_download
            snapshot_download('topdktu/unirec-0.1b', local_dir=model_path)

        def _from_huggingface():
            from huggingface_hub import snapshot_download
            snapshot_download('topdu/unirec-0.1b', local_dir=model_path)

        for source, download in (('modelscope', _from_modelscope),
                                 ('huggingface', _from_huggingface)):
            try:
                download()
            except Exception as e:
                # Best effort: log the reason and fall through to the next
                # source. KeyboardInterrupt/SystemExit still propagate.
                logger.error(
                    f'Try to download the model from {source} failed. '
                    f'({type(e).__name__}: {e})')
                continue
            # A "successful" snapshot is only useful if it produced the file.
            if os.path.exists(weights_file):
                return weights_file

        raise ImportError(
            'Please download the model from https://huggingface.co/topdu/unirec-0.1b or https://modelscope.cn/models/topdktu/unirec-0.1b and put the model in the directory: ~/.cache/openocr/unirec-0.1b'
        )

    unirec_cfg = Config(DEFAULT_CFG_PATH_UNIREC).cfg
    pretrained_model = unirec_cfg['Global']['pretrained_model']
    if pretrained_model is None or not os.path.exists(pretrained_model):
        unirec_cfg['Global']['pretrained_model'] = _resolve_unirec_weights()

    self.vl_rec_model = OpenRecognizer(unirec_cfg, numId=gpuId)
    self.format_block_content = config.get('format_block_content', False)

    self.batch_sampler = ImageBatchSampler(
        batch_size=config.get('batch_size', 1))
    self.img_reader = ReadImage(format='BGR')
    self.crop_by_boxes = CropByBoxes()

    self.use_queues = config.get('use_queues', False)
    self.merge_layout_blocks = config.get('merge_layout_blocks', True)
    self.markdown_ignore_labels = config.get(
        'markdown_ignore_labels',
        [
            'number',
            'footnote',
            'header',
            'header_image',
            'footer',
            'footer_image',
            'aside_text',
        ],
    )
207
+
208
def get_model_settings(
    self,
    use_doc_orientation_classify: Union[bool, None],
    use_doc_unwarping: Union[bool, None],
    use_layout_detection: Union[bool, None],
    use_chart_recognition: Union[bool, None],
    format_block_content: Union[bool, None],
    merge_layout_blocks: Union[bool, None],
    markdown_ignore_labels: Optional[list[str]] = None,
) -> dict:
    """Resolve per-call options against the pipeline defaults.

    Every option left as ``None`` falls back to the attribute of the same name
    set in ``__init__``. Document preprocessing is the exception. It uses the
    pipeline default only when both ``use_doc_orientation_classify`` and
    ``use_doc_unwarping`` are ``None``. Otherwise it is enabled exactly when
    at least one of them is ``True``.

    Args:
        use_doc_orientation_classify: Request orientation classification.
        use_doc_unwarping: Request document unwarping.
        use_layout_detection: Override for layout detection.
        use_chart_recognition: Override for chart recognition.
        format_block_content: Override for block-content formatting.
        merge_layout_blocks: Override for merging layout blocks.
        markdown_ignore_labels: Override for labels skipped in Markdown.

    Returns:
        dict: Keys ``use_doc_preprocessor``, ``use_layout_detection``,
        ``use_chart_recognition``, ``format_block_content``,
        ``merge_layout_blocks`` and ``markdown_ignore_labels``.
    """

    def pick(value, fallback):
        return fallback if value is None else value

    both_unset = (use_doc_orientation_classify is None
                  and use_doc_unwarping is None)
    if both_unset:
        doc_preprocessor = self.use_doc_preprocessor
    else:
        doc_preprocessor = (use_doc_orientation_classify is True
                            or use_doc_unwarping is True)

    settings = {}
    settings['use_doc_preprocessor'] = doc_preprocessor
    settings['use_layout_detection'] = pick(use_layout_detection,
                                            self.use_layout_detection)
    settings['use_chart_recognition'] = pick(use_chart_recognition,
                                             self.use_chart_recognition)
    settings['format_block_content'] = pick(format_block_content,
                                            self.format_block_content)
    settings['merge_layout_blocks'] = pick(merge_layout_blocks,
                                           self.merge_layout_blocks)
    settings['markdown_ignore_labels'] = pick(markdown_ignore_labels,
                                              self.markdown_ignore_labels)
    return settings
260
+
261
def check_model_settings_valid(self, input_params: dict) -> bool:
    """
    Check that every model requested by ``input_params`` was created in ``__init__``.

    The doc-preprocessor pipeline and the layout-detection model are built
    only when enabled in the pipeline config. Requesting either one at call
    time without it having been built would later fail with an
    ``AttributeError``, so that case is rejected here.

    Args:
        input_params (Dict): Resolved settings, as returned by
            ``get_model_settings``.

    Returns:
        bool: True if all required models are initialized according to input parameters, False otherwise.
    """

    if input_params[
            'use_doc_preprocessor'] and not self.use_doc_preprocessor:
        logging.error(
            'Set use_doc_preprocessor, but the models for doc preprocessor are not initialized.',
        )
        return False

    if input_params.get(
            'use_layout_detection') and not self.use_layout_detection:
        logging.error(
            'Set use_layout_detection, but the layout detection model is not initialized.',
        )
        return False

    return True
280
+
281
def get_layout_parsing_results(
    self,
    images,
    layout_det_results,
    imgs_in_doc,
    use_chart_recognition=False,
    vlm_kwargs=None,
    merge_layout_blocks=True,
):
    """Crop layout blocks, recognize their text with UniRec and build per-page results.

    Args:
        images: One page image per entry. Blocks are converted with
            ``COLOR_BGR2RGB``, so BGR channel order is assumed.
        layout_det_results: One layout-detection result per page, each
            providing ``'boxes'``.
        imgs_in_doc: Per-page figure crops. They are used to tokenize
            figures embedded in tables and are returned unchanged.
        use_chart_recognition: When False, ``chart`` blocks are treated like
            images and are not sent to recognition.
        vlm_kwargs: Sampling options coming from ``predict``. They are
            accepted for interface compatibility, but the UniRec recognizer
            called here does not consume them.
        merge_layout_blocks: Whether to merge layout blocks before
            recognition. Image-like and table blocks are never merged.

    Returns:
        tuple: ``(parsing_res_lists, table_res_lists, imgs_in_doc)``, made of
        one ``PaddleOCRVLBlock`` list per page, one (currently always empty)
        table-result list per page, and the input ``imgs_in_doc``.
    """
    image_labels = (IMAGE_LABELS
                    if use_chart_recognition else IMAGE_LABELS + ['chart'])

    def _is_formula(label):
        # formula_number blocks are plain numbering, not formulas to recognize.
        return 'formula' in label and label != 'formula_number'

    def _postprocess(result_str, block_label, figure_token_map):
        """Normalize recognized text: math delimiters, tables, figure tokens."""
        if result_str is None:
            result_str = ''
        result_str = truncate_repetitive_content(result_str)
        # When LaTeX \( \) or \[ \] delimiters are present, stray '$' would
        # clash with the '$' delimiters substituted below.
        has_inline = '\\(' in result_str and '\\)' in result_str
        has_display = '\\[' in result_str and '\\]' in result_str
        if has_inline or has_display:
            result_str = result_str.replace('$', '')
        result_str = (result_str.replace('\\(', ' $ ').replace(
            '\\)', ' $ ').replace('\\[', ' $$ ').replace('\\]', ' $$ '))
        if block_label == 'formula_number':
            result_str = result_str.replace('$', '')
        if block_label == 'table':
            html_str = convert_otsl_to_html(result_str)
            if html_str != '':
                result_str = html_str
            result_str = untokenize_figure_of_table(result_str,
                                                    figure_token_map)
        return result_str

    # Pass 1: crop blocks per page and collect those that need recognition.
    blocks = []
    rec_imgs = []
    rec_labels = []
    rec_block_ids = []
    figure_token_maps = []
    drop_figures_set = set()
    for i, (image, layout_det_res, imgs_in_doc_for_img) in enumerate(
            zip(images, layout_det_results, imgs_in_doc)):
        layout_det_res = filter_overlap_boxes(layout_det_res)
        blocks_for_img = self.crop_by_boxes(image, layout_det_res['boxes'])
        if merge_layout_blocks:
            blocks_for_img = merge_blocks(blocks_for_img,
                                          non_merge_labels=image_labels +
                                          ['table'])
        blocks.append(blocks_for_img)
        for j, block in enumerate(blocks_for_img):
            block_img = block['img']
            block_label = block['label']
            if block_label in image_labels or block_img is None:
                continue
            figure_token_map = {}
            if block_label == 'table':
                # Figures inside a table are replaced by tokens; the figures
                # consumed this way are not emitted again as image blocks.
                block_img, figure_token_map, drop_figures = (
                    tokenize_figure_of_table(block_img, block['box'],
                                             imgs_in_doc_for_img))
                drop_figures_set.update(drop_figures)
            elif _is_formula(block_label):
                block_img = crop_margin(block_img)
            rec_imgs.append(block_img)
            rec_labels.append(block_label)
            figure_token_maps.append(figure_token_map)
            rec_block_ids.append((i, j))

    # Pass 2: recognize each collected block with UniRec, one at a time.
    rec_texts = []
    for block_img, block_label in zip(rec_imgs, rec_labels):
        rgb_img = cv2.cvtColor(block_img, cv2.COLOR_BGR2RGB)
        output_unirec = self.vl_rec_model(img_numpy=Image.fromarray(rgb_img),
                                          batch_num=1)[0]
        text = output_unirec['text']
        if block_label == 'table':
            text = markdown_converter._handle_table(text)
        elif _is_formula(block_label):
            text = markdown_converter._handle_formula(text)
        else:
            text = markdown_converter._handle_text(text)
        rec_texts.append(text)

    # Pass 3: assemble per-page block lists in original block order.
    parsing_res_lists = []
    table_res_lists = []
    rec_idx = 0
    for i, blocks_for_img in enumerate(blocks):
        parsing_res_list = []
        for j, block in enumerate(blocks_for_img):
            block_img = block['img']
            block_bbox = block['box']
            block_label = block['label']
            block_content = ''
            # rec_block_ids is in the same (page, block) order as this loop.
            if rec_idx < len(rec_block_ids) and rec_block_ids[rec_idx] == (
                    i, j):
                block_content = _postprocess(rec_texts[rec_idx], block_label,
                                             figure_token_maps[rec_idx])
                rec_idx += 1

            block_info = PaddleOCRVLBlock(
                label=block_label,
                bbox=block_bbox,
                content=block_content,
                group_id=block.get('group_id', None),
            )
            if block_label in image_labels and block_img is not None:
                x_min, y_min, x_max, y_max = list(map(int, block_bbox))
                img_path = f'imgs/img_in_{block_label}_box_{x_min}_{y_min}_{x_max}_{y_max}.jpg'
                if img_path in drop_figures_set:
                    continue
                block_info.image = {
                    'path': img_path,
                    'img': Image.fromarray(
                        cv2.cvtColor(block_img, cv2.COLOR_BGR2RGB)),
                }

            parsing_res_list.append(block_info)
        parsing_res_lists.append(parsing_res_list)
        table_res_lists.append([])

    return parsing_res_lists, table_res_lists, imgs_in_doc
420
+
421
def predict(
    self,
    input: Union[str, list[str], np.ndarray, list[np.ndarray]],
    use_doc_orientation_classify: Union[bool, None] = False,
    use_doc_unwarping: Union[bool, None] = False,
    use_layout_detection: Union[bool, None] = None,
    use_chart_recognition: Union[bool, None] = None,
    layout_threshold: Optional[Union[float, dict]] = None,
    layout_nms: Optional[bool] = None,
    layout_unclip_ratio: Optional[Union[float, Tuple[float, float],
                                        dict]] = None,
    layout_merge_bboxes_mode: Optional[str] = None,
    use_queues: Optional[bool] = None,
    prompt_label: Optional[Union[str, None]] = None,
    format_block_content: Union[bool, None] = None,
    repetition_penalty: Optional[float] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    min_pixels: Optional[int] = None,
    max_pixels: Optional[int] = None,
    max_new_tokens: Optional[int] = None,
    merge_layout_blocks: Optional[bool] = None,
    markdown_ignore_labels: Optional[list[str]] = None,
    **kwargs,
) -> PaddleOCRVLResult:
    """
    Parse the layout of ``input`` and yield one result per page/image.

    This is a generator. With ``use_queues`` enabled, input loading, CV
    (preprocessing plus layout detection) and recognition run in three
    threads connected by bounded queues. All threads stop once the generator
    finishes, fails, or is closed early.

    Args:
        input (Union[str, list[str], np.ndarray, list[np.ndarray]]): Input image path, list of image paths,
            numpy array of an image, or list of numpy arrays.
        use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
        use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
        use_layout_detection (Optional[bool]): Whether to run layout detection. None uses the pipeline default.
        use_chart_recognition (Optional[bool]): Whether to recognize chart blocks. None uses the pipeline default.
        layout_threshold (Optional[float]): The threshold value to filter out low-confidence predictions. Default is None.
        layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to None.
        layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
            Defaults to None.
            If it's a single number, then both width and height are used.
            If it's a tuple of two numbers, then they are used separately for width and height respectively.
            If it's None, then no unclipping will be performed.
        layout_merge_bboxes_mode (Optional[str], optional): The mode for merging bounding boxes. Defaults to None.
        use_queues (Optional[bool], optional): Whether to use the threaded queue pipeline. None uses the pipeline default.
        prompt_label (Optional[Union[str, None]], optional): Used only when layout detection is disabled. The whole
            image becomes one block with this label, one of ['ocr', 'formula', 'table', 'chart'].
            None means 'ocr'.
        format_block_content (Optional[bool]): Whether to format the block content. Default is None.
        repetition_penalty, temperature, top_p, min_pixels, max_pixels, max_new_tokens:
            Sampling options forwarded to ``get_layout_parsing_results`` as ``vlm_kwargs``.
        merge_layout_blocks (Optional[bool]): Whether to merge layout blocks. Default is None.
        markdown_ignore_labels (Optional[list[str]]): The list of ignored markdown labels. Default is None.
        **kwargs (Any): Accepted for compatibility; not used.

    Yields:
        PaddleOCRVLResult: One result per input page. If the requested model
        settings need models that were not created in ``__init__``, a single
        ``{'error': ...}`` dict is yielded instead and the generator stops.

    Raises:
        AssertionError: If layout detection is disabled and ``prompt_label``
            is not a supported label.
        RuntimeError: In queue mode, when a worker thread fails. The original
            exception is chained as ``__cause__``.
    """
    model_settings = self.get_model_settings(
        use_doc_orientation_classify,
        use_doc_unwarping,
        use_layout_detection,
        use_chart_recognition,
        format_block_content,
        merge_layout_blocks,
        markdown_ignore_labels,
    )

    if not self.check_model_settings_valid(model_settings):
        yield {'error': 'the input params for model settings are invalid!'}
        # Stop here: continuing would use models that were never created.
        return

    if use_queues is None:
        use_queues = self.use_queues

    if not model_settings['use_layout_detection']:
        prompt_label = prompt_label if prompt_label else 'ocr'
        if prompt_label.lower() == 'chart':
            model_settings['use_chart_recognition'] = True
        assert prompt_label.lower() in [
            'ocr',
            'formula',
            'table',
            'chart',
        ], f"Layout detection is disabled (use_layout_detection=False). 'prompt_label' must be one of ['ocr', 'formula', 'table', 'chart'], but got '{prompt_label}'."

    def _process_cv(batch_data, new_batch_size=None):
        """Read, preprocess and layout-detect ``batch_data`` in sub-batches."""
        if not new_batch_size:
            new_batch_size = len(batch_data)

        for idx in range(0, len(batch_data), new_batch_size):
            instances = batch_data.instances[idx:idx + new_batch_size]
            input_paths = batch_data.input_paths[idx:idx + new_batch_size]
            page_indexes = batch_data.page_indexes[idx:idx + new_batch_size]
            page_counts = batch_data.page_counts[idx:idx + new_batch_size]

            image_arrays = self.img_reader(instances)

            if model_settings['use_doc_preprocessor']:
                doc_preprocessor_results = list(
                    self.doc_preprocessor_pipeline(
                        image_arrays,
                        use_doc_orientation_classify=
                        use_doc_orientation_classify,
                        use_doc_unwarping=use_doc_unwarping,
                    ))
            else:
                doc_preprocessor_results = [{
                    'output_img': arr
                } for arr in image_arrays]

            doc_preprocessor_images = [
                item['output_img'] for item in doc_preprocessor_results
            ]

            if model_settings['use_layout_detection']:
                layout_det_results = list(
                    self.layout_det_model(
                        doc_preprocessor_images,
                        threshold=layout_threshold,
                        layout_nms=layout_nms,
                        layout_unclip_ratio=layout_unclip_ratio,
                        layout_merge_bboxes_mode=layout_merge_bboxes_mode,
                    ))

                imgs_in_doc = [
                    gather_imgs(doc_pp_img, layout_det_res['boxes'])
                    for doc_pp_img, layout_det_res in zip(
                        doc_preprocessor_images, layout_det_results)
                ]
            else:
                # Without layout detection the whole page is a single block
                # labelled with prompt_label.
                layout_det_results = []
                for doc_preprocessor_image in doc_preprocessor_images:
                    layout_det_results.append({
                        'input_path': None,
                        'page_index': None,
                        'boxes': [{
                            'cls_id': 0,
                            'label': prompt_label.lower(),
                            'score': 1,
                            'coordinate': [
                                0,
                                0,
                                doc_preprocessor_image.shape[1],
                                doc_preprocessor_image.shape[0],
                            ],
                        }],
                    })
                imgs_in_doc = [[] for _ in layout_det_results]

            yield input_paths, page_indexes, page_counts, doc_preprocessor_images, doc_preprocessor_results, layout_det_results, imgs_in_doc

    def _process_vlm(results_cv):
        """Recognize the blocks of one CV result tuple and yield per-page results."""
        (
            input_paths,
            page_indexes,
            page_counts,
            doc_preprocessor_images,
            doc_preprocessor_results,
            layout_det_results,
            imgs_in_doc,
        ) = results_cv

        parsing_res_lists, table_res_lists, imgs_in_doc = (
            self.get_layout_parsing_results(
                doc_preprocessor_images,
                layout_det_results,
                imgs_in_doc,
                model_settings['use_chart_recognition'],
                {
                    'repetition_penalty': repetition_penalty,
                    'temperature': temperature,
                    'top_p': top_p,
                    'min_pixels': min_pixels,
                    'max_pixels': max_pixels,
                    'max_new_tokens': max_new_tokens,
                },
                model_settings['merge_layout_blocks'],
            ))

        for (
                input_path,
                page_index,
                page_count,
                doc_preprocessor_image,
                doc_preprocessor_res,
                layout_det_res,
                table_res_list,
                parsing_res_list,
                imgs_in_doc_for_img,
        ) in zip(
                input_paths,
                page_indexes,
                page_counts,
                doc_preprocessor_images,
                doc_preprocessor_results,
                layout_det_results,
                table_res_lists,
                parsing_res_lists,
                imgs_in_doc,
        ):
            single_img_res = {
                'input_path': input_path,
                'page_index': page_index,
                'page_count': page_count,
                'width': doc_preprocessor_image.shape[1],
                'height': doc_preprocessor_image.shape[0],
                'doc_preprocessor_res': doc_preprocessor_res,
                'layout_det_res': layout_det_res,
                'table_res_list': table_res_list,
                'parsing_res_list': parsing_res_list,
                'imgs_in_doc': imgs_in_doc_for_img,
                'model_settings': model_settings,
            }
            yield PaddleOCRVLResult(single_img_res)

    if use_queues:
        max_num_batches_in_process = 64
        queue_input = queue.Queue(maxsize=max_num_batches_in_process)
        queue_cv = queue.Queue(maxsize=max_num_batches_in_process)
        queue_vlm = queue.Queue(maxsize=self.batch_sampler.batch_size *
                                max_num_batches_in_process)
        event_shutdown = threading.Event()
        event_data_loading_done = threading.Event()
        event_cv_processing_done = threading.Event()
        event_vlm_processing_done = threading.Event()

        def _put(q, item):
            """Put ``item`` on ``q``; return False if shutdown is requested first.

            A plain blocking ``put`` on a full queue would never return once
            the consumer stopped, leaving a non-daemon thread hanging.
            """
            while not event_shutdown.is_set():
                try:
                    q.put(item, timeout=0.5)
                    return True
                except queue.Full:
                    continue
            return False

        def _worker_input(input_):
            try:
                all_batch_data = self.batch_sampler(input_)
                while not event_shutdown.is_set():
                    try:
                        batch_data = next(all_batch_data)
                    except StopIteration:
                        break
                    if not _put(queue_input, (True, batch_data)):
                        break
            except Exception as e:
                _put(queue_input, (False, 'input', e))
            finally:
                # Always signal completion; otherwise downstream workers
                # would wait forever.
                event_data_loading_done.set()

        def _worker_cv():
            while not event_shutdown.is_set():
                try:
                    item = queue_input.get(timeout=0.5)
                except queue.Empty:
                    # Test the flag before emptiness. The flag is set only
                    # after the producer's last put, so "done and empty"
                    # really means no item is left.
                    if event_data_loading_done.is_set(
                    ) and queue_input.empty():
                        event_cv_processing_done.set()
                        break
                    continue
                if not item[0]:
                    _put(queue_cv, item)
                    break
                try:
                    for results_cv in _process_cv(
                            item[1],
                        (self.layout_det_model.batch_sampler.batch_size
                         if model_settings['use_layout_detection'] else None),
                    ):
                        if not _put(queue_cv, (True, results_cv)):
                            return
                except Exception as e:
                    _put(queue_cv, (False, 'cv', e))
                    break

        def _worker_vlm():
            MAX_QUEUE_DELAY_SECS = 0.5
            MAX_NUM_BOXES = 4096

            while not event_shutdown.is_set():
                # Gather CV results for up to MAX_QUEUE_DELAY_SECS, or until
                # MAX_NUM_BOXES layout boxes are pending, then process them.
                results_cv_list = []
                start_time = time.time()
                should_break = False
                num_boxes = 0
                while True:
                    remaining_time = MAX_QUEUE_DELAY_SECS - (time.time() -
                                                             start_time)
                    if remaining_time <= 0:
                        break
                    try:
                        item = queue_cv.get(timeout=remaining_time)
                    except queue.Empty:
                        break
                    if not item[0]:
                        _put(queue_vlm, item)
                        should_break = True
                        break
                    results_cv_list.append(item[1])
                    # Index 5 of a CV result tuple is layout_det_results.
                    for res in results_cv_list[-1][5]:
                        num_boxes += len(res['boxes'])
                    if num_boxes >= MAX_NUM_BOXES:
                        break
                if should_break:
                    break
                if not results_cv_list:
                    # Same flag-then-emptiness ordering as in _worker_cv.
                    if event_cv_processing_done.is_set() and queue_cv.empty():
                        event_vlm_processing_done.set()
                        break
                    continue

                merged_results_cv = [
                    list(chain.from_iterable(lists))
                    for lists in zip(*results_cv_list)
                ]

                try:
                    for result_vlm in _process_vlm(merged_results_cv):
                        if not _put(queue_vlm, (True, result_vlm)):
                            return
                except Exception as e:
                    _put(queue_vlm, (False, 'vlm', e))
                    break

        thread_input = threading.Thread(target=_worker_input,
                                        args=(input, ),
                                        daemon=False)
        thread_input.start()
        thread_cv = threading.Thread(target=_worker_cv, daemon=False)
        thread_cv.start()
        thread_vlm = threading.Thread(target=_worker_vlm, daemon=False)
        thread_vlm.start()

    try:
        if use_queues:
            # The loop ends only once the VLM worker is done AND its output is
            # drained. Breaking on "done" right after a timeout could drop the
            # final result.
            while not (event_vlm_processing_done.is_set()
                       and queue_vlm.empty()):
                try:
                    item = queue_vlm.get(timeout=0.5)
                except queue.Empty:
                    continue
                if not item[0]:
                    raise RuntimeError(
                        f"Exception from the '{item[1]}' worker: {item[2]}"
                    ) from item[2]
                yield item[1]
        else:
            for batch_data in self.batch_sampler(input):
                results_cv_list = list(_process_cv(batch_data))
                assert len(results_cv_list) == 1, len(results_cv_list)
                yield from _process_vlm(results_cv_list[0])
    finally:
        if use_queues:
            event_shutdown.set()
            for worker_name, worker in (('Input', thread_input),
                                        ('CV', thread_cv), ('VLM',
                                                            thread_vlm)):
                worker.join(timeout=5)
                if worker.is_alive():
                    logging.warning(
                        f'{worker_name} worker did not terminate in time')
777
+
778
def concatenate_markdown_pages(self, markdown_list: list) -> str:
    """
    Concatenate Markdown content from multiple pages into a single document.

    Every page's text is preceded by a blank-line separator (two newline
    characters), so a non-empty result starts with that separator.

    Args:
        markdown_list (list): Per-page dicts; each must provide the page's
            Markdown text under the ``'markdown_texts'`` key.

    Returns:
        str: The concatenated Markdown text ('' for an empty list).
    """
    # join() avoids the quadratic cost of repeated string concatenation.
    return ''.join('\n\n' + res['markdown_texts'] for res in markdown_list)
794
+
795
+
796
def process_batch(process_batch_list: list, gpuId: int, output_path: str,
                  is_save_vis_img: bool, is_save_json: bool,
                  is_save_markdown: bool, pretty: bool):
    """Run the OpenDoc pipeline over a list of input files on one device.

    Args:
        process_batch_list (list): Paths of the image / PDF files to process.
        gpuId (int): Device id passed to ``OpenDoc(gpuId=...)``; a negative
            value is reported as CPU processing in the log messages.
        output_path (str): Directory the enabled outputs are written to.
        is_save_vis_img (bool): Save the visualization image of each result.
        is_save_json (bool): Save the structured JSON of each result.
        is_save_markdown (bool): Save the Markdown of each result.
        pretty (bool): Forwarded as ``save_to_markdown(pretty=...)``.
    """
    opendoc_pipeline = OpenDoc(gpuId=gpuId)
    for img_path in process_batch_list:
        # splitext removes the real extension whatever its length
        # ('.jpeg', '.tiff', ...) instead of blindly dropping 4 characters.
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        output = opendoc_pipeline.predict(img_path,
                                          use_doc_orientation_classify=False,
                                          use_doc_unwarping=False)
        for res in output:
            if is_save_vis_img:
                res.save_to_img(output_path)
            if is_save_json:
                # Save the structured JSON result of the current image.
                res.save_to_json(save_path=output_path)
            if is_save_markdown:
                # Save the Markdown-formatted result of the current image.
                res.save_to_markdown(save_path=output_path, pretty=pretty)
        if gpuId >= 0:
            logger.info(f'GPU {gpuId} processing {img_name} done!')
        else:
            logger.info(f'CPU processing {img_name} done!')
818
+
819
+
820
if __name__ == '__main__':

    parser = argparse.ArgumentParser(
        description='OpenDoc Pipeline for Document OCR')
    parser.add_argument(
        '--input_path',
        type=str,
        required=True,
        help='Path to the directory containing images or pdf files to process')
    parser.add_argument('--output_path',
                        type=str,
                        default='./output',
                        help='Path to save output results (default: ./output)')
    parser.add_argument(
        '--gpus',
        type=str,
        default='0',
        help=
        'GPU IDs to use, separated by comma (e.g., "0,1,2,3"). Use "-1" for CPU mode (default: "0")'
    )
    parser.add_argument(
        '--is_save_vis_img',
        action='store_true',
        help='Save visualized images with layout detection boxes')
    parser.add_argument('--is_save_json',
                        action='store_true',
                        help='Save JSON results')
    parser.add_argument('--is_save_markdown',
                        action='store_true',
                        help='Save Markdown results')
    parser.add_argument('--pretty',
                        action='store_true',
                        help='Pretty print Markdown results')

    args = parser.parse_args()

    # Parse GPU IDs. Empty entries (e.g. a trailing comma in "0,1,") are
    # ignored instead of crashing in int(''), and bad input is reported via
    # argparse (usage message + exit) rather than a traceback.
    try:
        gpus = [
            int(gpu_id.strip()) for gpu_id in args.gpus.split(',')
            if gpu_id.strip()
        ]
    except ValueError:
        parser.error(f'invalid --gpus value: {args.gpus!r}')
    if not gpus:
        parser.error('--gpus must list at least one device id')

    # Create output directory if not exists
    os.makedirs(args.output_path, exist_ok=True)

    # Get image list; shuffled before sharding (presumably to balance the
    # per-device workload).
    img_list = get_image_file_list(args.input_path)
    random.shuffle(img_list)

    logger.info(f'Found {len(img_list)} images in {args.input_path}')
    if gpus[0] == -1:
        logger.info('Running in CPU mode')
    else:
        logger.info(f'Using GPUs: {gpus}')
    logger.info(f'Output will be saved to: {args.output_path}')

    if len(gpus) == 1:
        process_batch(img_list, gpus[0], args.output_path,
                      args.is_save_vis_img, args.is_save_json,
                      args.is_save_markdown, args.pretty)
    else:
        # Split images into one interleaved shard per device.
        num_gpus = len(gpus)
        img_list_batch = [img_list[i::num_gpus] for i in range(num_gpus)]

        # (device id, process) pairs, so failures can be attributed.
        workers = []
        for gpuId, batch in zip(gpus, img_list_batch):
            # With fewer files than devices some shards are empty; skip them
            # instead of spawning a process that only builds the pipeline.
            if not batch:
                continue
            workers.append((gpuId,
                            Process(target=process_batch,
                                    args=(batch, gpuId, args.output_path,
                                          args.is_save_vis_img,
                                          args.is_save_json,
                                          args.is_save_markdown,
                                          args.pretty))))

        for _, process in workers:
            process.start()

        for _, process in workers:
            process.join()

        # A crashed child does not raise in the parent; check exit codes so a
        # failure is not reported as a successful run.
        failed_gpus = [
            gpu_id for gpu_id, process in workers if process.exitcode != 0
        ]
        if failed_gpus:
            logger.error(
                f'Worker process(es) for device(s) {failed_gpus} exited '
                f'abnormally; results in {args.output_path} may be incomplete.')
            raise SystemExit(1)

    logger.info(
        f'All processing completed! Results saved to {args.output_path}')