magic-pdf 1.0.1__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/dict2md/ocr_mkcontent.py +24 -0
- magic_pdf/filter/__init__.py +1 -1
- magic_pdf/filter/pdf_classify_by_type.py +6 -4
- magic_pdf/filter/pdf_meta_scan.py +4 -4
- magic_pdf/libs/boxbase.py +5 -2
- magic_pdf/libs/draw_bbox.py +14 -2
- magic_pdf/libs/language.py +9 -0
- magic_pdf/libs/pdf_check.py +11 -1
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/batch_analyze.py +103 -99
- magic_pdf/model/doc_analyze_by_custom_model.py +87 -36
- magic_pdf/model/magic_model.py +161 -4
- magic_pdf/model/pdf_extract_kit.py +23 -28
- magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +4 -3
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +7 -3
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +1 -1
- magic_pdf/model/sub_modules/model_init.py +34 -19
- magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +33 -26
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +25 -6
- magic_pdf/pdf_parse_union_core_v2.py +176 -61
- magic_pdf/post_proc/llm_aided.py +55 -24
- magic_pdf/pre_proc/ocr_dict_merge.py +14 -2
- magic_pdf/pre_proc/ocr_span_list_modify.py +1 -1
- magic_pdf/resources/model_config/model_configs.yaml +2 -2
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/METADATA +36 -19
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/RECORD +30 -30
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/WHEEL +0 -0
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/top_level.txt +0 -0
magic_pdf/model/magic_model.py
CHANGED
@@ -450,11 +450,168 @@ class MagicModel:
         )
         return ret
 
+
+    def __tie_up_category_by_distance_v3(
+        self,
+        page_no: int,
+        subject_category_id: int,
+        object_category_id: int,
+        priority_pos: PosRelationEnum,
+    ):
+        subjects = self.__reduct_overlap(
+            list(
+                map(
+                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
+                    filter(
+                        lambda x: x['category_id'] == subject_category_id,
+                        self.__model_list[page_no]['layout_dets'],
+                    ),
+                )
+            )
+        )
+        objects = self.__reduct_overlap(
+            list(
+                map(
+                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
+                    filter(
+                        lambda x: x['category_id'] == object_category_id,
+                        self.__model_list[page_no]['layout_dets'],
+                    ),
+                )
+            )
+        )
+
+        ret = []
+        N, M = len(subjects), len(objects)
+        subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
+        objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
+
+        OBJ_IDX_OFFSET = 10000
+        SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1
+
+        all_boxes_with_idx = [(i, SUB_BIT_KIND, sub['bbox'][0], sub['bbox'][1]) for i, sub in enumerate(subjects)] + [(i + OBJ_IDX_OFFSET , OBJ_BIT_KIND, obj['bbox'][0], obj['bbox'][1]) for i, obj in enumerate(objects)]
+        seen_idx = set()
+        seen_sub_idx = set()
+
+        while N > len(seen_sub_idx):
+            candidates = []
+            for idx, kind, x0, y0 in all_boxes_with_idx:
+                if idx in seen_idx:
+                    continue
+                candidates.append((idx, kind, x0, y0))
+
+            if len(candidates) == 0:
+                break
+            left_x = min([v[2] for v in candidates])
+            top_y = min([v[3] for v in candidates])
+
+            candidates.sort(key=lambda x: (x[2]-left_x) ** 2 + (x[3] - top_y) ** 2)
+
+
+            fst_idx, fst_kind, left_x, top_y = candidates[0]
+            candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y)**2)
+            nxt = None
+
+            for i in range(1, len(candidates)):
+                if candidates[i][1] ^ fst_kind == 1:
+                    nxt = candidates[i]
+                    break
+            if nxt is None:
+                break
+
+            if fst_kind == SUB_BIT_KIND:
+                sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
+
+            else:
+                sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET
+
+            pair_dis = bbox_distance(subjects[sub_idx]['bbox'], objects[obj_idx]['bbox'])
+            nearest_dis = float('inf')
+            for i in range(N):
+                if i in seen_idx:continue
+                nearest_dis = min(nearest_dis, bbox_distance(subjects[i]['bbox'], objects[obj_idx]['bbox']))
+
+            if pair_dis >= 3*nearest_dis:
+                seen_idx.add(sub_idx)
+                continue
+
+
+            seen_idx.add(sub_idx)
+            seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
+            seen_sub_idx.add(sub_idx)
+
+            ret.append(
+                {
+                    'sub_bbox': {
+                        'bbox': subjects[sub_idx]['bbox'],
+                        'score': subjects[sub_idx]['score'],
+                    },
+                    'obj_bboxes': [
+                        {'score': objects[obj_idx]['score'], 'bbox': objects[obj_idx]['bbox']}
+                    ],
+                    'sub_idx': sub_idx,
+                }
+            )
+
+        for i in range(len(objects)):
+            j = i + OBJ_IDX_OFFSET
+            if j in seen_idx:
+                continue
+            seen_idx.add(j)
+            nearest_dis, nearest_sub_idx = float('inf'), -1
+            for k in range(len(subjects)):
+                dis = bbox_distance(objects[i]['bbox'], subjects[k]['bbox'])
+                if dis < nearest_dis:
+                    nearest_dis = dis
+                    nearest_sub_idx = k
+
+            for k in range(len(subjects)):
+                if k != nearest_sub_idx: continue
+                if k in seen_sub_idx:
+                    for kk in range(len(ret)):
+                        if ret[kk]['sub_idx'] == k:
+                            ret[kk]['obj_bboxes'].append({'score': objects[i]['score'], 'bbox': objects[i]['bbox']})
+                            break
+                else:
+                    ret.append(
+                        {
+                            'sub_bbox': {
+                                'bbox': subjects[k]['bbox'],
+                                'score': subjects[k]['score'],
+                            },
+                            'obj_bboxes': [
+                                {'score': objects[i]['score'], 'bbox': objects[i]['bbox']}
+                            ],
+                            'sub_idx': k,
+                        }
+                    )
+                seen_sub_idx.add(k)
+                seen_idx.add(k)
+
+
+        for i in range(len(subjects)):
+            if i in seen_sub_idx:
+                continue
+            ret.append(
+                {
+                    'sub_bbox': {
+                        'bbox': subjects[i]['bbox'],
+                        'score': subjects[i]['score'],
+                    },
+                    'obj_bboxes': [],
+                    'sub_idx': i,
+                }
+            )
+
+
+        return ret
+
+
     def get_imgs_v2(self, page_no: int):
-        with_captions = self.
+        with_captions = self.__tie_up_category_by_distance_v3(
             page_no, 3, 4, PosRelationEnum.BOTTOM
         )
-        with_footnotes = self.
+        with_footnotes = self.__tie_up_category_by_distance_v3(
             page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
         )
         ret = []

@@ -470,10 +627,10 @@ class MagicModel:
         return ret
 
     def get_tables_v2(self, page_no: int) -> list:
-        with_captions = self.
+        with_captions = self.__tie_up_category_by_distance_v3(
             page_no, 5, 6, PosRelationEnum.UP
         )
-        with_footnotes = self.
+        with_footnotes = self.__tie_up_category_by_distance_v3(
             page_no, 5, 7, PosRelationEnum.ALL
         )
         ret = []
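Editorial note on the hunk above: the new `__tie_up_category_by_distance_v3` pairs "subject" detections (image or table bodies, e.g. category 3) with "object" detections (captions, footnotes, e.g. category 4). It repeatedly takes the remaining box closest to the top-left, matches it with the nearest unseen box of the opposite kind, rejects a match whose distance is at least 3× that of the best alternative subject, then attaches leftover objects to their nearest subject and finally emits subjects with an empty `obj_bboxes` list. The sketch below is a deliberately simplified, self-contained illustration of that nearest-subject attachment step and of the returned data shape; it is NOT the library code (it uses center-point distance instead of `bbox_distance` and omits the greedy sweep and the 3× rejection rule).

```python
# Simplified sketch (not the library code): attach each "object" box (e.g. a caption)
# to its nearest "subject" box (e.g. an image body), producing the same output shape
# that get_imgs_v2 / get_tables_v2 pass downstream.

def center(bbox):
    x0, y0, x1, y1 = bbox
    return ((x0 + x1) / 2, (y0 + y1) / 2)

def dist(a, b):
    (ax, ay), (bx, by) = center(a), center(b)
    return ((ax - bx) ** 2 + (ay - by) ** 2) ** 0.5

subjects = [{'bbox': [100, 100, 300, 250], 'score': 0.98},   # image bodies
            {'bbox': [100, 400, 300, 550], 'score': 0.95}]
objects = [{'bbox': [100, 260, 300, 290], 'score': 0.90},    # captions
           {'bbox': [100, 560, 300, 590], 'score': 0.88}]

ret = [{'sub_bbox': s, 'obj_bboxes': [], 'sub_idx': i} for i, s in enumerate(subjects)]
for obj in objects:
    i = min(range(len(subjects)), key=lambda k: dist(subjects[k]['bbox'], obj['bbox']))
    ret[i]['obj_bboxes'].append({'score': obj['score'], 'bbox': obj['bbox']})

print(ret)  # each caption ends up attached to its nearest image
```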
magic_pdf/model/pdf_extract_kit.py
CHANGED
@@ -69,6 +69,7 @@ class CustomPEKModel:
         self.apply_table = self.table_config.get('enable', False)
         self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
         self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
+        self.table_sub_model_name = self.table_config.get('sub_model', None)
 
         # ocr config
         self.apply_ocr = ocr

@@ -88,13 +89,6 @@ class CustomPEKModel:
         # 初始化解析方案
         self.device = kwargs.get('device', 'cpu')
 
-        if str(self.device).startswith("npu"):
-            import torch_npu
-            os.environ['FLAGS_npu_jit_compile'] = '0'
-            os.environ['FLAGS_use_stride_kernel'] = '0'
-        elif str(self.device).startswith("mps"):
-            os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
-
         logger.info('using device: {}'.format(self.device))
         models_dir = kwargs.get(
             'models_dir', os.path.join(root_dir, 'resources', 'models')

@@ -144,7 +138,7 @@ class CustomPEKModel:
                         model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
                     )
                 ),
-                device=self.device,
+                device='cpu' if str(self.device).startswith("mps") else self.device,
             )
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             self.layout_model = atom_model_manager.get_atom_model(

@@ -174,6 +168,7 @@ class CustomPEKModel:
                 table_max_time=self.table_max_time,
                 device=self.device,
                 ocr_engine=self.ocr_model,
+                table_sub_model_name=self.table_sub_model_name
             )
 
         logger.info('DocAnalysis init done!')

@@ -192,24 +187,24 @@ class CustomPEKModel:
             layout_res = self.layout_model(image, ignore_catids=[])
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             # doclayout_yolo
-            if height > width:
-                input_res = {"poly":[0,0,width,0,width,height,0,height]}
-                new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
-                paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
-                layout_res = self.layout_model.predict(new_image)
-                for res in layout_res:
-                    p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
-                    p1 = p1 - paste_x + xmin
-                    p2 = p2 - paste_y + ymin
-                    p3 = p3 - paste_x + xmin
-                    p4 = p4 - paste_y + ymin
-                    p5 = p5 - paste_x + xmin
-                    p6 = p6 - paste_y + ymin
-                    p7 = p7 - paste_x + xmin
-                    p8 = p8 - paste_y + ymin
-                    res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
-            else:
-                layout_res = self.layout_model.predict(image)
+            # if height > width:
+            #     input_res = {"poly":[0,0,width,0,width,height,0,height]}
+            #     new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
+            #     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+            #     layout_res = self.layout_model.predict(new_image)
+            #     for res in layout_res:
+            #         p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
+            #         p1 = p1 - paste_x + xmin
+            #         p2 = p2 - paste_y + ymin
+            #         p3 = p3 - paste_x + xmin
+            #         p4 = p4 - paste_y + ymin
+            #         p5 = p5 - paste_x + xmin
+            #         p6 = p6 - paste_y + ymin
+            #         p7 = p7 - paste_x + xmin
+            #         p8 = p8 - paste_y + ymin
+            #         res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
+            # else:
+            layout_res = self.layout_model.predict(image)
 
         layout_cost = round(time.time() - layout_start, 2)
         logger.info(f'layout detection time: {layout_cost}')

@@ -228,7 +223,7 @@ class CustomPEKModel:
         logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
 
         # 清理显存
-        clean_vram(self.device, vram_threshold=
+        clean_vram(self.device, vram_threshold=6)
 
         # 从layout_res中获取ocr区域、表格区域、公式区域
         ocr_res_list, table_res_list, single_page_mfdetrec_res = (

@@ -276,7 +271,7 @@ class CustomPEKModel:
                 elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                     html_code = self.table_model.img2html(new_image)
                 elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
-                    html_code, table_cell_bboxes, elapse = self.table_model.predict(
+                    html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
                         new_image
                     )
                     run_time = time.time() - single_table_start_time
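Editorial note: the new `sub_model` key is read from the same `table_config` block as the other table options and is forwarded through `atom_model_init` to `RapidTableModel` (see the model_init.py and rapid_table.py sections below). A minimal sketch of the lookups, mirroring only the `.get()` calls visible in the diff above; the concrete values ('rapid_table', 'slanet_plus', 400) are illustrative assumptions, and the real code compares against the `MODEL_NAME` constants rather than plain strings.

```python
# Hedged sketch: how CustomPEKModel reads the table options per the diff above.
table_config = {
    'enable': True,
    'max_time': 400,              # illustrative; real default is TABLE_MAX_TIME_VALUE
    'model': 'rapid_table',       # illustrative; real code uses MODEL_NAME.RAPID_TABLE
    'sub_model': 'slanet_plus',   # new key in 1.2.0, forwarded to RapidTableModel
}

apply_table = table_config.get('enable', False)
table_max_time = table_config.get('max_time', 400)
table_model_name = table_config.get('model', 'rapid_table')
table_sub_model_name = table_config.get('sub_model', None)  # None keeps RapidTable defaults
```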
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py
CHANGED
@@ -1,4 +1,5 @@
 # Copyright (c) Opendatalab. All rights reserved.
+import time
 from collections import Counter
 from uuid import uuid4
 

@@ -102,9 +103,9 @@ class YOLOv11LangDetModel(object):
         temp_images = split_images(image)
         for temp_image in temp_images:
             all_images.append(resize_images_to_224(temp_image))
-
-        images_lang_res = self.batch_predict(all_images, batch_size=
-        # logger.info(f"
+        # langdetect_start = time.time()
+        images_lang_res = self.batch_predict(all_images, batch_size=256)
+        # logger.info(f"image number of langdetect: {len(images_lang_res)}, langdetect time: {round(time.time() - langdetect_start, 2)}")
         if len(images_lang_res) > 0:
             count_dict = Counter(images_lang_res)
             language = max(count_dict, key=count_dict.get)
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py
CHANGED
@@ -9,7 +9,11 @@ class DocLayoutYOLOModel(object):
     def predict(self, image):
         layout_res = []
         doclayout_yolo_res = self.model.predict(
-            image,
+            image,
+            imgsz=1280,
+            conf=0.10,
+            iou=0.45,
+            verbose=False, device=self.device
         )[0]
         for xyxy, conf, cla in zip(
             doclayout_yolo_res.boxes.xyxy.cpu(),

@@ -32,8 +36,8 @@ class DocLayoutYOLOModel(object):
                 image_res.cpu()
                 for image_res in self.model.predict(
                     images[index : index + batch_size],
-                    imgsz=
-                    conf=0.
+                    imgsz=1280,
+                    conf=0.10,
                     iou=0.45,
                     verbose=False,
                     device=self.device,
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
CHANGED
@@ -89,7 +89,7 @@ class UnimernetModel(object):
             mf_image_list.append(bbox_img)
 
         dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-        dataloader = DataLoader(dataset, batch_size=
+        dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
         mfr_res = []
         for mf_img in dataloader:
             mf_img = mf_img.to(self.device)
magic_pdf/model/sub_modules/model_init.py
CHANGED
@@ -4,24 +4,39 @@ from loguru import logger
 from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.model.model_list import AtomicModel
 from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
-from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import
-
-from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
-    Layoutlmv3_Predictor
+from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
+from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
 from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
 from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
-
-
-from
-
-
-
-
-
-
-
-
-
+
+try:
+    from magic_pdf_ascend_plugin.libs.license_verifier import load_license, LicenseFormatError, LicenseSignatureError, LicenseExpiredError
+    from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
+    from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
+    license_key = load_license()
+    logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
+                f' License expired at {license_key["payload"]["date"]["end_date"]}')
+except Exception as e:
+    if isinstance(e, ImportError):
+        pass
+    elif isinstance(e, LicenseFormatError):
+        logger.error("Ascend Plugin: Invalid license format. Please check the license file.")
+    elif isinstance(e, LicenseSignatureError):
+        logger.error("Ascend Plugin: Invalid signature. The license may be tampered with.")
+    elif isinstance(e, LicenseExpiredError):
+        logger.error("Ascend Plugin: License has expired. Please renew your license.")
+    elif isinstance(e, FileNotFoundError):
+        logger.error("Ascend Plugin: Not found License file.")
+    else:
+        logger.error(f"Ascend Plugin: {e}")
+    from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
+    # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
+    from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
+
+from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
+from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
+
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
     elif table_model_type == MODEL_NAME.TABLE_MASTER:

@@ -31,7 +46,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
         }
         table_model = TableMasterPaddleModel(config)
     elif table_model_type == MODEL_NAME.RAPID_TABLE:
-        table_model = RapidTableModel(ocr_engine)
+        table_model = RapidTableModel(ocr_engine, table_sub_model_name)
     else:
         logger.error('table model type not allow')
         exit(1)

@@ -76,7 +91,6 @@ def ocr_model_init(show_log: bool = False,
                    use_dilation=True,
                    det_db_unclip_ratio=1.8,
                    ):
-
     if lang is not None and lang != '':
         model = ModifiedPaddleOCR(
             show_log=show_log,

@@ -163,7 +177,8 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('table_model_path'),
             kwargs.get('table_max_time'),
             kwargs.get('device'),
-            kwargs.get('ocr_engine')
+            kwargs.get('ocr_engine'),
+            kwargs.get('table_sub_model_name')
         )
     elif model_name == AtomicModel.LangDetect:
         if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py
CHANGED
@@ -7,6 +7,8 @@ import base64
 from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
 from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
 
+import importlib.resources
+from paddleocr import PaddleOCR
 from ppocr.utils.utility import check_and_read
 
 

@@ -327,30 +329,35 @@ class ONNXModelSingleton:
         return self._models[key]
 
 def onnx_model_init(key):
-
-
-
-    resource_path = importlib.resources.path('rapidocr_onnxruntime.models','')
-
-    onnx_model = None
-    additional_ocr_params = {
-        "use_onnx": True,
-        "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
-        "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
-        "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
-        "det_db_box_thresh": key[1],
-        "use_dilation": key[2],
-        "det_db_unclip_ratio": key[3],
-    }
-    # logger.info(f"additional_ocr_params: {additional_ocr_params}")
-    if key[0] is not None:
-        additional_ocr_params["lang"] = key[0]
-
-    from paddleocr import PaddleOCR
-    onnx_model = PaddleOCR(**additional_ocr_params)
-
-    if onnx_model is None:
-        logger.error('model init failed')
+    if len(key) < 4:
+        logger.error('Invalid key length, expected at least 4 elements')
         exit(1)
-
-
+
+    try:
+        with importlib.resources.path('rapidocr_onnxruntime.models', '') as resource_path:
+            additional_ocr_params = {
+                "use_onnx": True,
+                "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
+                "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
+                "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
+                "det_db_box_thresh": key[1],
+                "use_dilation": key[2],
+                "det_db_unclip_ratio": key[3],
+            }
+
+            if key[0] is not None:
+                additional_ocr_params["lang"] = key[0]
+
+            # logger.info(f"additional_ocr_params: {additional_ocr_params}")
+
+            onnx_model = PaddleOCR(**additional_ocr_params)
+
+            if onnx_model is None:
+                logger.error('model init failed')
+                exit(1)
+            else:
+                return onnx_model
+
+    except Exception as e:
+        logger.exception(f'Error initializing model: {e}')
+        exit(1)
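Editorial note: the substantive fix in `onnx_model_init` is that `importlib.resources.path()` is now used as the context manager it is, so the extracted model directory is referenced only while the `with` block is active; the previous bare assignment left `resource_path` holding the context-manager object itself rather than a filesystem path. A minimal standalone illustration follows, assuming `rapidocr_onnxruntime` is installed; the package and file names simply mirror the ones in the diff.

```python
import importlib.resources

# importlib.resources.path() yields a context manager; the concrete filesystem
# path it resolves is only guaranteed to be valid inside the `with` block.
with importlib.resources.path('rapidocr_onnxruntime.models', '') as resource_path:
    det_model_dir = f'{resource_path}/ch_PP-OCRv4_det_infer.onnx'
    # build the OCR engine here, while resource_path is still valid
    print(det_model_dir)
```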
magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py
CHANGED
@@ -2,12 +2,27 @@ import cv2
 import numpy as np
 import torch
 from loguru import logger
-from rapid_table import RapidTable
+from rapid_table import RapidTable, RapidTableInput
+from rapid_table.main import ModelType
+
+from magic_pdf.libs.config_reader import get_device
 
 
 class RapidTableModel(object):
-    def __init__(self, ocr_engine):
-
+    def __init__(self, ocr_engine, table_sub_model_name):
+        sub_model_list = [model.value for model in ModelType]
+        if table_sub_model_name is None:
+            input_args = RapidTableInput()
+        elif table_sub_model_name in sub_model_list:
+            if torch.cuda.is_available() and table_sub_model_name == "unitable":
+                input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
+            else:
+                input_args = RapidTableInput(model_type=table_sub_model_name)
+        else:
+            raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
+
+        self.table_model = RapidTable(input_args)
+
         # if ocr_engine is None:
         #     self.ocr_model_name = "RapidOCR"
         #     if torch.cuda.is_available():

@@ -45,7 +60,11 @@ class RapidTableModel(object):
             ocr_result = None
 
         if ocr_result:
-
-
+            table_results = self.table_model(np.asarray(image), ocr_result)
+            html_code = table_results.pred_html
+            table_cell_bboxes = table_results.cell_bboxes
+            logic_points = table_results.logic_points
+            elapse = table_results.elapse
+            return html_code, table_cell_bboxes, logic_points, elapse
         else:
-            return None, None, None
+            return None, None, None, None