openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +35 -1
- openocr/configs/dataset/rec/evaluation.yaml +41 -0
- openocr/configs/dataset/rec/ltb.yaml +9 -0
- openocr/configs/dataset/rec/mjsynth.yaml +11 -0
- openocr/configs/dataset/rec/openvino.yaml +25 -0
- openocr/configs/dataset/rec/ost.yaml +17 -0
- openocr/configs/dataset/rec/synthtext.yaml +7 -0
- openocr/configs/dataset/rec/test.yaml +77 -0
- openocr/configs/dataset/rec/textocr.yaml +13 -0
- openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
- openocr/configs/dataset/rec/union14m_b.yaml +47 -0
- openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
- openocr/configs/rec/cmer/cmer.yml +127 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
- openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
- openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
- openocr/demo_gradio.py +28 -8
- openocr/demo_opendoc.py +572 -0
- openocr/demo_unirec.py +392 -0
- openocr/opendet/losses/__init__.py +5 -7
- openocr/opendet/preprocess/crop_resize.py +2 -1
- openocr/openocr.py +685 -0
- openocr/openrec/losses/__init__.py +8 -3
- openocr/openrec/losses/cmer_loss.py +12 -0
- openocr/openrec/losses/mdiff_loss.py +11 -0
- openocr/openrec/losses/unirec_loss.py +12 -0
- openocr/openrec/metrics/__init__.py +4 -1
- openocr/openrec/metrics/rec_metric_cmer.py +328 -0
- openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
- openocr/openrec/modeling/decoders/__init__.py +1 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
- openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
- openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
- openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
- openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
- openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
- openocr/openrec/optimizer/__init__.py +4 -3
- openocr/openrec/optimizer/lr.py +49 -0
- openocr/openrec/postprocess/__init__.py +2 -0
- openocr/openrec/postprocess/abinet_postprocess.py +1 -1
- openocr/openrec/postprocess/ar_postprocess.py +1 -1
- openocr/openrec/postprocess/cmer_postprocess.py +86 -0
- openocr/openrec/postprocess/cppd_postprocess.py +1 -1
- openocr/openrec/postprocess/igtr_postprocess.py +1 -1
- openocr/openrec/postprocess/lister_postprocess.py +1 -1
- openocr/openrec/postprocess/mgp_postprocess.py +1 -1
- openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
- openocr/openrec/postprocess/smtr_postprocess.py +1 -1
- openocr/openrec/postprocess/srn_postprocess.py +1 -1
- openocr/openrec/postprocess/unirec_postprocess.py +58 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
- openocr/openrec/preprocess/__init__.py +5 -0
- openocr/openrec/preprocess/ce_label_encode.py +1 -1
- openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
- openocr/openrec/preprocess/ctc_label_encode.py +1 -1
- openocr/openrec/preprocess/dptr_label_encode.py +177 -157
- openocr/openrec/preprocess/igtr_label_encode.py +4 -2
- openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
- openocr/openrec/preprocess/rec_aug.py +128 -2
- openocr/openrec/preprocess/resize.py +57 -0
- openocr/openrec/preprocess/unirec_label_encode.py +62 -0
- openocr/tools/data/__init__.py +78 -55
- openocr/tools/data/cmer_web_dataset.py +310 -0
- openocr/tools/data/native_size_dataset.py +753 -0
- openocr/tools/data/native_size_sampler.py +158 -0
- openocr/tools/data/ratio_dataset_tvresize.py +2 -0
- openocr/tools/data/ratio_sampler.py +2 -1
- openocr/tools/download/download_dataset.py +38 -0
- openocr/tools/download/utils.py +28 -0
- openocr/tools/download_example_images.py +236 -0
- openocr/tools/engine/trainer.py +155 -39
- openocr/tools/eval_rec_all_ch.py +2 -2
- openocr/tools/infer_det.py +20 -2
- openocr/tools/infer_doc.py +898 -0
- openocr/tools/infer_doc_onnx.py +1172 -0
- openocr/tools/infer_e2e.py +27 -10
- openocr/tools/infer_rec.py +64 -15
- openocr/tools/infer_unirec_onnx.py +730 -0
- openocr/tools/to_markdown.py +468 -0
- openocr/tools/utils/ckpt.py +17 -5
- openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
- openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
- openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- openocr_python-0.0.9.dist-info/METADATA +0 -149
- /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,898 @@
|
|
|
1
|
+
from __future__ import absolute_import
|
|
2
|
+
from __future__ import division
|
|
3
|
+
from __future__ import print_function
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
import sys
|
|
8
|
+
|
|
9
|
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
|
10
|
+
sys.path.append(__dir__)
|
|
11
|
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
|
12
|
+
|
|
13
|
+
import queue
|
|
14
|
+
import threading
|
|
15
|
+
import time
|
|
16
|
+
from itertools import chain
|
|
17
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
|
18
|
+
import cv2
|
|
19
|
+
import numpy as np
|
|
20
|
+
from PIL import Image
|
|
21
|
+
import random
|
|
22
|
+
import argparse
|
|
23
|
+
from multiprocessing import Process
|
|
24
|
+
|
|
25
|
+
from paddlex.utils import logging
|
|
26
|
+
from paddlex.inference.common.batch_sampler import ImageBatchSampler
|
|
27
|
+
from paddlex.inference.common.reader import ReadImage
|
|
28
|
+
from paddlex.inference.utils.hpi import HPIConfig
|
|
29
|
+
from paddlex.inference.utils.pp_option import PaddlePredictorOption
|
|
30
|
+
from paddlex.inference.pipelines import load_pipeline_config
|
|
31
|
+
from paddlex.inference.pipelines.base import BasePipeline
|
|
32
|
+
from paddlex.inference.pipelines.components import CropByBoxes
|
|
33
|
+
from paddlex.inference.pipelines.layout_parsing.utils import gather_imgs
|
|
34
|
+
from paddlex.inference.pipelines.paddleocr_vl.result import PaddleOCRVLBlock, PaddleOCRVLResult
|
|
35
|
+
from paddlex.inference.pipelines.paddleocr_vl.uilts import (
|
|
36
|
+
convert_otsl_to_html,
|
|
37
|
+
crop_margin,
|
|
38
|
+
filter_overlap_boxes,
|
|
39
|
+
merge_blocks,
|
|
40
|
+
tokenize_figure_of_table,
|
|
41
|
+
truncate_repetitive_content,
|
|
42
|
+
untokenize_figure_of_table,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
from tools.engine.config import Config
|
|
46
|
+
from tools.utils.logging import get_logger
|
|
47
|
+
from tools.utils.utility import get_image_file_list
|
|
48
|
+
from tools.infer_rec import OpenRecognizer
|
|
49
|
+
from tools.to_markdown import MarkdownConverter
|
|
50
|
+
|
|
51
|
+
# Module-level logger used by the OpenDoc pipeline code in this file.
logger = get_logger(name='opendoc')

# Directory that holds this script; bundled config paths are built from it.
root_dir = Path(__file__).resolve().parent

# Path of the UniRec recognizer config shipped with the package.
DEFAULT_CFG_PATH_UNIREC = str(
    root_dir.joinpath(
        '../configs/rec/unirec/focalsvtr_ardecoder_unirec.yml'))

# Layout labels whose crops are attached to the result as images instead of
# being passed to the recognizer.
IMAGE_LABELS = ['image', 'header_image', 'footer_image', 'seal']

# Shared converter used to post-process recognized table/formula/text strings.
markdown_converter = MarkdownConverter()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class OpenDoc(BasePipeline):
|
|
63
|
+
"""_UniRec Pipeline"""
|
|
64
|
+
|
|
65
|
+
def __init__(
    self,
    gpuId: Optional[int] = 0,
    pp_option: Optional[PaddlePredictorOption] = None,
    use_hpip: bool = False,
    hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
) -> None:
    """
    Build the OpenDoc pipeline from the bundled ``opendoc_pipeline.yml``.

    The pipeline configuration is loaded from the package's config
    directory; it is not passed in by the caller.

    Args:
        gpuId (int, optional): GPU ID to run the predictions on. A negative
            value passes ``device=None`` to the base pipeline, otherwise
            ``gpu:<gpuId>`` is used. The value is also forwarded to the
            UniRec recognizer as ``numId``. Defaults to 0.
        pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
        use_hpip (bool, optional): Whether to use the high-performance
            inference plugin (HPIP) by default. Defaults to False.
        hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
            The default high-performance inference configuration dictionary.
            Defaults to None.

    Raises:
        AssertionError: If layout detection is enabled and the configured
            layout model is not ``PP-DocLayoutV2``.
        ImportError: If the UniRec weights are neither configured, cached
            under ``~/.cache/openocr/unirec-0.1b``, nor downloadable.
    """
    device = None if gpuId < 0 else f'gpu:{gpuId}'
    super().__init__(device=device,
                     pp_option=pp_option,
                     use_hpip=use_hpip,
                     hpi_config=hpi_config)
    config = load_pipeline_config(
        str(root_dir / '../configs/rec/unirec/opendoc_pipeline.yml'))

    # Optional document preprocessing sub-pipeline.
    self.use_doc_preprocessor = config.get('use_doc_preprocessor', True)
    if self.use_doc_preprocessor:
        doc_preprocessor_config = config.get('SubPipelines', {}).get(
            'DocPreprocessor',
            {
                'pipeline_config_error':
                'config error for doc_preprocessor_pipeline!'
            },
        )
        self.doc_preprocessor_pipeline = self.create_pipeline(
            doc_preprocessor_config)

    # Optional layout detection model; only created when enabled in config.
    self.use_layout_detection = config.get('use_layout_detection', True)
    if self.use_layout_detection:
        layout_det_config = config.get('SubModules', {}).get(
            'LayoutDetection',
            {'model_config_error': 'config error for layout_det_model!'},
        )
        model_name = layout_det_config.get('model_name', None)
        assert (model_name is not None and model_name
                == 'PP-DocLayoutV2'), 'model_name must be PP-DocLayoutV2'
        # Forward only the layout options that are explicitly configured.
        layout_kwargs = {}
        for option in (
                'threshold',
                'layout_nms',
                'layout_unclip_ratio',
                'layout_merge_bboxes_mode',
        ):
            option_value = layout_det_config.get(option, None)
            if option_value is not None:
                layout_kwargs[option] = option_value
        self.layout_det_model = self.create_model(layout_det_config,
                                                  **layout_kwargs)

    self.use_chart_recognition = config.get('use_chart_recognition', True)

    # Resolve the UniRec weights: configured path -> local cache -> download.
    unirec_cfg = Config(DEFAULT_CFG_PATH_UNIREC).cfg
    configured_weights = unirec_cfg['Global']['pretrained_model']
    if configured_weights is None or not os.path.exists(
            configured_weights):
        cache_dir = Path.home() / '.cache' / 'openocr'
        model_path = cache_dir / 'unirec-0.1b'
        weights_file = str(model_path) + '/model.pth'
        if os.path.exists(weights_file):
            logger.info(
                f'UniRec-0.1B Model already exists in {model_path}')
        else:
            download_flag = False
            last_error = None
            # First choice: modelscope.
            # ``except Exception`` (not a bare ``except``) so that Ctrl-C
            # and SystemExit still abort instead of triggering the fallback.
            try:
                from modelscope.hub.snapshot_download import snapshot_download
                snapshot_download(
                    'topdktu/unirec-0.1b',
                    local_dir=model_path,
                )
                download_flag = os.path.exists(weights_file)
            except Exception as exc:
                last_error = exc
                logger.error(
                    'Try to download the model from modelscope failed. '
                    f'Reason: {exc!r}')

            # Fallback: huggingface hub.
            if not download_flag:
                try:
                    from huggingface_hub import snapshot_download
                    snapshot_download(
                        'topdu/unirec-0.1b',
                        local_dir=model_path,
                    )
                    download_flag = os.path.exists(weights_file)
                except Exception as exc:
                    last_error = exc
                    logger.error(
                        'Try to download the model from huggingface failed. '
                        f'Reason: {exc!r}')

            if not download_flag:
                # Chain the last download error so the root cause is visible.
                raise ImportError(
                    'Please download the model from https://huggingface.co/topdu/unirec-0.1b or https://modelscope.cn/models/topdktu/unirec-0.1b and put the model in the directory: ~/.cache/openocr/unirec-0.1b'
                ) from last_error
        unirec_cfg['Global']['pretrained_model'] = weights_file

    self.vl_rec_model = OpenRecognizer(unirec_cfg, numId=gpuId)
    self.format_block_content = config.get('format_block_content', False)

    self.batch_sampler = ImageBatchSampler(
        batch_size=config.get('batch_size', 1))
    self.img_reader = ReadImage(format='BGR')
    self.crop_by_boxes = CropByBoxes()

    self.use_queues = config.get('use_queues', False)
    self.merge_layout_blocks = config.get('merge_layout_blocks', True)
    self.markdown_ignore_labels = config.get(
        'markdown_ignore_labels',
        [
            'number',
            'footnote',
            'header',
            'header_image',
            'footer',
            'footer_image',
            'aside_text',
        ],
    )
|
|
207
|
+
|
|
208
|
+
def get_model_settings(
    self,
    use_doc_orientation_classify: Union[bool, None],
    use_doc_unwarping: Union[bool, None],
    use_layout_detection: Union[bool, None],
    use_chart_recognition: Union[bool, None],
    format_block_content: Union[bool, None],
    merge_layout_blocks: Union[bool, None],
    markdown_ignore_labels: Optional[list[str]] = None,
) -> dict:
    """
    Resolve per-call settings, falling back to the pipeline defaults.

    Args:
        use_doc_orientation_classify (Union[bool, None]): Request document orientation classification.
        use_doc_unwarping (Union[bool, None]): Request document unwarping.
            When both this and ``use_doc_orientation_classify`` are None the
            pipeline default ``self.use_doc_preprocessor`` is used; otherwise
            the preprocessor is enabled only if at least one of them is
            exactly ``True``.
        use_layout_detection (Union[bool, None]): None falls back to ``self.use_layout_detection``.
        use_chart_recognition (Union[bool, None]): None falls back to ``self.use_chart_recognition``.
        format_block_content (Union[bool, None]): None falls back to ``self.format_block_content``.
        merge_layout_blocks (Union[bool, None]): None falls back to ``self.merge_layout_blocks``.
        markdown_ignore_labels (Optional[list[str]]): None falls back to ``self.markdown_ignore_labels``.

    Returns:
        dict: The resolved settings keyed by ``use_doc_preprocessor``,
            ``use_layout_detection``, ``use_chart_recognition``,
            ``format_block_content``, ``merge_layout_blocks`` and
            ``markdown_ignore_labels``.
    """
    nothing_requested = (use_doc_orientation_classify is None
                         and use_doc_unwarping is None)
    if nothing_requested:
        preprocessor_enabled = self.use_doc_preprocessor
    else:
        # Identity comparison on purpose: only a literal True enables it.
        preprocessor_enabled = (use_doc_orientation_classify is True
                                or use_doc_unwarping is True)

    return {
        'use_doc_preprocessor':
        preprocessor_enabled,
        'use_layout_detection':
        (self.use_layout_detection
         if use_layout_detection is None else use_layout_detection),
        'use_chart_recognition':
        (self.use_chart_recognition
         if use_chart_recognition is None else use_chart_recognition),
        'format_block_content':
        (self.format_block_content
         if format_block_content is None else format_block_content),
        'merge_layout_blocks':
        (self.merge_layout_blocks
         if merge_layout_blocks is None else merge_layout_blocks),
        'markdown_ignore_labels':
        (self.markdown_ignore_labels
         if markdown_ignore_labels is None else markdown_ignore_labels),
    }
|
|
260
|
+
|
|
261
|
+
def check_model_settings_valid(self, input_params: dict) -> bool:
    """
    Check if the input parameters are valid based on the initialized models.

    A setting is invalid when it requests a component that was disabled in
    the pipeline config and therefore never created in ``__init__``.

    Args:
        input_params (Dict): Resolved model settings. Must contain
            ``use_doc_preprocessor``; ``use_layout_detection`` is checked
            when present.

    Returns:
        bool: True if all required models are initialized according to input parameters, False otherwise.
    """
    if input_params[
            'use_doc_preprocessor'] and not self.use_doc_preprocessor:
        logging.error(
            'Set use_doc_preprocessor, but the models for doc preprocessor are not initialized.',
        )
        return False

    # The layout model only exists when the config enabled layout detection;
    # requesting it otherwise would fail later with an AttributeError.
    if input_params.get(
            'use_layout_detection') and not self.use_layout_detection:
        logging.error(
            'Set use_layout_detection, but the models for layout detection are not initialized.',
        )
        return False

    return True
|
|
280
|
+
|
|
281
|
+
def get_layout_parsing_results(
    self,
    images,
    layout_det_results,
    imgs_in_doc,
    use_chart_recognition=False,
    vlm_kwargs=None,
    merge_layout_blocks=True,
):
    """
    Recognize the content of every layout block and assemble page results.

    Runs in three phases: (1) crop (and optionally merge) the layout blocks
    of each page and collect the ones to recognize, (2) run the UniRec
    recognizer on each collected crop one at a time, (3) build one
    ``PaddleOCRVLBlock`` per layout block.

    Args:
        images: Page images, iterated in lockstep with ``layout_det_results``
            and ``imgs_in_doc``. They are converted with ``COLOR_BGR2RGB``
            before recognition (assumes BGR input — TODO confirm).
        layout_det_results: Per-page layout detection results; each must
            provide a ``'boxes'`` entry after ``filter_overlap_boxes``.
        imgs_in_doc: Per-page figure images, passed to
            ``tokenize_figure_of_table`` for table blocks and returned
            unchanged.
        use_chart_recognition (bool): When False, ``'chart'`` blocks are
            treated like image blocks and are not recognized.
        vlm_kwargs: Accepted for interface compatibility only. It is not
            forwarded to the recognizer and is left unmodified.
        merge_layout_blocks (bool): Whether to merge blocks with
            ``merge_blocks`` (image labels and ``'table'`` are never merged).

    Returns:
        tuple: ``(parsing_res_lists, table_res_lists, imgs_in_doc)`` where
            ``parsing_res_lists`` holds one list of ``PaddleOCRVLBlock`` per
            page and ``table_res_lists`` holds one empty list per page.
    """
    image_labels = (IMAGE_LABELS
                    if use_chart_recognition else IMAGE_LABELS + ['chart'])

    # Phase 1: crop blocks and collect the ones that need recognition.
    blocks = []
    # (page index, block index) -> (crop to recognize, label, figure tokens)
    rec_inputs = {}
    drop_figures_set = set()
    for i, (image, layout_det_res, imgs_in_doc_for_img) in enumerate(
            zip(images, layout_det_results, imgs_in_doc)):
        layout_det_res = filter_overlap_boxes(layout_det_res)
        boxes = layout_det_res['boxes']
        blocks_for_img = self.crop_by_boxes(image, boxes)
        if merge_layout_blocks:
            blocks_for_img = merge_blocks(blocks_for_img,
                                          non_merge_labels=image_labels +
                                          ['table'])
        blocks.append(blocks_for_img)
        for j, block in enumerate(blocks_for_img):
            block_img = block['img']
            block_label = block['label']
            if block_label in image_labels or block_img is None:
                continue

            figure_token_map = {}
            drop_figures = []
            if block_label == 'table':
                block_img, figure_token_map, drop_figures = (
                    tokenize_figure_of_table(block_img, block['box'],
                                             imgs_in_doc_for_img))
            elif 'formula' in block_label and block_label != 'formula_number':
                block_img = crop_margin(block_img)
            rec_inputs[(i, j)] = (block_img, block_label, figure_token_map)
            drop_figures_set.update(drop_figures)

    # Phase 2: recognize each collected crop and post-process by block type.
    rec_results = {}
    for block_id, (block_img, block_label, _) in rec_inputs.items():
        block_img_forunirec = cv2.cvtColor(block_img, cv2.COLOR_BGR2RGB)
        output_unirec = self.vl_rec_model(
            img_numpy=Image.fromarray(block_img_forunirec), batch_num=1)[0]
        unirec_res = output_unirec['text']
        if block_label == 'table':
            unirec_res = markdown_converter._handle_table(unirec_res)
        elif 'formula' in block_label and block_label != 'formula_number':
            unirec_res = markdown_converter._handle_formula(unirec_res)
        else:
            unirec_res = markdown_converter._handle_text(unirec_res)
        rec_results[block_id] = unirec_res

    # Phase 3: assemble one result block per layout block.
    parsing_res_lists = []
    table_res_lists = []
    for i, blocks_for_img in enumerate(blocks):
        parsing_res_list = []
        for j, block in enumerate(blocks_for_img):
            block_img = block['img']
            block_bbox = block['box']
            block_label = block['label']
            block_content = ''
            if (i, j) in rec_results:
                result_str = rec_results[(i, j)]
                figure_token_map = rec_inputs[(i, j)][2]
                if result_str is None:
                    result_str = ''
                result_str = truncate_repetitive_content(result_str)
                # Normalize \( \) and \[ \] math delimiters to $ / $$.
                if ('\\(' in result_str and '\\)' in result_str) or (
                        '\\[' in result_str and '\\]' in result_str):
                    result_str = result_str.replace('$', '')

                    result_str = (result_str.replace('\\(', ' $ ').replace(
                        '\\)',
                        ' $ ').replace('\\[',
                                       ' $$ ').replace('\\]', ' $$ '))
                    if block_label == 'formula_number':
                        result_str = result_str.replace('$', '')
                if block_label == 'table':
                    html_str = convert_otsl_to_html(result_str)
                    if html_str != '':
                        result_str = html_str
                    result_str = untokenize_figure_of_table(
                        result_str, figure_token_map)

                block_content = result_str

            block_info = PaddleOCRVLBlock(
                label=block_label,
                bbox=block_bbox,
                content=block_content,
                group_id=block.get('group_id', None),
            )
            if block_label in image_labels and block_img is not None:
                x_min, y_min, x_max, y_max = list(map(int, block_bbox))
                img_path = f'imgs/img_in_{block_label}_box_{x_min}_{y_min}_{x_max}_{y_max}.jpg'
                # Figures listed in drop_figures_set during table
                # tokenization are not emitted as separate image blocks.
                if img_path in drop_figures_set:
                    continue
                block_img = cv2.cvtColor(block_img, cv2.COLOR_BGR2RGB)
                block_info.image = {
                    'path': img_path,
                    'img': Image.fromarray(block_img),
                }

            parsing_res_list.append(block_info)
        parsing_res_lists.append(parsing_res_list)
        table_res_lists.append([])

    return parsing_res_lists, table_res_lists, imgs_in_doc
|
|
420
|
+
|
|
421
|
+
def predict(
|
|
422
|
+
self,
|
|
423
|
+
input: Union[str, list[str], np.ndarray, list[np.ndarray]],
|
|
424
|
+
use_doc_orientation_classify: Union[bool, None] = False,
|
|
425
|
+
use_doc_unwarping: Union[bool, None] = False,
|
|
426
|
+
use_layout_detection: Union[bool, None] = None,
|
|
427
|
+
use_chart_recognition: Union[bool, None] = None,
|
|
428
|
+
layout_threshold: Optional[Union[float, dict]] = None,
|
|
429
|
+
layout_nms: Optional[bool] = None,
|
|
430
|
+
layout_unclip_ratio: Optional[Union[float, Tuple[float, float],
|
|
431
|
+
dict]] = None,
|
|
432
|
+
layout_merge_bboxes_mode: Optional[str] = None,
|
|
433
|
+
use_queues: Optional[bool] = None,
|
|
434
|
+
prompt_label: Optional[Union[str, None]] = None,
|
|
435
|
+
format_block_content: Union[bool, None] = None,
|
|
436
|
+
repetition_penalty: Optional[float] = None,
|
|
437
|
+
temperature: Optional[float] = None,
|
|
438
|
+
top_p: Optional[float] = None,
|
|
439
|
+
min_pixels: Optional[int] = None,
|
|
440
|
+
max_pixels: Optional[int] = None,
|
|
441
|
+
max_new_tokens: Optional[int] = None,
|
|
442
|
+
merge_layout_blocks: Optional[bool] = None,
|
|
443
|
+
markdown_ignore_labels: Optional[list[str]] = None,
|
|
444
|
+
**kwargs,
|
|
445
|
+
) -> PaddleOCRVLResult:
|
|
446
|
+
"""
|
|
447
|
+
Predicts the layout parsing result for the given input.
|
|
448
|
+
|
|
449
|
+
Args:
|
|
450
|
+
input (Union[str, list[str], np.ndarray, list[np.ndarray]]): Input image path, list of image paths,
|
|
451
|
+
numpy array of an image, or list of numpy arrays.
|
|
452
|
+
use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
|
|
453
|
+
use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
|
|
454
|
+
layout_threshold (Optional[float]): The threshold value to filter out low-confidence predictions. Default is None.
|
|
455
|
+
layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to False.
|
|
456
|
+
layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
|
|
457
|
+
Defaults to None.
|
|
458
|
+
If it's a single number, then both width and height are used.
|
|
459
|
+
If it's a tuple of two numbers, then they are used separately for width and height respectively.
|
|
460
|
+
If it's None, then no unclipping will be performed.
|
|
461
|
+
layout_merge_bboxes_mode (Optional[str], optional): The mode for merging bounding boxes. Defaults to None.
|
|
462
|
+
use_queues (Optional[bool], optional): Whether to use queues. Defaults to None.
|
|
463
|
+
prompt_label (Optional[Union[str, None]], optional): The label of the prompt in ['ocr', 'formula', 'table', 'chart']. Defaults to None.
|
|
464
|
+
format_block_content (Optional[bool]): Whether to format the block content. Default is None.
|
|
465
|
+
repetition_penalty (Optional[float]): The repetition penalty parameter used for VL model sampling. Default is None.
|
|
466
|
+
temperature (Optional[float]): Temperature parameter used for VL model sampling. Default is None.
|
|
467
|
+
top_p (Optional[float]): Top-p parameter used for VL model sampling. Default is None.
|
|
468
|
+
min_pixels (Optional[int]): The minimum number of pixels allowed when the VL model preprocesses images. Default is None.
|
|
469
|
+
max_pixels (Optional[int]): The maximum number of pixels allowed when the VL model preprocesses images. Default is None.
|
|
470
|
+
max_new_tokens (Optional[int]): The maximum number of new tokens. Default is None.
|
|
471
|
+
merge_layout_blocks (Optional[bool]): Whether to merge layout blocks. Default is None.
|
|
472
|
+
markdown_ignore_labels (Optional[list[str]]): The list of ignored markdown labels. Default is None.
|
|
473
|
+
**kwargs (Any): Additional settings to extend functionality.
|
|
474
|
+
|
|
475
|
+
Returns:
|
|
476
|
+
PaddleOCRVLResult: The predicted layout parsing result.
|
|
477
|
+
"""
|
|
478
|
+
model_settings = self.get_model_settings(
|
|
479
|
+
use_doc_orientation_classify,
|
|
480
|
+
use_doc_unwarping,
|
|
481
|
+
use_layout_detection,
|
|
482
|
+
use_chart_recognition,
|
|
483
|
+
format_block_content,
|
|
484
|
+
merge_layout_blocks,
|
|
485
|
+
markdown_ignore_labels,
|
|
486
|
+
)
|
|
487
|
+
|
|
488
|
+
if not self.check_model_settings_valid(model_settings):
|
|
489
|
+
yield {'error': 'the input params for model settings are invalid!'}
|
|
490
|
+
|
|
491
|
+
if use_queues is None:
|
|
492
|
+
use_queues = self.use_queues
|
|
493
|
+
|
|
494
|
+
if not model_settings['use_layout_detection']:
|
|
495
|
+
prompt_label = prompt_label if prompt_label else 'ocr'
|
|
496
|
+
if prompt_label.lower() == 'chart':
|
|
497
|
+
model_settings['use_chart_recognition'] = True
|
|
498
|
+
assert prompt_label.lower() in [
|
|
499
|
+
'ocr',
|
|
500
|
+
'formula',
|
|
501
|
+
'table',
|
|
502
|
+
'chart',
|
|
503
|
+
], f"Layout detection is disabled (use_layout_detection=False). 'prompt_label' must be one of ['ocr', 'formula', 'table', 'chart'], but got '{prompt_label}'."
|
|
504
|
+
|
|
505
|
+
def _process_cv(batch_data, new_batch_size=None):
|
|
506
|
+
if not new_batch_size:
|
|
507
|
+
new_batch_size = len(batch_data)
|
|
508
|
+
|
|
509
|
+
for idx in range(0, len(batch_data), new_batch_size):
|
|
510
|
+
instances = batch_data.instances[idx:idx + new_batch_size]
|
|
511
|
+
input_paths = batch_data.input_paths[idx:idx + new_batch_size]
|
|
512
|
+
page_indexes = batch_data.page_indexes[idx:idx +
|
|
513
|
+
new_batch_size]
|
|
514
|
+
page_counts = batch_data.page_counts[idx:idx + new_batch_size]
|
|
515
|
+
|
|
516
|
+
image_arrays = self.img_reader(instances)
|
|
517
|
+
|
|
518
|
+
if model_settings['use_doc_preprocessor']:
|
|
519
|
+
doc_preprocessor_results = list(
|
|
520
|
+
self.doc_preprocessor_pipeline(
|
|
521
|
+
image_arrays,
|
|
522
|
+
use_doc_orientation_classify=
|
|
523
|
+
use_doc_orientation_classify,
|
|
524
|
+
use_doc_unwarping=use_doc_unwarping,
|
|
525
|
+
))
|
|
526
|
+
else:
|
|
527
|
+
doc_preprocessor_results = [{
|
|
528
|
+
'output_img': arr
|
|
529
|
+
} for arr in image_arrays]
|
|
530
|
+
|
|
531
|
+
doc_preprocessor_images = [
|
|
532
|
+
item['output_img'] for item in doc_preprocessor_results
|
|
533
|
+
]
|
|
534
|
+
|
|
535
|
+
if model_settings['use_layout_detection']:
|
|
536
|
+
layout_det_results = list(
|
|
537
|
+
self.layout_det_model(
|
|
538
|
+
doc_preprocessor_images,
|
|
539
|
+
threshold=layout_threshold,
|
|
540
|
+
layout_nms=layout_nms,
|
|
541
|
+
layout_unclip_ratio=layout_unclip_ratio,
|
|
542
|
+
layout_merge_bboxes_mode=layout_merge_bboxes_mode,
|
|
543
|
+
))
|
|
544
|
+
|
|
545
|
+
imgs_in_doc = [
|
|
546
|
+
gather_imgs(doc_pp_img, layout_det_res['boxes'])
|
|
547
|
+
for doc_pp_img, layout_det_res in zip(
|
|
548
|
+
doc_preprocessor_images, layout_det_results)
|
|
549
|
+
]
|
|
550
|
+
else:
|
|
551
|
+
layout_det_results = []
|
|
552
|
+
for doc_preprocessor_image in doc_preprocessor_images:
|
|
553
|
+
layout_det_results.append({
|
|
554
|
+
'input_path':
|
|
555
|
+
None,
|
|
556
|
+
'page_index':
|
|
557
|
+
None,
|
|
558
|
+
'boxes': [{
|
|
559
|
+
'cls_id':
|
|
560
|
+
0,
|
|
561
|
+
'label':
|
|
562
|
+
prompt_label.lower(),
|
|
563
|
+
'score':
|
|
564
|
+
1,
|
|
565
|
+
'coordinate': [
|
|
566
|
+
0,
|
|
567
|
+
0,
|
|
568
|
+
doc_preprocessor_image.shape[1],
|
|
569
|
+
doc_preprocessor_image.shape[0],
|
|
570
|
+
],
|
|
571
|
+
}],
|
|
572
|
+
})
|
|
573
|
+
imgs_in_doc = [[] for _ in layout_det_results]
|
|
574
|
+
|
|
575
|
+
yield input_paths, page_indexes, page_counts, doc_preprocessor_images, doc_preprocessor_results, layout_det_results, imgs_in_doc
|
|
576
|
+
|
|
577
|
+
def _process_vlm(results_cv):
|
|
578
|
+
(
|
|
579
|
+
input_paths,
|
|
580
|
+
page_indexes,
|
|
581
|
+
page_counts,
|
|
582
|
+
doc_preprocessor_images,
|
|
583
|
+
doc_preprocessor_results,
|
|
584
|
+
layout_det_results,
|
|
585
|
+
imgs_in_doc,
|
|
586
|
+
) = results_cv
|
|
587
|
+
|
|
588
|
+
parsing_res_lists, table_res_lists, imgs_in_doc = (
|
|
589
|
+
self.get_layout_parsing_results(
|
|
590
|
+
doc_preprocessor_images,
|
|
591
|
+
layout_det_results,
|
|
592
|
+
imgs_in_doc,
|
|
593
|
+
model_settings['use_chart_recognition'],
|
|
594
|
+
{
|
|
595
|
+
'repetition_penalty': repetition_penalty,
|
|
596
|
+
'temperature': temperature,
|
|
597
|
+
'top_p': top_p,
|
|
598
|
+
'min_pixels': min_pixels,
|
|
599
|
+
'max_pixels': max_pixels,
|
|
600
|
+
'max_new_tokens': max_new_tokens,
|
|
601
|
+
},
|
|
602
|
+
model_settings['merge_layout_blocks'],
|
|
603
|
+
))
|
|
604
|
+
|
|
605
|
+
for (
|
|
606
|
+
input_path,
|
|
607
|
+
page_index,
|
|
608
|
+
page_count,
|
|
609
|
+
doc_preprocessor_image,
|
|
610
|
+
doc_preprocessor_res,
|
|
611
|
+
layout_det_res,
|
|
612
|
+
table_res_list,
|
|
613
|
+
parsing_res_list,
|
|
614
|
+
imgs_in_doc_for_img,
|
|
615
|
+
) in zip(
|
|
616
|
+
input_paths,
|
|
617
|
+
page_indexes,
|
|
618
|
+
page_counts,
|
|
619
|
+
doc_preprocessor_images,
|
|
620
|
+
doc_preprocessor_results,
|
|
621
|
+
layout_det_results,
|
|
622
|
+
table_res_lists,
|
|
623
|
+
parsing_res_lists,
|
|
624
|
+
imgs_in_doc,
|
|
625
|
+
):
|
|
626
|
+
single_img_res = {
|
|
627
|
+
'input_path': input_path,
|
|
628
|
+
'page_index': page_index,
|
|
629
|
+
'page_count': page_count,
|
|
630
|
+
'width': doc_preprocessor_image.shape[1],
|
|
631
|
+
'height': doc_preprocessor_image.shape[0],
|
|
632
|
+
'doc_preprocessor_res': doc_preprocessor_res,
|
|
633
|
+
'layout_det_res': layout_det_res,
|
|
634
|
+
'table_res_list': table_res_list,
|
|
635
|
+
'parsing_res_list': parsing_res_list,
|
|
636
|
+
'imgs_in_doc': imgs_in_doc_for_img,
|
|
637
|
+
'model_settings': model_settings,
|
|
638
|
+
}
|
|
639
|
+
yield PaddleOCRVLResult(single_img_res)
|
|
640
|
+
|
|
641
|
+
if use_queues:
|
|
642
|
+
max_num_batches_in_process = 64
|
|
643
|
+
queue_input = queue.Queue(maxsize=max_num_batches_in_process)
|
|
644
|
+
queue_cv = queue.Queue(maxsize=max_num_batches_in_process)
|
|
645
|
+
queue_vlm = queue.Queue(maxsize=self.batch_sampler.batch_size *
|
|
646
|
+
max_num_batches_in_process)
|
|
647
|
+
event_shutdown = threading.Event()
|
|
648
|
+
event_data_loading_done = threading.Event()
|
|
649
|
+
event_cv_processing_done = threading.Event()
|
|
650
|
+
event_vlm_processing_done = threading.Event()
|
|
651
|
+
|
|
652
|
+
def _worker_input(input_):
    """Producer stage: pull batches from the sampler into ``queue_input``.

    Successful batches are enqueued as ``(True, batch_data)``. If the
    sampler raises, a single ``(False, 'input', exc)`` item is enqueued
    so the failure travels down the pipeline. ``event_data_loading_done``
    is always set on exit so the downstream worker can stop waiting.

    Args:
        input_: The raw pipeline input handed to ``self.batch_sampler``.
    """

    def _put_unless_shutdown(item):
        # A plain ``put()`` blocks forever once the queue is full and the
        # consumer has stopped; because this thread is non-daemon that
        # would keep the interpreter alive. Wait in short slices instead
        # so a shutdown request is always noticed.
        while not event_shutdown.is_set():
            try:
                queue_input.put(item, timeout=0.5)
            except queue.Full:
                continue
            return True
        return False

    try:
        all_batch_data = self.batch_sampler(input_)
        while not event_shutdown.is_set():
            try:
                batch_data = next(all_batch_data)
            except StopIteration:
                break
            if not _put_unless_shutdown((True, batch_data)):
                break
    except Exception as e:
        # Covers both creating the sampler iterator and advancing it.
        _put_unless_shutdown((False, 'input', e))
    finally:
        # Must be set on every exit path, otherwise the CV worker polls
        # an empty queue indefinitely.
        event_data_loading_done.set()
|
|
665
|
+
|
|
666
|
+
def _worker_cv():
    """Middle stage: run ``_process_cv`` on each batch from ``queue_input``.

    Every result is forwarded to ``queue_cv`` as ``(True, results_cv)``.
    An error item received from upstream is forwarded unchanged, and an
    exception raised here is forwarded as ``(False, 'cv', exc)``; in both
    cases the worker then stops. ``event_cv_processing_done`` is set only
    when the input stage has finished and ``queue_input`` is drained.
    """

    def _forward(item):
        # Bounded put: an unbounded ``put()`` on a full queue would never
        # return once the consumer is gone, leaving this non-daemon
        # thread alive after shutdown.
        while not event_shutdown.is_set():
            try:
                queue_cv.put(item, timeout=0.5)
            except queue.Full:
                continue
            return True
        return False

    while not event_shutdown.is_set():
        try:
            item = queue_input.get(timeout=0.5)
        except queue.Empty:
            # The producer may enqueue its last batch and set the done
            # event right after our timeout fired; re-check emptiness
            # (after reading the event) so that batch is not dropped.
            if event_data_loading_done.is_set() and queue_input.empty():
                event_cv_processing_done.set()
                break
            continue
        if not item[0]:
            # Upstream failure: pass it along and stop.
            _forward(item)
            break
        try:
            for results_cv in _process_cv(
                    item[1],
                (self.layout_det_model.batch_sampler.batch_size
                 if model_settings['use_layout_detection'] else None),
            ):
                if not _forward((True, results_cv)):
                    # Shutdown requested while the queue was full.
                    return
        except Exception as e:
            _forward((False, 'cv', e))
            break
|
|
688
|
+
|
|
689
|
+
def _worker_vlm():
    """Final stage: gather CV results into a batch and run ``_process_vlm``.

    Items are collected from ``queue_cv`` for at most
    ``MAX_QUEUE_DELAY_SECS`` or until ``MAX_NUM_BOXES`` boxes have been
    accumulated, merged field by field, and passed to ``_process_vlm``.
    Each result is put on ``queue_vlm`` as ``(True, result_vlm)``.
    An upstream error item is forwarded unchanged, and an exception
    raised here is forwarded as ``(False, 'vlm', exc)``; the worker then
    stops. ``event_vlm_processing_done`` is set only when the CV stage
    has finished and ``queue_cv`` is drained.
    """
    MAX_QUEUE_DELAY_SECS = 0.5
    # NOTE(review): fixed cap; the original author left
    # ``self.vl_rec_model.batch_sampler.batch_size`` as a commented-out
    # alternative.
    MAX_NUM_BOXES = 4096

    def _forward(item):
        # Bounded put so a full queue cannot pin this non-daemon thread
        # after the consumer has requested shutdown.
        while not event_shutdown.is_set():
            try:
                queue_vlm.put(item, timeout=0.5)
            except queue.Full:
                continue
            return True
        return False

    while not event_shutdown.is_set():
        results_cv_list = []
        start_time = time.time()
        should_break = False
        num_boxes = 0
        while True:
            remaining_time = MAX_QUEUE_DELAY_SECS - (time.time() -
                                                     start_time)
            if remaining_time <= 0:
                break
            try:
                item = queue_cv.get(timeout=remaining_time)
            except queue.Empty:
                break
            if not item[0]:
                # Upstream failure: pass it along and stop the worker.
                _forward(item)
                should_break = True
                break
            results_cv_list.append(item[1])
            # Element 5 presumably holds the layout detection results
            # for this item - TODO confirm against ``_process_cv``.
            for res in results_cv_list[-1][5]:
                num_boxes += len(res['boxes'])
            if num_boxes >= MAX_NUM_BOXES:
                break
        if should_break:
            break
        if not results_cv_list:
            # Re-check emptiness after reading the event so an item
            # enqueued just after our timeout is not lost.
            if event_cv_processing_done.is_set() and queue_cv.empty():
                event_vlm_processing_done.set()
                break
            continue

        # Concatenate the i-th field of every collected item into one
        # flat list per field.
        merged_results_cv = [
            list(chain.from_iterable(lists))
            for lists in zip(*results_cv_list)
        ]

        try:
            for result_vlm in _process_vlm(merged_results_cv):
                if not _forward((True, result_vlm)):
                    # Shutdown requested while the queue was full.
                    return
        except Exception as e:
            _forward((False, 'vlm', e))
            break
|
|
735
|
+
|
|
736
|
+
thread_input = threading.Thread(target=_worker_input,
|
|
737
|
+
args=(input, ),
|
|
738
|
+
daemon=False)
|
|
739
|
+
thread_input.start()
|
|
740
|
+
thread_cv = threading.Thread(target=_worker_cv, daemon=False)
|
|
741
|
+
thread_cv.start()
|
|
742
|
+
thread_vlm = threading.Thread(target=_worker_vlm, daemon=False)
|
|
743
|
+
thread_vlm.start()
|
|
744
|
+
|
|
745
|
+
try:
|
|
746
|
+
if use_queues:
|
|
747
|
+
while not (event_vlm_processing_done.is_set()
|
|
748
|
+
and queue_vlm.empty()):
|
|
749
|
+
try:
|
|
750
|
+
item = queue_vlm.get(timeout=0.5)
|
|
751
|
+
except queue.Empty:
|
|
752
|
+
if event_vlm_processing_done.is_set():
|
|
753
|
+
break
|
|
754
|
+
continue
|
|
755
|
+
if not item[0]:
|
|
756
|
+
raise RuntimeError(
|
|
757
|
+
f"Exception from the '{item[1]}' worker: {item[2]}"
|
|
758
|
+
)
|
|
759
|
+
else:
|
|
760
|
+
yield item[1]
|
|
761
|
+
else:
|
|
762
|
+
for batch_data in self.batch_sampler(input):
|
|
763
|
+
results_cv_list = list(_process_cv(batch_data))
|
|
764
|
+
assert len(results_cv_list) == 1, len(results_cv_list)
|
|
765
|
+
results_cv = results_cv_list[0]
|
|
766
|
+
for res in _process_vlm(results_cv):
|
|
767
|
+
yield res
|
|
768
|
+
finally:
|
|
769
|
+
if use_queues:
|
|
770
|
+
event_shutdown.set()
|
|
771
|
+
thread_cv.join(timeout=5)
|
|
772
|
+
if thread_cv.is_alive():
|
|
773
|
+
logging.warning('CV worker did not terminate in time')
|
|
774
|
+
thread_vlm.join(timeout=5)
|
|
775
|
+
if thread_vlm.is_alive():
|
|
776
|
+
logging.warning('VLM worker did not terminate in time')
|
|
777
|
+
|
|
778
|
+
def concatenate_markdown_pages(self, markdown_list: list) -> str:
    """
    Concatenate Markdown content from multiple pages into a single document.

    Every page's text is preceded by a blank-line separator (two newline
    characters), so a non-empty result also starts with that separator.

    Args:
        markdown_list (list): Per-page mappings; each must provide the
            page's Markdown under the ``'markdown_texts'`` key.

    Returns:
        str: The combined Markdown text, or an empty string when
            ``markdown_list`` is empty.
    """
    # Single join instead of repeated ``+=`` to avoid quadratic copying.
    return ''.join('\n\n' + res['markdown_texts'] for res in markdown_list)
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
def process_batch(process_batch_list: list, gpuId: int, output_path: str,
                  is_save_vis_img: bool, is_save_json: bool,
                  is_save_markdown: bool, pretty: bool):
    """Run the OpenDoc pipeline over a list of files on one device.

    Args:
        process_batch_list (list): Paths of the input files to process.
        gpuId (int): Device index passed to ``OpenDoc``; negative values
            are reported as CPU processing in the log.
        output_path (str): Directory handed to the result ``save_*`` calls.
        is_save_vis_img (bool): Save the visualisation image per result.
        is_save_json (bool): Save the structured JSON per result.
        is_save_markdown (bool): Save the Markdown per result.
        pretty (bool): Forwarded to ``save_to_markdown``.
    """

    opendoc_pipeline = OpenDoc(gpuId=gpuId)
    for img_path in process_batch_list:
        # splitext copes with any extension length (".jpeg", ".tiff",
        # none at all); slicing off four characters did not.
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        output = opendoc_pipeline.predict(img_path,
                                          use_doc_orientation_classify=False,
                                          use_doc_unwarping=False)
        for res in output:
            if is_save_vis_img:
                res.save_to_img(output_path)
            if is_save_json:
                # Save the structured JSON result for the current image.
                res.save_to_json(save_path=output_path)
            if is_save_markdown:
                # Save the current image's result in Markdown format.
                res.save_to_markdown(save_path=output_path, pretty=pretty)
        if gpuId >= 0:
            logger.info(f'GPU {gpuId} processing {img_name} done!')
        else:
            logger.info(f'CPU processing {img_name} done!')
|
|
818
|
+
|
|
819
|
+
|
|
820
|
+
if __name__ == '__main__':
    # Command-line entry point: parse options, collect the input files,
    # then process them on one device in-process or fan out to one
    # worker process per listed device.

    parser = argparse.ArgumentParser(
        description='OpenDoc Pipeline for Document OCR')
    parser.add_argument(
        '--input_path',
        type=str,
        required=True,
        help='Path to the directory containing images or pdf files to process')
    parser.add_argument('--output_path',
                        type=str,
                        default='./output',
                        help='Path to save output results (default: ./output)')
    parser.add_argument(
        '--gpus',
        type=str,
        default='0',
        help=
        'GPU IDs to use, separated by comma (e.g., "0,1,2,3"). Use "-1" for CPU mode (default: "0")'
    )
    parser.add_argument(
        '--is_save_vis_img',
        action='store_true',
        help='Save visualized images with layout detection boxes')
    parser.add_argument('--is_save_json',
                        action='store_true',
                        help='Save JSON results')
    parser.add_argument('--is_save_markdown',
                        action='store_true',
                        help='Save Markdown results')
    parser.add_argument('--pretty',
                        action='store_true',
                        help='Pretty print Markdown results')

    args = parser.parse_args()

    # Parse GPU IDs
    gpus = [int(gpu_id.strip()) for gpu_id in args.gpus.split(',')]

    # Create output directory if not exists
    os.makedirs(args.output_path, exist_ok=True)

    # Get image list; shuffled so the strided split below spreads files
    # across devices without regard to their original order.
    img_list = get_image_file_list(args.input_path)
    random.shuffle(img_list)

    logger.info(f'Found {len(img_list)} images in {args.input_path}')
    if gpus[0] == -1:
        logger.info('Running in CPU mode')
    else:
        logger.info(f'Using GPUs: {gpus}')
    logger.info(f'Output will be saved to: {args.output_path}')

    # Split images into batches for each GPU
    if len(gpus) == 1:
        process_batch(img_list, gpus[0], args.output_path,
                      args.is_save_vis_img, args.is_save_json,
                      args.is_save_markdown, args.pretty)
    else:
        num_gpus = len(gpus)
        img_list_batch = [img_list[i::num_gpus] for i in range(num_gpus)]

        # Create and start processes. Devices whose share is empty are
        # skipped so no model is loaded for nothing.
        process_list = []
        for idx, gpuId in enumerate(gpus):
            if not img_list_batch[idx]:
                continue
            process_list.append((gpuId,
                                 Process(target=process_batch,
                                         args=(img_list_batch[idx], gpuId,
                                               args.output_path,
                                               args.is_save_vis_img,
                                               args.is_save_json,
                                               args.is_save_markdown,
                                               args.pretty))))

        for _, process in process_list:
            process.start()

        for _, process in process_list:
            process.join()

        # A crashed worker must not be reported as success.
        failed_gpus = [
            gpuId for gpuId, process in process_list if process.exitcode != 0
        ]
        if failed_gpus:
            logger.error(f'Worker process(es) failed on devices: {failed_gpus}')
            raise SystemExit(1)

    logger.info(
        f'All processing completed! Results saved to {args.output_path}')
|