magic-pdf 1.0.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (30):
  1. magic_pdf/dict2md/ocr_mkcontent.py +24 -0
  2. magic_pdf/filter/__init__.py +1 -1
  3. magic_pdf/filter/pdf_classify_by_type.py +6 -4
  4. magic_pdf/filter/pdf_meta_scan.py +4 -4
  5. magic_pdf/libs/boxbase.py +5 -2
  6. magic_pdf/libs/draw_bbox.py +14 -2
  7. magic_pdf/libs/language.py +9 -0
  8. magic_pdf/libs/pdf_check.py +11 -1
  9. magic_pdf/libs/version.py +1 -1
  10. magic_pdf/model/batch_analyze.py +103 -99
  11. magic_pdf/model/doc_analyze_by_custom_model.py +87 -36
  12. magic_pdf/model/magic_model.py +161 -4
  13. magic_pdf/model/pdf_extract_kit.py +23 -28
  14. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +4 -3
  15. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +7 -3
  16. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +1 -1
  17. magic_pdf/model/sub_modules/model_init.py +34 -19
  18. magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +33 -26
  19. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +25 -6
  20. magic_pdf/pdf_parse_union_core_v2.py +176 -61
  21. magic_pdf/post_proc/llm_aided.py +55 -24
  22. magic_pdf/pre_proc/ocr_dict_merge.py +14 -2
  23. magic_pdf/pre_proc/ocr_span_list_modify.py +1 -1
  24. magic_pdf/resources/model_config/model_configs.yaml +2 -2
  25. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/METADATA +36 -19
  26. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/RECORD +30 -30
  27. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/LICENSE.md +0 -0
  28. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/WHEEL +0 -0
  29. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/entry_points.txt +0 -0
  30. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/top_level.txt +0 -0
@@ -450,11 +450,168 @@ class MagicModel:
450
450
  )
451
451
  return ret
452
452
 
453
+
454
+ def __tie_up_category_by_distance_v3(
455
+ self,
456
+ page_no: int,
457
+ subject_category_id: int,
458
+ object_category_id: int,
459
+ priority_pos: PosRelationEnum,
460
+ ):
461
+ subjects = self.__reduct_overlap(
462
+ list(
463
+ map(
464
+ lambda x: {'bbox': x['bbox'], 'score': x['score']},
465
+ filter(
466
+ lambda x: x['category_id'] == subject_category_id,
467
+ self.__model_list[page_no]['layout_dets'],
468
+ ),
469
+ )
470
+ )
471
+ )
472
+ objects = self.__reduct_overlap(
473
+ list(
474
+ map(
475
+ lambda x: {'bbox': x['bbox'], 'score': x['score']},
476
+ filter(
477
+ lambda x: x['category_id'] == object_category_id,
478
+ self.__model_list[page_no]['layout_dets'],
479
+ ),
480
+ )
481
+ )
482
+ )
483
+
484
+ ret = []
485
+ N, M = len(subjects), len(objects)
486
+ subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
487
+ objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
488
+
489
+ OBJ_IDX_OFFSET = 10000
490
+ SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1
491
+
492
+ all_boxes_with_idx = [(i, SUB_BIT_KIND, sub['bbox'][0], sub['bbox'][1]) for i, sub in enumerate(subjects)] + [(i + OBJ_IDX_OFFSET , OBJ_BIT_KIND, obj['bbox'][0], obj['bbox'][1]) for i, obj in enumerate(objects)]
493
+ seen_idx = set()
494
+ seen_sub_idx = set()
495
+
496
+ while N > len(seen_sub_idx):
497
+ candidates = []
498
+ for idx, kind, x0, y0 in all_boxes_with_idx:
499
+ if idx in seen_idx:
500
+ continue
501
+ candidates.append((idx, kind, x0, y0))
502
+
503
+ if len(candidates) == 0:
504
+ break
505
+ left_x = min([v[2] for v in candidates])
506
+ top_y = min([v[3] for v in candidates])
507
+
508
+ candidates.sort(key=lambda x: (x[2]-left_x) ** 2 + (x[3] - top_y) ** 2)
509
+
510
+
511
+ fst_idx, fst_kind, left_x, top_y = candidates[0]
512
+ candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y)**2)
513
+ nxt = None
514
+
515
+ for i in range(1, len(candidates)):
516
+ if candidates[i][1] ^ fst_kind == 1:
517
+ nxt = candidates[i]
518
+ break
519
+ if nxt is None:
520
+ break
521
+
522
+ if fst_kind == SUB_BIT_KIND:
523
+ sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
524
+
525
+ else:
526
+ sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET
527
+
528
+ pair_dis = bbox_distance(subjects[sub_idx]['bbox'], objects[obj_idx]['bbox'])
529
+ nearest_dis = float('inf')
530
+ for i in range(N):
531
+ if i in seen_idx:continue
532
+ nearest_dis = min(nearest_dis, bbox_distance(subjects[i]['bbox'], objects[obj_idx]['bbox']))
533
+
534
+ if pair_dis >= 3*nearest_dis:
535
+ seen_idx.add(sub_idx)
536
+ continue
537
+
538
+
539
+ seen_idx.add(sub_idx)
540
+ seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
541
+ seen_sub_idx.add(sub_idx)
542
+
543
+ ret.append(
544
+ {
545
+ 'sub_bbox': {
546
+ 'bbox': subjects[sub_idx]['bbox'],
547
+ 'score': subjects[sub_idx]['score'],
548
+ },
549
+ 'obj_bboxes': [
550
+ {'score': objects[obj_idx]['score'], 'bbox': objects[obj_idx]['bbox']}
551
+ ],
552
+ 'sub_idx': sub_idx,
553
+ }
554
+ )
555
+
556
+ for i in range(len(objects)):
557
+ j = i + OBJ_IDX_OFFSET
558
+ if j in seen_idx:
559
+ continue
560
+ seen_idx.add(j)
561
+ nearest_dis, nearest_sub_idx = float('inf'), -1
562
+ for k in range(len(subjects)):
563
+ dis = bbox_distance(objects[i]['bbox'], subjects[k]['bbox'])
564
+ if dis < nearest_dis:
565
+ nearest_dis = dis
566
+ nearest_sub_idx = k
567
+
568
+ for k in range(len(subjects)):
569
+ if k != nearest_sub_idx: continue
570
+ if k in seen_sub_idx:
571
+ for kk in range(len(ret)):
572
+ if ret[kk]['sub_idx'] == k:
573
+ ret[kk]['obj_bboxes'].append({'score': objects[i]['score'], 'bbox': objects[i]['bbox']})
574
+ break
575
+ else:
576
+ ret.append(
577
+ {
578
+ 'sub_bbox': {
579
+ 'bbox': subjects[k]['bbox'],
580
+ 'score': subjects[k]['score'],
581
+ },
582
+ 'obj_bboxes': [
583
+ {'score': objects[i]['score'], 'bbox': objects[i]['bbox']}
584
+ ],
585
+ 'sub_idx': k,
586
+ }
587
+ )
588
+ seen_sub_idx.add(k)
589
+ seen_idx.add(k)
590
+
591
+
592
+ for i in range(len(subjects)):
593
+ if i in seen_sub_idx:
594
+ continue
595
+ ret.append(
596
+ {
597
+ 'sub_bbox': {
598
+ 'bbox': subjects[i]['bbox'],
599
+ 'score': subjects[i]['score'],
600
+ },
601
+ 'obj_bboxes': [],
602
+ 'sub_idx': i,
603
+ }
604
+ )
605
+
606
+
607
+ return ret
608
+
609
+
453
610
  def get_imgs_v2(self, page_no: int):
454
- with_captions = self.__tie_up_category_by_distance_v2(
611
+ with_captions = self.__tie_up_category_by_distance_v3(
455
612
  page_no, 3, 4, PosRelationEnum.BOTTOM
456
613
  )
457
- with_footnotes = self.__tie_up_category_by_distance_v2(
614
+ with_footnotes = self.__tie_up_category_by_distance_v3(
458
615
  page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
459
616
  )
460
617
  ret = []
@@ -470,10 +627,10 @@ class MagicModel:
470
627
  return ret
471
628
 
472
629
  def get_tables_v2(self, page_no: int) -> list:
473
- with_captions = self.__tie_up_category_by_distance_v2(
630
+ with_captions = self.__tie_up_category_by_distance_v3(
474
631
  page_no, 5, 6, PosRelationEnum.UP
475
632
  )
476
- with_footnotes = self.__tie_up_category_by_distance_v2(
633
+ with_footnotes = self.__tie_up_category_by_distance_v3(
477
634
  page_no, 5, 7, PosRelationEnum.ALL
478
635
  )
479
636
  ret = []
@@ -69,6 +69,7 @@ class CustomPEKModel:
69
69
  self.apply_table = self.table_config.get('enable', False)
70
70
  self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
71
71
  self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
72
+ self.table_sub_model_name = self.table_config.get('sub_model', None)
72
73
 
73
74
  # ocr config
74
75
  self.apply_ocr = ocr
@@ -88,13 +89,6 @@ class CustomPEKModel:
88
89
  # 初始化解析方案
89
90
  self.device = kwargs.get('device', 'cpu')
90
91
 
91
- if str(self.device).startswith("npu"):
92
- import torch_npu
93
- os.environ['FLAGS_npu_jit_compile'] = '0'
94
- os.environ['FLAGS_use_stride_kernel'] = '0'
95
- elif str(self.device).startswith("mps"):
96
- os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
97
-
98
92
  logger.info('using device: {}'.format(self.device))
99
93
  models_dir = kwargs.get(
100
94
  'models_dir', os.path.join(root_dir, 'resources', 'models')
@@ -144,7 +138,7 @@ class CustomPEKModel:
144
138
  model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
145
139
  )
146
140
  ),
147
- device=self.device,
141
+ device='cpu' if str(self.device).startswith("mps") else self.device,
148
142
  )
149
143
  elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
150
144
  self.layout_model = atom_model_manager.get_atom_model(
@@ -174,6 +168,7 @@ class CustomPEKModel:
174
168
  table_max_time=self.table_max_time,
175
169
  device=self.device,
176
170
  ocr_engine=self.ocr_model,
171
+ table_sub_model_name=self.table_sub_model_name
177
172
  )
178
173
 
179
174
  logger.info('DocAnalysis init done!')
@@ -192,24 +187,24 @@ class CustomPEKModel:
192
187
  layout_res = self.layout_model(image, ignore_catids=[])
193
188
  elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
194
189
  # doclayout_yolo
195
- if height > width:
196
- input_res = {"poly":[0,0,width,0,width,height,0,height]}
197
- new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
198
- paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
199
- layout_res = self.layout_model.predict(new_image)
200
- for res in layout_res:
201
- p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
202
- p1 = p1 - paste_x + xmin
203
- p2 = p2 - paste_y + ymin
204
- p3 = p3 - paste_x + xmin
205
- p4 = p4 - paste_y + ymin
206
- p5 = p5 - paste_x + xmin
207
- p6 = p6 - paste_y + ymin
208
- p7 = p7 - paste_x + xmin
209
- p8 = p8 - paste_y + ymin
210
- res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
211
- else:
212
- layout_res = self.layout_model.predict(image)
190
+ # if height > width:
191
+ # input_res = {"poly":[0,0,width,0,width,height,0,height]}
192
+ # new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
193
+ # paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
194
+ # layout_res = self.layout_model.predict(new_image)
195
+ # for res in layout_res:
196
+ # p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
197
+ # p1 = p1 - paste_x + xmin
198
+ # p2 = p2 - paste_y + ymin
199
+ # p3 = p3 - paste_x + xmin
200
+ # p4 = p4 - paste_y + ymin
201
+ # p5 = p5 - paste_x + xmin
202
+ # p6 = p6 - paste_y + ymin
203
+ # p7 = p7 - paste_x + xmin
204
+ # p8 = p8 - paste_y + ymin
205
+ # res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
206
+ # else:
207
+ layout_res = self.layout_model.predict(image)
213
208
 
214
209
  layout_cost = round(time.time() - layout_start, 2)
215
210
  logger.info(f'layout detection time: {layout_cost}')
@@ -228,7 +223,7 @@ class CustomPEKModel:
228
223
  logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
229
224
 
230
225
  # 清理显存
231
- clean_vram(self.device, vram_threshold=8)
226
+ clean_vram(self.device, vram_threshold=6)
232
227
 
233
228
  # 从layout_res中获取ocr区域、表格区域、公式区域
234
229
  ocr_res_list, table_res_list, single_page_mfdetrec_res = (
@@ -276,7 +271,7 @@ class CustomPEKModel:
276
271
  elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
277
272
  html_code = self.table_model.img2html(new_image)
278
273
  elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
279
- html_code, table_cell_bboxes, elapse = self.table_model.predict(
274
+ html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
280
275
  new_image
281
276
  )
282
277
  run_time = time.time() - single_table_start_time
@@ -1,4 +1,5 @@
1
1
  # Copyright (c) Opendatalab. All rights reserved.
2
+ import time
2
3
  from collections import Counter
3
4
  from uuid import uuid4
4
5
 
@@ -102,9 +103,9 @@ class YOLOv11LangDetModel(object):
102
103
  temp_images = split_images(image)
103
104
  for temp_image in temp_images:
104
105
  all_images.append(resize_images_to_224(temp_image))
105
-
106
- images_lang_res = self.batch_predict(all_images, batch_size=8)
107
- # logger.info(f"images_lang_res: {images_lang_res}")
106
+ # langdetect_start = time.time()
107
+ images_lang_res = self.batch_predict(all_images, batch_size=256)
108
+ # logger.info(f"image number of langdetect: {len(images_lang_res)}, langdetect time: {round(time.time() - langdetect_start, 2)}")
108
109
  if len(images_lang_res) > 0:
109
110
  count_dict = Counter(images_lang_res)
110
111
  language = max(count_dict, key=count_dict.get)
@@ -9,7 +9,11 @@ class DocLayoutYOLOModel(object):
9
9
  def predict(self, image):
10
10
  layout_res = []
11
11
  doclayout_yolo_res = self.model.predict(
12
- image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device
12
+ image,
13
+ imgsz=1280,
14
+ conf=0.10,
15
+ iou=0.45,
16
+ verbose=False, device=self.device
13
17
  )[0]
14
18
  for xyxy, conf, cla in zip(
15
19
  doclayout_yolo_res.boxes.xyxy.cpu(),
@@ -32,8 +36,8 @@ class DocLayoutYOLOModel(object):
32
36
  image_res.cpu()
33
37
  for image_res in self.model.predict(
34
38
  images[index : index + batch_size],
35
- imgsz=1024,
36
- conf=0.25,
39
+ imgsz=1280,
40
+ conf=0.10,
37
41
  iou=0.45,
38
42
  verbose=False,
39
43
  device=self.device,
@@ -89,7 +89,7 @@ class UnimernetModel(object):
89
89
  mf_image_list.append(bbox_img)
90
90
 
91
91
  dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
92
- dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
92
+ dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
93
93
  mfr_res = []
94
94
  for mf_img in dataloader:
95
95
  mf_img = mf_img.to(self.device)
@@ -4,24 +4,39 @@ from loguru import logger
4
4
  from magic_pdf.config.constants import MODEL_NAME
5
5
  from magic_pdf.model.model_list import AtomicModel
6
6
  from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
7
- from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
8
- DocLayoutYOLOModel
9
- from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
10
- Layoutlmv3_Predictor
7
+ from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
8
+ from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
11
9
  from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
12
10
  from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
13
- from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
14
- ModifiedPaddleOCR
15
- from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
16
- RapidTableModel
17
- # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
18
- from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
19
- StructTableModel
20
- from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
21
- TableMasterPaddleModel
22
-
23
-
24
- def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
11
+
12
+ try:
13
+ from magic_pdf_ascend_plugin.libs.license_verifier import load_license, LicenseFormatError, LicenseSignatureError, LicenseExpiredError
14
+ from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
15
+ from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
16
+ license_key = load_license()
17
+ logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
18
+ f' License expired at {license_key["payload"]["date"]["end_date"]}')
19
+ except Exception as e:
20
+ if isinstance(e, ImportError):
21
+ pass
22
+ elif isinstance(e, LicenseFormatError):
23
+ logger.error("Ascend Plugin: Invalid license format. Please check the license file.")
24
+ elif isinstance(e, LicenseSignatureError):
25
+ logger.error("Ascend Plugin: Invalid signature. The license may be tampered with.")
26
+ elif isinstance(e, LicenseExpiredError):
27
+ logger.error("Ascend Plugin: License has expired. Please renew your license.")
28
+ elif isinstance(e, FileNotFoundError):
29
+ logger.error("Ascend Plugin: Not found License file.")
30
+ else:
31
+ logger.error(f"Ascend Plugin: {e}")
32
+ from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
33
+ # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
34
+ from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
35
+
36
+ from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
37
+ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
38
+
39
+ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
25
40
  if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
26
41
  table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
27
42
  elif table_model_type == MODEL_NAME.TABLE_MASTER:
@@ -31,7 +46,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
31
46
  }
32
47
  table_model = TableMasterPaddleModel(config)
33
48
  elif table_model_type == MODEL_NAME.RAPID_TABLE:
34
- table_model = RapidTableModel(ocr_engine)
49
+ table_model = RapidTableModel(ocr_engine, table_sub_model_name)
35
50
  else:
36
51
  logger.error('table model type not allow')
37
52
  exit(1)
@@ -76,7 +91,6 @@ def ocr_model_init(show_log: bool = False,
76
91
  use_dilation=True,
77
92
  det_db_unclip_ratio=1.8,
78
93
  ):
79
-
80
94
  if lang is not None and lang != '':
81
95
  model = ModifiedPaddleOCR(
82
96
  show_log=show_log,
@@ -163,7 +177,8 @@ def atom_model_init(model_name: str, **kwargs):
163
177
  kwargs.get('table_model_path'),
164
178
  kwargs.get('table_max_time'),
165
179
  kwargs.get('device'),
166
- kwargs.get('ocr_engine')
180
+ kwargs.get('ocr_engine'),
181
+ kwargs.get('table_sub_model_name')
167
182
  )
168
183
  elif model_name == AtomicModel.LangDetect:
169
184
  if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
@@ -7,6 +7,8 @@ import base64
7
7
  from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
8
8
  from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
9
9
 
10
+ import importlib.resources
11
+ from paddleocr import PaddleOCR
10
12
  from ppocr.utils.utility import check_and_read
11
13
 
12
14
 
@@ -327,30 +329,35 @@ class ONNXModelSingleton:
327
329
  return self._models[key]
328
330
 
329
331
  def onnx_model_init(key):
330
-
331
- import importlib.resources
332
-
333
- resource_path = importlib.resources.path('rapidocr_onnxruntime.models','')
334
-
335
- onnx_model = None
336
- additional_ocr_params = {
337
- "use_onnx": True,
338
- "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
339
- "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
340
- "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
341
- "det_db_box_thresh": key[1],
342
- "use_dilation": key[2],
343
- "det_db_unclip_ratio": key[3],
344
- }
345
- # logger.info(f"additional_ocr_params: {additional_ocr_params}")
346
- if key[0] is not None:
347
- additional_ocr_params["lang"] = key[0]
348
-
349
- from paddleocr import PaddleOCR
350
- onnx_model = PaddleOCR(**additional_ocr_params)
351
-
352
- if onnx_model is None:
353
- logger.error('model init failed')
332
+ if len(key) < 4:
333
+ logger.error('Invalid key length, expected at least 4 elements')
354
334
  exit(1)
355
- else:
356
- return onnx_model
335
+
336
+ try:
337
+ with importlib.resources.path('rapidocr_onnxruntime.models', '') as resource_path:
338
+ additional_ocr_params = {
339
+ "use_onnx": True,
340
+ "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
341
+ "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
342
+ "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
343
+ "det_db_box_thresh": key[1],
344
+ "use_dilation": key[2],
345
+ "det_db_unclip_ratio": key[3],
346
+ }
347
+
348
+ if key[0] is not None:
349
+ additional_ocr_params["lang"] = key[0]
350
+
351
+ # logger.info(f"additional_ocr_params: {additional_ocr_params}")
352
+
353
+ onnx_model = PaddleOCR(**additional_ocr_params)
354
+
355
+ if onnx_model is None:
356
+ logger.error('model init failed')
357
+ exit(1)
358
+ else:
359
+ return onnx_model
360
+
361
+ except Exception as e:
362
+ logger.exception(f'Error initializing model: {e}')
363
+ exit(1)
@@ -2,12 +2,27 @@ import cv2
2
2
  import numpy as np
3
3
  import torch
4
4
  from loguru import logger
5
- from rapid_table import RapidTable
5
+ from rapid_table import RapidTable, RapidTableInput
6
+ from rapid_table.main import ModelType
7
+
8
+ from magic_pdf.libs.config_reader import get_device
6
9
 
7
10
 
8
11
  class RapidTableModel(object):
9
- def __init__(self, ocr_engine):
10
- self.table_model = RapidTable()
12
+ def __init__(self, ocr_engine, table_sub_model_name):
13
+ sub_model_list = [model.value for model in ModelType]
14
+ if table_sub_model_name is None:
15
+ input_args = RapidTableInput()
16
+ elif table_sub_model_name in sub_model_list:
17
+ if torch.cuda.is_available() and table_sub_model_name == "unitable":
18
+ input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
19
+ else:
20
+ input_args = RapidTableInput(model_type=table_sub_model_name)
21
+ else:
22
+ raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
23
+
24
+ self.table_model = RapidTable(input_args)
25
+
11
26
  # if ocr_engine is None:
12
27
  # self.ocr_model_name = "RapidOCR"
13
28
  # if torch.cuda.is_available():
@@ -45,7 +60,11 @@ class RapidTableModel(object):
45
60
  ocr_result = None
46
61
 
47
62
  if ocr_result:
48
- html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
49
- return html_code, table_cell_bboxes, elapse
63
+ table_results = self.table_model(np.asarray(image), ocr_result)
64
+ html_code = table_results.pred_html
65
+ table_cell_bboxes = table_results.cell_bboxes
66
+ logic_points = table_results.logic_points
67
+ elapse = table_results.elapse
68
+ return html_code, table_cell_bboxes, logic_points, elapse
50
69
  else:
51
- return None, None, None
70
+ return None, None, None, None