openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1172 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+ from pathlib import Path
5
+
6
+ import os
7
+ import json
8
+ import time
9
+ import argparse
10
+ from typing import Dict, Optional, Union, List, Tuple
11
+
12
+ import cv2
13
+ import numpy as np
14
+ from PIL import Image
15
+ import onnxruntime as ort
16
+ from tools.utils.logging import get_logger
17
+ from tools.utils.utility import get_image_file_list
18
+
19
+ from tools.utils.opendoc_onnx_utils.utils import (
20
+ convert_otsl_to_html,
21
+ crop_margin,
22
+ filter_overlap_boxes,
23
+ merge_blocks,
24
+ tokenize_figure_of_table,
25
+ truncate_repetitive_content,
26
+ untokenize_figure_of_table,
27
+ )
28
+ from tools.to_markdown import MarkdownConverter
29
+ from tools.infer_unirec_onnx import (
30
+ UniRecONNX
31
+ )
32
+
33
# Base layout labels whose regions are kept as pictures instead of being sent
# to the text recognizer ('chart' is appended to this set when chart
# recognition is disabled).
IMAGE_LABELS = ['image', 'header_image', 'footer_image', 'seal']

# Directory that contains this module.
root_dir = Path(__file__).resolve().parent

# Shared Markdown converter used to post-process recognized block text.
markdown_converter = MarkdownConverter()

logger = get_logger(name='opendoc_onnx')
41
+
42
+
43
def download_layout_model(model_dir=None):
    """Download the PP-DoclayoutV2 layout-detection ONNX model.

    ModelScope is tried first. If it is not installed, fails, or does not
    yield the model file, HuggingFace Hub is used as a fallback.

    Args:
        model_dir: Directory in which ``PP-DoclayoutV2.onnx`` is stored. When
            None, ``~/.cache/openocr/PP_DoclayoutV2_onnx`` is used. The
            directory is created if it does not exist.

    Returns:
        str: Path to the model file. If the file already exists it is
        returned immediately without downloading.

    Raises:
        ImportError: ModelScope did not produce the model and
            ``huggingface_hub`` is not installed.
        RuntimeError: Neither source produced the model file.
        Exception: Any other error from the HuggingFace download is logged
            and re-raised.
    """
    import shutil

    if model_dir is None:
        model_dir = Path.home() / '.cache' / 'openocr' / 'PP_DoclayoutV2_onnx'
    else:
        model_dir = Path(model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)

    model_file = 'PP-DoclayoutV2.onnx'
    model_path = model_dir / model_file

    if model_path.exists():
        logger.info(f'✅ Layout model found in {model_dir}')
        return str(model_path)

    logger.info(f'📥 Downloading layout model to {model_dir}...')

    try:
        download_success = False

        # 1) ModelScope (tried first).
        logger.info('🌐 Trying ModelScope (China mirror) first...')
        try:
            from modelscope import snapshot_download
            downloaded_path = snapshot_download(
                'topdktu/PP_DoclayoutV2_onnx',
                cache_dir=str(model_dir.parent))
            logger.info(f'✅ Downloaded to {downloaded_path}')

            # The snapshot is laid out in ModelScope's own cache structure;
            # copy the model file into model_dir.
            src = Path(downloaded_path) / model_file
            if src.exists() and not model_path.exists():
                shutil.copy(str(src), str(model_path))
                logger.info(f' ✓ {model_file}')

            if model_path.exists():
                download_success = True
                logger.info('✅ Layout model downloaded successfully from ModelScope!')
            else:
                logger.info('⚠️ ModelScope download incomplete, trying HuggingFace...')
        except ImportError:
            logger.info('ModelScope not installed. Install with: pip install modelscope')
            logger.info('Trying HuggingFace...')
        except Exception as e:
            # Best effort: any ModelScope failure falls through to HuggingFace.
            logger.info(f'ModelScope download failed: {e}')
            logger.info('Trying HuggingFace...')

        # 2) HuggingFace Hub fallback.
        if not download_success:
            logger.info('🌐 Using HuggingFace...')
            # Guard only the import, so an ImportError raised while
            # downloading is not misreported as a missing package.
            try:
                from huggingface_hub import hf_hub_download
            except ImportError as err:
                raise ImportError(
                    'HuggingFace Hub not installed. Install with: pip install huggingface_hub'
                ) from err

            logger.info(f' Downloading {model_file}...')
            hf_hub_download(
                repo_id='topdu/PP_DoclayoutV2_onnx',
                filename=model_file,
                cache_dir=str(model_dir.parent),
                local_dir=str(model_dir),
                local_dir_use_symlinks=False)
            logger.info(f' ✓ {model_file}')

            if model_path.exists():
                download_success = True
                logger.info('✅ Layout model downloaded successfully from HuggingFace!')

        if not download_success:
            raise RuntimeError(
                'Failed to download layout model. Please manually download from:\n'
                ' - https://huggingface.co/topdu/PP_DoclayoutV2_onnx\n'
                ' - https://modelscope.cn/models/topdktu/PP_DoclayoutV2_onnx'
            )

    except Exception as e:
        logger.error(f'❌ Failed to download layout model: {e}')
        raise

    return str(model_path)
140
+
141
+
142
def check_and_download_layout_model(model_path, auto_download=True):
    """Return a usable layout-model path, downloading the model if missing.

    Args:
        model_path: Expected location of the layout ONNX model (may be None).
        auto_download: When False, a missing model raises instead of being
            downloaded.

    Returns:
        ``model_path`` unchanged if it exists; otherwise the path returned by
        ``download_layout_model``.

    Raises:
        FileNotFoundError: The model is missing and ``auto_download`` is False.
        Exception: Whatever the second (final) download attempt raises.
    """
    if model_path and os.path.exists(model_path):
        return model_path

    if not auto_download:
        logger.error(f'⚠️ Layout model not found: {model_path}')
        logger.info('\n📝 Manual download instructions:')
        logger.info(' 1. Visit: https://huggingface.co/topdu/PP_DoclayoutV2_onnx')
        logger.info(' 2. Download PP-DoclayoutV2.onnx')
        logger.info(' 3. Specify path with --layout_model argument')
        raise FileNotFoundError(f'Layout model not found: {model_path}')

    # A custom location downloads into that file's directory; the default
    # (or no) path lets download_layout_model use its own cache directory.
    default_path = str(Path.home() / '.cache' / 'openocr' / 'PP_DoclayoutV2_onnx' / 'PP-DoclayoutV2.onnx')
    if model_path and str(model_path) != default_path:
        model_dir = os.path.dirname(model_path)
    else:
        model_dir = None

    logger.info('🇨🇳 Trying ModelScope (China mirror) first...')
    try:
        return download_layout_model(model_dir)
    except Exception as err:
        # download_layout_model already falls back from ModelScope to
        # HuggingFace internally; one more full attempt covers transient
        # network errors. Catching Exception (not a bare except) lets
        # KeyboardInterrupt/SystemExit propagate immediately.
        logger.warning(f'⚠️ Layout model download failed ({err}); retrying once...')
        return download_layout_model(model_dir)
180
+
181
+
182
+ def _get_image_name_and_dir(result: Dict, output_path: str):
183
+ """根据图片名创建子目录并返回(img_name, img_dir)"""
184
+ img_name = os.path.basename(result['input_path'])
185
+ if '.' in img_name:
186
+ img_name = img_name.rsplit('.', 1)[0]
187
+
188
+ img_dir = os.path.join(output_path, img_name)
189
+ os.makedirs(img_dir, exist_ok=True)
190
+
191
+ return img_name, img_dir
192
+
193
+
194
+ # ==================== Layout Detection ONNX ====================
195
class LayoutDetectorONNX:
    """Page layout detector backed by an ONNX Runtime session.

    Produces labelled, reading-ordered boxes and the matching cropped image
    blocks for a BGR page image.
    """

    def __init__(self,
                 model_path: str,
                 use_gpu: Optional[bool] = None,
                 threshold: float = 0.5,
                 auto_download: bool = True):
        """Create the ONNX Runtime session for the layout model.

        Args:
            model_path: Path to the layout ONNX model; resolved (and fetched
                if missing) via ``check_and_download_layout_model``.
            use_gpu: None auto-detects CUDA; True requests GPU (falls back to
                CPU with a warning); False forces CPU.
            threshold: Default minimum score for keeping a detection.
            auto_download: Download the model automatically if it is missing.
        """
        self.threshold = threshold

        model_path = check_and_download_layout_model(model_path, auto_download=auto_download)

        providers = self._get_execution_providers(use_gpu)
        logger.info(f'Layout detector using: {providers[0]}')

        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.session = ort.InferenceSession(model_path,
                                            sess_options,
                                            providers=providers)

        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]
        logger.info(f' Input names: {self.input_names}')
        logger.info(f' Output names: {self.output_names}')

        # Class id -> label name for the model's 25 layout classes.
        self.label_map = {
            0: 'abstract',
            1: 'algorithm',
            2: 'aside_text',
            3: 'chart',
            4: 'content',
            5: 'display_formula',
            6: 'doc_title',
            7: 'figure_title',
            8: 'footer',
            9: 'footer_image',
            10: 'footnote',
            11: 'formula_number',
            12: 'header',
            13: 'header_image',
            14: 'image',
            15: 'inline_formula',
            16: 'number',
            17: 'paragraph_title',
            18: 'reference',
            19: 'reference_content',
            20: 'seal',
            21: 'table',
            22: 'text',
            23: 'vertical_text',
            24: 'vision_footnote'
        }

    def _get_execution_providers(self, use_gpu):
        """Choose ONNX Runtime execution providers.

        Args:
            use_gpu: None (auto-detect), True (request GPU) or False (force CPU).

        Returns:
            Provider names in priority order; CPU is always the last fallback.
        """
        available_providers = ort.get_available_providers()

        if use_gpu is False:
            logger.info('🔧 User specified: Using CPU for layout detection')
            return ['CPUExecutionProvider']

        # Only CUDA is considered as a GPU provider.
        gpu_providers = []
        if 'CUDAExecutionProvider' in available_providers:
            gpu_providers.append('CUDAExecutionProvider')

        if use_gpu is True:
            if gpu_providers:
                logger.info(f'🔧 User specified: Using GPU for layout detection ({gpu_providers[0]})')
                return gpu_providers + ['CPUExecutionProvider']
            logger.warning('⚠️ GPU requested but not available, falling back to CPU')
            return ['CPUExecutionProvider']

        # Auto-detect (use_gpu is None).
        if gpu_providers:
            logger.info(f'✅ GPU detected for layout detection: Using {gpu_providers[0]}')
            return gpu_providers + ['CPUExecutionProvider']
        logger.info('ℹ️ No GPU detected for layout detection, using CPU')
        return ['CPUExecutionProvider']

    def crop_by_boxes(self, image: np.ndarray,
                      boxes: List[Dict]) -> List[Dict]:
        """Crop one image region per detection box.

        Args:
            image: Original BGR image.
            boxes: Detections with 'coordinate' ([x1, y1, x2, y2]) and 'label',
                and optionally 'score', 'cls_id', 'custom_value'.

        Returns:
            One block dict per box. 'img' is None when the crop is empty.
        """
        blocks = []
        for box in boxes:
            coord = box['coordinate']
            x1, y1, x2, y2 = map(int, coord)
            cropped_img = image[y1:y2, x1:x2]
            blocks.append({
                'img': cropped_img if cropped_img.size > 0 else None,
                'box': coord,
                'label': box['label'],
                'score': box.get('score', 1.0),
                'cls_id': box.get('cls_id', -1),
                'custom_value': box.get('custom_value', 0),
            })
        return blocks

    def preprocess(
        self, image: np.ndarray, target_input_size: tuple = (800, 800)
    ) -> Tuple[Dict, Tuple[float, float], int, int]:
        """Resize a BGR image to the model input and build the feed dict.

        The image is stretched to exactly ``target_input_size`` (aspect ratio
        is not preserved), converted to RGB, scaled to [0, 1] and laid out as
        a 1xCxHxW float32 blob.

        Args:
            image: BGR image.
            target_input_size: Target size as (height, width).

        Returns:
            (inputs, (scale_h, scale_w), orig_h, orig_w), where ``inputs``
            maps 'im_shape', 'image' and 'scale_factor' to float32 arrays.
        """
        orig_h, orig_w = image.shape[:2]

        target_h, target_w = target_input_size
        scale_h = target_h / orig_h
        scale_w = target_w / orig_w

        # Resize straight to the target size. Recomputing it as
        # int(orig * (target / orig)) can round down to target - 1 through
        # floating-point error, which would disagree with 'im_shape'.
        resized = cv2.resize(image, (target_w, target_h),
                             interpolation=cv2.INTER_LINEAR)
        resized_rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)

        input_blob = (resized_rgb.astype(np.float32) / 255.0).transpose(2, 0, 1)[np.newaxis, ...]

        inputs = {
            'im_shape': np.array([[target_h, target_w]], dtype=np.float32),  # [1, 2]
            'image': input_blob,
            'scale_factor': np.array([[scale_h, scale_w]], dtype=np.float32),  # [1, 2]
        }
        return inputs, (scale_h, scale_w), orig_h, orig_w

    def postprocess(
        self,
        image: np.ndarray,
        outputs: list,
        scale: Tuple[float, float],
        ori_h: int,
        ori_w: int,
        merge_layout_blocks: bool = True,
        use_chart_recognition: bool = False,
    ) -> Dict:
        """Turn raw detector output into ordered boxes and cropped blocks.

        Args:
            image: Original BGR image, used for cropping.
            outputs: Session outputs. ``outputs[0]`` is read row-wise as
                [class_id, score, x1, y1, x2, y2, order, ...]. Coordinates are
                clipped to the original image size directly, i.e. treated as
                original-image pixels.
            scale: (scale_h, scale_w) from ``preprocess``; not used here.
            ori_h: Original image height, used for clipping.
            ori_w: Original image width, used for clipping.
            merge_layout_blocks: Merge blocks via ``merge_blocks``; picture
                labels and 'table' are excluded from merging.
            use_chart_recognition: When False, 'chart' is treated as a picture
                label.

        Returns:
            {'boxes': [...], 'blocks': [...]}. Boxes scoring at or below
            ``self.threshold`` are dropped, and the rest are sorted by the
            order column. Each label gets a 1-based '_NN' suffix in that order.
        """
        bboxes = outputs[0]
        if bboxes.shape[0] == 0:
            return {'boxes': [], 'blocks': []}

        filtered_bboxes = bboxes[bboxes[:, 1] > self.threshold]
        if filtered_bboxes.shape[0] == 0:
            return {'boxes': [], 'blocks': []}

        result_boxes = []
        for bbox in filtered_bboxes:
            class_id = int(bbox[0])
            x1, y1, x2, y2 = bbox[2:6]
            result_boxes.append({
                'cls_id': class_id,
                'label': self.label_map.get(class_id, f'class_{class_id}'),
                'score': float(bbox[1]),
                'coordinate': [
                    float(np.clip(x1, 0, ori_w)),
                    float(np.clip(y1, 0, ori_h)),
                    float(np.clip(x2, 0, ori_w)),
                    float(np.clip(y2, 0, ori_h)),
                ],
                'custom_value': float(bbox[6]),
            })

        result_dict = filter_overlap_boxes({'boxes': result_boxes})

        # Reading order is given by the model's order column.
        result_dict['boxes'] = sorted(result_dict['boxes'],
                                      key=lambda box: box['custom_value'])
        for idx, box in enumerate(result_dict['boxes'], start=1):
            box['label'] = f"{box['label']}_{idx:02d}"

        blocks = self.crop_by_boxes(image, result_dict['boxes'])

        image_labels = IMAGE_LABELS if use_chart_recognition else IMAGE_LABELS + ['chart']
        if merge_layout_blocks:
            blocks = merge_blocks(blocks,
                                  non_merge_labels=image_labels + ['table'])

        result_dict['blocks'] = blocks
        return result_dict

    def __call__(self,
                 images: Union[np.ndarray, List[np.ndarray]],
                 threshold: Optional[float] = None) -> List[Dict]:
        """Run layout detection on one image or a list of images.

        Args:
            images: A BGR image, or a list of them.
            threshold: Score threshold for this call only. The instance
                default is restored afterwards, even if inference raises.

        Returns:
            One ``postprocess`` result dict per image.
        """
        if isinstance(images, np.ndarray):
            images = [images]

        original_threshold = self.threshold
        if threshold is not None:
            self.threshold = threshold
        try:
            results = []
            for image in images:
                input_dict, scale, ori_h, ori_w = self.preprocess(image)
                # Feed only the inputs the loaded graph declares; ONNX Runtime
                # rejects feed names that are not graph inputs.
                feed = {name: value for name, value in input_dict.items()
                        if name in self.input_names}
                outputs = self.session.run(self.output_names, feed)
                results.append(self.postprocess(image, outputs, scale, ori_h, ori_w))
            return results
        finally:
            self.threshold = original_threshold
515
+
516
+
517
+
518
+ # ==================== OpenDoc ONNX Pipeline ====================
519
+ class OpenDocONNX:
520
+ """完整的文档OCR ONNX Pipeline"""
521
+
522
def __init__(
    self,
    layout_model_path: Optional[str] = None,
    unirec_encoder_path: Optional[str] = None,
    unirec_decoder_path: Optional[str] = None,
    tokenizer_mapping_path: Optional[str] = None,
    use_gpu: Optional[bool] = None,
    layout_threshold: float = 0.5,
    use_layout_detection: bool = True,
    use_chart_recognition: bool = True,
    auto_download: bool = True,
):
    """Build the document pipeline: optional layout detector + UniRec recognizer.

    Args:
        layout_model_path: Layout-detection ONNX model. Defaults to
            ``~/.cache/openocr/PP_DoclayoutV2_onnx/PP-DoclayoutV2.onnx``.
        unirec_encoder_path: UniRec encoder ONNX model (None: recognizer default).
        unirec_decoder_path: UniRec decoder ONNX model (None: recognizer default).
        tokenizer_mapping_path: Tokenizer mapping file (None: recognizer default).
        use_gpu: None auto-detects; True requests GPU; False forces CPU.
        layout_threshold: Default score threshold for layout detection.
        use_layout_detection: When False, no detector is created and the
            whole page is treated as one region.
        use_chart_recognition: When False, charts are treated as pictures.
        auto_download: Download missing models automatically.
    """
    self.use_layout_detection = use_layout_detection
    self.use_chart_recognition = use_chart_recognition

    # Layout labels excluded from Markdown output.
    self.markdown_ignore_labels = [
        'number', 'footnote', 'header', 'footer', 'aside_text',
        'footer_image', 'header_image', 'chart'
    ]

    # Drawing color for each of the 25 layout labels (3-tuples, annotated as
    # BGR by the original author). 'chart' and 'image' share one color.
    palette = (
        ('abstract', (255, 128, 0)),
        ('algorithm', (128, 0, 255)),
        ('aside_text', (128, 128, 128)),
        ('chart', (0, 255, 255)),
        ('content', (0, 255, 0)),
        ('display_formula', (255, 0, 255)),
        ('doc_title', (255, 0, 0)),
        ('figure_title', (255, 128, 128)),
        ('footer', (64, 64, 64)),
        ('footer_image', (128, 64, 0)),
        ('footnote', (192, 192, 192)),
        ('formula_number', (255, 128, 255)),
        ('header', (96, 96, 96)),
        ('header_image', (0, 128, 128)),
        ('image', (0, 255, 255)),
        ('inline_formula', (200, 0, 200)),
        ('number', (128, 255, 0)),
        ('paragraph_title', (255, 64, 0)),
        ('reference', (0, 128, 255)),
        ('reference_content', (128, 192, 255)),
        ('seal', (0, 0, 128)),
        ('table', (0, 0, 255)),
        ('text', (0, 200, 0)),
        ('vertical_text', (128, 255, 128)),
        ('vision_footnote', (160, 160, 160)),
    )
    self.colors = dict(palette)

    if layout_model_path is None:
        layout_model_path = str(
            Path.home() / '.cache' / 'openocr' / 'PP_DoclayoutV2_onnx' / 'PP-DoclayoutV2.onnx')

    self.layout_detector = (
        LayoutDetectorONNX(layout_model_path,
                           use_gpu=use_gpu,
                           threshold=layout_threshold,
                           auto_download=auto_download)
        if use_layout_detection else None)

    # Text/table/formula recognizer applied to every non-picture block.
    self.vlm_recognizer = UniRecONNX(
        encoder_path=unirec_encoder_path,
        decoder_path=unirec_decoder_path,
        mapping_path=tokenizer_mapping_path,
        use_gpu=use_gpu,
        auto_download=auto_download)
604
+
605
+
606
def __call__(
    self,
    img_path: Optional[str] = None,
    img_numpy: Optional[np.ndarray] = None,
    image_path: Optional[str] = None,
    layout_threshold: Optional[float] = None,
    max_length: int = 2048,
    merge_layout_blocks: bool = True,
) -> Dict:
    """Run layout detection and UniRec recognition on one document image.

    Args:
        img_path: Path to the input image (str or os.PathLike). Takes
            precedence over ``img_numpy``.
        img_numpy: Input image as a numpy array in BGR order.
        image_path: Alias for ``img_path`` (backward compatibility); used
            only when ``img_path`` is None.
        layout_threshold: Layout score threshold for this call (None keeps
            the detector's default).
        max_length: Maximum generation length passed to the recognizer.
        merge_layout_blocks: Whether to merge layout blocks via ``merge_blocks``.

    Returns:
        Dict with 'input_path' (the path as str, or '<numpy_array>'), 'width',
        'height', 'layout_results', 'recognition_results' (one entry per
        block, in block order), 'blocks' and 'timing'.

    Raises:
        ValueError: No input was given, or the image could not be read/decoded.
    """

    def _base_label(label):
        # Layout labels carry a reading-order suffix (e.g. 'formula_number_07');
        # strip a trailing '_<digits>' so type checks see the bare label.
        head, sep, tail = label.rpartition('_')
        return head if sep and tail.isdigit() else label

    def _is_formula(base_label):
        # Formula regions, excluding the equation-number label.
        return 'formula' in base_label and base_label != 'formula_number'

    def _figure_path(base_label, box):
        # Relative path under which a picture region is referenced/saved.
        x_min, y_min, x_max, y_max = map(int, box)
        return f'imgs/img_in_{base_label}_box_{x_min}_{y_min}_{x_max}_{y_max}.jpg'

    def _finalize_text(text, figure_token_map, base_label):
        # Post-process one block's recognized text for the final output.
        text = truncate_repetitive_content(text or '')
        # Convert LaTeX delimiters \( \) and \[ \] into Markdown $ / $$.
        has_paren = '\\(' in text and '\\)' in text
        has_bracket = '\\[' in text and '\\]' in text
        if has_paren or has_bracket:
            text = text.replace('$', '')
            text = (text.replace('\\(', ' $ ').replace('\\)', ' $ ')
                    .replace('\\[', ' $$ ').replace('\\]', ' $$ '))
        if base_label == 'formula_number':
            text = text.replace('$', '')
        if 'table' in base_label:
            # OTSL -> HTML (kept as-is if conversion yields nothing), then
            # restore figure tokens substituted before recognition.
            html_str = convert_otsl_to_html(text)
            if html_str:
                text = html_str
            text = untokenize_figure_of_table(text, figure_token_map)
        return text

    if image_path is not None and img_path is None:
        img_path = image_path

    # ---- Load the page as a 3-channel BGR image ----
    start_time = time.time()
    if img_path is not None:
        input_path = str(img_path)
        image = cv2.imread(input_path)
    elif img_numpy is not None:
        input_path = '<numpy_array>'
        # Lossless in-memory round trip. Decoding with IMREAD_COLOR normalizes
        # grayscale/BGRA arrays to 3-channel BGR just as reading a file would,
        # with no temporary file and no JPEG compression artifacts.
        ok, encoded = cv2.imencode('.png', img_numpy)
        image = cv2.imdecode(encoded, cv2.IMREAD_COLOR) if ok else None
    else:
        raise ValueError('Either img_path or img_numpy must be provided')
    if image is None:
        raise ValueError(f'Failed to read image: {input_path}')

    ori_h, ori_w = image.shape[:2]

    # ---- Layout detection ----
    if self.use_layout_detection:
        layout_results = self.layout_detector(
            [image], threshold=layout_threshold)[0]
    else:
        # Treat the whole page as a single text region.
        layout_results = {
            'boxes': [{
                'cls_id': 0,
                'label': 'text',
                'score': 1.0,
                'coordinate': [0, 0, ori_w, ori_h]
            }]
        }
        logger.info(' Layout detection disabled, processing whole image')

    # Labels kept as pictures; charts are recognized only when enabled.
    image_labels = (IMAGE_LABELS if self.use_chart_recognition else
                    IMAGE_LABELS + ['chart'])

    # ---- Crop each region and optionally merge blocks ----
    blocks = []
    for box in layout_results['boxes']:
        coord = box['coordinate']
        x1, y1, x2, y2 = map(int, coord)
        cropped_img = image[y1:y2, x1:x2]
        blocks.append({
            'img': cropped_img if cropped_img.size > 0 else None,
            'box': coord,
            'label': box['label'],
            'score': box.get('score', 1.0),
        })
    if merge_layout_blocks:
        blocks = merge_blocks(blocks,
                              non_merge_labels=image_labels + ['table'])

    base_labels = [_base_label(block['label']) for block in blocks]

    # Picture regions, needed to tokenize figures embedded in tables.
    imgs_in_doc = [
        {'coordinate': block['box'], 'path': _figure_path(base, block['box'])}
        for block, base in zip(blocks, base_labels)
        if base in image_labels and block['img'] is not None
    ]

    # ---- Recognize every non-picture block ----
    # The recognizer is called without a task prompt; block type only affects
    # the image preparation and text post-processing below.
    vlm_outputs = {}  # block index -> (converted text, figure_token_map)
    drop_figures_set = set()
    for j, (block, base_label) in enumerate(zip(blocks, base_labels)):
        block_img = block['img']
        if base_label in image_labels or block_img is None:
            continue

        figure_token_map = {}
        if 'table' in base_label:
            block_img, figure_token_map, drop_figures = tokenize_figure_of_table(
                block_img, block['box'], imgs_in_doc)
            drop_figures_set.update(drop_figures)
        elif _is_formula(base_label):
            block_img = crop_margin(block_img)

        pil_image = Image.fromarray(cv2.cvtColor(block_img, cv2.COLOR_BGR2RGB))
        try:
            text, _ = self.vlm_recognizer(image=pil_image, max_length=max_length)
        except Exception as e:
            # Best effort: one failing block must not abort the whole page.
            logger.error(f' Error processing block: {e}')
            text = ''

        if 'table' in base_label:
            text = markdown_converter._handle_table(text)
        elif _is_formula(base_label):
            text = markdown_converter._handle_formula(text)
        else:
            text = markdown_converter._handle_text(text)
        vlm_outputs[j] = (text, figure_token_map)

    # ---- Assemble per-block results ----
    recognition_results = []
    for j, (block, base_label) in enumerate(zip(blocks, base_labels)):
        block_bbox = block['box']
        block_label = block['label']
        score = block.get('score', 1.0)

        if base_label in image_labels and block['img'] is not None:
            fig_path = _figure_path(base_label, block_bbox)
            # Figures inside tables are still listed; 'in_table' flags those
            # that were tokenized into a table.
            recognition_results.append({
                'label': block_label,
                'bbox': block_bbox,
                'score': score,
                'text': '',
                'text_unirec': '',
                'is_image': True,
                'img_path': fig_path,
                'is_merged_continuation': False,
                'in_table': fig_path in drop_figures_set
            })
            continue

        block_content = ''
        if j in vlm_outputs:
            text, figure_token_map = vlm_outputs[j]
            block_content = _finalize_text(text, figure_token_map, base_label)

        recognition_results.append({
            'label': block_label,
            'bbox': block_bbox,
            'score': score,
            'text': block_content,
            'text_unirec': block_content,
            'is_image': False,
            # img is None for continuation parts of a merged block (and for
            # empty crops).
            'is_merged_continuation': block['img'] is None
        })

    total_time = time.time() - start_time
    logger.info(f' Total time: {total_time:.3f}s')

    return {
        'input_path': input_path,
        'width': ori_w,
        'height': ori_h,
        'layout_results': layout_results,
        'recognition_results': recognition_results,
        'blocks': blocks,
        'timing': {
            'total': total_time,
        }
    }
878
+
879
+ def save_to_json(self, result: Dict, output_path: str):
880
+ """保存结果为JSON"""
881
+ if 'layout_results' in result:
882
+ del result['layout_results']
883
+
884
+ if 'blocks' in result:
885
+ del result['blocks']
886
+
887
+ img_name, img_dir = _get_image_name_and_dir(result, output_path)
888
+ json_path = os.path.join(img_dir, f'{img_name}.json')
889
+
890
+ with open(json_path, 'w', encoding='utf-8') as f:
891
+ json.dump(result, f, ensure_ascii=False, indent=2)
892
+
893
+ # logger.info(f" Saved JSON to {json_path}")
894
+
895
+ def save_to_markdown(self, result: Dict, output_path: str):
896
+ """保存结果为Markdown,按阅读顺序包含图片"""
897
+ img_name, img_dir = _get_image_name_and_dir(result, output_path)
898
+ md_path = os.path.join(img_dir, f'{img_name}.md')
899
+
900
+ # 创建imgs子目录
901
+ imgs_dir = os.path.join(img_dir, 'imgs')
902
+ os.makedirs(imgs_dir, exist_ok=True)
903
+
904
+ # 读取原始图像用于裁剪保存图片
905
+ original_image = cv2.imread(result['input_path'])
906
+ ori_width = result.get(
907
+ 'width',
908
+ original_image.shape[1] if original_image is not None else 1)
909
+
910
+ # 保存所有图片区域(包括表格中的图片)
911
+ if original_image is not None:
912
+ for rec in result['recognition_results']:
913
+ if rec.get('is_image', False):
914
+ img_path = rec.get('img_path', '')
915
+ if img_path:
916
+ bbox = rec.get('bbox', [])
917
+ if bbox:
918
+ x1, y1, x2, y2 = map(int, bbox)
919
+ cropped_img = original_image[y1:y2, x1:x2]
920
+ if cropped_img.size > 0:
921
+ save_img_path = os.path.join(img_dir, img_path)
922
+ os.makedirs(os.path.dirname(save_img_path), exist_ok=True)
923
+ cv2.imwrite(save_img_path, cropped_img)
924
+
925
+ with open(md_path, 'w', encoding='utf-8') as f:
926
+ pending_text = [] # 用于收集合并块的文本
927
+ pending_label = None # 当前合并块的标签类型
928
+
929
+ for rec in result['recognition_results']:
930
+ # 获取基础标签名(去除编号后缀,如 text_01 -> text)
931
+ label = rec['label']
932
+ base_label = label.rsplit(
933
+ '_', 1)[0] if '_' in label and label.rsplit(
934
+ '_', 1)[1].isdigit() else label
935
+
936
+ # 跳过忽略的标签
937
+ if base_label in self.markdown_ignore_labels:
938
+ continue
939
+
940
+ # 处理图片类型
941
+ if rec.get('is_image', False):
942
+ # 先输出之前累积的文本
943
+ if pending_text:
944
+ self._write_merged_text(f, pending_text, pending_label)
945
+ pending_text = []
946
+ pending_label = None
947
+
948
+ # 如果图片在表格中,跳过在markdown中独立显示(已在表格HTML中引用)
949
+ if rec.get('in_table', False):
950
+ continue
951
+
952
+ img_path = rec.get('img_path', '')
953
+ if img_path:
954
+ # 计算图片宽度占原图的百分比
955
+ bbox = rec.get('bbox', [])
956
+ if bbox:
957
+ x1, y1, x2, y2 = map(int, bbox)
958
+ img_width = x2 - x1
959
+ width_percent = int((img_width / ori_width) * 100)
960
+ width_percent = max(5, min(width_percent, 100)) # 限制在5%-100%之间
961
+ else:
962
+ width_percent = 50 # 默认50%
963
+ f.write(
964
+ f'<img src="{img_path}" alt="Image" width="{width_percent}%" />\\n\\n'
965
+ )
966
+ continue
967
+ text = rec['text'].strip()
968
+ if not text:
969
+ continue
970
+
971
+ # 检查是否是合并块的后续部分
972
+ is_merged_continuation = rec.get('is_merged_continuation', False)
973
+
974
+ if is_merged_continuation and pending_text:
975
+ # 是合并块的后续部分,追加文本
976
+ pending_text.append(text)
977
+ else:
978
+ # 先输出之前累积的文本
979
+ if pending_text:
980
+ self._write_merged_text(f, pending_text, pending_label)
981
+ pending_text = []
982
+ pending_label = None
983
+
984
+ # 开始新的文本块
985
+ pending_text.append(text)
986
+ pending_label = base_label
987
+
988
+ # 输出最后累积的文本
989
+ if pending_text:
990
+ self._write_merged_text(f, pending_text, pending_label)
991
+
992
+ def _write_merged_text(self, f, texts: List[str], base_label: str):
993
+ """将合并的文本写入文件"""
994
+ merged_text = ' '.join(texts)
995
+
996
+ # 根据标签类型格式化输出
997
+ if 'title' in base_label or base_label == 'doc_title':
998
+ f.write(f'## {merged_text}\n\n')
999
+ elif 'table' in base_label:
1000
+ f.write(f'{merged_text}\n\n')
1001
+ elif 'formula' in base_label or base_label == 'equation':
1002
+ f.write(f'$${merged_text}$$\n\n')
1003
+ else:
1004
+ f.write(f'{merged_text}\n\n')
1005
+
1006
+ def save_visualization(self, result: Dict, output_path: str):
1007
+ """保存可视化结果"""
1008
+ img_name, img_dir = _get_image_name_and_dir(result, output_path)
1009
+ vis_path = os.path.join(img_dir, f'{img_name}_vis.jpg')
1010
+
1011
+ image = cv2.imread(result['input_path'])
1012
+
1013
+ for box_info in result['layout_results']['boxes']:
1014
+ x1, y1, x2, y2 = map(int, box_info['coordinate'])
1015
+ label = box_info['label']
1016
+ score = box_info['score']
1017
+
1018
+ # 提取基础标签名(去除编号后缀,如 text_01 -> text)
1019
+ base_label = label.rsplit('_', 1)[0] if '_' in label and label.rsplit('_', 1)[1].isdigit() else label
1020
+
1021
+ # 获取颜色,如果没有定义则使用默认红色
1022
+ color = self.colors.get(base_label, (255, 0, 0))
1023
+
1024
+ cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
1025
+ cv2.putText(image, f'{label}: {score: .2f}', (x1, y1 - 10),
1026
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
1027
+
1028
+ cv2.imwrite(vis_path, image)
1029
+ # logger.info(f" Saved visualization to {vis_path}")
1030
+
1031
+
1032
+ # ==================== Main Function ====================
1033
+ def main():
1034
+ desc = 'OpenDoc ONNX Pipeline - Full Document OCR with Layout Detection'
1035
+ parser = argparse.ArgumentParser(description=desc)
1036
+
1037
+ # Input/Output
1038
+ parser.add_argument('--input_path',
1039
+ type=str,
1040
+ required=True,
1041
+ help='Path to input image or directory')
1042
+ parser.add_argument('--output_path',
1043
+ type=str,
1044
+ default='./output_onnx',
1045
+ help='Path to save output results')
1046
+
1047
+ # Model paths
1048
+ parser.add_argument('--layout_model',
1049
+ type=str,
1050
+ default=None,
1051
+ help='Path to layout detection ONNX model (default: ~/.cache/openocr/PP_DoclayoutV2_onnx/PP-DoclayoutV2.onnx)')
1052
+ parser.add_argument('--encoder_model',
1053
+ type=str,
1054
+ default=None,
1055
+ help='Path to UniRec encoder ONNX model (default: ~/.cache/openocr/unirec_0_1b_onnx/unirec_encoder.onnx)')
1056
+ parser.add_argument('--decoder_model',
1057
+ type=str,
1058
+ default=None,
1059
+ help='Path to UniRec decoder ONNX model (default: ~/.cache/openocr/unirec_0_1b_onnx/unirec_decoder.onnx)')
1060
+ parser.add_argument('--tokenizer_mapping',
1061
+ type=str,
1062
+ default=None,
1063
+ help='Path to tokenizer mapping JSON file (default: ~/.cache/openocr/unirec_0_1b_onnx/unirec_tokenizer_mapping.json)')
1064
+
1065
+ # Settings
1066
+ parser.add_argument('--use-gpu',
1067
+ type=str,
1068
+ default='auto',
1069
+ choices=['auto', 'true', 'false'],
1070
+ help='Use GPU for inference (auto: auto-detect, true: force GPU, false: force CPU)')
1071
+ parser.add_argument('--layout_threshold',
1072
+ type=float,
1073
+ default=0.4,
1074
+ help='Layout detection threshold')
1075
+ parser.add_argument('--max_length',
1076
+ type=int,
1077
+ default=2048,
1078
+ help='Maximum generation length for VLM')
1079
+ parser.add_argument('--use_layout_detection',
1080
+ action='store_true',
1081
+ help='Use layout detection')
1082
+ parser.add_argument('--no_layout_detection',
1083
+ dest='use_layout_detection',
1084
+ action='store_false',
1085
+ help='Disable layout detection (process whole image)')
1086
+ parser.add_argument('--use_chart_recognition',
1087
+ action='store_true',
1088
+ help='Recognize charts')
1089
+ parser.add_argument('--no-auto-download',
1090
+ action='store_true',
1091
+ help='Disable automatic model download')
1092
+
1093
+ # Output formats
1094
+ parser.add_argument('--save_vis',
1095
+ action='store_true',
1096
+ help='Save visualization images')
1097
+ parser.add_argument('--save_json',
1098
+ action='store_true',
1099
+ help='Save JSON results')
1100
+ parser.add_argument('--save_markdown',
1101
+ action='store_true',
1102
+ help='Save Markdown results')
1103
+
1104
+ args = parser.parse_args()
1105
+
1106
+ # Parse use_gpu argument
1107
+ if args.use_gpu == 'auto':
1108
+ use_gpu = None
1109
+ elif args.use_gpu == 'true':
1110
+ use_gpu = True
1111
+ else:
1112
+ use_gpu = False
1113
+
1114
+ # 创建输出目录
1115
+ os.makedirs(args.output_path, exist_ok=True)
1116
+
1117
+ opendoc_onnx = OpenDocONNX(
1118
+ layout_model_path=args.layout_model,
1119
+ unirec_encoder_path=args.encoder_model,
1120
+ unirec_decoder_path=args.decoder_model,
1121
+ tokenizer_mapping_path=args.tokenizer_mapping,
1122
+ use_gpu=use_gpu,
1123
+ layout_threshold=args.layout_threshold,
1124
+ use_layout_detection=args.use_layout_detection,
1125
+ use_chart_recognition=args.use_chart_recognition,
1126
+ auto_download=not args.no_auto_download,
1127
+ )
1128
+
1129
+ # 获取图像列表
1130
+ img_list = get_image_file_list(args.input_path)
1131
+ logger.info(f'\nFound {len(img_list)} images in {args.input_path}')
1132
+ logger.info(f'Output will be saved to: {args.output_path}')
1133
+ logger.info('=' * 80)
1134
+
1135
+ # 处理每张图像
1136
+ for idx, img_path in enumerate(img_list):
1137
+ logger.info(
1138
+ f'\n[{idx + 1}/{len(img_list)}] Processing: {os.path.basename(img_path)}'
1139
+ )
1140
+
1141
+ try:
1142
+ # 预测
1143
+ result = opendoc_onnx(
1144
+ img_path=img_path,
1145
+ layout_threshold=args.layout_threshold,
1146
+ max_length=args.max_length,
1147
+ )
1148
+
1149
+ # 保存结果
1150
+ if args.save_vis:
1151
+ opendoc_onnx.save_visualization(result, args.output_path)
1152
+
1153
+ if args.save_json:
1154
+ opendoc_onnx.save_to_json(result, args.output_path)
1155
+
1156
+ if args.save_markdown:
1157
+ opendoc_onnx.save_to_markdown(result, args.output_path)
1158
+
1159
+ except Exception as e:
1160
+ logger.error(f'Error processing {img_path}: {str(e)}')
1161
+ import traceback
1162
+ traceback.print_exc()
1163
+ continue
1164
+
1165
+ logger.info('\n' + '=' * 80)
1166
+ logger.info(
1167
+ f'✅ All processing completed! Results saved to {args.output_path}')
1168
+ logger.info('=' * 80)
1169
+
1170
+
1171
+ if __name__ == '__main__':
1172
+ main()