magic-pdf 0.9.2__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff covers the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as published.
- magic_pdf/config/constants.py +53 -0
- magic_pdf/config/drop_reason.py +35 -0
- magic_pdf/config/drop_tag.py +19 -0
- magic_pdf/config/make_content_config.py +11 -0
- magic_pdf/{libs/ModelBlockTypeEnum.py → config/model_block_type.py} +2 -1
- magic_pdf/data/read_api.py +1 -1
- magic_pdf/dict2md/mkcontent.py +226 -185
- magic_pdf/dict2md/ocr_mkcontent.py +12 -12
- magic_pdf/filter/pdf_meta_scan.py +101 -79
- magic_pdf/integrations/rag/utils.py +4 -5
- magic_pdf/libs/config_reader.py +6 -6
- magic_pdf/libs/draw_bbox.py +13 -6
- magic_pdf/libs/pdf_image_tools.py +36 -12
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/doc_analyze_by_custom_model.py +2 -0
- magic_pdf/model/magic_model.py +13 -13
- magic_pdf/model/pdf_extract_kit.py +142 -351
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +21 -0
- magic_pdf/model/sub_modules/mfd/__init__.py +0 -0
- magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +12 -0
- magic_pdf/model/sub_modules/mfd/yolov8/__init__.py +0 -0
- magic_pdf/model/sub_modules/mfr/__init__.py +0 -0
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +98 -0
- magic_pdf/model/sub_modules/mfr/unimernet/__init__.py +0 -0
- magic_pdf/model/sub_modules/model_init.py +149 -0
- magic_pdf/model/sub_modules/model_utils.py +51 -0
- magic_pdf/model/sub_modules/ocr/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +285 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +176 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +213 -0
- magic_pdf/model/sub_modules/reading_oreder/__init__.py +0 -0
- magic_pdf/model/sub_modules/reading_oreder/layoutreader/__init__.py +0 -0
- magic_pdf/model/sub_modules/reading_oreder/layoutreader/xycut.py +242 -0
- magic_pdf/model/sub_modules/table/__init__.py +0 -0
- magic_pdf/model/sub_modules/table/rapidtable/__init__.py +0 -0
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +16 -0
- magic_pdf/model/sub_modules/table/structeqtable/__init__.py +0 -0
- magic_pdf/model/{pek_sub_modules/structeqtable/StructTableModel.py → sub_modules/table/structeqtable/struct_eqtable.py} +3 -11
- magic_pdf/model/sub_modules/table/table_utils.py +11 -0
- magic_pdf/model/sub_modules/table/tablemaster/__init__.py +0 -0
- magic_pdf/model/{ppTableModel.py → sub_modules/table/tablemaster/tablemaster_paddle.py} +31 -29
- magic_pdf/para/para_split.py +411 -248
- magic_pdf/para/para_split_v2.py +352 -182
- magic_pdf/para/para_split_v3.py +121 -66
- magic_pdf/pdf_parse_by_ocr.py +2 -0
- magic_pdf/pdf_parse_by_txt.py +2 -0
- magic_pdf/pdf_parse_union_core.py +174 -100
- magic_pdf/pdf_parse_union_core_v2.py +253 -50
- magic_pdf/pipe/AbsPipe.py +28 -44
- magic_pdf/pipe/OCRPipe.py +5 -5
- magic_pdf/pipe/TXTPipe.py +5 -6
- magic_pdf/pipe/UNIPipe.py +24 -25
- magic_pdf/post_proc/pdf_post_filter.py +7 -14
- magic_pdf/pre_proc/cut_image.py +9 -11
- magic_pdf/pre_proc/equations_replace.py +203 -212
- magic_pdf/pre_proc/ocr_detect_all_bboxes.py +235 -49
- magic_pdf/pre_proc/ocr_dict_merge.py +5 -5
- magic_pdf/pre_proc/ocr_span_list_modify.py +122 -63
- magic_pdf/pre_proc/pdf_pre_filter.py +37 -33
- magic_pdf/pre_proc/remove_bbox_overlap.py +20 -18
- magic_pdf/pre_proc/remove_colored_strip_bbox.py +36 -14
- magic_pdf/pre_proc/remove_footer_header.py +2 -5
- magic_pdf/pre_proc/remove_rotate_bbox.py +111 -63
- magic_pdf/pre_proc/resolve_bbox_conflict.py +10 -17
- magic_pdf/resources/model_config/model_configs.yaml +2 -1
- magic_pdf/spark/spark_api.py +15 -17
- magic_pdf/tools/cli.py +3 -4
- magic_pdf/tools/cli_dev.py +6 -9
- magic_pdf/tools/common.py +70 -36
- magic_pdf/user_api.py +29 -38
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/METADATA +18 -13
- magic_pdf-0.10.0.dist-info/RECORD +198 -0
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/WHEEL +1 -1
- magic_pdf/libs/Constants.py +0 -53
- magic_pdf/libs/MakeContentConfig.py +0 -11
- magic_pdf/libs/drop_reason.py +0 -27
- magic_pdf/libs/drop_tag.py +0 -19
- magic_pdf/model/pek_sub_modules/post_process.py +0 -36
- magic_pdf/model/pek_sub_modules/self_modify.py +0 -388
- magic_pdf/para/para_pipeline.py +0 -297
- magic_pdf-0.9.2.dist-info/RECORD +0 -178
- /magic_pdf/{libs → config}/ocr_content_type.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules}/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules/layoutlmv3 → sub_modules/layout}/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules/structeqtable → sub_modules/layout/doclayout_yolo}/__init__.py +0 -0
- /magic_pdf/model/{v3 → sub_modules/layout/layoutlmv3}/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/backbone.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/beit.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/deit.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/cord.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/data_collator.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/funsd.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/image_utils.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/xfund.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/model_init.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/rcnn_vl.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/visualizer.py +0 -0
- /magic_pdf/model/{v3 → sub_modules/reading_oreder/layoutreader}/helpers.py +0 -0
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/top_level.txt +0 -0
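Most of this release is a package reorganisation: constants, drop reasons/tags and content-type enums move from `magic_pdf.libs` into the new `magic_pdf.config` package, and the model wrappers move from `magic_pdf.model.pek_sub_modules` into the new `magic_pdf.model.sub_modules` tree. For downstream code that imported these internals directly, the migration looks roughly like the sketch below (module paths are taken from the renames listed above; the imported symbol names are assumptions, not confirmed by this diff):

    # magic-pdf 0.9.2 (old paths)
    # from magic_pdf.libs.drop_reason import DropReason
    # from magic_pdf.libs.ocr_content_type import ContentType
    # from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor

    # magic-pdf 0.10.0 (new paths)
    from magic_pdf.config.drop_reason import DropReason
    from magic_pdf.config.ocr_content_type import ContentType
    from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor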
magic_pdf/model/pdf_extract_kit.py

@@ -1,195 +1,32 @@
-
+# flake8: noqa
import os
import time
-
-import
-
-
-
+
+import cv2
+import numpy as np
+import torch
+import yaml
+from loguru import logger
+from PIL import Image

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
+
try:
-    import cv2
-    import yaml
-    import argparse
-    import numpy as np
-    import torch
    import torchtext

-    if torchtext.__version__ >=
+    if torchtext.__version__ >= '0.18.0':
        torchtext.disable_torchtext_deprecation_warning()
-
-
-    from torch.utils.data import Dataset, DataLoader
-    from ultralytics import YOLO
-    from unimernet.common.config import Config
-    import unimernet.tasks as tasks
-    from unimernet.processors import load_processor
-    from doclayout_yolo import YOLOv10
-
-except ImportError as e:
-    logger.exception(e)
-    logger.error(
-        'Required dependency not installed, please install by \n'
-        '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
-    exit(1)
-
-from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
-from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
-from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
-from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
-from magic_pdf.model.ppTableModel import ppTableModel
-
-
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
-    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
-        table_model = StructTableModel(model_path, max_time=max_time)
-    elif table_model_type == MODEL_NAME.TABLE_MASTER:
-        config = {
-            "model_dir": model_path,
-            "device": _device_
-        }
-        table_model = ppTableModel(config)
-    else:
-        logger.error("table model type not allow")
-        exit(1)
-    return table_model
-
-
-def mfd_model_init(weight):
-    mfd_model = YOLO(weight)
-    return mfd_model
-
-
-def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
-    args = argparse.Namespace(cfg_path=cfg_path, options=None)
-    cfg = Config(args)
-    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
-    cfg.config.model.model_config.model_name = weight_dir
-    cfg.config.model.tokenizer_config.path = weight_dir
-    task = tasks.setup_task(cfg)
-    model = task.build_model(cfg)
-    model.to(_device_)
-    model.eval()
-    vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
-    mfr_transform = transforms.Compose([vis_processor, ])
-    return [model, mfr_transform]
-
-
-def layout_model_init(weight, config_file, device):
-    model = Layoutlmv3_Predictor(weight, config_file, device)
-    return model
-
-
-def doclayout_yolo_model_init(weight):
-    model = YOLOv10(weight)
-    return model
-
-
-def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
-    if lang is not None:
-        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
-    else:
-        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
-    return model
-
-
-class MathDataset(Dataset):
-    def __init__(self, image_paths, transform=None):
-        self.image_paths = image_paths
-        self.transform = transform
-
-    def __len__(self):
-        return len(self.image_paths)
-
-    def __getitem__(self, idx):
-        # if not pil image, then convert to pil image
-        if isinstance(self.image_paths[idx], str):
-            raw_image = Image.open(self.image_paths[idx])
-        else:
-            raw_image = self.image_paths[idx]
-        if self.transform:
-            image = self.transform(raw_image)
-            return image
-
-
-class AtomModelSingleton:
-    _instance = None
-    _models = {}
-
-    def __new__(cls, *args, **kwargs):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-        return cls._instance
-
-    def get_atom_model(self, atom_model_name: str, **kwargs):
-        lang = kwargs.get("lang", None)
-        layout_model_name = kwargs.get("layout_model_name", None)
-        key = (atom_model_name, layout_model_name, lang)
-        if key not in self._models:
-            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
-        return self._models[key]
-
-
-def atom_model_init(model_name: str, **kwargs):
-
-    if model_name == AtomicModel.Layout:
-        if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
-            atom_model = layout_model_init(
-                kwargs.get("layout_weights"),
-                kwargs.get("layout_config_file"),
-                kwargs.get("device")
-            )
-        elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
-            atom_model = doclayout_yolo_model_init(
-                kwargs.get("doclayout_yolo_weights"),
-            )
-    elif model_name == AtomicModel.MFD:
-        atom_model = mfd_model_init(
-            kwargs.get("mfd_weights")
-        )
-    elif model_name == AtomicModel.MFR:
-        atom_model = mfr_model_init(
-            kwargs.get("mfr_weight_dir"),
-            kwargs.get("mfr_cfg_path"),
-            kwargs.get("device")
-        )
-    elif model_name == AtomicModel.OCR:
-        atom_model = ocr_model_init(
-            kwargs.get("ocr_show_log"),
-            kwargs.get("det_db_box_thresh"),
-            kwargs.get("lang")
-        )
-    elif model_name == AtomicModel.Table:
-        atom_model = table_model_init(
-            kwargs.get("table_model_name"),
-            kwargs.get("table_model_path"),
-            kwargs.get("table_max_time"),
-            kwargs.get("device")
-        )
-    else:
-        logger.error("model name not allow")
-        exit(1)
-
-    return atom_model
-
+except ImportError:
+    pass

-
-
-
-
-
-
-
-    return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
-
-    # Crop image
-    crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
-    cropped_img = input_pil_img.crop(crop_box)
-    return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
-    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
-    return return_image, return_list
+from magic_pdf.config.constants import *
+from magic_pdf.model.model_list import AtomicModel
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
+from magic_pdf.model.sub_modules.model_utils import (
+    clean_vram, crop_img, get_res_list_from_layout_res)
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
+    get_adjusted_mfdetrec_res, get_ocr_result_list)


class CustomPEKModel:
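The module-level helpers and the `AtomModelSingleton` cache removed in this hunk are not dropped from the package: the new import block above pulls `AtomModelSingleton` from `magic_pdf.model.sub_modules.model_init`, a new module with 149 added lines in this release. Judging from the removed 0.9.2 implementation, the cache is keyed by model kind, layout backend and OCR language, which is why `CustomPEKModel` can simply request each sub-model without reloading weights. A minimal usage sketch (the weight path is a placeholder, not a value from this diff; the 0.10.0 caching details may differ):

    from magic_pdf.model.model_list import AtomicModel
    from magic_pdf.model.sub_modules.model_init import AtomModelSingleton

    manager = AtomModelSingleton()

    # First call loads the formula-detection weights (hypothetical path).
    mfd_model = manager.get_atom_model(
        atom_model_name=AtomicModel.MFD,
        mfd_weights='/path/to/yolo_v8_mfd_weights.pt',  # hypothetical path
        device='cpu',
    )

    # Repeating the call with the same arguments returns the cached instance
    # instead of loading the model again (per the removed 0.9.2 logic).
    assert mfd_model is manager.get_atom_model(
        atom_model_name=AtomicModel.MFD,
        mfd_weights='/path/to/yolo_v8_mfd_weights.pt',  # hypothetical path
        device='cpu',
    )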
@@ -208,61 +45,80 @@ class CustomPEKModel:
        model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
        # 构建 model_configs.yaml 文件的完整路径
        config_path = os.path.join(model_config_dir, 'model_configs.yaml')
-        with open(config_path,
+        with open(config_path, 'r', encoding='utf-8') as f:
            self.configs = yaml.load(f, Loader=yaml.FullLoader)
        # 初始化解析配置

        # layout config
-        self.layout_config = kwargs.get(
-        self.layout_model_name = self.layout_config.get(
+        self.layout_config = kwargs.get('layout_config')
+        self.layout_model_name = self.layout_config.get(
+            'model', MODEL_NAME.DocLayout_YOLO
+        )

        # formula config
-        self.formula_config = kwargs.get(
-        self.mfd_model_name = self.formula_config.get(
-
-
+        self.formula_config = kwargs.get('formula_config')
+        self.mfd_model_name = self.formula_config.get(
+            'mfd_model', MODEL_NAME.YOLO_V8_MFD
+        )
+        self.mfr_model_name = self.formula_config.get(
+            'mfr_model', MODEL_NAME.UniMerNet_v2_Small
+        )
+        self.apply_formula = self.formula_config.get('enable', True)

        # table config
-        self.table_config = kwargs.get(
-        self.apply_table = self.table_config.get(
-        self.table_max_time = self.table_config.get(
-        self.table_model_name = self.table_config.get(
+        self.table_config = kwargs.get('table_config')
+        self.apply_table = self.table_config.get('enable', False)
+        self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
+        self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)

        # ocr config
        self.apply_ocr = ocr
-        self.lang = kwargs.get(
+        self.lang = kwargs.get('lang', None)

        logger.info(
-
-
-            self.layout_model_name,
+            'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
+            'apply_table: {}, table_model: {}, lang: {}'.format(
+                self.layout_model_name,
+                self.apply_formula,
+                self.apply_ocr,
+                self.apply_table,
+                self.table_model_name,
+                self.lang,
            )
        )
        # 初始化解析方案
-        self.device = kwargs.get(
-        logger.info(
-        models_dir = kwargs.get(
-
+        self.device = kwargs.get('device', 'cpu')
+        logger.info('using device: {}'.format(self.device))
+        models_dir = kwargs.get(
+            'models_dir', os.path.join(root_dir, 'resources', 'models')
+        )
+        logger.info('using models_dir: {}'.format(models_dir))

        atom_model_manager = AtomModelSingleton()

        # 初始化公式识别
        if self.apply_formula:
-
            # 初始化公式检测模型
            self.mfd_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFD,
-                mfd_weights=str(
+                mfd_weights=str(
+                    os.path.join(
+                        models_dir, self.configs['weights'][self.mfd_model_name]
+                    )
+                ),
+                device=self.device,
            )

            # 初始化公式解析模型
-            mfr_weight_dir = str(
-
-
+            mfr_weight_dir = str(
+                os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
+            )
+            mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
+            self.mfr_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFR,
                mfr_weight_dir=mfr_weight_dir,
                mfr_cfg_path=mfr_cfg_path,
-                device=self.device
+                device=self.device,
            )

        # 初始化layout模型
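This hunk shows `CustomPEKModel.__init__` reading nested `layout_config`, `formula_config` and `table_config` dicts from its kwargs instead of flat keys. Pieced together from the `.get()` calls and defaults visible above, the expected shape is roughly the following (keys and defaults are from this hunk; constants such as `MODEL_NAME` and `TABLE_MAX_TIME_VALUE` presumably come in through the `magic_pdf.config.constants` star import):

    # Approximate kwargs shape consumed by CustomPEKModel.__init__ in 0.10.0,
    # reconstructed from the .get() calls in the hunk above.
    custom_pek_kwargs = {
        'layout_config': {
            'model': MODEL_NAME.DocLayout_YOLO,           # default layout backend
        },
        'formula_config': {
            'mfd_model': MODEL_NAME.YOLO_V8_MFD,          # formula detection
            'mfr_model': MODEL_NAME.UniMerNet_v2_Small,   # formula recognition
            'enable': True,                               # apply_formula default
        },
        'table_config': {
            'model': MODEL_NAME.RAPID_TABLE,              # default table backend
            'enable': False,                              # tables off by default
            'max_time': TABLE_MAX_TIME_VALUE,
        },
        'lang': None,       # OCR language
        'device': 'cpu',    # default device
        # 'models_dir' defaults to <package root>/resources/models
    }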
@@ -270,172 +126,110 @@ class CustomPEKModel:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.LAYOUTLMv3,
-                layout_weights=str(
-
-
+                layout_weights=str(
+                    os.path.join(
+                        models_dir, self.configs['weights'][self.layout_model_name]
+                    )
+                ),
+                layout_config_file=str(
+                    os.path.join(
+                        model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
+                    )
+                ),
+                device=self.device,
            )
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.DocLayout_YOLO,
-                doclayout_yolo_weights=str(
+                doclayout_yolo_weights=str(
+                    os.path.join(
+                        models_dir, self.configs['weights'][self.layout_model_name]
+                    )
+                ),
+                device=self.device,
            )
        # 初始化ocr
-
-
-
-
-
-
-            )
+        self.ocr_model = atom_model_manager.get_atom_model(
+            atom_model_name=AtomicModel.OCR,
+            ocr_show_log=show_log,
+            det_db_box_thresh=0.3,
+            lang=self.lang
+        )
        # init table model
        if self.apply_table:
-            table_model_dir = self.configs[
+            table_model_dir = self.configs['weights'][self.table_model_name]
            self.table_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Table,
                table_model_name=self.table_model_name,
                table_model_path=str(os.path.join(models_dir, table_model_dir)),
                table_max_time=self.table_max_time,
-                device=self.device
+                device=self.device,
            )

        logger.info('DocAnalysis init done!')

    def __call__(self, image):

-        page_start = time.time()
-
-        latex_filling_list = []
-        mf_image_list = []
-
        # layout检测
        layout_start = time.time()
+        layout_res = []
        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            # layoutlmv3
            layout_res = self.layout_model(image, ignore_catids=[])
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo
-            layout_res =
-            doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
-            for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
-                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-                new_item = {
-                    'category_id': int(cla.item()),
-                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                    'score': round(float(conf.item()), 3),
-                }
-                layout_res.append(new_item)
+            layout_res = self.layout_model.predict(image)
        layout_cost = round(time.time() - layout_start, 2)
-        logger.info(f
+        logger.info(f'layout detection time: {layout_cost}')

        pil_img = Image.fromarray(image)

        if self.apply_formula:
            # 公式检测
            mfd_start = time.time()
-            mfd_res = self.mfd_model.predict(image
-            logger.info(f
-            for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
-                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-                new_item = {
-                    'category_id': 13 + int(cla.item()),
-                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                    'score': round(float(conf.item()), 2),
-                    'latex': '',
-                }
-                layout_res.append(new_item)
-                latex_filling_list.append(new_item)
-                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
-                mf_image_list.append(bbox_img)
+            mfd_res = self.mfd_model.predict(image)
+            logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')

            # 公式识别
            mfr_start = time.time()
-
-
-            mfr_res = []
-            for mf_img in dataloader:
-                mf_img = mf_img.to(self.device)
-                with torch.no_grad():
-                    output = self.mfr_model.generate({'image': mf_img})
-                mfr_res.extend(output['pred_str'])
-            for res, latex in zip(latex_filling_list, mfr_res):
-                res['latex'] = latex_rm_whitespace(latex)
+            formula_list = self.mfr_model.predict(mfd_res, image)
+            layout_res.extend(formula_list)
            mfr_cost = round(time.time() - mfr_start, 2)
-            logger.info(f
+            logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')

-        #
-
-        table_res_list = []
-        single_page_mfdetrec_res = []
-        for res in layout_res:
-            if int(res['category_id']) in [13, 14]:
-                single_page_mfdetrec_res.append({
-                    "bbox": [int(res['poly'][0]), int(res['poly'][1]),
-                             int(res['poly'][4]), int(res['poly'][5])],
-                })
-            elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
-                ocr_res_list.append(res)
-            elif int(res['category_id']) in [5]:
-                table_res_list.append(res)
+        # 清理显存
+        clean_vram(self.device, vram_threshold=8)

-
-
-
-
-        gc_start = time.time()
-        clean_memory()
-        gc_time = round(time.time() - gc_start, 2)
-        logger.info(f"gc time: {gc_time}")
+        # 从layout_res中获取ocr区域、表格区域、公式区域
+        ocr_res_list, table_res_list, single_page_mfdetrec_res = (
+            get_res_list_from_layout_res(layout_res)
+        )

        # ocr识别
-
-
-
-
-
-
-
-
-
-                mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
-                # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
-                x0 = mf_xmin - xmin + paste_x
-                y0 = mf_ymin - ymin + paste_y
-                x1 = mf_xmax - xmin + paste_x
-                y1 = mf_ymax - ymin + paste_y
-                # Filter formula blocks outside the graph
-                if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
-                    continue
-                else:
-                    adjusted_mfdetrec_res.append({
-                        "bbox": [x0, y0, x1, y1],
-                    })
-
-            # OCR recognition
-            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+        ocr_start = time.time()
+        # Process each area that requires OCR processing
+        for res in ocr_res_list:
+            new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
+            adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
+
+            # OCR recognition
+            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+            if self.apply_ocr:
                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
+            else:
+                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]

-
-
-
-
-                    text, score = box_ocr_res[1]
+            # Integration results
+            if ocr_res:
+                ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
+                layout_res.extend(ocr_result_list)

-
-
-                    p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
-                    p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
-                    p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
-
-                    layout_res.append({
-                        'category_id': 15,
-                        'poly': p1 + p2 + p3 + p4,
-                        'score': round(score, 2),
-                        'text': text,
-                    })
-
-        ocr_cost = round(time.time() - ocr_start, 2)
+        ocr_cost = round(time.time() - ocr_start, 2)
+        if self.apply_ocr:
            logger.info(f"ocr time: {ocr_cost}")
+        else:
+            logger.info(f"det time: {ocr_cost}")

        # 表格识别 table recognition
        if self.apply_table:
@@ -443,41 +237,38 @@ class CustomPEKModel:
            for res in table_res_list:
                new_image, _ = crop_img(res, pil_img)
                single_table_start_time = time.time()
-                # logger.info("------------------table recognition processing begins-----------------")
-                latex_code = None
                html_code = None
                if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                    with torch.no_grad():
-                        table_result = self.table_model.predict(new_image,
+                        table_result = self.table_model.predict(new_image, 'html')
                        if len(table_result) > 0:
                            html_code = table_result[0]
-
+                elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                    html_code = self.table_model.img2html(new_image)
-
+                elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
+                    html_code, table_cell_bboxes, elapse = self.table_model.predict(
+                        new_image
+                    )
                run_time = time.time() - single_table_start_time
-                # logger.info(f"------------table recognition processing ends within {run_time}s-----")
                if run_time > self.table_max_time:
-                    logger.warning(
+                    logger.warning(
+                        f'table recognition processing exceeds max time {self.table_max_time}s'
+                    )
                # 判断是否返回正常
-
-
-
-
-                        res["latex"] = latex_code
-                    else:
-                        logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
-                elif html_code:
-                    expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
+                if html_code:
+                    expected_ending = html_code.strip().endswith(
+                        '</html>'
+                    ) or html_code.strip().endswith('</table>')
                    if expected_ending:
-                        res[
+                        res['html'] = html_code
                    else:
-                        logger.warning(
+                        logger.warning(
+                            'table recognition processing fails, not found expected HTML table end'
+                        )
                else:
-                    logger.warning(
-
-
-
+                    logger.warning(
+                        'table recognition processing fails, not get html return'
+                    )
+            logger.info(f'table time: {round(time.time() - table_start, 2)}')

        return layout_res
-
-
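The inline bucketing of layout results that 0.9.2 performed inside `__call__` (category ids 13/14 collected as formula regions, 0/1/2/4/6/7 as OCR regions, 5 as table regions, as visible in the removed lines) is replaced by `get_res_list_from_layout_res` from the new `model_utils` module. A minimal sketch of what that helper presumably does, reconstructed from the removed code; the actual 0.10.0 implementation lives in magic_pdf/model/sub_modules/model_utils.py and is not shown in this diff:

    def get_res_list_from_layout_res(layout_res):
        """Split layout detections into OCR, table and formula regions,
        mirroring the inline 0.9.2 logic removed above."""
        ocr_res_list, table_res_list, single_page_mfdetrec_res = [], [], []
        for res in layout_res:
            category_id = int(res['category_id'])
            if category_id in [13, 14]:               # formula regions from the MFD step
                single_page_mfdetrec_res.append({
                    'bbox': [int(res['poly'][0]), int(res['poly'][1]),
                             int(res['poly'][4]), int(res['poly'][5])],
                })
            elif category_id in [0, 1, 2, 4, 6, 7]:   # text-like blocks needing OCR
                ocr_res_list.append(res)
            elif category_id == 5:                    # tables
                table_res_list.append(res)
        return ocr_res_list, table_res_list, single_page_mfdetrec_res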
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py

@@ -0,0 +1,21 @@
+from doclayout_yolo import YOLOv10
+
+
+class DocLayoutYOLOModel(object):
+    def __init__(self, weight, device):
+        self.model = YOLOv10(weight)
+        self.device = device
+
+    def predict(self, image):
+        layout_res = []
+        doclayout_yolo_res = self.model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+        for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(),
+                                   doclayout_yolo_res.boxes.cls.cpu()):
+            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+            new_item = {
+                'category_id': int(cla.item()),
+                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                'score': round(float(conf.item()), 3),
+            }
+            layout_res.append(new_item)
+        return layout_res
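For reference, this wrapper is what `CustomPEKModel` now calls as `self.layout_model.predict(image)` in its DocLayout-YOLO branch. A standalone usage sketch (weight path and input image are placeholders; real paths are resolved from model_configs.yaml):

    import cv2

    from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel

    # Hypothetical weight path and page image.
    model = DocLayoutYOLOModel(weight='/path/to/doclayout_yolo_weights.pt', device='cpu')
    image = cv2.imread('page_0.png')  # page image as an ndarray, as produced upstream

    layout_res = model.predict(image)
    # each entry: {'category_id': ..., 'poly': [x0, y0, x1, y0, x1, y1, x0, y1], 'score': ...}
    print(layout_res[:3])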
File without changes
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py

@@ -0,0 +1,12 @@
+from ultralytics import YOLO
+
+
+class YOLOv8MFDModel(object):
+    def __init__(self, weight, device='cpu'):
+        self.mfd_model = YOLO(weight)
+        self.device = device
+
+    def predict(self, image):
+        mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+        return mfd_res
+
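`YOLOv8MFDModel.predict` returns the raw ultralytics result object for the page; in `CustomPEKModel.__call__` above it is handed, together with the page image, to the formula-recognition wrapper (`self.mfr_model.predict(mfd_res, image)`). A short usage sketch with placeholder paths:

    import cv2

    from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel

    # Hypothetical weight path; the real one is resolved from model_configs.yaml.
    mfd = YOLOv8MFDModel(weight='/path/to/yolo_v8_mfd_weights.pt', device='cpu')
    image = cv2.imread('page_0.png')

    mfd_res = mfd.predict(image)  # ultralytics Results with boxes.xyxy / conf / cls
    print(len(mfd_res.boxes))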
File without changes
File without changes