openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +35 -1
- openocr/configs/dataset/rec/evaluation.yaml +41 -0
- openocr/configs/dataset/rec/ltb.yaml +9 -0
- openocr/configs/dataset/rec/mjsynth.yaml +11 -0
- openocr/configs/dataset/rec/openvino.yaml +25 -0
- openocr/configs/dataset/rec/ost.yaml +17 -0
- openocr/configs/dataset/rec/synthtext.yaml +7 -0
- openocr/configs/dataset/rec/test.yaml +77 -0
- openocr/configs/dataset/rec/textocr.yaml +13 -0
- openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
- openocr/configs/dataset/rec/union14m_b.yaml +47 -0
- openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
- openocr/configs/rec/cmer/cmer.yml +127 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
- openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
- openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
- openocr/demo_gradio.py +28 -8
- openocr/demo_opendoc.py +572 -0
- openocr/demo_unirec.py +392 -0
- openocr/opendet/losses/__init__.py +5 -7
- openocr/opendet/preprocess/crop_resize.py +2 -1
- openocr/openocr.py +685 -0
- openocr/openrec/losses/__init__.py +8 -3
- openocr/openrec/losses/cmer_loss.py +12 -0
- openocr/openrec/losses/mdiff_loss.py +11 -0
- openocr/openrec/losses/unirec_loss.py +12 -0
- openocr/openrec/metrics/__init__.py +4 -1
- openocr/openrec/metrics/rec_metric_cmer.py +328 -0
- openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
- openocr/openrec/modeling/decoders/__init__.py +1 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
- openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
- openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
- openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
- openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
- openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
- openocr/openrec/optimizer/__init__.py +4 -3
- openocr/openrec/optimizer/lr.py +49 -0
- openocr/openrec/postprocess/__init__.py +2 -0
- openocr/openrec/postprocess/abinet_postprocess.py +1 -1
- openocr/openrec/postprocess/ar_postprocess.py +1 -1
- openocr/openrec/postprocess/cmer_postprocess.py +86 -0
- openocr/openrec/postprocess/cppd_postprocess.py +1 -1
- openocr/openrec/postprocess/igtr_postprocess.py +1 -1
- openocr/openrec/postprocess/lister_postprocess.py +1 -1
- openocr/openrec/postprocess/mgp_postprocess.py +1 -1
- openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
- openocr/openrec/postprocess/smtr_postprocess.py +1 -1
- openocr/openrec/postprocess/srn_postprocess.py +1 -1
- openocr/openrec/postprocess/unirec_postprocess.py +58 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
- openocr/openrec/preprocess/__init__.py +5 -0
- openocr/openrec/preprocess/ce_label_encode.py +1 -1
- openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
- openocr/openrec/preprocess/ctc_label_encode.py +1 -1
- openocr/openrec/preprocess/dptr_label_encode.py +177 -157
- openocr/openrec/preprocess/igtr_label_encode.py +4 -2
- openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
- openocr/openrec/preprocess/rec_aug.py +128 -2
- openocr/openrec/preprocess/resize.py +57 -0
- openocr/openrec/preprocess/unirec_label_encode.py +62 -0
- openocr/tools/data/__init__.py +78 -55
- openocr/tools/data/cmer_web_dataset.py +310 -0
- openocr/tools/data/native_size_dataset.py +753 -0
- openocr/tools/data/native_size_sampler.py +158 -0
- openocr/tools/data/ratio_dataset_tvresize.py +2 -0
- openocr/tools/data/ratio_sampler.py +2 -1
- openocr/tools/download/download_dataset.py +38 -0
- openocr/tools/download/utils.py +28 -0
- openocr/tools/download_example_images.py +236 -0
- openocr/tools/engine/trainer.py +155 -39
- openocr/tools/eval_rec_all_ch.py +2 -2
- openocr/tools/infer_det.py +20 -2
- openocr/tools/infer_doc.py +898 -0
- openocr/tools/infer_doc_onnx.py +1172 -0
- openocr/tools/infer_e2e.py +27 -10
- openocr/tools/infer_rec.py +64 -15
- openocr/tools/infer_unirec_onnx.py +730 -0
- openocr/tools/to_markdown.py +468 -0
- openocr/tools/utils/ckpt.py +17 -5
- openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
- openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
- openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- openocr_python-0.0.9.dist-info/METADATA +0 -149
- /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
openocr/demo_unirec.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import gradio as gr
|
|
3
|
+
import numpy as np
|
|
4
|
+
from PIL import Image
|
|
5
|
+
from threading import Thread
|
|
6
|
+
import queue
|
|
7
|
+
import time
|
|
8
|
+
|
|
9
|
+
# Import ONNX inference components
|
|
10
|
+
import sys
|
|
11
|
+
import os
|
|
12
|
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
13
|
+
from tools.infer_unirec_onnx import UniRecONNX, clean_special_tokens
|
|
14
|
+
from tools.download_example_images import get_example_images_path
|
|
15
|
+
from tools.to_markdown import MarkdownConverter
|
|
16
|
+
|
|
17
|
+
# Shared MarkdownConverter instance; format_markdown_output() uses its
# _handle_table and _handle_formula helpers.
markdown_converter = MarkdownConverter()
|
|
19
|
+
|
|
20
|
+
# LaTeX delimiter table handed to the Gradio Markdown preview
# (``latex_delimiters=``). Each entry is built from a
# (left delimiter, right delimiter, display mode) triple, kept in this order.
LATEX_DELIMS = [
    {'left': left, 'right': right, 'display': display}
    for left, right, display in (
        ('$$', '$$', True),      # display math
        ('$', '$', False),       # inline math
        ('\\(', '\\)', False),   # inline math, LaTeX-style
        ('\\[', '\\]', True),    # display math, LaTeX-style
    )
]
|
|
43
|
+
|
|
44
|
+
# --- 1. Initialize ONNX Model ---
def initialize_model(
    encoder_path=None,
    decoder_path=None,
    mapping_path=None,
    use_gpu=None,
    auto_download=True
):
    """Construct and return a ``UniRecONNX`` inference engine.

    Every argument is forwarded unchanged to ``UniRecONNX``; a progress
    message is printed before and after construction.

    Args:
        encoder_path: Encoder ONNX model path; None selects the default cache directory.
        decoder_path: Decoder ONNX model path; None selects the default cache directory.
        mapping_path: Tokenizer mapping JSON path; None selects the default cache directory.
        use_gpu: None to auto-detect, True to force GPU, False to force CPU.
        auto_download: When True, missing model files are downloaded automatically.

    Returns:
        The constructed ``UniRecONNX`` instance.
    """
    print('Initializing UniRec ONNX model...')
    engine_options = {
        'encoder_path': encoder_path,
        'decoder_path': decoder_path,
        'mapping_path': mapping_path,
        'use_gpu': use_gpu,
        'auto_download': auto_download,
    }
    engine = UniRecONNX(**engine_options)
    print('✅ Model initialized successfully!')
    return engine
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Global UniRecONNX instance read by stream_recognize_image(); it stays None
# until launch_demo() assigns it.
model = None
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# --- 2. Streaming generation function ---
def stream_generate(inference, image, max_length=2048, result_queue=None):
    """Greedily decode ``image`` with the ONNX model, optionally streaming text.

    The image is encoded once, then tokens are generated one at a time
    (argmax over the last-position logits) with a self-attention KV cache,
    stopping at EOS or after ``max_length - 1`` steps.

    When ``result_queue`` is given, the cumulative cleaned text is pushed
    every ``put_token_num`` generated tokens and once more when generation
    stops. ``None`` is then pushed as a completion sentinel. The flush and the
    sentinel are sent from a ``finally`` clause, so a consumer is never left
    waiting if encoding or decoding raises. The exception still propagates.

    Args:
        inference: UniRecONNX-like engine exposing ``tokenizer``,
            ``encode_image``, ``decode_step``, ``num_decoder_layers``,
            ``num_heads`` and ``head_dim``.
        image: Image passed to ``inference.encode_image``.
        max_length: Upper bound on the token sequence length, including BOS.
        result_queue: Optional queue receiving partial texts, then ``None``.

    Returns:
        The cleaned generated text.
    """
    tokenizer = inference.tokenizer
    bos_token_id = tokenizer.bos_token_id
    eos_token_id = tokenizer.eos_token_id
    pad_token_id = tokenizer.pad_token_id
    put_token_num = 30  # push a partial result every N generated tokens

    cleaned_text = ''
    try:
        encoder_hidden_states, cross_k, cross_v = inference.encode_image(image)

        # Empty self-attention cache: one (key, value) pair per decoder layer,
        # each with a zero-length sequence axis (axis 2).
        batch_size = encoder_hidden_states.shape[0]
        cache_shape = (batch_size, inference.num_heads, 0, inference.head_dim)
        past_key_values = [
            (np.zeros(cache_shape, dtype=np.float32),
             np.zeros(cache_shape, dtype=np.float32))
            for _ in range(inference.num_decoder_layers)
        ]

        generated_ids = [bos_token_id]
        for step in range(max_length - 1):
            # `step` equals the number of tokens already in the KV cache.
            logits, past_key_values = inference.decode_step(
                generated_ids[-1],
                step,
                cross_k,
                cross_v,
                past_key_values,
                padding_idx=pad_token_id
            )

            next_token_id = int(np.argmax(logits[0, -1, :]))
            generated_ids.append(next_token_id)

            # Decode only the new token and append its cleaned text.
            token_text = tokenizer.decode([next_token_id],
                                          skip_special_tokens=False)
            cleaned_text += clean_special_tokens(token_text)

            if next_token_id == eos_token_id:
                break
            if result_queue is not None and (step + 1) % put_token_num == 0:
                result_queue.put(cleaned_text)
    finally:
        if result_queue is not None:
            # Flush the final text (not necessarily on a batch boundary) and
            # signal completion, even when generation failed.
            result_queue.put(cleaned_text)
            result_queue.put(None)

    return cleaned_text
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# --- 3. Gradio streaming function for dual display ---
def stream_recognize_image(input_image):
    """Stream recognition results to the two Gradio output panes.

    Yields ``(markdown_source, markdown_preview)`` pairs. While generation
    runs in a background thread, only the source pane is updated and the
    preview shows a "recognizing" placeholder. When generation finishes, both
    panes receive the text produced by ``format_markdown_output``.

    Args:
        input_image: PIL image or array-like image from ``gr.Image``, or None.
    """
    if input_image is None:
        yield '请先上传一张图片。', '**请先上传一张图片。**'
        return
    if model is None:
        # The demo was launched without launch_demo(), so no model is loaded;
        # report it instead of letting the worker thread crash silently.
        message = '模型尚未初始化,请通过 launch_demo() 启动。'
        yield message, f'**{message}**'
        return

    # Normalize to an RGB PIL image.
    if isinstance(input_image, Image.Image):
        input_image = input_image.convert('RGB')
    else:
        input_image = Image.fromarray(input_image).convert('RGB')

    # Run generation in a daemon thread; it reports through result_queue.
    result_queue = queue.Queue()
    thread = Thread(target=stream_generate,
                    args=(model, input_image, 2048, result_queue),
                    daemon=True)
    thread.start()

    last_update_time = time.time()
    current_text = ''

    while True:
        try:
            result = result_queue.get(timeout=1.0)
        except queue.Empty:
            if not thread.is_alive():
                # The worker may have enqueued its last items between the
                # timeout above and this liveness check. Once it is dead the
                # queue cannot grow, so drain whatever is left before exiting.
                if result_queue.empty():
                    break
                continue
            # Still generating: periodically re-yield to keep the UI responsive.
            current_time = time.time()
            if current_time - last_update_time > 0.5:
                yield current_text if current_text else '正在识别中...', '_正在识别中,请稍候..._'
                last_update_time = current_time
            continue

        if result is None:  # completion sentinel from stream_generate
            break
        current_text = result
        last_update_time = time.time()
        # Only update the markdown source; keep the placeholder in the preview.
        yield current_text, '_正在识别中,请稍候..._'

    thread.join(timeout=2.0)

    # Final yield: render the complete result in both panes.
    formatted_result = format_markdown_output(current_text)
    yield formatted_result, formatted_result
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def format_markdown_output(markdown_text):
    r"""Format recognized markdown for the Gradio Markdown preview.

    Processing steps:

    - Text containing ``<table>`` is passed through
      ``markdown_converter._handle_table``.
    - Each display formula written as ``\n\n\[...\]\n\n`` (blank lines on
      both sides) is replaced, in a single pass, by
      ``markdown_converter._handle_formula`` applied to that match.
    - Inline ``\(`` / ``\)`` delimiters are rewritten to ``$``.

    Args:
        markdown_text: Recognized text; may be empty or None.

    Returns:
        The formatted markdown, or a placeholder string when the input is empty.
    """
    if not markdown_text:
        return '_等待识别结果..._'

    if '<table>' in markdown_text:
        markdown_text = markdown_converter._handle_table(markdown_text)

    if '\\[' in markdown_text:
        # re.sub with a callable handles every match exactly once and never
        # rescans replacement text. A findall + str.replace loop would re-process
        # a formula that occurs more than once. A callable also keeps
        # backslashes in the returned LaTeX literal, since they are not treated
        # as template escapes.
        markdown_text = re.sub(
            r'\n\n\\\[.*?\\\]\n\n',
            lambda match: markdown_converter._handle_formula(match.group(0)),
            markdown_text,
            flags=re.DOTALL)

    if '\\(' in markdown_text:
        markdown_text = markdown_text.replace('\\(', '$').replace('\\)', '$')

    # Gradio's Markdown component renders the result (tables, $-delimited math).
    return markdown_text
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
# --- 4. Gradio Interface ---
# Resolve the example-image directory (the helper downloads the images if necessary).
example_img_dir = get_example_images_path(demo_type='unirec')

# Collect the supported image files in that directory, sorted by full path.
# A missing directory simply yields no examples.
example_images = sorted(
    os.path.join(example_img_dir, name)
    for name in (os.listdir(example_img_dir)
                 if os.path.exists(example_img_dir) else ())
    if name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))
)
|
|
238
|
+
|
|
239
|
+
# Build the Gradio UI: an image input with examples and action buttons on the
# left, and tabs with the markdown source and its rendered preview on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header: title, subtitle and a link to the local deployment docs.
    gr.HTML("""
<h1 style='text-align: center;'><a href="https://github.com/Topdu/OpenOCR">UniRec-0.1B: Unified Text and Formula Recognition with 0.1B Parameters</a></h1>
<p style='text-align: center;'>0.1B超轻量模型统一文本与公式识别(由<a href="https://fvl.fudan.edu.cn">FVL实验室</a> <a href="https://github.com/Topdu/OpenOCR">OCR Team</a> 创建)</p>
<p style='text-align: center;'><a href="https://github.com/Topdu/OpenOCR/blob/main/docs/unirec.md">[本地GPU部署]</a>获取快速识别体验</p>"""
            )
    gr.Markdown('上传一张图片,点击"运行识别"按钮进行文本和公式识别。')
    with gr.Row():
        with gr.Column(scale=4):  # Left column: image input + examples + buttons
            image_input = gr.Image(label='上传图片 or 粘贴截图', type='pil')

            # Add examples if available
            if example_images:
                gr.Examples(
                    examples=example_images,
                    inputs=image_input,
                    label='📚 示例图片'
                )

            with gr.Row():
                run_button = gr.Button('🚀 运行识别', variant='primary')
                clear_button = gr.Button('🗑️ 清空', variant='secondary')

        with gr.Column(scale=6):  # Right column: source and rendered preview
            with gr.Tabs():
                with gr.Tab('📝 Markdown Source'):
                    markdown_output = gr.Code(label='Markdown Source',
                                              language='markdown',
                                              lines=20)
                with gr.Tab('📝 Markdown Preview'):
                    markdown_render = gr.Markdown(
                        value='_渲染后的表格/公式将显示在这里..._',
                        latex_delimiters=LATEX_DELIMS,
                        elem_id='md_preview')

    # Run button: stream recognition results into both output panes.
    run_button.click(
        stream_recognize_image,
        inputs=[image_input],
        outputs=[markdown_output, markdown_render]
    )

    # Clear button: reset the image input and both output panes.
    def clear_all():
        return None, '', '_渲染后的表格/公式将显示在这里..._'

    clear_button.click(
        clear_all,
        outputs=[image_input, markdown_output, markdown_render]
    )
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def launch_demo(
    encoder_path=None,
    decoder_path=None,
    mapping_path=None,
    use_gpu=None,
    auto_download=True,
    share=False,
    server_name='0.0.0.0',
    server_port=7860
):
    """Load the UniRec ONNX model and launch the module-level Gradio demo.

    The model built by ``initialize_model`` is stored in the module-level
    ``model`` global, which ``stream_recognize_image`` reads. The module-level
    ``demo`` is then launched with request queuing enabled.

    Args:
        encoder_path: Encoder ONNX model path (default: auto-download).
        decoder_path: Decoder ONNX model path (default: auto-download).
        mapping_path: Tokenizer mapping JSON path (default: auto-download).
        use_gpu: None to auto-detect, True/False to force GPU/CPU (default: None).
        auto_download: Download missing models automatically (default: True).
        share: Create a public Gradio share link (default: False).
        server_name: Host name Gradio binds to (default: '0.0.0.0').
        server_port: Port Gradio listens on (default: 7860).

    Returns:
        gr.Blocks: The Gradio demo instance.
    """
    global model

    model = initialize_model(encoder_path=encoder_path,
                             decoder_path=decoder_path,
                             mapping_path=mapping_path,
                             use_gpu=use_gpu,
                             auto_download=auto_download)

    queued_demo = demo.queue()
    queued_demo.launch(share=share,
                       server_name=server_name,
                       server_port=server_port)
    return demo
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
# --- 5. Launch application ---
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='UniRec ONNX Gradio Demo')
    # The underscore spellings are kept for backward compatibility; the
    # hyphenated aliases match the style of every other flag below.
    parser.add_argument('--encoder_model', '--encoder-model',
                        dest='encoder_model',
                        type=str,
                        default=None,
                        help='Path to encoder ONNX model (default: ~/.cache/openocr/unirec_0_1b_onnx/unirec_encoder.onnx)')
    parser.add_argument('--decoder_model', '--decoder-model',
                        dest='decoder_model',
                        type=str,
                        default=None,
                        help='Path to decoder ONNX model (default: ~/.cache/openocr/unirec_0_1b_onnx/unirec_decoder.onnx)')
    parser.add_argument('--mapping',
                        type=str,
                        default=None,
                        help='Path to tokenizer mapping JSON (default: ~/.cache/openocr/unirec_0_1b_onnx/unirec_tokenizer_mapping.json)')
    parser.add_argument('--use-gpu',
                        type=str,
                        default='auto',
                        choices=['auto', 'true', 'false'],
                        help='Use GPU for inference (auto: auto-detect, true: force GPU, false: force CPU)')
    parser.add_argument('--no-auto-download',
                        action='store_true',
                        help='Disable automatic model download')
    parser.add_argument('--share',
                        action='store_true',
                        help='Create a public share link')
    parser.add_argument('--server-name',
                        type=str,
                        default='0.0.0.0',
                        help='Server name for Gradio')
    parser.add_argument('--server-port',
                        type=int,
                        default=7860,
                        help='Server port for Gradio')
    args = parser.parse_args()

    # Map the textual --use-gpu choice onto the tri-state use_gpu flag
    # (None = auto-detect). argparse `choices` guarantees the key exists.
    use_gpu = {'auto': None, 'true': True, 'false': False}[args.use_gpu]

    # Launch demo with parsed arguments
    launch_demo(
        encoder_path=args.encoder_model,
        decoder_path=args.decoder_model,
        mapping_path=args.mapping,
        use_gpu=use_gpu,
        auto_download=not args.no_auto_download,
        share=args.share,
        server_name=args.server_name,
        server_port=args.server_port
    )
|
|
@@ -10,13 +10,11 @@ def build_loss(config):
|
|
|
10
10
|
config = copy.deepcopy(config)
|
|
11
11
|
module_name = config.pop('name')
|
|
12
12
|
assert module_name in name_to_module, Exception(
|
|
13
|
-
'
|
|
13
|
+
'{} is not supported. The losses in {} are supportes'.format(
|
|
14
|
+
module_name, list(name_to_module.keys())))
|
|
14
15
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
module_path = name_to_module[module_name]
|
|
19
|
-
module = import_module(module_path, package=__package__)
|
|
20
|
-
module_class = getattr(module, module_name)
|
|
16
|
+
module_path = name_to_module[module_name]
|
|
17
|
+
module = import_module(module_path, package=__package__)
|
|
18
|
+
module_class = getattr(module, module_name)
|
|
21
19
|
|
|
22
20
|
return module_class(**config)
|
|
@@ -111,7 +111,8 @@ def crop_area(im, text_polys, min_crop_side_ratio, max_tries):
|
|
|
111
111
|
else:
|
|
112
112
|
ymin, ymax = random_select(h_axis, h)
|
|
113
113
|
|
|
114
|
-
if (xmax - xmin < min_crop_side_ratio * w
|
|
114
|
+
if (xmax - xmin < min_crop_side_ratio * w
|
|
115
|
+
or ymax - ymin < min_crop_side_ratio * h):
|
|
115
116
|
# area too small
|
|
116
117
|
continue
|
|
117
118
|
num_poly_in_rect = 0
|