openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +35 -1
- openocr/configs/dataset/rec/evaluation.yaml +41 -0
- openocr/configs/dataset/rec/ltb.yaml +9 -0
- openocr/configs/dataset/rec/mjsynth.yaml +11 -0
- openocr/configs/dataset/rec/openvino.yaml +25 -0
- openocr/configs/dataset/rec/ost.yaml +17 -0
- openocr/configs/dataset/rec/synthtext.yaml +7 -0
- openocr/configs/dataset/rec/test.yaml +77 -0
- openocr/configs/dataset/rec/textocr.yaml +13 -0
- openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
- openocr/configs/dataset/rec/union14m_b.yaml +47 -0
- openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
- openocr/configs/rec/cmer/cmer.yml +127 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
- openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
- openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
- openocr/demo_gradio.py +28 -8
- openocr/demo_opendoc.py +572 -0
- openocr/demo_unirec.py +392 -0
- openocr/opendet/losses/__init__.py +5 -7
- openocr/opendet/preprocess/crop_resize.py +2 -1
- openocr/openocr.py +685 -0
- openocr/openrec/losses/__init__.py +8 -3
- openocr/openrec/losses/cmer_loss.py +12 -0
- openocr/openrec/losses/mdiff_loss.py +11 -0
- openocr/openrec/losses/unirec_loss.py +12 -0
- openocr/openrec/metrics/__init__.py +4 -1
- openocr/openrec/metrics/rec_metric_cmer.py +328 -0
- openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
- openocr/openrec/modeling/decoders/__init__.py +1 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
- openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
- openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
- openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
- openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
- openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
- openocr/openrec/optimizer/__init__.py +4 -3
- openocr/openrec/optimizer/lr.py +49 -0
- openocr/openrec/postprocess/__init__.py +2 -0
- openocr/openrec/postprocess/abinet_postprocess.py +1 -1
- openocr/openrec/postprocess/ar_postprocess.py +1 -1
- openocr/openrec/postprocess/cmer_postprocess.py +86 -0
- openocr/openrec/postprocess/cppd_postprocess.py +1 -1
- openocr/openrec/postprocess/igtr_postprocess.py +1 -1
- openocr/openrec/postprocess/lister_postprocess.py +1 -1
- openocr/openrec/postprocess/mgp_postprocess.py +1 -1
- openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
- openocr/openrec/postprocess/smtr_postprocess.py +1 -1
- openocr/openrec/postprocess/srn_postprocess.py +1 -1
- openocr/openrec/postprocess/unirec_postprocess.py +58 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
- openocr/openrec/preprocess/__init__.py +5 -0
- openocr/openrec/preprocess/ce_label_encode.py +1 -1
- openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
- openocr/openrec/preprocess/ctc_label_encode.py +1 -1
- openocr/openrec/preprocess/dptr_label_encode.py +177 -157
- openocr/openrec/preprocess/igtr_label_encode.py +4 -2
- openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
- openocr/openrec/preprocess/rec_aug.py +128 -2
- openocr/openrec/preprocess/resize.py +57 -0
- openocr/openrec/preprocess/unirec_label_encode.py +62 -0
- openocr/tools/data/__init__.py +78 -55
- openocr/tools/data/cmer_web_dataset.py +310 -0
- openocr/tools/data/native_size_dataset.py +753 -0
- openocr/tools/data/native_size_sampler.py +158 -0
- openocr/tools/data/ratio_dataset_tvresize.py +2 -0
- openocr/tools/data/ratio_sampler.py +2 -1
- openocr/tools/download/download_dataset.py +38 -0
- openocr/tools/download/utils.py +28 -0
- openocr/tools/download_example_images.py +236 -0
- openocr/tools/engine/trainer.py +155 -39
- openocr/tools/eval_rec_all_ch.py +2 -2
- openocr/tools/infer_det.py +20 -2
- openocr/tools/infer_doc.py +898 -0
- openocr/tools/infer_doc_onnx.py +1172 -0
- openocr/tools/infer_e2e.py +27 -10
- openocr/tools/infer_rec.py +64 -15
- openocr/tools/infer_unirec_onnx.py +730 -0
- openocr/tools/to_markdown.py +468 -0
- openocr/tools/utils/ckpt.py +17 -5
- openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
- openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
- openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- openocr_python-0.0.9.dist-info/METADATA +0 -149
- /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,730 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ONNX inference script for UniRec model.
|
|
3
|
+
Standalone version without transformers dependency.
|
|
4
|
+
|
|
5
|
+
Version: Optimized v2
|
|
6
|
+
- Supports optimized KV cache format: [batch_size, num_heads, seq_len, head_dim]
|
|
7
|
+
- Compatible with merged QKV/KV projection models
|
|
8
|
+
- No reshape overhead during generation
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import os
|
|
13
|
+
import re
|
|
14
|
+
import time
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
import numpy as np
|
|
17
|
+
import onnxruntime as ort
|
|
18
|
+
from PIL import Image
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def download_model_files(model_dir=None):
    """Download the UniRec ONNX model files from ModelScope or HuggingFace.

    ModelScope is tried first. If it is not installed, fails, or leaves any
    required file missing, the files that are still missing are fetched from
    HuggingFace.

    Args:
        model_dir: Directory to save model files. If None, use the default
            cache directory ``~/.cache/openocr/unirec_0_1b_onnx``.

    Returns:
        Tuple of (encoder_path, decoder_path, mapping_path) as strings.

    Raises:
        RuntimeError: If HuggingFace is needed but ``huggingface_hub`` is not
            installed, or if the required files are still missing afterwards.
        Exception: Any error raised while downloading from HuggingFace is
            re-raised after manual download instructions are printed.
    """
    # Use default cache directory if not specified
    if model_dir is None:
        cache_dir = Path.home() / '.cache' / 'openocr'
        model_dir = cache_dir / 'unirec_0_1b_onnx'
    else:
        model_dir = Path(model_dir)

    model_dir.mkdir(parents=True, exist_ok=True)

    required_files = [
        'unirec_encoder.onnx',
        'unirec_decoder.onnx',
        'unirec_tokenizer_mapping.json'
    ]

    # Check which files are missing
    missing_files = [f for f in required_files if not (model_dir / f).exists()]

    if not missing_files:
        print(f'✅ All model files found in {model_dir}')
        return tuple(str(model_dir / f) for f in required_files)

    print(f'📥 Missing files: {missing_files}')
    print(f'📥 Downloading model files to {model_dir}...')

    download_success = False

    try:
        # Try ModelScope first (default)
        print('🌐 Trying ModelScope (China mirror) first...')
        try:
            from modelscope import snapshot_download
            model_path = snapshot_download(
                'topdktu/unirec_0_1b_onnx',
                cache_dir=str(model_dir.parent)
            )
            print(f'✅ Downloaded to {model_path}')

            # Copy files from the ModelScope snapshot into the target directory
            import shutil
            for file in required_files:
                src = Path(model_path) / file
                dst = model_dir / file
                if src.exists() and not dst.exists():
                    shutil.copy(str(src), str(dst))
                    print(f' ✓ {file}')

            # Verify all files exist after download
            if all((model_dir / f).exists() for f in required_files):
                download_success = True
                print('✅ All files downloaded successfully from ModelScope!')
            else:
                print('⚠️ ModelScope download incomplete, trying HuggingFace...')

        except ImportError:
            print('⚠️ modelscope not installed. Install with: pip install modelscope')
            print(' Trying HuggingFace...')
        except Exception as e:
            print(f'⚠️ ModelScope download failed: {e}')
            print(' Trying HuggingFace...')

        if not download_success:
            # Fall back to HuggingFace
            print('🌐 Using HuggingFace...')
            try:
                from huggingface_hub import hf_hub_download
            except ImportError:
                print('⚠️ huggingface_hub not installed. Install with: pip install huggingface_hub')
                raise RuntimeError(
                    'Cannot download models. Please install either:\n'
                    ' - huggingface_hub: pip install huggingface_hub\n'
                    ' - modelscope: pip install modelscope\n'
                    'Or manually download from:\n'
                    ' - https://huggingface.co/topdu/unirec_0_1b_onnx\n'
                    ' - https://modelscope.cn/models/topdktu/unirec_0_1b_onnx'
                )

            # Recompute: a partial ModelScope attempt may already have copied
            # some files, so only fetch what is still absent.
            still_missing = [
                f for f in required_files if not (model_dir / f).exists()
            ]
            for file in still_missing:
                print(f' Downloading {file}...')
                # `local_dir_use_symlinks` is deprecated and ignored by recent
                # huggingface_hub releases, so it is no longer passed.
                hf_hub_download(
                    repo_id='topdu/unirec_0_1b_onnx',
                    filename=file,
                    cache_dir=str(model_dir.parent),
                    local_dir=str(model_dir),
                )
                print(f' ✓ {file}')

            # Verify all files exist after download
            if all((model_dir / f).exists() for f in required_files):
                download_success = True
                print('✅ All files downloaded successfully from HuggingFace!')

        if not download_success:
            raise RuntimeError(
                'Failed to download all required files. Please manually download from:\n'
                ' - https://huggingface.co/topdu/unirec_0_1b_onnx\n'
                ' - https://modelscope.cn/models/topdktu/unirec_0_1b_onnx'
            )
    except Exception as e:
        # Print manual instructions for any failure, then propagate it.
        print(f'❌ Download failed: {e}')
        print('\n📝 Manual download instructions:')
        print(' 1. Visit: https://huggingface.co/topdu/unirec_0_1b_onnx')
        print(' or: https://modelscope.cn/models/topdktu/unirec_0_1b_onnx')
        print(f' 2. Download these files to {model_dir}:')
        for file in required_files:
            print(f' - {file}')
        raise

    return tuple(str(model_dir / f) for f in required_files)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def check_and_download_models(encoder_path, decoder_path, mapping_path, auto_download=True):
    """Check that the model files exist, downloading them if any are missing.

    Args:
        encoder_path: Path to encoder ONNX model.
        decoder_path: Path to decoder ONNX model.
        mapping_path: Path to tokenizer mapping JSON.
        auto_download: If True, automatically download missing files;
            otherwise raise FileNotFoundError.

    Returns:
        Tuple of (encoder_path, decoder_path, mapping_path). If every file
        already exists, the inputs are returned unchanged. Otherwise, the
        paths returned by ``download_model_files`` are returned.

    Raises:
        FileNotFoundError: If files are missing and ``auto_download`` is False.
        Exception: Whatever ``download_model_files`` raises on its retry.
    """
    files_to_check = {
        'encoder': encoder_path,
        'decoder': decoder_path,
        'mapping': mapping_path
    }

    missing_files = {k: v for k, v in files_to_check.items() if not os.path.exists(v)}

    if not missing_files:
        return encoder_path, decoder_path, mapping_path

    print('⚠️ Missing model files:')
    for name, path in missing_files.items():
        print(f' - {name}: {path}')

    if not auto_download:
        raise FileNotFoundError(
            'Model files not found. Please download from:\n'
            ' - https://huggingface.co/topdu/unirec_0_1b_onnx\n'
            ' - https://modelscope.cn/models/topdktu/unirec_0_1b_onnx'
        )

    # Download next to the encoder when the caller chose a custom location.
    # An empty dirname or the relative './unirec_0_1b_onnx' default is
    # redirected to the default cache directory (model_dir=None).
    encoder_dir = os.path.dirname(encoder_path)
    if encoder_dir and encoder_dir != './unirec_0_1b_onnx':
        model_dir = encoder_dir
    else:
        model_dir = None

    # download_model_files already falls back from ModelScope to HuggingFace
    # internally, so the retry below only covers transient failures.
    # `except Exception` (not a bare except) keeps Ctrl+C / SystemExit from
    # being swallowed and turned into a second full download.
    try:
        return download_model_files(model_dir)
    except Exception as e:
        print(f'🔁 Download failed ({e}); retrying once...')
        return download_model_files(model_dir)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class SimpleImageProcessor:
    """Standalone image processor without transformers dependency.

    Preprocessing steps:
    - Resize a PIL image to fit within ``max_side``, preserving aspect ratio.
    - Floor each side to a multiple of ``divided_factor``, never below 64 px.
    - Normalize pixels as ``(x / 255 - mean) / std``.
    """

    def __init__(
        self,
        max_side=(960, 1408),  # (width, height)
        divided_factor=(64, 64),  # (width, height)
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
    ):
        self.max_side = max_side
        self.divided_factor = divided_factor
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)

    def _calculate_target_size(self, original_width, original_height):
        """Return the ``(width, height)`` the image should be resized to.

        Images exceeding ``max_side`` in either dimension are scaled down to
        fit, preserving aspect ratio; smaller images keep their size. Each
        side is then floored to a multiple of ``divided_factor`` with a fixed
        64-pixel minimum.
        """
        max_width, max_height = self.max_side
        aspect_ratio = original_width / original_height

        if original_width > max_width or original_height > max_height:
            if (max_width / max_height) >= aspect_ratio:
                # Height is the binding constraint
                new_height = max_height
                new_width = int(new_height * aspect_ratio)
            else:
                # Width is the binding constraint
                new_width = max_width
                new_height = int(new_width / aspect_ratio)
        else:
            new_width, new_height = original_width, original_height

        # Apply divided factor (note: the 64 px floor does not follow divided_factor)
        div_w, div_h = self.divided_factor
        final_width = max(int(new_width // div_w * div_w), 64)
        final_height = max(int(new_height // div_h * div_h), 64)

        return (final_width, final_height)

    def __call__(self, image):
        """Process image for model input.

        Args:
            image: PIL Image in any mode. Non-RGB images (grayscale, palette,
                RGBA, CMYK, ...) are converted to RGB first.

        Returns:
            dict with 'pixel_values': float32 numpy array of shape [1, 3, H, W].

        Raises:
            ValueError: If ``image`` is not a PIL Image.
        """
        if not isinstance(image, Image.Image):
            raise ValueError('Input must be PIL Image')

        # np.array() of an 'L'/'P' image is 2-D, so a channel slice would
        # raise IndexError, and CMYK would give C/M/Y instead of R/G/B.
        # Convert explicitly; RGBA -> RGB simply drops the alpha channel.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        target_size = self._calculate_target_size(*image.size)
        image = image.resize(target_size, resample=Image.BICUBIC)

        # [H, W, 3] scaled to [0, 1], then (x - mean) / std
        image_np = np.array(image, dtype=np.float32) / 255.0
        image_np = (image_np - self.image_mean) / self.image_std

        # [H, W, C] -> [C, H, W] -> [1, C, H, W]
        image_np = image_np.transpose(2, 0, 1)[np.newaxis]

        return {'pixel_values': image_np}
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
class SimpleTokenizer:
    """Standalone tokenizer without transformers dependency.

    Only decoding is supported. Token ids are mapped back to string pieces
    through the ``id_to_token`` table of an exported mapping file.
    """

    def __init__(self, mapping_file=None):
        """Load the vocabulary from an exported tokenizer mapping file.

        Args:
            mapping_file: Path to ``unirec_tokenizer_mapping.json``. The JSON
                is read for the keys ``id_to_token``, ``vocab_size`` and
                ``special_tokens`` (with ``bos_token_id``, ``eos_token_id``
                and ``pad_token_id``).

        Raises:
            FileNotFoundError: If ``mapping_file`` is None or does not exist.
        """
        # Without a mapping file the tokenizer cannot decode anything; fail
        # here instead of leaving a half-initialized object behind.
        if not mapping_file or not os.path.exists(mapping_file):
            raise FileNotFoundError(
                f'Tokenizer mapping file not found: {mapping_file}')

        # Use the exported mapping file (recommended)
        print(f'Loading tokenizer from mapping file: {mapping_file}')
        with open(mapping_file, 'r', encoding='utf-8') as f:
            mapping_data = json.load(f)

        # Use the id_to_token mapping directly; JSON object keys are strings,
        # so convert them back to int token ids.
        self.id_to_token = {
            int(k): v
            for k, v in mapping_data['id_to_token'].items()
        }
        self.vocab_size = mapping_data['vocab_size']

        # Special tokens
        special_tokens = mapping_data['special_tokens']
        self.bos_token_id = special_tokens['bos_token_id']
        self.eos_token_id = special_tokens['eos_token_id']
        self.pad_token_id = special_tokens['pad_token_id']

        print(f'✅ Loaded vocabulary with {self.vocab_size} tokens')

    def decode(self, token_ids, skip_special_tokens=False):
        """Decode token IDs to text.

        Args:
            token_ids: Iterable of token IDs.
            skip_special_tokens: If True, drop BOS/EOS/PAD tokens.

        Returns:
            The concatenated token strings. Ids not in the vocabulary are
            rendered as ``<unk_{id}>``.
        """
        special_ids = {self.bos_token_id, self.eos_token_id, self.pad_token_id}
        tokens = []
        for token_id in token_ids:
            if token_id not in self.id_to_token:
                tokens.append(f'<unk_{token_id}>')
            elif not (skip_special_tokens and token_id in special_ids):
                tokens.append(self.id_to_token[token_id])

        return ''.join(tokens)
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def clean_special_tokens(text):
    """Strip tokenizer markers from decoded text and normalize it.

    The steps run in this order:
      1. 'Ġ' becomes a space and 'Ċ' becomes a newline.
      2. The literal markers '<|bos|>', '<|eos|>' and '<|pad|>' are deleted.
      3. Regex clean-ups:
         - '<|sn|>' is removed after '-', kept as one space after ' ', and
           otherwise becomes a space.
         - '<|unk|>', '<s>', '</s>' and U+FFFF are deleted.
         - Runs of four or more underscores or dots collapse to three.
    """
    text = text.translate(str.maketrans({'Ġ': ' ', 'Ċ': '\n'}))

    for marker in ('<|bos|>', '<|eos|>', '<|pad|>'):
        text = text.replace(marker, '')

    substitutions = (
        (r'-<\|sn\|>', ''),
        (r' <\|sn\|>', ' '),
        (r'<\|sn\|>', ' '),
        (r'<\|unk\|>', ''),
        (r'<s>', ''),
        (r'</s>', ''),
        (r'\uffff', ''),
        (r'_{4,}', '___'),
        (r'\.{4,}', '...'),
    )
    for regex, repl in substitutions:
        text = re.sub(regex, repl, text)

    return text
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
class UniRecONNX:
    """ONNX-based inference for UniRec model (standalone version).

    The class holds two ONNX Runtime sessions:
    - an encoder session, whose first three outputs are used as
      (hidden states, cross_k, cross_v);
    - a decoder session, run one token at a time with an explicit
      self-attention KV cache.

    Generation is greedy (argmax) until the EOS token or ``max_length``.
    """

    def __init__(
        self,
        encoder_path=None,
        decoder_path=None,
        mapping_path=None,
        use_gpu=None,
        auto_download=True,
    ):
        """Initialize ONNX inference sessions.

        Args:
            encoder_path: Path to encoder ONNX model. If None, use default cache directory.
            decoder_path: Path to decoder ONNX model. If None, use default cache directory.
            mapping_path: Path to tokenizer mapping JSON. If None, use default cache directory.
            use_gpu: Whether to use GPU. If None, auto-detect. If True, force GPU. If False, force CPU.
            auto_download: If True, automatically download missing model files.
        """
        # Set default paths if not provided
        if encoder_path is None or decoder_path is None or mapping_path is None:
            cache_dir = Path.home() / '.cache' / 'openocr'
            model_path = cache_dir / 'unirec_0_1b_onnx'
            if encoder_path is None:
                encoder_path = str(model_path / 'unirec_encoder.onnx')
            if decoder_path is None:
                decoder_path = str(model_path / 'unirec_decoder.onnx')
            if mapping_path is None:
                mapping_path = str(model_path / 'unirec_tokenizer_mapping.json')

        # Check and download models if needed
        encoder_path, decoder_path, mapping_path = check_and_download_models(
            encoder_path, decoder_path, mapping_path, auto_download=auto_download
        )

        print('Loading ONNX models...')

        # Determine execution provider
        providers = self._get_execution_providers(use_gpu)
        print(f'Using execution providers: {providers}')

        # Create ONNX runtime sessions
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.decoder_session = ort.InferenceSession(decoder_path, sess_options, providers=providers)
        self.encoder_session = ort.InferenceSession(encoder_path, sess_options, providers=providers)

        # Initialize processor and tokenizer
        self.processor = SimpleImageProcessor()
        self.tokenizer = SimpleTokenizer(mapping_file=mapping_path)

        # Infer the KV-cache layout from the decoder's 'past_key_<i>' inputs.
        # The layer count is the largest index + 1. For 4-D inputs, axes 1
        # and 3 are read as num_heads and head_dim (assumed layout
        # [batch_size, num_heads, seq_len, head_dim]). Symbolic (non-int)
        # dims leave the value as None.
        self.num_decoder_layers = None
        self.num_heads = None
        self.head_dim = None

        for inp in self.decoder_session.get_inputs():
            if 'past_key' not in inp.name:
                continue
            layer_idx = int(inp.name.split('_')[-1])
            if self.num_decoder_layers is None or layer_idx + 1 > self.num_decoder_layers:
                self.num_decoder_layers = layer_idx + 1
            if len(inp.shape) == 4:
                if self.num_heads is None and isinstance(inp.shape[1], int):
                    self.num_heads = inp.shape[1]
                if self.head_dim is None and isinstance(inp.shape[3], int):
                    self.head_dim = inp.shape[3]

        # Calculate d_model
        if self.num_heads and self.head_dim:
            self.d_model = self.num_heads * self.head_dim
        else:
            self.d_model = None

        print('\n✅ Models loaded successfully!')
        print(f' Number of decoder layers: {self.num_decoder_layers}')
        print(f' Number of attention heads: {self.num_heads}')
        print(f' Head dimension: {self.head_dim}')
        print(f' Model dimension (d_model): {self.d_model}')
        print(f' Vocabulary size: {self.tokenizer.vocab_size}')

    def _get_execution_providers(self, use_gpu):
        """Determine execution providers based on GPU availability and user preference.

        Args:
            use_gpu: None (auto-detect), True (force GPU), or False (force CPU).

        Returns:
            List of execution providers in priority order; CPU is always last.
        """
        available_providers = ort.get_available_providers()

        if use_gpu is False:
            # Force CPU
            print('🔧 User specified: Using CPU')
            return ['CPUExecutionProvider']

        # Only CUDA is considered; TensorrtExecutionProvider is not used.
        gpu_providers = []
        if 'CUDAExecutionProvider' in available_providers:
            gpu_providers.append('CUDAExecutionProvider')

        if use_gpu is True:
            # Force GPU
            if gpu_providers:
                print(f'🔧 User specified: Using GPU ({gpu_providers[0]})')
                return gpu_providers + ['CPUExecutionProvider']
            print('⚠️ GPU requested but not available, falling back to CPU')
            return ['CPUExecutionProvider']

        # Auto-detect (use_gpu is None)
        if gpu_providers:
            print(f'✅ GPU detected: Using {gpu_providers[0]}')
            return gpu_providers + ['CPUExecutionProvider']
        print('ℹ️ No GPU detected, using CPU')
        return ['CPUExecutionProvider']

    @staticmethod
    def _load_image(img_path, img_numpy, image):
        """Return an RGB PIL image from whichever input source was given.

        Precedence is ``img_path`` > ``img_numpy`` > ``image``. A numpy array
        is treated as BGR (3 channels), BGRA (4 channels) or grayscale
        (2-D, or 3-D with a single channel).

        Raises:
            ValueError: If no input was provided.
        """
        if img_path is not None:
            return Image.open(img_path).convert('RGB')

        if img_numpy is not None:
            arr = np.asarray(img_numpy)
            if arr.ndim == 3 and arr.shape[2] == 1:
                # HxWx1 -> HxW so PIL treats it as grayscale
                arr = arr[:, :, 0]
            elif arr.ndim == 3 and arr.shape[2] == 3:
                # BGR -> RGB (pure channel swap, as cv2.COLOR_BGR2RGB does)
                arr = arr[:, :, ::-1]
            elif arr.ndim == 3 and arr.shape[2] == 4:
                # BGRA -> RGBA; without this, red and blue would be swapped
                arr = arr[:, :, [2, 1, 0, 3]]
            image = Image.fromarray(np.ascontiguousarray(arr))
        elif image is None:
            raise ValueError('Either img_path, img_numpy, or image must be provided')

        # Grayscale / palette / RGBA inputs must become 3-channel RGB.
        # Non-PIL objects are passed through so the processor reports them.
        if isinstance(image, Image.Image) and image.mode != 'RGB':
            image = image.convert('RGB')
        return image

    def encode_image(self, image):
        """Encode image using encoder ONNX model.

        Returns:
            Tuple (encoder_hidden_states, cross_k, cross_v): the first three
            encoder outputs, in graph output order.
        """
        # Preprocess image
        pixel_values = self.processor(image)['pixel_values']

        # Run encoder (copy=False: pixel_values is normally float32 already)
        encoder_outputs = self.encoder_session.run(
            None, {'pixel_values': pixel_values.astype(np.float32, copy=False)})

        encoder_hidden_states = encoder_outputs[0]
        cross_k = encoder_outputs[1]
        cross_v = encoder_outputs[2]

        return encoder_hidden_states, cross_k, cross_v

    def decode_step(self,
                    input_id,
                    past_length,
                    cross_k,
                    cross_v,
                    past_key_values,
                    padding_idx=1):
        """Run one decoder step for a single token (batch size 1).

        Args:
            input_id: Token id fed at this step.
            past_length: Number of tokens already held in the self-attention cache.
            cross_k: Cross-attention keys from the encoder.
            cross_v: Cross-attention values from the encoder.
            past_key_values: List of (key, value) arrays, one pair per decoder layer.
            padding_idx: Pad token id; position ids are
                ``padding_idx + 1 + past_length`` (M2M100's scheme).

        Returns:
            (logits, present_key_values): the first decoder output and the
            updated per-layer (key, value) cache.
        """
        input_ids = np.array([[input_id]], dtype=np.int64)
        # Use M2M100's position ID calculation with past_key_values_length
        position_ids = np.array([[padding_idx + 1 + past_length]],
                                dtype=np.int64)

        # copy=False avoids re-copying the cross-attention tensors and the
        # growing cache on every generated token when they are already float32.
        decoder_inputs = {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'cross_k': cross_k.astype(np.float32, copy=False),
            'cross_v': cross_v.astype(np.float32, copy=False),
        }
        for i, (past_key, past_value) in enumerate(past_key_values):
            decoder_inputs[f'past_key_{i}'] = past_key.astype(np.float32, copy=False)
            decoder_inputs[f'past_value_{i}'] = past_value.astype(np.float32, copy=False)

        decoder_outputs = self.decoder_session.run(None, decoder_inputs)

        logits = decoder_outputs[0]
        # Outputs after logits alternate present_key_i, present_value_i
        present_key_values = [
            (decoder_outputs[1 + i * 2], decoder_outputs[1 + i * 2 + 1])
            for i in range(self.num_decoder_layers)
        ]

        return logits, present_key_values

    def __call__(
        self,
        img_path=None,
        img_numpy=None,
        image=None,
        max_length=2048,
        bos_token_id=None,
        eos_token_id=None,
        pad_token_id=None,
    ):
        """Unified interface for UniRec inference.

        Args:
            img_path: Path to input image (str or Path).
            img_numpy: Input image as numpy array (BGR, BGRA or grayscale).
            image: PIL Image object (converted to RGB if needed).
            max_length: Maximum generation length, including the BOS token.
            bos_token_id: Beginning of sequence token ID (default: tokenizer's).
            eos_token_id: End of sequence token ID (default: tokenizer's).
            pad_token_id: Padding token ID (default: tokenizer's).

        Returns:
            Tuple of (generated_text, generated_ids).

        Raises:
            ValueError: If no image input was provided.
            RuntimeError: If the KV-cache layout could not be read from the
                decoder graph.
        """
        image = self._load_image(img_path, img_numpy, image)

        # Get token IDs
        if bos_token_id is None:
            bos_token_id = self.tokenizer.bos_token_id
        if eos_token_id is None:
            eos_token_id = self.tokenizer.eos_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.pad_token_id

        # Fail fast with a clear message instead of a numpy TypeError later.
        if None in (self.num_decoder_layers, self.num_heads, self.head_dim):
            raise RuntimeError(
                'Could not infer the KV-cache layout (num_decoder_layers='
                f'{self.num_decoder_layers}, num_heads={self.num_heads}, '
                f'head_dim={self.head_dim}) from the decoder inputs; expected '
                'past_key_<i> inputs with static head dimensions.')

        # Encode image
        print('Encoding image...')
        t_start = time.perf_counter()
        encoder_hidden_states, cross_k, cross_v = self.encode_image(image)
        print(f'Encoding time: {time.perf_counter() - t_start:.2f} seconds')
        print(f' cross_k shape: {cross_k.shape}')
        print(f' cross_v shape: {cross_v.shape}')

        # Convert once so decode_step never has to copy them
        cross_k = cross_k.astype(np.float32, copy=False)
        cross_v = cross_v.astype(np.float32, copy=False)

        # Initialize generation
        print('Generating text...')
        generated_ids = [bos_token_id]

        # Empty self-attention cache for the first step:
        # [batch_size, num_heads, 0, head_dim]
        batch_size = encoder_hidden_states.shape[0]
        empty_shape = (batch_size, self.num_heads, 0, self.head_dim)
        past_key_values = [
            (np.zeros(empty_shape, dtype=np.float32),
             np.zeros(empty_shape, dtype=np.float32))
            for _ in range(self.num_decoder_layers)
        ]

        # Greedy generation loop
        t_start = time.perf_counter()
        for step in range(max_length - 1):
            # The cache holds exactly `step` tokens before this call
            logits, past_key_values = self.decode_step(
                generated_ids[-1],
                step,
                cross_k,
                cross_v,
                past_key_values,
                padding_idx=pad_token_id)

            next_token_id = int(np.argmax(logits[0, -1, :]))
            generated_ids.append(next_token_id)

            if next_token_id == eos_token_id:
                break

            # Progress indicator
            if (step + 1) % 50 == 0:
                print(f' Generated {step + 1} tokens...')

        elapsed = time.perf_counter() - t_start
        print(f'✅ Generation complete! Total tokens: {len(generated_ids)}')
        print(f' Time taken: {elapsed:.2f} seconds')
        # elapsed can be 0 (e.g. max_length <= 1 or a coarse clock)
        if elapsed > 0:
            print(f' Tokens per second: {len(generated_ids) / elapsed:.2f}')

        # Decode tokens
        generated_text = self.tokenizer.decode(generated_ids,
                                               skip_special_tokens=False)
        cleaned_text = clean_special_tokens(generated_text)

        return cleaned_text, generated_ids
|
|
654
|
+
|
|
655
|
+
|
|
656
|
+
def main():
    """Command-line entry point: run UniRec ONNX inference on one image and print the text."""
    import argparse

    cli = argparse.ArgumentParser(
        description='UniRec ONNX Inference (Standalone)')
    cli.add_argument('--image', type=str, required=True,
                     help='Path to input image')
    cli.add_argument(
        '--encoder-model', type=str, default=None,
        help='Path to encoder ONNX model (default: ~/.cache/openocr/unirec_0_1b_onnx/unirec_encoder.onnx)')
    cli.add_argument(
        '--decoder-model', type=str, default=None,
        help='Path to decoder ONNX model (default: ~/.cache/openocr/unirec_0_1b_onnx/unirec_decoder.onnx)')
    cli.add_argument(
        '--mapping', type=str, default=None,
        help='Path to tokenizer mapping JSON (default: ~/.cache/openocr/unirec_0_1b_onnx/unirec_tokenizer_mapping.json)')
    cli.add_argument('--max-length', type=int, default=2048,
                     help='Maximum generation length')
    cli.add_argument(
        '--use-gpu', type=str, default='auto',
        choices=['auto', 'true', 'false'],
        help='Use GPU for inference (auto: auto-detect, true: force GPU, false: force CPU)')
    cli.add_argument('--no-auto-download', action='store_true',
                     help='Disable automatic model download')
    opts = cli.parse_args()

    # Tri-state flag -> UniRecONNX's use_gpu: 'auto' -> None, 'true' -> True, otherwise False
    use_gpu = {'auto': None, 'true': True}.get(opts.use_gpu, False)

    print(f'Loading image: {opts.image}')
    image = Image.open(opts.image).convert('RGB')

    recognizer = UniRecONNX(
        encoder_path=opts.encoder_model,
        decoder_path=opts.decoder_model,
        mapping_path=opts.mapping,
        use_gpu=use_gpu,
        auto_download=not opts.no_auto_download,
    )

    text, token_ids = recognizer(image=image, max_length=opts.max_length)

    rule = '=' * 80
    print('\n' + rule)
    print('RESULT:')
    print(rule)
    print(text)
    print(rule)
    print(f'\nGenerated {len(token_ids)} tokens')
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
# Script entry point: run the CLI only when this file is executed directly.
if __name__ == '__main__':
    main()
|