openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,730 @@
1
+ """
2
+ ONNX inference script for UniRec model.
3
+ Standalone version without transformers dependency.
4
+
5
+ Version: Optimized v2
6
+ - Supports optimized KV cache format: [batch_size, num_heads, seq_len, head_dim]
7
+ - Compatible with merged QKV/KV projection models
8
+ - No reshape overhead during generation
9
+ """
10
+
11
+ import json
12
+ import os
13
+ import re
14
+ import time
15
+ from pathlib import Path
16
+ import numpy as np
17
+ import onnxruntime as ort
18
+ from PIL import Image
19
+
20
+
21
def download_model_files(model_dir=None):
    """Download ONNX model files from ModelScope or HuggingFace.

    ModelScope is tried first; if it is not installed, fails, or leaves the
    target directory incomplete, HuggingFace is used for whatever is still
    missing.

    Args:
        model_dir: Directory to save model files. If None, use default cache
            directory (``~/.cache/openocr/unirec_0_1b_onnx``).

    Returns:
        Tuple of (encoder_path, decoder_path, mapping_path) as strings.

    Raises:
        RuntimeError: If neither hub client is usable or the required files
            are still missing after both attempts.
        Exception: Any other download error is re-raised after manual
            download instructions have been printed.
    """
    # Use default cache directory if not specified
    if model_dir is None:
        cache_dir = Path.home() / '.cache' / 'openocr'
        model_dir = cache_dir / 'unirec_0_1b_onnx'
    else:
        model_dir = Path(model_dir)

    model_dir.mkdir(parents=True, exist_ok=True)

    required_files = [
        'unirec_encoder.onnx',
        'unirec_decoder.onnx',
        'unirec_tokenizer_mapping.json',
    ]

    def _still_missing():
        # Re-evaluated on every call so each stage sees the current state of
        # model_dir rather than a snapshot taken before any download ran.
        return [f for f in required_files if not (model_dir / f).exists()]

    missing_files = _still_missing()

    if not missing_files:
        print(f'✅ All model files found in {model_dir}')
        return tuple(str(model_dir / f) for f in required_files)

    print(f'📥 Missing files: {missing_files}')
    print(f'📥 Downloading model files to {model_dir}...')

    download_success = False

    try:
        # Try ModelScope first (default)
        print('🌐 Trying ModelScope (China mirror) first...')
        try:
            import shutil

            from modelscope import snapshot_download
            model_path = snapshot_download(
                'topdktu/unirec_0_1b_onnx',
                cache_dir=str(model_dir.parent)
            )
            print(f'✅ Downloaded to {model_path}')

            # Copy files to target directory
            for file in required_files:
                src = Path(model_path) / file
                dst = model_dir / file
                if src.exists() and not dst.exists():
                    shutil.copy(str(src), str(dst))
                    print(f' ✓ {file}')

            # Verify all files exist after download
            if not _still_missing():
                download_success = True
                print('✅ All files downloaded successfully from ModelScope!')
            else:
                print('⚠️ ModelScope download incomplete, trying HuggingFace...')

        except ImportError:
            print('⚠️ modelscope not installed. Install with: pip install modelscope')
            print(' Trying HuggingFace...')
        except Exception as e:
            print(f'⚠️ ModelScope download failed: {e}')
            print(' Trying HuggingFace...')

        if not download_success:
            # Try HuggingFace
            print('🌐 Using HuggingFace...')
            try:
                from huggingface_hub import hf_hub_download

                # Only fetch what is missing *now*; ModelScope may already
                # have delivered part of the set.
                for file in _still_missing():
                    print(f' Downloading {file}...')
                    hf_hub_download(
                        repo_id='topdu/unirec_0_1b_onnx',
                        filename=file,
                        cache_dir=str(model_dir.parent),
                        local_dir=str(model_dir),
                        local_dir_use_symlinks=False
                    )
                    print(f' ✓ {file}')

                # Verify all files exist after download
                if not _still_missing():
                    download_success = True
                    print('✅ All files downloaded successfully from HuggingFace!')

            except ImportError as err:
                print('⚠️ huggingface_hub not installed. Install with: pip install huggingface_hub')
                raise RuntimeError(
                    'Cannot download models. Please install either:\n'
                    ' - huggingface_hub: pip install huggingface_hub\n'
                    ' - modelscope: pip install modelscope\n'
                    'Or manually download from:\n'
                    ' - https://huggingface.co/topdu/unirec_0_1b_onnx\n'
                    ' - https://modelscope.cn/models/topdktu/unirec_0_1b_onnx'
                ) from err

        if not download_success:
            raise RuntimeError(
                'Failed to download all required files. Please manually download from:\n'
                ' - https://huggingface.co/topdu/unirec_0_1b_onnx\n'
                ' - https://modelscope.cn/models/topdktu/unirec_0_1b_onnx'
            )
    except Exception as e:
        # Print actionable instructions, then let the caller see the error.
        print(f'❌ Download failed: {e}')
        print('\n📝 Manual download instructions:')
        print(' 1. Visit: https://huggingface.co/topdu/unirec_0_1b_onnx')
        print(' or: https://modelscope.cn/models/topdktu/unirec_0_1b_onnx')
        print(f' 2. Download these files to {model_dir}:')
        for file in required_files:
            print(f' - {file}')
        raise

    return tuple(str(model_dir / f) for f in required_files)
143
+
144
+
145
def check_and_download_models(encoder_path, decoder_path, mapping_path, auto_download=True):
    """Check if model files exist, download if missing.

    Args:
        encoder_path: Path to encoder ONNX model
        decoder_path: Path to decoder ONNX model
        mapping_path: Path to tokenizer mapping JSON
        auto_download: If True, automatically download missing files

    Returns:
        Tuple of (encoder_path, decoder_path, mapping_path). When every file
        already exists the inputs are returned unchanged; otherwise the
        paths reported by ``download_model_files`` are returned.

    Raises:
        FileNotFoundError: If files are missing and ``auto_download`` is
            False.
    """
    files_to_check = {
        'encoder': encoder_path,
        'decoder': decoder_path,
        'mapping': mapping_path,
    }

    missing_files = {
        name: path
        for name, path in files_to_check.items()
        if not os.path.exists(path)
    }

    if not missing_files:
        return encoder_path, decoder_path, mapping_path

    print('⚠️ Missing model files:')
    for name, path in missing_files.items():
        print(f' - {name}: {path}')

    if not auto_download:
        raise FileNotFoundError(
            'Model files not found. Please download from:\n'
            ' - https://huggingface.co/topdu/unirec_0_1b_onnx\n'
            ' - https://modelscope.cn/models/topdktu/unirec_0_1b_onnx'
        )

    # Determine model directory from encoder path
    encoder_dir = os.path.dirname(encoder_path)
    if encoder_dir and encoder_dir != './unirec_0_1b_onnx':
        # User specified a custom path
        model_dir = encoder_dir
    else:
        # Use default cache directory
        model_dir = None

    # Try ModelScope first (faster in China), then HuggingFace
    try:
        print('🇨🇳 Trying ModelScope (China mirror) first...')
        return download_model_files(model_dir)
    except Exception:
        # Retry once on ordinary failures only; KeyboardInterrupt and
        # SystemExit must propagate instead of triggering a second download.
        print('🌍 Trying HuggingFace...')
        return download_model_files(model_dir)
195
+
196
+
197
class SimpleImageProcessor:
    """Standalone image processor without transformers dependency.

    Resizes a PIL image so it fits inside ``max_side`` while keeping its
    aspect ratio, snaps both sides down to a multiple of ``divided_factor``
    (never below 64 pixels), and normalizes it channel-wise.
    """

    def __init__(
        self,
        max_side=(960, 1408),  # (width, height)
        divided_factor=(64, 64),
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
    ):
        """Store resize limits and per-channel normalization constants.

        Args:
            max_side: Maximum (width, height) of the resized image.
            divided_factor: (width, height) multiples the output is snapped to.
            image_mean: Per-channel mean subtracted after scaling to [0, 1].
            image_std: Per-channel divisor applied after mean subtraction.
        """
        self.max_side = max_side
        self.divided_factor = divided_factor
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)

    def _calculate_target_size(self, original_width, original_height):
        """Calculate target size with aspect ratio preservation.

        Args:
            original_width: Source image width in pixels.
            original_height: Source image height in pixels.

        Returns:
            Tuple ``(width, height)``; each side is a multiple of the
            corresponding divided factor and at least 64.
        """
        max_width, max_height = self.max_side
        aspect_ratio = original_width / original_height

        if original_width > max_width or original_height > max_height:
            # Shrink along whichever side hits its limit first.
            if (max_width / max_height) >= aspect_ratio:
                new_height = max_height
                new_width = int(new_height * aspect_ratio)
            else:
                new_width = max_width
                new_height = int(new_width / aspect_ratio)
        else:
            # Images that already fit are not upscaled.
            new_width, new_height = original_width, original_height

        # Apply divided factor
        div_w, div_h = self.divided_factor
        final_width = max(int(new_width // div_w * div_w), 64)
        final_height = max(int(new_height // div_h * div_h), 64)

        return (final_width, final_height)

    def __call__(self, image):
        """
        Process image for model input.

        Args:
            image: PIL Image

        Returns:
            dict with 'pixel_values' as numpy array [1, 3, H, W]

        Raises:
            ValueError: If ``image`` is not a PIL Image.
        """
        if not isinstance(image, Image.Image):
            raise ValueError('Input must be PIL Image')

        # Grayscale, palette, CMYK, etc. do not yield an [H, W, >=3] RGB
        # array, so the channel slice below would fail or pick the wrong
        # channels. RGB and RGBA keep their original handling.
        if image.mode not in ('RGB', 'RGBA'):
            image = image.convert('RGB')

        original_width, original_height = image.size

        # Resize
        target_size = self._calculate_target_size(original_width,
                                                  original_height)
        image = image.resize(target_size, resample=Image.BICUBIC)

        # Convert to numpy array [H, W, C] and normalize to [0, 1]
        image_np = np.array(image, dtype=np.float32)[:, :, :3] / 255.0

        # Normalize: (x - mean) / std
        image_np = (image_np - self.image_mean) / self.image_std

        # Transpose to [C, H, W]
        image_np = image_np.transpose(2, 0, 1)

        # Add batch dimension [1, C, H, W]
        image_np = np.expand_dims(image_np, axis=0)

        return {'pixel_values': image_np}
267
+
268
+
269
class SimpleTokenizer:
    """Standalone tokenizer without transformers dependency.

    Only decoding is supported: token IDs are mapped back to their string
    pieces using an exported ``id_to_token`` table.
    """

    def __init__(self, mapping_file=None):
        """
        Load vocabulary from the exported mapping file.

        When ``mapping_file`` is None or does not exist, the tokenizer is
        created with an empty vocabulary (``vocab_size == 0``, special token
        IDs set to None) so that attribute access and ``decode`` still work.

        Args:
            mapping_file: path to unirec_tokenizer_mapping.json (recommended)
        """
        # Defaults keep the instance usable even without a mapping file.
        self.id_to_token = {}
        self.vocab_size = 0
        self.bos_token_id = None
        self.eos_token_id = None
        self.pad_token_id = None

        if mapping_file and os.path.exists(mapping_file):
            # Use the exported mapping file (recommended)
            print(f'Loading tokenizer from mapping file: {mapping_file}')
            with open(mapping_file, 'r', encoding='utf-8') as f:
                mapping_data = json.load(f)

            # Use the id_to_token mapping directly; JSON keys are strings,
            # so convert them back to integer token IDs.
            self.id_to_token = {
                int(k): v
                for k, v in mapping_data['id_to_token'].items()
            }
            self.vocab_size = mapping_data['vocab_size']

            # Special tokens
            special_tokens = mapping_data['special_tokens']
            self.bos_token_id = special_tokens['bos_token_id']
            self.eos_token_id = special_tokens['eos_token_id']
            self.pad_token_id = special_tokens['pad_token_id']

            print(f'✅ Loaded vocabulary with {self.vocab_size} tokens')
        elif mapping_file:
            print(f'⚠️ Tokenizer mapping file not found: {mapping_file}')

    def decode(self, token_ids, skip_special_tokens=False):
        """
        Decode token IDs to text.

        IDs absent from the vocabulary are rendered as ``<unk_ID>``.

        Args:
            token_ids: list of token IDs
            skip_special_tokens: whether to skip special tokens

        Returns:
            decoded text string
        """
        special_ids = (self.bos_token_id, self.eos_token_id,
                       self.pad_token_id)
        tokens = []
        for token_id in token_ids:
            if token_id not in self.id_to_token:
                tokens.append(f'<unk_{token_id}>')
                continue

            # Skip special tokens if requested
            if skip_special_tokens and token_id in special_ids:
                continue

            tokens.append(self.id_to_token[token_id])

        # Join tokens
        return ''.join(tokens)
332
+
333
+
334
def clean_special_tokens(text):
    """Strip tokenizer artefacts and special markers from decoded text.

    Literal substitutions run first (byte-level space/newline markers and the
    bos/eos/pad tags), followed by an ordered list of regex rewrites.

    Args:
        text: Raw string produced by joining decoded tokens.

    Returns:
        The cleaned string.
    """
    # Plain string substitutions, applied in order.
    literal_swaps = (
        ('Ġ', ' '),
        ('Ċ', '\n'),
        ('<|bos|>', ''),
        ('<|eos|>', ''),
        ('<|pad|>', ''),
    )
    for needle, substitute in literal_swaps:
        text = text.replace(needle, substitute)

    # Regex rewrites; order matters because the more specific <|sn|> forms
    # must be consumed before the generic one.
    regex_swaps = (
        (r'-<\|sn\|>', ''),
        (r' <\|sn\|>', ' '),
        (r'<\|sn\|>', ' '),
        (r'<\|unk\|>', ''),
        (r'<s>', ''),
        (r'</s>', ''),
        (r'\uffff', ''),
        (r'_{4,}', '___'),
        (r'\.{4,}', '...'),
    )
    for expr, substitute in regex_swaps:
        text = re.sub(expr, substitute, text)

    return text
358
+
359
+
360
class UniRecONNX:
    """ONNX-based inference for UniRec model (standalone version).

    Wraps an encoder session and a decoder session and performs greedy,
    token-by-token generation with a per-layer key/value cache laid out as
    [batch_size, num_heads, seq_len, head_dim].
    """

    def __init__(
        self,
        encoder_path=None,
        decoder_path=None,
        mapping_path=None,
        use_gpu=None,
        auto_download=True,
    ):
        """Initialize ONNX inference sessions.

        Args:
            encoder_path: Path to encoder ONNX model. If None, use default cache directory.
            decoder_path: Path to decoder ONNX model. If None, use default cache directory.
            mapping_path: Path to tokenizer mapping JSON. If None, use default cache directory.
            use_gpu: Whether to use GPU. If None, auto-detect. If True, force GPU. If False, force CPU.
            auto_download: If True, automatically download missing model files
        """
        # Set default paths if not provided
        if encoder_path is None or decoder_path is None or mapping_path is None:
            cache_dir = Path.home() / '.cache' / 'openocr'
            model_path = cache_dir / 'unirec_0_1b_onnx'
            if encoder_path is None:
                encoder_path = str(model_path / 'unirec_encoder.onnx')
            if decoder_path is None:
                decoder_path = str(model_path / 'unirec_decoder.onnx')
            if mapping_path is None:
                mapping_path = str(model_path / 'unirec_tokenizer_mapping.json')

        # Check and download models if needed
        encoder_path, decoder_path, mapping_path = check_and_download_models(
            encoder_path, decoder_path, mapping_path, auto_download=auto_download
        )

        print('Loading ONNX models...')

        # Determine execution provider
        providers = self._get_execution_providers(use_gpu)
        print(f'Using execution providers: {providers}')

        # Create ONNX runtime sessions
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.decoder_session = ort.InferenceSession(decoder_path, sess_options, providers=providers)
        self.encoder_session = ort.InferenceSession(encoder_path, sess_options, providers=providers)

        # Initialize processor and tokenizer
        self.processor = SimpleImageProcessor()
        self.tokenizer = SimpleTokenizer(mapping_file=mapping_path)

        # Get model info from decoder session
        # Shape: [batch_size, num_heads, seq_len, head_dim]
        self.num_decoder_layers = None
        self.num_heads = None
        self.head_dim = None

        for inp in self.decoder_session.get_inputs():
            if 'past_key' not in inp.name:
                continue
            # The trailing "_<n>" of the input name is the layer index; the
            # layer count is the highest index seen plus one.
            layer_idx = int(inp.name.split('_')[-1])
            if self.num_decoder_layers is None or layer_idx + 1 > self.num_decoder_layers:
                self.num_decoder_layers = layer_idx + 1
            # Get shape info: [batch_size, num_heads, seq_len, head_dim].
            # Only concrete integer dims are accepted; non-int entries are
            # skipped.
            if len(inp.shape) == 4:
                if self.num_heads is None and isinstance(inp.shape[1], int):
                    self.num_heads = inp.shape[1]
                if self.head_dim is None and isinstance(inp.shape[3], int):
                    self.head_dim = inp.shape[3]

        # Calculate d_model
        if self.num_heads and self.head_dim:
            self.d_model = self.num_heads * self.head_dim
        else:
            self.d_model = None

        print('\n✅ Models loaded successfully!')
        print(f' Number of decoder layers: {self.num_decoder_layers}')
        print(f' Number of attention heads: {self.num_heads}')
        print(f' Head dimension: {self.head_dim}')
        print(f' Model dimension (d_model): {self.d_model}')
        print(f' Vocabulary size: {self.tokenizer.vocab_size}')

    def _get_execution_providers(self, use_gpu):
        """Determine execution providers based on GPU availability and user preference.

        Args:
            use_gpu: None (auto-detect), True (force GPU), or False (force CPU)

        Returns:
            List of execution providers in priority order
        """
        if use_gpu is False:
            # Force CPU; no need to ask the runtime what else is available.
            print('🔧 User specified: Using CPU')
            return ['CPUExecutionProvider']

        available_providers = ort.get_available_providers()

        # Check for GPU providers
        gpu_providers = []
        if 'CUDAExecutionProvider' in available_providers:
            gpu_providers.append('CUDAExecutionProvider')

        if use_gpu is True:
            # Force GPU
            if gpu_providers:
                print(f'🔧 User specified: Using GPU ({gpu_providers[0]})')
                return gpu_providers + ['CPUExecutionProvider']
            print('⚠️ GPU requested but not available, falling back to CPU')
            return ['CPUExecutionProvider']

        # Auto-detect (use_gpu is None)
        if gpu_providers:
            print(f'✅ GPU detected: Using {gpu_providers[0]}')
            return gpu_providers + ['CPUExecutionProvider']
        print('ℹ️ No GPU detected, using CPU')
        return ['CPUExecutionProvider']

    def encode_image(self, image):
        """Encode image using encoder ONNX model.

        Args:
            image: PIL Image passed to the image processor.

        Returns:
            Tuple ``(encoder_hidden_states, cross_k, cross_v)`` taken from
            the first three encoder outputs, in that order.
        """
        # Preprocess image
        data_img = self.processor(image)
        pixel_values = data_img['pixel_values']

        # Run encoder
        encoder_outputs = self.encoder_session.run(
            None, {'pixel_values': pixel_values.astype(np.float32)})

        encoder_hidden_states = encoder_outputs[0]
        cross_k = encoder_outputs[1]
        cross_v = encoder_outputs[2]

        return encoder_hidden_states, cross_k, cross_v

    def decode_step(self,
                    input_id,
                    past_length,
                    cross_k,
                    cross_v,
                    past_key_values,
                    padding_idx=1):
        """Unified decoder step with or without cache.

        Args:
            input_id: Token ID fed to the decoder at this step.
            past_length: Number of positions already held in the cache.
            cross_k: Cross-attention keys from the encoder.
            cross_v: Cross-attention values from the encoder.
            past_key_values: List of ``(key, value)`` arrays, one per layer.
            padding_idx: Padding token ID used as the position-ID offset.

        Returns:
            Tuple ``(logits, present_key_values)`` where
            ``present_key_values`` is a list of ``(key, value)`` pairs, one
            per decoder layer.
        """
        # Prepare inputs
        input_ids = np.array([[input_id]], dtype=np.int64)
        # Use M2M100's position ID calculation with past_key_values_length
        position_ids = np.array([[padding_idx + 1 + past_length]],
                                dtype=np.int64)

        decoder_inputs = {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'cross_k': cross_k.astype(np.float32),
            'cross_v': cross_v.astype(np.float32),
        }

        # Add past_key_values
        for i, (past_key, past_value) in enumerate(past_key_values):
            decoder_inputs[f'past_key_{i}'] = past_key.astype(np.float32)
            decoder_inputs[f'past_value_{i}'] = past_value.astype(np.float32)

        # Run decoder
        decoder_outputs = self.decoder_session.run(None, decoder_inputs)

        # Output 0 is the logits; the remaining outputs are read as
        # alternating key/value arrays, two per layer.
        logits = decoder_outputs[0]
        present_key_values = [
            (decoder_outputs[1 + i * 2], decoder_outputs[1 + i * 2 + 1])
            for i in range(self.num_decoder_layers)
        ]

        return logits, present_key_values

    def __call__(
        self,
        img_path=None,
        img_numpy=None,
        image=None,
        max_length=2048,
        bos_token_id=None,
        eos_token_id=None,
        pad_token_id=None,
    ):
        """
        Unified interface for UniRec inference.

        Inputs are consulted in the order ``img_path``, ``img_numpy``,
        ``image``; the first one that is not None is used.

        Args:
            img_path: Path to input image (str or Path)
            img_numpy: Input image as numpy array (BGR format)
            image: PIL Image object (RGB format)
            max_length: Maximum generation length
            bos_token_id: Beginning of sequence token ID
            eos_token_id: End of sequence token ID
            pad_token_id: Padding token ID

        Returns:
            Tuple of (generated_text, generated_ids)

        Raises:
            ValueError: If no image input is provided.
            RuntimeError: If the decoder's cache dimensions could not be
                determined when the sessions were created.
        """
        # Load image from path, numpy array, or use provided PIL image
        if img_path is not None:
            image = Image.open(img_path).convert('RGB')
        elif img_numpy is not None:
            # Convert BGR to RGB if needed
            if len(img_numpy.shape) == 3 and img_numpy.shape[2] == 3:
                import cv2
                img_numpy = cv2.cvtColor(img_numpy, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(img_numpy)
        elif image is None:
            raise ValueError('Either img_path, img_numpy, or image must be provided')

        # The empty cache below needs concrete dimensions; fail with a clear
        # message rather than a TypeError from range()/np.zeros().
        if (self.num_decoder_layers is None or self.num_heads is None
                or self.head_dim is None):
            raise RuntimeError(
                'Could not determine KV cache dimensions from the decoder '
                'model inputs (num_decoder_layers='
                f'{self.num_decoder_layers}, num_heads={self.num_heads}, '
                f'head_dim={self.head_dim})')

        # Get token IDs
        if bos_token_id is None:
            bos_token_id = self.tokenizer.bos_token_id
        if eos_token_id is None:
            eos_token_id = self.tokenizer.eos_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.pad_token_id

        # Encode image
        print('Encoding image...')
        t_start = time.perf_counter()
        encoder_hidden_states, cross_k, cross_v = self.encode_image(image)
        print(f'Encoding time: {time.perf_counter() - t_start:.2f} seconds')
        print(f' cross_k shape: {cross_k.shape}')
        print(f' cross_v shape: {cross_v.shape}')

        # Initialize generation
        print('Generating text...')
        generated_ids = [bos_token_id]

        # Initialize empty past_key_values for first step
        # Shape: [batch_size, num_heads, 0, head_dim]
        batch_size = encoder_hidden_states.shape[0]
        cache_shape = (batch_size, self.num_heads, 0, self.head_dim)
        past_key_values = [
            (np.zeros(cache_shape, dtype=np.float32),
             np.zeros(cache_shape, dtype=np.float32))
            for _ in range(self.num_decoder_layers)
        ]

        # Generation loop
        t_start = time.perf_counter()
        for step in range(max_length - 1):
            # One token is appended to the cache per step, so the cached
            # length equals the step index.
            logits, past_key_values = self.decode_step(
                generated_ids[-1],
                step,
                cross_k,
                cross_v,
                past_key_values,
                padding_idx=pad_token_id)

            # Get next token (greedy)
            next_token_id = int(np.argmax(logits[0, -1, :]))
            generated_ids.append(next_token_id)

            # Check for EOS
            if next_token_id == eos_token_id:
                break

            # Progress indicator
            if (step + 1) % 50 == 0:
                print(f' Generated {step + 1} tokens...')

        elapsed = time.perf_counter() - t_start
        # Guard against a zero-length timing window (e.g. max_length <= 1).
        tokens_per_second = (len(generated_ids) / elapsed
                             if elapsed > 0 else float('inf'))
        print(f'✅ Generation complete! Total tokens: {len(generated_ids)}')
        print(f' Time taken: {elapsed:.2f} seconds')
        print(f' Tokens per second: {tokens_per_second:.2f}')

        # Decode tokens
        generated_text = self.tokenizer.decode(generated_ids,
                                               skip_special_tokens=False)
        cleaned_text = clean_special_tokens(generated_text)

        return cleaned_text, generated_ids
654
+
655
+
656
def main():
    """Command-line entry point: recognise a single image and print the text."""
    import argparse

    cli = argparse.ArgumentParser(
        description='UniRec ONNX Inference (Standalone)')
    cli.add_argument(
        '--image',
        type=str,
        required=True,
        help='Path to input image')
    cli.add_argument(
        '--encoder-model',
        type=str,
        default=None,
        help='Path to encoder ONNX model (default: ~/.cache/openocr/unirec_0_1b_onnx/unirec_encoder.onnx)')
    cli.add_argument(
        '--decoder-model',
        type=str,
        default=None,
        help='Path to decoder ONNX model (default: ~/.cache/openocr/unirec_0_1b_onnx/unirec_decoder.onnx)')
    cli.add_argument(
        '--mapping',
        type=str,
        default=None,
        help='Path to tokenizer mapping JSON (default: ~/.cache/openocr/unirec_0_1b_onnx/unirec_tokenizer_mapping.json)')
    cli.add_argument(
        '--max-length',
        type=int,
        default=2048,
        help='Maximum generation length')
    cli.add_argument(
        '--use-gpu',
        type=str,
        default='auto',
        choices=['auto', 'true', 'false'],
        help='Use GPU for inference (auto: auto-detect, true: force GPU, false: force CPU)')
    cli.add_argument(
        '--no-auto-download',
        action='store_true',
        help='Disable automatic model download')
    opts = cli.parse_args()

    # Translate the tri-state CLI string into None / True / False.
    gpu_choice = {'auto': None, 'true': True}.get(opts.use_gpu, False)

    # Read the input image up front so a bad path fails before model loading.
    print(f'Loading image: {opts.image}')
    pil_image = Image.open(opts.image).convert('RGB')

    # Build the inference pipeline.
    recognizer = UniRecONNX(
        encoder_path=opts.encoder_model,
        decoder_path=opts.decoder_model,
        mapping_path=opts.mapping,
        use_gpu=gpu_choice,
        auto_download=not opts.no_auto_download,
    )

    # Run generation.
    text, token_ids = recognizer(
        image=pil_image,
        max_length=opts.max_length,
    )

    # Report the result between two horizontal rules.
    rule = '=' * 80
    print('\n' + rule)
    print('RESULT:')
    print(rule)
    print(text)
    print(rule)
    print(f'\nGenerated {len(token_ids)} tokens')


if __name__ == '__main__':
    main()