magic-pdf 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/config/__init__.py +0 -0
- magic_pdf/config/enums.py +7 -0
- magic_pdf/config/exceptions.py +32 -0
- magic_pdf/data/__init__.py +0 -0
- magic_pdf/data/data_reader_writer/__init__.py +12 -0
- magic_pdf/data/data_reader_writer/base.py +51 -0
- magic_pdf/data/data_reader_writer/filebase.py +59 -0
- magic_pdf/data/data_reader_writer/multi_bucket_s3.py +137 -0
- magic_pdf/data/data_reader_writer/s3.py +69 -0
- magic_pdf/data/dataset.py +194 -0
- magic_pdf/data/io/__init__.py +0 -0
- magic_pdf/data/io/base.py +42 -0
- magic_pdf/data/io/http.py +37 -0
- magic_pdf/data/io/s3.py +114 -0
- magic_pdf/data/read_api.py +95 -0
- magic_pdf/data/schemas.py +15 -0
- magic_pdf/data/utils.py +32 -0
- magic_pdf/dict2md/ocr_mkcontent.py +74 -234
- magic_pdf/libs/Constants.py +21 -8
- magic_pdf/libs/MakeContentConfig.py +1 -0
- magic_pdf/libs/boxbase.py +54 -0
- magic_pdf/libs/clean_memory.py +10 -0
- magic_pdf/libs/config_reader.py +53 -23
- magic_pdf/libs/draw_bbox.py +150 -65
- magic_pdf/libs/ocr_content_type.py +2 -0
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/doc_analyze_by_custom_model.py +77 -32
- magic_pdf/model/magic_model.py +418 -51
- magic_pdf/model/pdf_extract_kit.py +164 -80
- magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py +8 -1
- magic_pdf/model/ppTableModel.py +2 -2
- magic_pdf/model/pp_structure_v2.py +5 -2
- magic_pdf/model/v3/__init__.py +0 -0
- magic_pdf/model/v3/helpers.py +125 -0
- magic_pdf/para/para_split_v3.py +296 -0
- magic_pdf/pdf_parse_by_ocr.py +6 -3
- magic_pdf/pdf_parse_by_txt.py +6 -3
- magic_pdf/pdf_parse_union_core_v2.py +644 -0
- magic_pdf/pipe/AbsPipe.py +5 -1
- magic_pdf/pipe/OCRPipe.py +10 -4
- magic_pdf/pipe/TXTPipe.py +10 -4
- magic_pdf/pipe/UNIPipe.py +16 -7
- magic_pdf/pre_proc/ocr_detect_all_bboxes.py +83 -1
- magic_pdf/pre_proc/ocr_dict_merge.py +27 -2
- magic_pdf/resources/model_config/UniMERNet/demo.yaml +7 -7
- magic_pdf/resources/model_config/model_configs.yaml +5 -13
- magic_pdf/tools/cli.py +14 -1
- magic_pdf/tools/common.py +19 -9
- magic_pdf/user_api.py +25 -6
- magic_pdf/utils/__init__.py +0 -0
- magic_pdf/utils/annotations.py +11 -0
- {magic_pdf-0.8.0.dist-info → magic_pdf-0.9.0.dist-info}/LICENSE.md +1 -0
- magic_pdf-0.9.0.dist-info/METADATA +507 -0
- {magic_pdf-0.8.0.dist-info → magic_pdf-0.9.0.dist-info}/RECORD +57 -33
- magic_pdf-0.8.0.dist-info/METADATA +0 -459
- {magic_pdf-0.8.0.dist-info → magic_pdf-0.9.0.dist-info}/WHEEL +0 -0
- {magic_pdf-0.8.0.dist-info → magic_pdf-0.9.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.8.0.dist-info → magic_pdf-0.9.0.dist-info}/top_level.txt +0 -0
magic_pdf/model/pdf_extract_kit.py
CHANGED
@@ -1,11 +1,14 @@
 from loguru import logger
 import os
 import time
-
+from pathlib import Path
+import shutil
 from magic_pdf.libs.Constants import *
+from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.model.model_list import AtomicModel
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update check
+os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
 try:
     import cv2
     import yaml
@@ -23,6 +26,7 @@ try:
     from unimernet.common.config import Config
     import unimernet.tasks as tasks
     from unimernet.processors import load_processor
+    from doclayout_yolo import YOLOv10
 
 except ImportError as e:
     logger.exception(e)
@@ -32,21 +36,26 @@ except ImportError as e:
     exit(1)
 
 from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
-from magic_pdf.model.pek_sub_modules.post_process import …
+from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
 from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
-from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
+# from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
 from magic_pdf.model.ppTableModel import ppTableModel
 
 
 def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
-    if table_model_type == STRUCT_EQTABLE:
-        table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
-    else:
+    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
+        # table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
+        logger.error("StructEqTable is under upgrade, the current version does not support it.")
+        exit(1)
+    elif table_model_type == MODEL_NAME.TABLE_MASTER:
         config = {
             "model_dir": model_path,
             "device": _device_
         }
         table_model = ppTableModel(config)
+    else:
+        logger.error("table model type not allow")
+        exit(1)
     return table_model
 
 
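The rewritten dispatch keys on the MODEL_NAME constants (extended in magic_pdf/libs/Constants.py in this release, see the +21 -8 entry above) instead of bare strings, and only TableMaster is currently usable. A minimal usage sketch under those assumptions; the weights directory below is a placeholder, not a path shipped with the wheel:

    # Hypothetical call into the new dispatch; selecting MODEL_NAME.STRUCT_EQTABLE
    # in 0.9.0 logs an error and exits.
    table_model = table_model_init(
        MODEL_NAME.TABLE_MASTER,
        "/models/TabRec/TableMaster",  # placeholder weights directory
        max_time=400,
        _device_="cpu",
    )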
@@ -58,12 +67,13 @@ def mfd_model_init(weight):
 def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
     args = argparse.Namespace(cfg_path=cfg_path, options=None)
     cfg = Config(args)
-    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.…
+    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
     cfg.config.model.model_config.model_name = weight_dir
     cfg.config.model.tokenizer_config.path = weight_dir
     task = tasks.setup_task(cfg)
     model = task.build_model(cfg)
-    model…
+    model.to(_device_)
+    model.eval()
     vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
     mfr_transform = transforms.Compose([vis_processor, ])
     return [model, mfr_transform]
@@ -74,8 +84,16 @@ def layout_model_init(weight, config_file, device):
     return model
 
 
-def …
-    model = …
+def doclayout_yolo_model_init(weight):
+    model = YOLOv10(weight)
+    return model
+
+
+def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
+    if lang is not None:
+        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
+    else:
+        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
     return model
 
 
@@ -108,19 +126,27 @@ class AtomModelSingleton:
         return cls._instance
 
     def get_atom_model(self, atom_model_name: str, **kwargs):
-        if atom_model_name not in self._models:
-            self._models[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs)
-        return self._models[atom_model_name]
+        lang = kwargs.get("lang", None)
+        layout_model_name = kwargs.get("layout_model_name", None)
+        key = (atom_model_name, layout_model_name, lang)
+        if key not in self._models:
+            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
+        return self._models[key]
 
 
 def atom_model_init(model_name: str, **kwargs):
 
     if model_name == AtomicModel.Layout:
-        atom_model = layout_model_init(
-            kwargs.get("layout_weights"),
-            kwargs.get("layout_config_file"),
-            kwargs.get("device")
-        )
+        if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
+            atom_model = layout_model_init(
+                kwargs.get("layout_weights"),
+                kwargs.get("layout_config_file"),
+                kwargs.get("device")
+            )
+        elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
+            atom_model = doclayout_yolo_model_init(
+                kwargs.get("doclayout_yolo_weights"),
+            )
     elif model_name == AtomicModel.MFD:
         atom_model = mfd_model_init(
             kwargs.get("mfd_weights")
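The cache key now includes the layout model name and the OCR language, so requesting OCR models for two languages yields two cached instances instead of one shared one. A self-contained sketch of the same keyed-singleton pattern (names here are illustrative, not the package's):

    from typing import Any, Callable, Dict, Tuple

    class ModelCache:
        # One instance per (model_name, layout_model_name, lang) tuple,
        # mirroring AtomModelSingleton.get_atom_model above.
        def __init__(self, factory: Callable[..., Any]):
            self._factory = factory
            self._models: Dict[Tuple, Any] = {}

        def get(self, name: str, **kwargs) -> Any:
            key = (name, kwargs.get("layout_model_name"), kwargs.get("lang"))
            if key not in self._models:
                self._models[key] = self._factory(model_name=name, **kwargs)
            return self._models[key]

    cache = ModelCache(factory=lambda **kw: object())  # stand-in factory
    assert cache.get("ocr", lang="en") is cache.get("ocr", lang="en")
    assert cache.get("ocr", lang="en") is not cache.get("ocr", lang="fr")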
@@ -134,11 +160,12 @@ def atom_model_init(model_name: str, **kwargs):
     elif model_name == AtomicModel.OCR:
         atom_model = ocr_model_init(
             kwargs.get("ocr_show_log"),
-            kwargs.get("det_db_box_thresh")
+            kwargs.get("det_db_box_thresh"),
+            kwargs.get("lang")
         )
     elif model_name == AtomicModel.Table:
         atom_model = table_model_init(
-            kwargs.get("…
+            kwargs.get("table_model_name"),
             kwargs.get("table_model_path"),
             kwargs.get("table_max_time"),
             kwargs.get("device")
@@ -150,6 +177,23 @@ def atom_model_init(model_name: str, **kwargs):
     return atom_model
 
 
+# Unified crop img logic
+def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
+    crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
+    crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
+    # Create a white background with an additional width and height of 50
+    crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
+    crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
+    return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
+
+    # Crop image
+    crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
+    cropped_img = input_pil_img.crop(crop_box)
+    return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
+    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
+    return return_image, return_list
+
+
 class CustomPEKModel:
 
     def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
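This module-level crop_img replaces a per-call helper previously defined inside CustomPEKModel.__call__ (its removal appears in the @@ -303,23 +393,14 @@ hunk below). It reads opposite corners from indices 0/1 and 4/5 of the detection's 'poly' and pastes the crop onto a white canvas enlarged by the paste offsets. A small sketch with synthetic data, assuming crop_img is in scope:

    from PIL import Image

    # Synthetic detection: an axis-aligned box from (10, 20) to (110, 70).
    res = {'poly': [10, 20, 110, 20, 110, 70, 10, 70]}
    page = Image.new('RGB', (200, 100), 'gray')

    # 25 px white margin on every side of the crop.
    table_img, (px, py, x0, y0, x1, y1, w, h) = crop_img(res, page, crop_paste_x=25, crop_paste_y=25)
    assert (w, h) == (100 + 2 * 25, 50 + 2 * 25)  # crop size plus twice the margins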
@@ -169,22 +213,35 @@ class CustomPEKModel:
         with open(config_path, "r", encoding='utf-8') as f:
             self.configs = yaml.load(f, Loader=yaml.FullLoader)
         # initialize parsing config
-        …
-        …
+
+        # layout config
+        self.layout_config = kwargs.get("layout_config")
+        self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)
+
+        # formula config
+        self.formula_config = kwargs.get("formula_config")
+        self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
+        self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
+        self.apply_formula = self.formula_config.get("enable", True)
+
         # table config
-        self.table_config = kwargs.get("table_config"…
-        self.apply_table = self.table_config.get("…
+        self.table_config = kwargs.get("table_config")
+        self.apply_table = self.table_config.get("enable", False)
         self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
-        self.…
+        self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
+
+        # ocr config
         self.apply_ocr = ocr
+        self.lang = kwargs.get("lang", None)
+
         logger.info(
-            "DocAnalysis init, this may take some times…
-            …
+            "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
+            "apply_table: {}, table_model: {}, lang: {}".format(
+                self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
             )
         )
-        assert self.apply_layout, "DocAnalysis must contain layout model."
         # initialize parsing pipeline
-        self.device = kwargs.get("device", …
+        self.device = kwargs.get("device", "cpu")
         logger.info("using device: {}".format(self.device))
         models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
         logger.info("using models_dir: {}".format(models_dir))
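The constructor now expects layout_config, formula_config, and table_config dicts carrying exactly the keys read above. A hypothetical kwargs sketch matching those .get() calls; the real values come from the user's magic-pdf.json / model_configs.yaml, which this view does not show:

    custom_model = CustomPEKModel(
        ocr=True,
        show_log=False,
        lang="en",     # forwarded to the PaddleOCR atom model
        device="cuda",
        layout_config={"model": MODEL_NAME.DocLayout_YOLO},
        formula_config={
            "enable": True,
            "mfd_model": MODEL_NAME.YOLO_V8_MFD,
            "mfr_model": MODEL_NAME.UniMerNet_v2_Small,
        },
        table_config={"model": MODEL_NAME.TABLE_MASTER, "enable": False, "max_time": 400},
    )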
@@ -193,17 +250,16 @@ class CustomPEKModel:
 
         # initialize formula recognition
         if self.apply_formula:
+
             # initialize formula detection model
-            # self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
             self.mfd_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFD,
-                mfd_weights=str(os.path.join(models_dir, self.configs["weights"][…
+                mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
             )
+
             # initialize formula recognition model
-            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][…
+            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
             mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
-            # self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
-            # self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
             self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFR,
                 mfr_weight_dir=mfr_weight_dir,
@@ -212,17 +268,20 @@ class CustomPEKModel:
             )
 
         # initialize layout model
-        …
+        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
+            self.layout_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Layout,
+                layout_model_name=MODEL_NAME.LAYOUTLMv3,
+                layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
+                layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
+                device=self.device
+            )
+        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
+            self.layout_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Layout,
+                layout_model_name=MODEL_NAME.DocLayout_YOLO,
+                doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
+            )
         # initialize ocr
         if self.apply_ocr:
 
@@ -230,37 +289,67 @@ class CustomPEKModel:
             self.ocr_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.OCR,
                 ocr_show_log=show_log,
-                det_db_box_thresh=0.3
+                det_db_box_thresh=0.3,
+                lang=self.lang
             )
         # init table model
         if self.apply_table:
-            table_model_dir = self.configs["weights"][self.…
-            # self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
-            #                                     max_time=self.table_max_time, _device_=self.device)
+            table_model_dir = self.configs["weights"][self.table_model_name]
             self.table_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.Table,
-                …
+                table_model_name=self.table_model_name,
                 table_model_path=str(os.path.join(models_dir, table_model_dir)),
                 table_max_time=self.table_max_time,
                 device=self.device
             )
 
+            home_directory = Path.home()
+            det_source = os.path.join(models_dir, table_model_dir, DETECT_MODEL_DIR)
+            rec_source = os.path.join(models_dir, table_model_dir, REC_MODEL_DIR)
+            det_dest_dir = os.path.join(home_directory, PP_DET_DIRECTORY)
+            rec_dest_dir = os.path.join(home_directory, PP_REC_DIRECTORY)
+
+            if not os.path.exists(det_dest_dir):
+                shutil.copytree(det_source, det_dest_dir)
+            if not os.path.exists(rec_dest_dir):
+                shutil.copytree(rec_source, rec_dest_dir)
+
         logger.info('DocAnalysis init done!')
 
     def __call__(self, image):
 
+        page_start = time.time()
+
         latex_filling_list = []
         mf_image_list = []
 
         # layout detection
         layout_start = time.time()
-        layout_res = self.layout_model(image, ignore_catids=[])
+        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
+            # layoutlmv3
+            layout_res = self.layout_model(image, ignore_catids=[])
+        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
+            # doclayout_yolo
+            layout_res = []
+            doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+            for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
+                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                new_item = {
+                    'category_id': int(cla.item()),
+                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                    'score': round(float(conf.item()), 3),
+                }
+                layout_res.append(new_item)
         layout_cost = round(time.time() - layout_start, 2)
-        logger.info(f"layout detection…
+        logger.info(f"layout detection time: {layout_cost}")
+
+        pil_img = Image.fromarray(image)
 
         if self.apply_formula:
             # formula detection
-            …
+            mfd_start = time.time()
+            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+            logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
             for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
                 xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                 new_item = {
@@ -271,7 +360,7 @@ class CustomPEKModel:
                 }
                 layout_res.append(new_item)
                 latex_filling_list.append(new_item)
-                bbox_img = …
+                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
                 mf_image_list.append(bbox_img)
 
             # formula recognition
@@ -281,7 +370,8 @@ class CustomPEKModel:
             mfr_res = []
             for mf_img in dataloader:
                 mf_img = mf_img.to(self.device)
-                output = self.mfr_model.generate({'image': mf_img})
+                with torch.no_grad():
+                    output = self.mfr_model.generate({'image': mf_img})
                 mfr_res.extend(output['pred_str'])
             for res, latex in zip(latex_filling_list, mfr_res):
                 res['latex'] = latex_rm_whitespace(latex)
@@ -303,23 +393,14 @@ class CustomPEKModel:
             elif int(res['category_id']) in [5]:
                 table_res_list.append(res)
 
-        …
-            # Crop image
-            crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
-            cropped_img = input_pil_img.crop(crop_box)
-            return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
-            return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
-            return return_image, return_list
-
-        pil_img = Image.fromarray(image)
+        if torch.cuda.is_available():
+            properties = torch.cuda.get_device_properties(self.device)
+            total_memory = properties.total_memory / (1024 ** 3)  # convert bytes to GB
+            if total_memory <= 10:
+                gc_start = time.time()
+                clean_memory()
+                gc_time = round(time.time() - gc_start, 2)
+                logger.info(f"gc time: {gc_time}")
 
         # ocr recognition
         if self.apply_ocr:
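clean_memory is new in this release (magic_pdf/libs/clean_memory.py, +10 lines) and its body is not shown in this view. A plausible minimal implementation, assuming it simply pairs a CUDA cache flush with the Python garbage collector:

    import gc
    import torch

    def clean_memory():
        # Assumed behaviour, not the verbatim package code.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()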
@@ -369,7 +450,7 @@ class CustomPEKModel:
                     })
 
             ocr_cost = round(time.time() - ocr_start, 2)
-            logger.info(f"ocr …
+            logger.info(f"ocr time: {ocr_cost}")
 
         # table recognition
         if self.apply_table:
@@ -377,17 +458,17 @@ class CustomPEKModel:
             for res in table_res_list:
                 new_image, _ = crop_img(res, pil_img)
                 single_table_start_time = time.time()
-                logger.info("------------------table recognition processing begins-----------------")
+                # logger.info("------------------table recognition processing begins-----------------")
                 latex_code = None
                 html_code = None
-                if self.…
+                if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                     with torch.no_grad():
                         latex_code = self.table_model.image2latex(new_image)[0]
                 else:
                     html_code = self.table_model.img2html(new_image)
 
                 run_time = time.time() - single_table_start_time
-                logger.info(f"------------table recognition processing ends within {run_time}s-----")
+                # logger.info(f"------------table recognition processing ends within {run_time}s-----")
                 if run_time > self.table_max_time:
                     logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
                 # check whether the result is valid
@@ -398,12 +479,15 @@ class CustomPEKModel:
                     if expected_ending:
                         res["latex"] = latex_code
                     else:
-                        logger.warning(f"…
+                        logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
                 elif html_code:
                     res["html"] = html_code
                 else:
-                    logger.warning(f"…
-                    …
-                    …
+                    logger.warning(f"table recognition processing fails, not get latex or html return")
+            logger.info(f"table time: {round(time.time() - table_start, 2)}")
+
+        logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
 
         return layout_res
+
+
magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py
CHANGED
@@ -1,5 +1,12 @@
-from …
+from loguru import logger
+
+try:
+    from struct_eqtable.model import StructTable
+except ImportError:
+    logger.error("StructEqTable is under upgrade, the current version does not support it.")
 from pypandoc import convert_text
+
+
 class StructTableModel:
     def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'):
         # init
magic_pdf/model/ppTableModel.py
CHANGED
@@ -52,11 +52,11 @@ class ppTableModel(object):
         rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
         rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
         device = kwargs.get("device", "cpu")
-        use_gpu = True if device…
+        use_gpu = True if device.startswith("cuda") else False
         config = {
             "use_gpu": use_gpu,
             "table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
-            "table_algorithm": …
+            "table_algorithm": "TableMaster",
             "table_model_dir": table_model_dir,
             "table_char_dict_path": table_char_dict_path,
             "det_model_dir": det_model_dir,
magic_pdf/model/pp_structure_v2.py
CHANGED
@@ -18,8 +18,11 @@ def region_to_bbox(region):
 
 
 class CustomPaddleModel:
-    def __init__(self, ocr: bool = False, show_log: bool = False):
-        self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
+    def __init__(self, ocr: bool = False, show_log: bool = False, lang=None):
+        if lang is not None:
+            self.model = PPStructure(table=False, ocr=ocr, show_log=show_log, lang=lang)
+        else:
+            self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
 
     def __call__(self, img):
         try:
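Callers can now route a PaddleOCR language code through to PPStructure; omitting lang preserves the old behaviour:

    model_default = CustomPaddleModel(ocr=True, show_log=False)            # unchanged behaviour
    model_english = CustomPaddleModel(ocr=True, show_log=False, lang="en")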
magic_pdf/model/v3/__init__.py
File without changes
magic_pdf/model/v3/helpers.py
ADDED
@@ -0,0 +1,125 @@
+from collections import defaultdict
+from typing import List, Dict
+
+import torch
+from transformers import LayoutLMv3ForTokenClassification
+
+MAX_LEN = 510
+CLS_TOKEN_ID = 0
+UNK_TOKEN_ID = 3
+EOS_TOKEN_ID = 2
+
+
+class DataCollator:
+    def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
+        bbox = []
+        labels = []
+        input_ids = []
+        attention_mask = []
+
+        # clip bbox and labels to max length, build input_ids and attention_mask
+        for feature in features:
+            _bbox = feature["source_boxes"]
+            if len(_bbox) > MAX_LEN:
+                _bbox = _bbox[:MAX_LEN]
+            _labels = feature["target_index"]
+            if len(_labels) > MAX_LEN:
+                _labels = _labels[:MAX_LEN]
+            _input_ids = [UNK_TOKEN_ID] * len(_bbox)
+            _attention_mask = [1] * len(_bbox)
+            assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask)
+            bbox.append(_bbox)
+            labels.append(_labels)
+            input_ids.append(_input_ids)
+            attention_mask.append(_attention_mask)
+
+        # add CLS and EOS tokens
+        for i in range(len(bbox)):
+            bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]]
+            labels[i] = [-100] + labels[i] + [-100]
+            input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID]
+            attention_mask[i] = [1] + attention_mask[i] + [1]
+
+        # padding to max length
+        max_len = max(len(x) for x in bbox)
+        for i in range(len(bbox)):
+            bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i]))
+            labels[i] = labels[i] + [-100] * (max_len - len(labels[i]))
+            input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i]))
+            attention_mask[i] = attention_mask[i] + [0] * (
+                max_len - len(attention_mask[i])
+            )
+
+        ret = {
+            "bbox": torch.tensor(bbox),
+            "attention_mask": torch.tensor(attention_mask),
+            "labels": torch.tensor(labels),
+            "input_ids": torch.tensor(input_ids),
+        }
+        # set label > MAX_LEN to -100, because original labels may be > MAX_LEN
+        ret["labels"][ret["labels"] > MAX_LEN] = -100
+        # set label > 0 to label-1, because original labels are 1-indexed
+        ret["labels"][ret["labels"] > 0] -= 1
+        return ret
+
+
+def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
+    bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
+    input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
+    attention_mask = [1] + [1] * len(boxes) + [1]
+    return {
+        "bbox": torch.tensor([bbox]),
+        "attention_mask": torch.tensor([attention_mask]),
+        "input_ids": torch.tensor([input_ids]),
+    }
+
+
+def prepare_inputs(
+    inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
+) -> Dict[str, torch.Tensor]:
+    ret = {}
+    for k, v in inputs.items():
+        v = v.to(model.device)
+        if torch.is_floating_point(v):
+            v = v.to(model.dtype)
+        ret[k] = v
+    return ret
+
+
+def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
+    """
+    parse logits to orders
+
+    :param logits: logits from model
+    :param length: input length
+    :return: orders
+    """
+    logits = logits[1 : length + 1, :length]
+    orders = logits.argsort(descending=False).tolist()
+    ret = [o.pop() for o in orders]
+    while True:
+        order_to_idxes = defaultdict(list)
+        for idx, order in enumerate(ret):
+            order_to_idxes[order].append(idx)
+        # filter idxes len > 1
+        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
+        if not order_to_idxes:
+            break
+        # filter
+        for order, idxes in order_to_idxes.items():
+            # find original logits of idxes
+            idxes_to_logit = {}
+            for idx in idxes:
+                idxes_to_logit[idx] = logits[idx, order]
+            idxes_to_logit = sorted(
+                idxes_to_logit.items(), key=lambda x: x[1], reverse=True
+            )
+            # keep the highest logit as order, set others to next candidate
+            for idx, _ in idxes_to_logit[1:]:
+                ret[idx] = orders[idx].pop()
+
+    return ret
+
+
+def check_duplicate(a: List[int]) -> bool:
+    return len(a) != len(set(a))
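Taken together, boxes2inputs, prepare_inputs, and parse_logits form a reading-order pipeline for a LayoutLMv3 token-classification head: one dummy token is encoded per layout box, the model scores each box against each target position, and parse_logits resolves ties so every box receives a unique position. A usage sketch; the checkpoint name is an assumption, not something pinned by this diff:

    import torch
    from transformers import LayoutLMv3ForTokenClassification

    # Assumed reading-order checkpoint; substitute whatever weights magic-pdf is configured with.
    model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")

    boxes = [[10, 10, 100, 30], [10, 40, 100, 60], [10, 70, 100, 90]]  # x0, y0, x1, y1 per box, 0-1000 scale
    inputs = prepare_inputs(boxes2inputs(boxes), model)
    with torch.no_grad():
        logits = model(**inputs).logits[0].cpu()

    order = parse_logits(logits, len(boxes))  # e.g. [0, 1, 2]
    assert not check_duplicate(order)        # every box gets a unique position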