openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
openocr/demo_unirec.py ADDED
@@ -0,0 +1,392 @@
1
+ import re
2
+ import gradio as gr
3
+ import numpy as np
4
+ from PIL import Image
5
+ from threading import Thread
6
+ import queue
7
+ import time
8
+
9
+ # Import ONNX inference components
10
+ import sys
11
+ import os
12
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
13
+ from tools.infer_unirec_onnx import UniRecONNX, clean_special_tokens
14
+ from tools.download_example_images import get_example_images_path
15
+ from tools.to_markdown import MarkdownConverter
16
+
17
# Module-level MarkdownConverter instance, shared by format_markdown_output()
# for its HTML-table and display-formula post-processing.
markdown_converter = MarkdownConverter()
19
+
20
# Math delimiters passed to gr.Markdown(latex_delimiters=...) for the preview
# tab. Each row is (left, right, display); display=True marks block math
# ($$...$$ and \[...\]), display=False marks inline math ($...$ and \(...\)).
# Row order is significant and preserved.
LATEX_DELIMS = [
    {'left': left, 'right': right, 'display': display}
    for left, right, display in (
        ('$$', '$$', True),
        ('$', '$', False),
        ('\\(', '\\)', False),
        ('\\[', '\\]', True),
    )
]
43
+
44
# --- 1. Initialize ONNX Model ---
def initialize_model(
    encoder_path=None,
    decoder_path=None,
    mapping_path=None,
    use_gpu=None,
    auto_download=True
):
    """Construct and return a ``UniRecONNX`` inference object.

    Every argument is forwarded unchanged to ``UniRecONNX``; a status line is
    printed before and after construction.

    Args:
        encoder_path: Encoder ONNX file; None leaves the choice of location
            to ``UniRecONNX`` (default cache directory per the original docs).
        decoder_path: Decoder ONNX file; None behaves as for ``encoder_path``.
        mapping_path: Tokenizer mapping JSON; None behaves as for
            ``encoder_path``.
        use_gpu: None = auto-detect, True = force GPU, False = force CPU, as
            interpreted by ``UniRecONNX``.
        auto_download: Whether missing model files may be downloaded.

    Returns:
        The constructed ``UniRecONNX`` instance.
    """
    options = dict(
        encoder_path=encoder_path,
        decoder_path=decoder_path,
        mapping_path=mapping_path,
        use_gpu=use_gpu,
        auto_download=auto_download,
    )
    print('Initializing UniRec ONNX model...')
    onnx_model = UniRecONNX(**options)
    print('✅ Model initialized successfully!')
    return onnx_model
71
+
72
+
73
# Shared UniRecONNX instance. It is None at import time and is assigned by
# launch_demo(); stream_recognize_image() reads it for every request.
model = None
75
+
76
+
77
# --- 2. Streaming generation function ---
def stream_generate(inference, image, max_length=2048, result_queue=None,
                    put_token_num=30):
    """Greedy-decode ``image`` with UniRec, optionally streaming text.

    The image is encoded once. The decoder is then stepped one token at a
    time, feeding back its key/value cache, until the EOS token is produced
    or ``max_length - 1`` new tokens have been generated.

    Queue protocol (only when ``result_queue`` is given): after every
    ``put_token_num`` generated tokens the accumulated cleaned text is put on
    the queue. When generation ends, normally or through an exception, the
    final accumulated text is put, followed by a ``None`` sentinel, so the
    consumer never has to rely on a timeout to notice completion.

    Args:
        inference: ``UniRecONNX`` instance providing ``tokenizer``,
            ``encode_image``, ``decode_step``, ``num_decoder_layers``,
            ``num_heads`` and ``head_dim``.
        image: Image passed unchanged to ``inference.encode_image``.
        max_length: Maximum sequence length including the BOS token.
        result_queue: Optional ``queue.Queue``-like object that receives text
            snapshots. When None, tokens are not decoded to text at all.
        put_token_num: Snapshot interval in generated tokens. Values below 1
            are treated as 1 (a snapshot after every token).

    Returns:
        list[int]: The generated token ids, beginning with BOS.
    """
    interval = max(1, int(put_token_num))
    text_pieces = []
    generated_ids = []
    try:
        tokenizer = inference.tokenizer
        bos_token_id = tokenizer.bos_token_id
        eos_token_id = tokenizer.eos_token_id
        pad_token_id = tokenizer.pad_token_id

        encoder_hidden_states, cross_k, cross_v = inference.encode_image(image)
        generated_ids.append(bos_token_id)

        # Empty self-attention cache (sequence length 0) for every layer.
        batch_size = encoder_hidden_states.shape[0]
        cache_shape = (batch_size, inference.num_heads, 0, inference.head_dim)
        past_key_values = [
            (np.zeros(cache_shape, dtype=np.float32),
             np.zeros(cache_shape, dtype=np.float32))
            for _ in range(inference.num_decoder_layers)
        ]

        for step in range(max_length - 1):
            # ``step`` equals the number of tokens already in the cache.
            logits, past_key_values = inference.decode_step(
                generated_ids[-1],
                step,
                cross_k,
                cross_v,
                past_key_values,
                padding_idx=pad_token_id
            )
            next_token_id = int(np.argmax(logits[0, -1, :]))
            generated_ids.append(next_token_id)

            if result_queue is not None:
                # NOTE(review): decoding one token at a time could split
                # characters that span several tokens if the tokenizer is
                # byte-level; confirm against the UniRec tokenizer.
                piece = tokenizer.decode([next_token_id],
                                         skip_special_tokens=False)
                text_pieces.append(clean_special_tokens(piece))
                # Batch UI updates: one snapshot per ``interval`` tokens.
                if (step + 1) % interval == 0:
                    result_queue.put(''.join(text_pieces))

            if next_token_id == eos_token_id:
                break
    finally:
        if result_queue is not None:
            # Flush the tail (fewer than ``interval`` tokens) and always send
            # the completion sentinel, even when encoding/decoding raised.
            result_queue.put(''.join(text_pieces))
            result_queue.put(None)
    return generated_ids
142
+
143
+
144
# --- 3. Gradio streaming function for dual display ---
def _drain_queue(result_queue, current_text):
    """Return the newest text left in ``result_queue`` without blocking.

    Consumes items until the queue is empty or the ``None`` sentinel is seen.
    ``current_text`` is returned unchanged when no text item remains.
    """
    while True:
        try:
            item = result_queue.get_nowait()
        except queue.Empty:
            return current_text
        if item is None:
            return current_text
        current_text = item


def stream_recognize_image(input_image):
    """Run recognition on ``input_image`` and stream results to the UI.

    Generator used as a Gradio event handler with two outputs: the Markdown
    source box and the rendered preview. While recognition runs, only the
    source box receives partial text and the preview shows a "recognizing"
    placeholder. When it finishes, both receive the formatted result.

    Args:
        input_image: PIL image or numpy array from ``gr.Image``; None when
            nothing was uploaded.

    Yields:
        tuple[str, str]: (markdown source, markdown preview).
    """
    if input_image is None:
        yield '请先上传一张图片。', '**请先上传一张图片。**'
        return
    if model is None:
        # launch_demo() has not run (e.g. ``demo.launch()`` was called
        # directly), so there is no model to run the worker thread with.
        message = '模型尚未初始化,请通过 launch_demo() 启动演示。'
        yield message, f'**{message}**'
        return

    # Normalize to an RGB PIL image.
    if isinstance(input_image, Image.Image):
        input_image = input_image.convert('RGB')
    else:
        input_image = Image.fromarray(input_image).convert('RGB')

    result_queue = queue.Queue()
    # Daemon thread so an unfinished generation does not block process exit.
    thread = Thread(target=stream_generate,
                    args=(model, input_image, 2048, result_queue),
                    daemon=True)
    thread.start()

    current_text = ''
    last_update_time = time.time()
    while True:
        try:
            result = result_queue.get(timeout=1.0)
        except queue.Empty:
            if not thread.is_alive():
                # The worker may have enqueued its last items between the
                # timeout above and this liveness check; keep them.
                current_text = _drain_queue(result_queue, current_text)
                break
            # Keep the UI responsive while waiting for the next snapshot.
            now = time.time()
            if now - last_update_time > 0.5:
                yield current_text if current_text else '正在识别中...', '_正在识别中,请稍候..._'
                last_update_time = now
            continue

        if result is None:  # completion sentinel
            break
        current_text = result
        last_update_time = time.time()
        # Only the source box changes during recognition.
        yield current_text, '_正在识别中,请稍候..._'

    thread.join(timeout=2.0)

    # Final update: render the complete result in both outputs.
    formatted_result = format_markdown_output(current_text)
    yield formatted_result, formatted_result
198
+
199
+
200
def format_markdown_output(markdown_text):
    r"""Prepare recognized Markdown for the ``gr.Markdown`` preview.

    Steps, applied in order:
      1. Empty or None input yields the "waiting for result" placeholder.
      2. If an HTML ``<table>`` tag is present, the whole text goes through
         ``markdown_converter._handle_table``.
      3. If ``\(`` or ``\[`` occurs, every display formula written as a
         blank-line-separated ``\[ ... \]`` paragraph is replaced by
         ``markdown_converter._handle_formula`` of that paragraph. Then, if
         ``\(`` is still present, every ``\(`` and ``\)`` becomes ``$``.

    Any other text is returned unchanged for Gradio to render.
    """
    if not markdown_text:
        return '_等待识别结果..._'

    text = markdown_text
    if '<table>' in text:
        text = markdown_converter._handle_table(text)

    if '\\(' in text or '\\[' in text:
        display_formula = re.compile(r'\n\n\\\[.*?\\\]\n\n', re.DOTALL)
        for paragraph in display_formula.findall(text):
            text = text.replace(paragraph,
                                markdown_converter._handle_formula(paragraph))
        if '\\(' in text:
            text = text.replace('\\(', '$').replace('\\)', '$')
    return text
225
+
226
+
227
# --- 4. Gradio Interface ---
# Folder holding the demo's example images; get_example_images_path is
# expected to download them if necessary.
example_img_dir = get_example_images_path(demo_type='unirec')

# Sorted full paths of the image files in that folder; empty when the folder
# does not exist.
example_images = []
if os.path.exists(example_img_dir):
    example_images = sorted(
        os.path.join(example_img_dir, entry)
        for entry in os.listdir(example_img_dir)
        if entry.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')))
238
+
239
# Gradio UI. Component layout follows statement order and nesting below.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header: title, credits and a link to the local GPU deployment docs.
    gr.HTML("""
<h1 style='text-align: center;'><a href="https://github.com/Topdu/OpenOCR">UniRec-0.1B: Unified Text and Formula Recognition with 0.1B Parameters</a></h1>
<p style='text-align: center;'>0.1B超轻量模型统一文本与公式识别(由<a href="https://fvl.fudan.edu.cn">FVL实验室</a> <a href="https://github.com/Topdu/OpenOCR">OCR Team</a> 创建)</p>
<p style='text-align: center;'><a href="https://github.com/Topdu/OpenOCR/blob/main/docs/unirec.md">[本地GPU部署]</a>获取快速识别体验</p>"""
            )
    gr.Markdown('上传一张图片,点击"运行识别"按钮进行文本和公式识别。')
    with gr.Row():
        with gr.Column(scale=4):  # Left column: image input, examples, buttons
            image_input = gr.Image(label='上传图片 or 粘贴截图', type='pil')

            # Example gallery only when example image files were found.
            if example_images:
                gr.Examples(
                    examples=example_images,
                    inputs=image_input,
                    label='📚 示例图片'
                )

            with gr.Row():
                run_button = gr.Button('🚀 运行识别', variant='primary')
                clear_button = gr.Button('🗑️ 清空', variant='secondary')

        with gr.Column(scale=6):  # Right column: source / rendered tabs
            with gr.Tabs():
                with gr.Tab('📝 Markdown Source'):
                    markdown_output = gr.Code(label='Markdown Source',
                                              language='markdown',
                                              lines=20)
                with gr.Tab('📝 Markdown Preview'):
                    markdown_render = gr.Markdown(
                        value='_渲染后的表格/公式将显示在这里..._',
                        latex_delimiters=LATEX_DELIMS,
                        elem_id='md_preview')

    # Run button: stream_recognize_image is a generator yielding
    # (source, preview) pairs for the two output components.
    run_button.click(
        stream_recognize_image,
        inputs=[image_input],
        outputs=[markdown_output, markdown_render]
    )

    # Clear button: reset the image input and both outputs.
    def clear_all():
        """Return reset values for (image_input, markdown_output, markdown_render)."""
        return None, '', '_渲染后的表格/公式将显示在这里..._'

    clear_button.click(
        clear_all,
        outputs=[image_input, markdown_output, markdown_render]
    )
289
+
290
+
291
def launch_demo(
    encoder_path=None,
    decoder_path=None,
    mapping_path=None,
    use_gpu=None,
    auto_download=True,
    share=False,
    server_name='0.0.0.0',
    server_port=7860
):
    """Load the UniRec ONNX model into this module and serve the Gradio demo.

    The model is stored in the module-level ``model`` used by
    ``stream_recognize_image``. The demo's queue is then enabled and the
    server is launched.

    Args:
        encoder_path: Encoder ONNX model path (None: auto-download/default).
        decoder_path: Decoder ONNX model path (None: auto-download/default).
        mapping_path: Tokenizer mapping JSON path (None: auto-download/default).
        use_gpu: None = auto-detect, True = force GPU, False = force CPU.
        auto_download: Allow downloading missing model files.
        share: Ask Gradio for a public share link.
        server_name: Host/interface to bind; '0.0.0.0' listens on all
            interfaces, so the demo is reachable from other machines.
        server_port: Port to bind.

    Returns:
        gr.Blocks: The module-level ``demo`` instance.
        NOTE(review): ``launch()`` normally blocks outside notebooks, so this
        return is only reached after the server stops; confirm for the
        installed Gradio version.
    """
    global model

    model = initialize_model(
        encoder_path=encoder_path,
        decoder_path=decoder_path,
        mapping_path=mapping_path,
        use_gpu=use_gpu,
        auto_download=auto_download
    )

    queued_app = demo.queue()
    queued_app.launch(
        share=share,
        server_name=server_name,
        server_port=server_port
    )
    return demo
334
+
335
+
336
# --- 5. Launch application ---
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='UniRec ONNX Gradio Demo')

    # Optional model-file overrides; None lets UniRecONNX use its defaults.
    path_options = (
        ('--encoder_model',
         'Path to encoder ONNX model (default: ~/.cache/openocr/unirec_0_1b_onnx/unirec_encoder.onnx)'),
        ('--decoder_model',
         'Path to decoder ONNX model (default: ~/.cache/openocr/unirec_0_1b_onnx/unirec_decoder.onnx)'),
        ('--mapping',
         'Path to tokenizer mapping JSON (default: ~/.cache/openocr/unirec_0_1b_onnx/unirec_tokenizer_mapping.json)'),
    )
    for flag, help_text in path_options:
        parser.add_argument(flag, type=str, default=None, help=help_text)

    parser.add_argument('--use-gpu',
                        type=str,
                        default='auto',
                        choices=['auto', 'true', 'false'],
                        help='Use GPU for inference (auto: auto-detect, true: force GPU, false: force CPU)')
    parser.add_argument('--no-auto-download',
                        action='store_true',
                        help='Disable automatic model download')
    parser.add_argument('--share',
                        action='store_true',
                        help='Create a public share link')
    parser.add_argument('--server-name',
                        type=str,
                        default='0.0.0.0',
                        help='Server name for Gradio')
    parser.add_argument('--server-port',
                        type=int,
                        default=7860,
                        help='Server port for Gradio')
    args = parser.parse_args()

    # Map the --use-gpu string onto launch_demo's tri-state flag
    # (None = auto-detect); choices= above guarantees the key exists.
    use_gpu = {'auto': None, 'true': True, 'false': False}[args.use_gpu]

    launch_demo(
        encoder_path=args.encoder_model,
        decoder_path=args.decoder_model,
        mapping_path=args.mapping,
        use_gpu=use_gpu,
        auto_download=not args.no_auto_download,
        share=args.share,
        server_name=args.server_name,
        server_port=args.server_port
    )
@@ -10,13 +10,11 @@ def build_loss(config):
10
10
  config = copy.deepcopy(config)
11
11
  module_name = config.pop('name')
12
12
  assert module_name in name_to_module, Exception(
13
- 'loss only support {}'.format(list(name_to_module.keys())))
13
+ '{} is not supported. The losses in {} are supportes'.format(
14
+ module_name, list(name_to_module.keys())))
14
15
 
15
- if module_name in globals():
16
- module_class = globals()[module_name]
17
- else:
18
- module_path = name_to_module[module_name]
19
- module = import_module(module_path, package=__package__)
20
- module_class = getattr(module, module_name)
16
+ module_path = name_to_module[module_name]
17
+ module = import_module(module_path, package=__package__)
18
+ module_class = getattr(module, module_name)
21
19
 
22
20
  return module_class(**config)
@@ -111,7 +111,8 @@ def crop_area(im, text_polys, min_crop_side_ratio, max_tries):
111
111
  else:
112
112
  ymin, ymax = random_select(h_axis, h)
113
113
 
114
- if (xmax - xmin < min_crop_side_ratio * w or ymax - ymin < min_crop_side_ratio * h):
114
+ if (xmax - xmin < min_crop_side_ratio * w
115
+ or ymax - ymin < min_crop_side_ratio * h):
115
116
  # area too small
116
117
  continue
117
118
  num_poly_in_rect = 0