magic-pdf 0.10.6__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their public registry.
- magic_pdf/config/constants.py +2 -0
- magic_pdf/config/exceptions.py +7 -0
- magic_pdf/data/data_reader_writer/filebase.py +1 -1
- magic_pdf/data/data_reader_writer/multi_bucket_s3.py +8 -6
- magic_pdf/data/dataset.py +13 -1
- magic_pdf/data/read_api.py +59 -12
- magic_pdf/data/utils.py +35 -0
- magic_pdf/dict2md/ocr_mkcontent.py +14 -13
- magic_pdf/libs/clean_memory.py +11 -4
- magic_pdf/libs/config_reader.py +9 -0
- magic_pdf/libs/draw_bbox.py +8 -12
- magic_pdf/libs/language.py +3 -0
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/__init__.py +1 -125
- magic_pdf/model/batch_analyze.py +275 -0
- magic_pdf/model/doc_analyze_by_custom_model.py +4 -51
- magic_pdf/model/magic_model.py +4 -435
- magic_pdf/model/model_list.py +1 -0
- magic_pdf/model/pdf_extract_kit.py +33 -22
- magic_pdf/model/sub_modules/language_detection/__init__.py +1 -0
- magic_pdf/model/sub_modules/language_detection/utils.py +82 -0
- magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +139 -0
- magic_pdf/model/sub_modules/language_detection/yolov11/__init__.py +1 -0
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +44 -7
- magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +21 -2
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +70 -27
- magic_pdf/model/sub_modules/model_init.py +30 -4
- magic_pdf/model/sub_modules/model_utils.py +8 -2
- magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +51 -1
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +32 -6
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +42 -7
- magic_pdf/operators/__init__.py +94 -0
- magic_pdf/{model/operators.py → operators/models.py} +2 -38
- magic_pdf/{pipe/operators.py → operators/pipes.py} +70 -17
- magic_pdf/pdf_parse_union_core_v2.py +71 -17
- magic_pdf/post_proc/__init__.py +1 -0
- magic_pdf/post_proc/llm_aided.py +133 -0
- magic_pdf/pre_proc/ocr_span_list_modify.py +8 -0
- magic_pdf/pre_proc/remove_bbox_overlap.py +1 -1
- magic_pdf/resources/yolov11-langdetect/yolo_v11_ft.pt +0 -0
- magic_pdf/tools/cli.py +36 -11
- magic_pdf/tools/common.py +28 -18
- magic_pdf/utils/office_to_pdf.py +29 -0
- {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.1.dist-info}/METADATA +73 -23
- {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.1.dist-info}/RECORD +50 -53
- magic_pdf/para/__init__.py +0 -0
- magic_pdf/pdf_parse_by_ocr.py +0 -22
- magic_pdf/pdf_parse_by_txt.py +0 -23
- magic_pdf/pipe/AbsPipe.py +0 -99
- magic_pdf/pipe/OCRPipe.py +0 -80
- magic_pdf/pipe/TXTPipe.py +0 -42
- magic_pdf/pipe/UNIPipe.py +0 -150
- magic_pdf/pipe/__init__.py +0 -0
- magic_pdf/rw/AbsReaderWriter.py +0 -17
- magic_pdf/rw/DiskReaderWriter.py +0 -74
- magic_pdf/rw/S3ReaderWriter.py +0 -142
- magic_pdf/rw/__init__.py +0 -0
- magic_pdf/user_api.py +0 -144
- /magic_pdf/{para → post_proc}/para_split_v3.py +0 -0
- {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.1.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.1.dist-info}/WHEEL +0 -0
- {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.1.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,139 @@
+# Copyright (c) Opendatalab. All rights reserved.
+from collections import Counter
+from uuid import uuid4
+
+import torch
+from PIL import Image
+from loguru import logger
+from ultralytics import YOLO
+
+language_dict = {
+    "ch": "中文简体",
+    "en": "英语",
+    "japan": "日语",
+    "korean": "韩语",
+    "fr": "法语",
+    "german": "德语",
+    "ar": "阿拉伯语",
+    "ru": "俄语"
+}
+
+
+def split_images(image, result_images=None):
+    """
+    Recursively split an image whose longer side exceeds 400 px: halve it along the longer
+    side until every sub-image is at most 400 px, collecting the sub-images in result_images.
+    Crop boxes that would fall outside the image are skipped, avoiding invalid black-padded fragments.
+    """
+    if result_images is None:
+        result_images = []
+
+    width, height = image.size
+    long_side = max(width, height)  # length of the longer side
+
+    if long_side <= 400:
+        result_images.append(image)
+        return result_images
+
+    new_long_side = long_side // 2
+    sub_images = []
+
+    if width >= height:  # width is the longer side
+        for x in range(0, width, new_long_side):
+            # skip crop boxes that would extend past the image bounds
+            if x + new_long_side > width:
+                continue
+            box = (x, 0, x + new_long_side, height)
+            sub_image = image.crop(box)
+            sub_images.append(sub_image)
+    else:  # height is the longer side
+        for y in range(0, height, new_long_side):
+            # skip crop boxes that would extend past the image bounds
+            if y + new_long_side > height:
+                continue
+            box = (0, y, width, y + new_long_side)
+            sub_image = image.crop(box)
+            sub_images.append(sub_image)
+
+    for sub_image in sub_images:
+        split_images(sub_image, result_images)
+
+    return result_images
+
+
+def resize_images_to_224(image):
+    """
+    Pad images smaller than 224 px onto a black 224x224 canvas; resize larger images down to 224x224.
+    """
+    try:
+        width, height = image.size
+        if width < 224 or height < 224:
+            new_image = Image.new('RGB', (224, 224), (0, 0, 0))
+            paste_x = (224 - width) // 2
+            paste_y = (224 - height) // 2
+            new_image.paste(image, (paste_x, paste_y))
+            image = new_image
+        else:
+            image = image.resize((224, 224), Image.Resampling.LANCZOS)
+
+        # uuid = str(uuid4())
+        # image.save(f"/tmp/{uuid}.jpg")
+        return image
+    except Exception as e:
+        logger.exception(e)
+
+
+class YOLOv11LangDetModel(object):
+    def __init__(self, langdetect_model_weight, device):
+
+        self.model = YOLO(langdetect_model_weight)
+
+        if str(device).startswith("npu"):
+            self.device = torch.device(device)
+        else:
+            self.device = device
+    def do_detect(self, images: list):
+        all_images = []
+        for image in images:
+            width, height = image.size
+            # logger.info(f"image size: {width} x {height}")
+            if width < 100 and height < 100:
+                continue
+            temp_images = split_images(image)
+            for temp_image in temp_images:
+                all_images.append(resize_images_to_224(temp_image))
+
+        images_lang_res = self.batch_predict(all_images, batch_size=8)
+        # logger.info(f"images_lang_res: {images_lang_res}")
+        if len(images_lang_res) > 0:
+            count_dict = Counter(images_lang_res)
+            language = max(count_dict, key=count_dict.get)
+        else:
+            language = None
+        return language
+
+    def predict(self, image):
+        results = self.model.predict(image, verbose=False, device=self.device)
+        predicted_class_id = int(results[0].probs.top1)
+        predicted_class_name = self.model.names[predicted_class_id]
+        return predicted_class_name
+
+
+    def batch_predict(self, images: list, batch_size: int) -> list:
+        images_lang_res = []
+
+        for index in range(0, len(images), batch_size):
+            lang_res = [
+                image_res.cpu()
+                for image_res in self.model.predict(
+                    images[index: index + batch_size],
+                    verbose=False,
+                    device=self.device,
+                )
+            ]
+            for res in lang_res:
+                predicted_class_id = int(res.probs.top1)
+                predicted_class_name = self.model.names[predicted_class_id]
+                images_lang_res.append(predicted_class_name)
+
+        return images_lang_res
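For orientation, here is a minimal usage sketch of the new language detector. The class and its import path come from the hunk above and from the `model_init.py` changes later in this diff; the weight path is an assumption based on the `magic_pdf/resources/yolov11-langdetect/yolo_v11_ft.pt` resource added in this release.

```python
# Hypothetical usage of the new YOLOv11-based language detector.
from PIL import Image

from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import (
    YOLOv11LangDetModel,
)

model = YOLOv11LangDetModel(
    "magic_pdf/resources/yolov11-langdetect/yolo_v11_ft.pt",  # illustrative path
    device="cpu",
)

# do_detect splits each page image, classifies the 224x224 crops in batches of 8,
# and returns the majority-vote language code (e.g. "ch", "en"), or None.
pages = [Image.open("page_0.png"), Image.open("page_1.png")]
print(model.do_detect(pages))
```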
@@ -0,0 +1 @@
+# Copyright (c) Opendatalab. All rights reserved.
@@ -8,14 +8,51 @@ class DocLayoutYOLOModel(object):
 
     def predict(self, image):
         layout_res = []
-        doclayout_yolo_res = self.model.predict(
-
-
+        doclayout_yolo_res = self.model.predict(
+            image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device
+        )[0]
+        for xyxy, conf, cla in zip(
+            doclayout_yolo_res.boxes.xyxy.cpu(),
+            doclayout_yolo_res.boxes.conf.cpu(),
+            doclayout_yolo_res.boxes.cls.cpu(),
+        ):
             xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
             new_item = {
-
-
-
+                "category_id": int(cla.item()),
+                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                "score": round(float(conf.item()), 3),
             }
             layout_res.append(new_item)
-        return layout_res
+        return layout_res
+
+    def batch_predict(self, images: list, batch_size: int) -> list:
+        images_layout_res = []
+        for index in range(0, len(images), batch_size):
+            doclayout_yolo_res = [
+                image_res.cpu()
+                for image_res in self.model.predict(
+                    images[index : index + batch_size],
+                    imgsz=1024,
+                    conf=0.25,
+                    iou=0.45,
+                    verbose=False,
+                    device=self.device,
+                )
+            ]
+            for image_res in doclayout_yolo_res:
+                layout_res = []
+                for xyxy, conf, cla in zip(
+                    image_res.boxes.xyxy,
+                    image_res.boxes.conf,
+                    image_res.boxes.cls,
+                ):
+                    xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                    new_item = {
+                        "category_id": int(cla.item()),
+                        "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                        "score": round(float(conf.item()), 3),
+                    }
+                    layout_res.append(new_item)
+                images_layout_res.append(layout_res)
+
+        return images_layout_res
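The new `batch_predict` walks the input list in `batch_size` chunks and returns one result list per input image, in input order, so callers can zip results back to pages. A sketch of the calling convention, assuming `layout_model` is an already-constructed `DocLayoutYOLOModel` and the page images exist:

```python
import cv2

# layout_model: a DocLayoutYOLOModel instance; images: per-page arrays.
images = [cv2.imread(f"page_{i}.png") for i in range(4)]
per_page_layouts = layout_model.batch_predict(images, batch_size=4)

assert len(per_page_layouts) == len(images)  # one layout list per page
for page_idx, layout in enumerate(per_page_layouts):
    for block in layout:
        # each block: {"category_id": ..., "poly": [x0, y0, ...], "score": ...}
        print(page_idx, block["category_id"], block["score"])
```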
@@ -2,11 +2,30 @@ from ultralytics import YOLO
 
 
 class YOLOv8MFDModel(object):
-    def __init__(self, weight, device=
+    def __init__(self, weight, device="cpu"):
         self.mfd_model = YOLO(weight)
         self.device = device
 
     def predict(self, image):
-        mfd_res = self.mfd_model.predict(
+        mfd_res = self.mfd_model.predict(
+            image, imgsz=1888, conf=0.25, iou=0.45, verbose=False, device=self.device
+        )[0]
         return mfd_res
 
+    def batch_predict(self, images: list, batch_size: int) -> list:
+        images_mfd_res = []
+        for index in range(0, len(images), batch_size):
+            mfd_res = [
+                image_res.cpu()
+                for image_res in self.mfd_model.predict(
+                    images[index : index + batch_size],
+                    imgsz=1888,
+                    conf=0.25,
+                    iou=0.45,
+                    verbose=False,
+                    device=self.device,
+                )
+            ]
+            for image_res in mfd_res:
+                images_mfd_res.append(image_res)
+        return images_mfd_res
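Unlike the layout model, the MFD `batch_predict` returns the raw ultralytics result objects (moved to CPU) rather than parsed dicts; downstream code such as `UnimernetModel.batch_predict` below consumes `res.boxes` directly. A sketch, assuming `mfd_model` is a constructed `YOLOv8MFDModel` and `images` a list of page arrays:

```python
# mfd_model: a YOLOv8MFDModel instance; images: numpy arrays or PIL images.
mfd_results = mfd_model.batch_predict(images, batch_size=4)
for res in mfd_results:
    # each res is an ultralytics Results object already on CPU
    print(res.boxes.xyxy.shape, res.boxes.cls.tolist())
```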
@@ -1,13 +1,13 @@
-import os
 import argparse
+import os
 import re
 
-from PIL import Image
 import torch
-
+import unimernet.tasks as tasks
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
 from torchvision import transforms
 from unimernet.common.config import Config
-import unimernet.tasks as tasks
 from unimernet.processors import load_processor
 
 
@@ -31,27 +31,25 @@ class MathDataset(Dataset):
 
 
 def latex_rm_whitespace(s: str):
-    """Remove unnecessary whitespace from LaTeX code.
-    ""
-
-
-
-    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
+    """Remove unnecessary whitespace from LaTeX code."""
+    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
+    letter = "[a-zA-Z]"
+    noletter = "[\W_^\d]"
+    names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
     s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
     news = s
     while True:
         s = news
-        news = re.sub(r
-        news = re.sub(r
-        news = re.sub(r
+        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
+        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
+        news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
         if news == s:
             break
     return s
 
 
 class UnimernetModel(object):
-    def __init__(self, weight_dir, cfg_path, _device_=
-
+    def __init__(self, weight_dir, cfg_path, _device_="cpu"):
         args = argparse.Namespace(cfg_path=cfg_path, options=None)
         cfg = Config(args)
         cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
@@ -62,20 +60,28 @@ class UnimernetModel(object):
         self.device = _device_
         self.model.to(_device_)
         self.model.eval()
-        vis_processor = load_processor(
-
+        vis_processor = load_processor(
+            "formula_image_eval",
+            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
+        )
+        self.mfr_transform = transforms.Compose(
+            [
+                vis_processor,
+            ]
+        )
 
     def predict(self, mfd_res, image):
-
         formula_list = []
         mf_image_list = []
-        for xyxy, conf, cla in zip(
+        for xyxy, conf, cla in zip(
+            mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
+        ):
             xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
             new_item = {
-
-
-
-
+                "category_id": 13 + int(cla.item()),
+                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                "score": round(float(conf.item()), 2),
+                "latex": "",
             }
             formula_list.append(new_item)
             pil_img = Image.fromarray(image)
@@ -88,11 +94,48 @@ class UnimernetModel(object):
         for mf_img in dataloader:
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
-                output = self.model.generate({
-            mfr_res.extend(output[
+                output = self.model.generate({"image": mf_img})
+            mfr_res.extend(output["pred_str"])
         for res, latex in zip(formula_list, mfr_res):
-            res[
+            res["latex"] = latex_rm_whitespace(latex)
         return formula_list
 
+    def batch_predict(
+        self, images_mfd_res: list, images: list, batch_size: int = 64
+    ) -> list:
+        images_formula_list = []
+        mf_image_list = []
+        backfill_list = []
+        for image_index in range(len(images_mfd_res)):
+            mfd_res = images_mfd_res[image_index]
+            pil_img = Image.fromarray(images[image_index])
+            formula_list = []
+
+            for xyxy, conf, cla in zip(
+                mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
+            ):
+                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                new_item = {
+                    "category_id": 13 + int(cla.item()),
+                    "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                    "score": round(float(conf.item()), 2),
+                    "latex": "",
+                }
+                formula_list.append(new_item)
+                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+                mf_image_list.append(bbox_img)
+
+            images_formula_list.append(formula_list)
+            backfill_list += formula_list
 
-
+        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
+        mfr_res = []
+        for mf_img in dataloader:
+            mf_img = mf_img.to(self.device)
+            with torch.no_grad():
+                output = self.model.generate({"image": mf_img})
+            mfr_res.extend(output["pred_str"])
+        for res, latex in zip(backfill_list, mfr_res):
+            res["latex"] = latex_rm_whitespace(latex)
+        return images_formula_list
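One detail worth noting in `batch_predict`: `backfill_list += formula_list` stores references to the same dicts that live in `images_formula_list`, so writing `res["latex"]` through the flat `backfill_list` after the single batched recognition pass also fills in the per-page lists. A stripped-down, runnable illustration of the aliasing:

```python
# Minimal illustration of the backfill pattern used above.
images_formula_list, backfill_list = [], []
for page in range(2):
    formula_list = [{"latex": ""} for _ in range(3)]
    images_formula_list.append(formula_list)
    backfill_list += formula_list  # same dict objects, flat order

for i, item in enumerate(backfill_list):
    item["latex"] = f"result_{i}"  # one flat recognition pass

print(images_formula_list[1][0]["latex"])  # "result_3": pages stay in sync
```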
@@ -1,7 +1,9 @@
+import torch
 from loguru import logger
 
 from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.model.model_list import AtomicModel
+from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
 from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
     DocLayoutYOLOModel
 from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
@@ -19,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster_paddle import \
     TableMasterPaddleModel
 
 
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
     elif table_model_type == MODEL_NAME.TABLE_MASTER:
@@ -29,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
         }
         table_model = TableMasterPaddleModel(config)
     elif table_model_type == MODEL_NAME.RAPID_TABLE:
-        table_model = RapidTableModel()
+        table_model = RapidTableModel(ocr_engine)
     else:
         logger.error('table model type not allow')
         exit(1)
@@ -38,6 +40,8 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
 
 
 def mfd_model_init(weight, device='cpu'):
+    if str(device).startswith("npu"):
+        device = torch.device(device)
     mfd_model = YOLOv8MFDModel(weight, device)
     return mfd_model
 
@@ -53,16 +57,26 @@ def layout_model_init(weight, config_file, device):
 
 
 def doclayout_yolo_model_init(weight, device='cpu'):
+    if str(device).startswith("npu"):
+        device = torch.device(device)
     model = DocLayoutYOLOModel(weight, device)
     return model
 
 
+def langdetect_model_init(langdetect_model_weight, device='cpu'):
+    if str(device).startswith("npu"):
+        device = torch.device(device)
+    model = YOLOv11LangDetModel(langdetect_model_weight, device)
+    return model
+
+
 def ocr_model_init(show_log: bool = False,
                    det_db_box_thresh=0.3,
                    lang=None,
                    use_dilation=True,
                    det_db_unclip_ratio=1.8,
                    ):
+
     if lang is not None and lang != '':
         model = ModifiedPaddleOCR(
             show_log=show_log,
@@ -77,7 +91,6 @@ def ocr_model_init(show_log: bool = False,
             det_db_box_thresh=det_db_box_thresh,
             use_dilation=use_dilation,
             det_db_unclip_ratio=det_db_unclip_ratio,
-            # use_angle_cls=True,
         )
     return model
 
@@ -124,6 +137,9 @@ def atom_model_init(model_name: str, **kwargs):
                 kwargs.get('doclayout_yolo_weights'),
                 kwargs.get('device')
             )
+        else:
+            logger.error('layout model name not allow')
+            exit(1)
     elif model_name == AtomicModel.MFD:
         atom_model = mfd_model_init(
             kwargs.get('mfd_weights'),
@@ -146,8 +162,18 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('table_model_name'),
             kwargs.get('table_model_path'),
             kwargs.get('table_max_time'),
-            kwargs.get('device')
+            kwargs.get('device'),
+            kwargs.get('ocr_engine')
         )
+    elif model_name == AtomicModel.LangDetect:
+        if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
+            atom_model = langdetect_model_init(
+                kwargs.get('langdetect_model_weight'),
+                kwargs.get('device')
+            )
+        else:
+            logger.error('langdetect model name not allow')
+            exit(1)
     else:
         logger.error('model name not allow')
         exit(1)
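With these branches in place, the language detector is constructed through the same atomic-model factory as the other models. A hedged sketch of the call: the kwarg names match the branch above, `AtomicModel.LangDetect` and `MODEL_NAME.YOLO_V11_LangDetect` come from the `model_list.py` and `constants.py` additions listed at the top of this diff, and the weight path is illustrative.

```python
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import atom_model_init

langdetect_model = atom_model_init(
    AtomicModel.LangDetect,
    langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
    langdetect_model_weight="path/to/yolo_v11_ft.pt",  # illustrative path
    device="cpu",
)
```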
@@ -45,7 +45,7 @@ def clean_vram(device, vram_threshold=8):
     total_memory = get_vram(device)
     if total_memory and total_memory <= vram_threshold:
         gc_start = time.time()
-        clean_memory()
+        clean_memory(device)
         gc_time = round(time.time() - gc_start, 2)
         logger.info(f"gc time: {gc_time}")
 
@@ -54,4 +54,10 @@ def get_vram(device):
     if torch.cuda.is_available() and device != 'cpu':
         total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
         return total_memory
-
+    elif str(device).startswith("npu"):
+        import torch_npu
+        if torch_npu.npu.is_available():
+            total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
+            return total_memory
+    else:
+        return None
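The NPU branch mirrors the CUDA path via Ascend's `torch_npu` plugin, and `clean_memory` now takes the device so it can empty the matching cache (per the `clean_memory.py` change listed above, whose body is not shown in this diff). A hedged sketch of what a device-aware cache clear can look like; this is illustrative, not the shipped implementation, and assumes `torch_npu` is installed for the NPU branch:

```python
import gc

import torch


def clean_memory_sketch(device: str = "cuda") -> None:
    """Illustrative device-aware cache clearing; not the shipped implementation."""
    if str(device).startswith("cuda") and torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif str(device).startswith("npu"):
        import torch_npu  # Ascend plugin; optional dependency

        if torch_npu.npu.is_available():
            torch_npu.npu.empty_cache()
    gc.collect()
```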
@@ -303,4 +303,54 @@ def calculate_is_angle(poly):
         return False
     else:
         # logger.info((p3[1] - p1[1])/height)
-        return True
+        return True
+
+
+class ONNXModelSingleton:
+    _instance = None
+    _models = {}
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def get_onnx_model(self, **kwargs):
+
+        lang = kwargs.get('lang', None)
+        det_db_box_thresh = kwargs.get('det_db_box_thresh', 0.3)
+        use_dilation = kwargs.get('use_dilation', True)
+        det_db_unclip_ratio = kwargs.get('det_db_unclip_ratio', 1.8)
+        key = (lang, det_db_box_thresh, use_dilation, det_db_unclip_ratio)
+        if key not in self._models:
+            self._models[key] = onnx_model_init(key)
+        return self._models[key]
+
+def onnx_model_init(key):
+
+    import importlib.resources
+
+    resource_path = importlib.resources.path('rapidocr_onnxruntime.models', '')
+
+    onnx_model = None
+    additional_ocr_params = {
+        "use_onnx": True,
+        "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
+        "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
+        "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
+        "det_db_box_thresh": key[1],
+        "use_dilation": key[2],
+        "det_db_unclip_ratio": key[3],
+    }
+    # logger.info(f"additional_ocr_params: {additional_ocr_params}")
+    if key[0] is not None:
+        additional_ocr_params["lang"] = key[0]
+
+    from paddleocr import PaddleOCR
+    onnx_model = PaddleOCR(**additional_ocr_params)
+
+    if onnx_model is None:
+        logger.error('model init failed')
+        exit(1)
+    else:
+        return onnx_model
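The singleton keys its cache on the four detector tuning parameters, so repeated calls with the same configuration reuse a single ONNX pipeline. A sketch of the caching behavior (actually running it requires the `rapidocr_onnxruntime` model files to be present):

```python
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import ONNXModelSingleton

manager_a = ONNXModelSingleton()
manager_b = ONNXModelSingleton()
assert manager_a is manager_b  # one instance process-wide

# identical kwargs -> the cached ONNX PaddleOCR pipeline is reused
ocr_1 = manager_a.get_onnx_model(lang="ch", det_db_box_thresh=0.3)
ocr_2 = manager_b.get_onnx_model(lang="ch", det_db_box_thresh=0.3)
assert ocr_1 is ocr_2
```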
@@ -1,7 +1,9 @@
 import copy
+import platform
 import time
 import cv2
 import numpy as np
+import torch
 
 from paddleocr import PaddleOCR
 from ppocr.utils.logging import get_logger
@@ -9,12 +11,25 @@ from ppocr.utils.utility import alpha_to_color, binarize_img
 from tools.infer.predict_system import sorted_boxes
 from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
 
-from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img, \
+    ONNXModelSingleton
 
 logger = get_logger()
 
 
 class ModifiedPaddleOCR(PaddleOCR):
+    def __init__(self, *args, **kwargs):
+
+        super().__init__(*args, **kwargs)
+        self.lang = kwargs.get('lang', 'ch')
+        # fall back to ONNX when the CPU is ARM and CUDA is unavailable
+        if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
+            self.use_onnx = True
+            onnx_model_manager = ONNXModelSingleton()
+            self.additional_ocr = onnx_model_manager.get_onnx_model(**kwargs)
+        else:
+            self.use_onnx = False
+
     def ocr(self,
             img,
             det=True,
@@ -79,7 +94,10 @@ class ModifiedPaddleOCR(PaddleOCR):
         ocr_res = []
         for img in imgs:
             img = preprocess_image(img)
-            dt_boxes, elapse = self.text_detector(img)
+            if self.lang in ['ch'] and self.use_onnx:
+                dt_boxes, elapse = self.additional_ocr.text_detector(img)
+            else:
+                dt_boxes, elapse = self.text_detector(img)
             if dt_boxes is None:
                 ocr_res.append(None)
                 continue
@@ -106,7 +124,10 @@ class ModifiedPaddleOCR(PaddleOCR):
                     img, cls_res_tmp, elapse = self.text_classifier(img)
                     if not rec:
                         cls_res.append(cls_res_tmp)
-                rec_res, elapse = self.text_recognizer(img)
+                if self.lang in ['ch'] and self.use_onnx:
+                    rec_res, elapse = self.additional_ocr.text_recognizer(img)
+                else:
+                    rec_res, elapse = self.text_recognizer(img)
                 ocr_res.append(rec_res)
             if not rec:
                 return cls_res
@@ -121,7 +142,10 @@ class ModifiedPaddleOCR(PaddleOCR):
 
         start = time.time()
         ori_im = img.copy()
-        dt_boxes, elapse = self.text_detector(img)
+        if self.lang in ['ch'] and self.use_onnx:
+            dt_boxes, elapse = self.additional_ocr.text_detector(img)
+        else:
+            dt_boxes, elapse = self.text_detector(img)
         time_dict['det'] = elapse
 
         if dt_boxes is None:
@@ -159,8 +183,10 @@ class ModifiedPaddleOCR(PaddleOCR):
             time_dict['cls'] = elapse
             logger.debug("cls num : {}, elapsed : {}".format(
                 len(img_crop_list), elapse))
-
-        rec_res, elapse = self.text_recognizer(img_crop_list)
+        if self.lang in ['ch'] and self.use_onnx:
+            rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
+        else:
+            rec_res, elapse = self.text_recognizer(img_crop_list)
         time_dict['rec'] = elapse
         logger.debug("rec_res num : {}, elapsed : {}".format(
             len(rec_res), elapse))
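Taken together, the `ModifiedPaddleOCR` changes make the ONNX fallback transparent to callers: on an ARM CPU without CUDA, detection and recognition for the default 'ch' models are routed through the cached ONNX pipeline, while every other configuration keeps the stock PaddleOCR path. A sketch showing that call sites do not change (the image path is illustrative):

```python
# The fallback is decided inside __init__, so call sites stay the same.
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR

ocr = ModifiedPaddleOCR(show_log=False, lang="ch")
result = ocr.ocr("page_0.png")  # ONNX on ARM without CUDA, Paddle otherwise
```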