openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -75,7 +75,7 @@ def sorted_boxes(dt_boxes):
75
75
  return _boxes
76
76
 
77
77
 
78
- class OpenOCR(object):
78
+ class OpenOCRE2E(object):
79
79
 
80
80
  def __init__(self,
81
81
  mode='mobile',
@@ -84,7 +84,7 @@ class OpenOCR(object):
84
84
  onnx_rec_model_path=None,
85
85
  drop_score=0.5,
86
86
  det_box_type='quad',
87
- device='gpu'):
87
+ use_gpu='auto'):
88
88
  """
89
89
  初始化函数,用于初始化OCR引擎的相关配置和组件。
90
90
 
@@ -92,11 +92,26 @@ class OpenOCR(object):
92
92
  mode (str, optional): 运行模式,可选值为'mobile'或'server'。默认为'mobile'。
93
93
  drop_score (float, optional): 检测框的置信度阈值,低于该阈值的检测框将被丢弃。默认为0.5。
94
94
  det_box_type (str, optional): 检测框的类型,可选值为'quad' and 'poly'。默认为'quad'。
95
+ use_gpu (str, optional): GPU使用策略,可选值为'auto'/'true'/'false'。默认为'auto'。
95
96
 
96
97
  Returns:
97
98
  无返回值。
98
99
 
99
100
  """
101
+ # Parse use_gpu parameter
102
+ if use_gpu == 'auto':
103
+ try:
104
+ import torch
105
+ device = 'gpu' if torch.cuda.is_available() else 'cpu'
106
+ except:
107
+ device = 'cpu'
108
+ elif use_gpu == 'true':
109
+ device = 'gpu'
110
+ elif use_gpu == 'false':
111
+ device = 'cpu'
112
+ else:
113
+ raise ValueError(f"use_gpu must be 'auto', 'true', or 'false', got '{use_gpu}'")
114
+
100
115
  cfg_det = Config(DEFAULT_CFG_PATH_DET).cfg # mobile model
101
116
  cfg_det['Global']['device'] = device
102
117
  if mode == 'server':
@@ -108,9 +123,10 @@ class OpenOCR(object):
108
123
 
109
124
  self.text_detector = OpenDetector(cfg_det,
110
125
  backend=backend,
111
- onnx_model_path=onnx_det_model_path)
126
+ onnx_model_path=onnx_det_model_path,
127
+ use_gpu=use_gpu)
112
128
  self.text_recognizer = OpenRecognizer(
113
- cfg_rec, backend=backend, onnx_model_path=onnx_rec_model_path)
129
+ cfg_rec, backend=backend, onnx_model_path=onnx_rec_model_path, use_gpu=use_gpu)
114
130
  self.det_box_type = det_box_type
115
131
  self.drop_score = drop_score
116
132
 
@@ -415,10 +431,11 @@ def main():
415
431
  type=float,
416
432
  default=0.5,
417
433
  help='Score threshold for text recognition.')
418
- parser.add_argument('--device',
434
+ parser.add_argument('--use_gpu',
419
435
  type=str,
420
- default='gpu',
421
- help='Device to use for inference.')
436
+ default='auto',
437
+ choices=['auto', 'true', 'false'],
438
+ help='GPU usage strategy: auto (detect automatically), true (force GPU), false (force CPU)')
422
439
  args = parser.parse_args()
423
440
 
424
441
  img_path = args.img_path
@@ -429,15 +446,15 @@ def main():
429
446
  save_dir = args.save_dir
430
447
  is_visualize = args.is_vis
431
448
  drop_score = args.drop_score
432
- device = args.device
449
+ use_gpu = args.use_gpu
433
450
 
434
- text_sys = OpenOCR(mode=mode,
451
+ text_sys = OpenOCRE2E(mode=mode,
435
452
  backend=backend,
436
453
  onnx_det_model_path=onnx_det_model_path,
437
454
  onnx_rec_model_path=onnx_rec_model_path,
438
455
  drop_score=drop_score,
439
456
  det_box_type='quad',
440
- device=device) # det_box_type: 'quad' or 'poly'
457
+ use_gpu=use_gpu) # det_box_type: 'quad' or 'poly'
441
458
  text_sys(img_path=img_path, save_dir=save_dir, is_visualize=is_visualize)
442
459
 
443
460
 
@@ -127,7 +127,7 @@ def build_rec_process(cfg):
127
127
  ratio_resize_flag = True
128
128
  for op in cfg['Eval']['dataset']['transforms']:
129
129
  op_name = list(op)[0]
130
- if 'Resize' in op_name:
130
+ if 'Resize' in op_name or 'Processor' in op_name:
131
131
  ratio_resize_flag = False
132
132
  if 'Label' in op_name:
133
133
  continue
@@ -149,6 +149,8 @@ def set_device(device, numId=0):
149
149
  import torch
150
150
  if device == 'gpu' and torch.cuda.is_available():
151
151
  device = torch.device(f'cuda:{numId}')
152
+ elif device == 'mps' and torch.backends.mps.is_available():
153
+ device = torch.device('mps')
152
154
  else:
153
155
  logger.info('GPU is not available, using CPU.')
154
156
  device = torch.device('cpu')
@@ -162,6 +164,7 @@ class OpenRecognizer:
162
164
  mode='mobile',
163
165
  backend='torch',
164
166
  onnx_model_path=None,
167
+ use_gpu='auto',
165
168
  numId=0):
166
169
  """
167
170
  Args:
@@ -169,12 +172,30 @@ class OpenRecognizer:
169
172
  mode (str, optional): 模式,'server' 或 'mobile'。默认为'mobile'。
170
173
  backend (str): 'torch' 或 'onnx'
171
174
  onnx_model_path (str): ONNX模型路径(仅当backend='onnx'时需要)
175
+ use_gpu (str, optional): GPU使用策略,可选值为'auto'/'true'/'false'。默认为'auto'。
172
176
  numId (int, optional): 设备编号。默认为0。
173
177
  """
174
178
 
175
179
  if config is None:
176
180
  config_file = DEFAULT_CFG_PATH_REC_SERVER if mode == 'server' else DEFAULT_CFG_PATH_REC
177
181
  config = Config(config_file).cfg
182
+
183
+ # Parse use_gpu parameter
184
+ if use_gpu == 'auto':
185
+ try:
186
+ import torch
187
+ device = 'gpu' if torch.cuda.is_available() else 'cpu'
188
+ except:
189
+ device = 'cpu'
190
+ elif use_gpu == 'true':
191
+ device = 'gpu'
192
+ elif use_gpu == 'false':
193
+ device = 'cpu'
194
+ else:
195
+ raise ValueError(f"use_gpu must be 'auto', 'true', or 'false', got '{use_gpu}'")
196
+
197
+ config['Global']['device'] = device
198
+
178
199
  self.cfg = config
179
200
  # 公共初始化
180
201
  self._init_common()
@@ -197,7 +218,7 @@ class OpenRecognizer:
197
218
  else:
198
219
  raise ValueError('ONNX模式需要指定onnx_model_path参数')
199
220
  self.onnx_rec_engine = ONNXEngine(
200
- onnx_model_path, use_gpu=config['Global']['device'] == 'gpu')
221
+ onnx_model_path, use_gpu=(device == 'gpu'))
201
222
  else:
202
223
  raise ValueError("backend参数必须是'torch'或'onnx'")
203
224
 
@@ -222,26 +243,43 @@ class OpenRecognizer:
222
243
 
223
244
  def _init_torch_model(self, numId):
224
245
  from tools.utils.ckpt import load_ckpt
225
- from tools.infer_det import replace_batchnorm
226
- # PyTorch专用初始化
227
246
  algorithm_name = self.cfg['Architecture']['algorithm']
228
- if algorithm_name in ['SVTRv2_mobile', 'SVTRv2_server']:
229
- if not os.path.exists(self.cfg['Global']['pretrained_model']):
230
- pretrained_model = check_and_download_model(
231
- MODEL_NAME_REC, DOWNLOAD_URL_REC
232
- ) if algorithm_name == 'SVTRv2_mobile' else check_and_download_model(
233
- MODEL_NAME_REC_SERVER, DOWNLOAD_URL_REC_SERVER)
234
- self.cfg['Global']['pretrained_model'] = pretrained_model
235
-
236
- from openrec.modeling import build_model as build_rec_model
247
+ if self.cfg['Global'].get('use_transformers', False):
248
+ if algorithm_name == 'UniRec':
249
+ from openrec.modeling.unirec_modeling.modeling_unirec import UniRecForConditionalGenerationNew
250
+ from openrec.modeling.unirec_modeling.configuration_unirec import UniRecConfig
251
+ cfg_model = UniRecConfig.from_pretrained(
252
+ self.cfg['Global']['vlm_ocr_config'])
253
+ # cfg_model._attn_implementation = "flash_attention_2"
254
+ cfg_model._attn_implementation = 'eager'
255
+ self.model = UniRecForConditionalGenerationNew(
256
+ config=cfg_model)
257
+ elif algorithm_name == 'CMER':
258
+ from openrec.modeling.cmer_modeling.modeling_cmer import CMER, CMERConfig
259
+ cfg_model = CMERConfig(
260
+ self.cfg['Architecture']['vision_config'],
261
+ self.cfg['Architecture']['decoder_config'])
262
+ self.model = CMER(config=cfg_model)
263
+ else:
264
+ # PyTorch专用初始化
265
+ if algorithm_name in ['SVTRv2_mobile', 'SVTRv2_server']:
266
+ if not os.path.exists(self.cfg['Global']['pretrained_model']):
267
+ pretrained_model = check_and_download_model(
268
+ MODEL_NAME_REC, DOWNLOAD_URL_REC
269
+ ) if algorithm_name == 'SVTRv2_mobile' else check_and_download_model(
270
+ MODEL_NAME_REC_SERVER, DOWNLOAD_URL_REC_SERVER)
271
+ self.cfg['Global']['pretrained_model'] = pretrained_model
272
+
273
+ from openrec.modeling import build_model as build_rec_model
274
+ self.model = build_rec_model(self.cfg['Architecture'])
237
275
 
238
- self.model = build_rec_model(self.cfg['Architecture'])
239
276
  load_ckpt(self.model, self.cfg)
240
277
 
241
278
  self.device = set_device(self.cfg['Global']['device'], numId)
242
279
  self.model.to(self.device)
243
280
  self.model.eval()
244
281
  if algorithm_name == 'SVTRv2_mobile':
282
+ from tools.infer_det import replace_batchnorm
245
283
  replace_batchnorm(self.model.encoder)
246
284
 
247
285
  def _inference_onnx(self, images):
@@ -329,7 +367,18 @@ class OpenRecognizer:
329
367
  images = self.torch.from_numpy(padded_batch).to(
330
368
  device=self.device)
331
369
  with self.torch.no_grad():
332
- preds = self.model(images, others) # bs, len, num_classes
370
+ if self.cfg['Global'].get('use_transformers', False):
371
+ # transformers模型推理
372
+ inputs = {
373
+ 'pixel_values': images,
374
+ 'input_ids': None,
375
+ 'attention_mask': None
376
+ }
377
+ preds = self.model.generate(**inputs)
378
+ else:
379
+ # PyTorch模型推理
380
+ preds = self.model(images,
381
+ others) # bs, len, num_classes
333
382
  torch_tensor = True
334
383
  elif self.backend == 'onnx':
335
384
  # ONNX推理