magic-pdf 0.8.1__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/config/__init__.py +0 -0
- magic_pdf/config/enums.py +7 -0
- magic_pdf/config/exceptions.py +32 -0
- magic_pdf/data/__init__.py +0 -0
- magic_pdf/data/data_reader_writer/__init__.py +12 -0
- magic_pdf/data/data_reader_writer/base.py +51 -0
- magic_pdf/data/data_reader_writer/filebase.py +59 -0
- magic_pdf/data/data_reader_writer/multi_bucket_s3.py +143 -0
- magic_pdf/data/data_reader_writer/s3.py +73 -0
- magic_pdf/data/dataset.py +194 -0
- magic_pdf/data/io/__init__.py +6 -0
- magic_pdf/data/io/base.py +42 -0
- magic_pdf/data/io/http.py +37 -0
- magic_pdf/data/io/s3.py +114 -0
- magic_pdf/data/read_api.py +95 -0
- magic_pdf/data/schemas.py +19 -0
- magic_pdf/data/utils.py +32 -0
- magic_pdf/dict2md/ocr_mkcontent.py +106 -244
- magic_pdf/libs/Constants.py +21 -8
- magic_pdf/libs/MakeContentConfig.py +1 -0
- magic_pdf/libs/boxbase.py +35 -0
- magic_pdf/libs/clean_memory.py +10 -0
- magic_pdf/libs/config_reader.py +53 -23
- magic_pdf/libs/draw_bbox.py +150 -65
- magic_pdf/libs/ocr_content_type.py +2 -0
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/doc_analyze_by_custom_model.py +77 -32
- magic_pdf/model/magic_model.py +331 -15
- magic_pdf/model/pdf_extract_kit.py +170 -83
- magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py +40 -16
- magic_pdf/model/ppTableModel.py +8 -6
- magic_pdf/model/pp_structure_v2.py +5 -2
- magic_pdf/model/v3/__init__.py +0 -0
- magic_pdf/model/v3/helpers.py +125 -0
- magic_pdf/para/para_split_v3.py +322 -0
- magic_pdf/pdf_parse_by_ocr.py +6 -3
- magic_pdf/pdf_parse_by_txt.py +6 -3
- magic_pdf/pdf_parse_union_core_v2.py +644 -0
- magic_pdf/pipe/AbsPipe.py +5 -1
- magic_pdf/pipe/OCRPipe.py +10 -4
- magic_pdf/pipe/TXTPipe.py +10 -4
- magic_pdf/pipe/UNIPipe.py +16 -7
- magic_pdf/pre_proc/ocr_detect_all_bboxes.py +83 -1
- magic_pdf/pre_proc/ocr_dict_merge.py +27 -2
- magic_pdf/resources/model_config/UniMERNet/demo.yaml +7 -7
- magic_pdf/resources/model_config/model_configs.yaml +5 -13
- magic_pdf/tools/cli.py +14 -1
- magic_pdf/tools/common.py +18 -8
- magic_pdf/user_api.py +25 -6
- magic_pdf/utils/__init__.py +0 -0
- magic_pdf/utils/annotations.py +11 -0
- {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.1.dist-info}/LICENSE.md +1 -0
- {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.1.dist-info}/METADATA +124 -78
- {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.1.dist-info}/RECORD +57 -33
- {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.1.dist-info}/WHEEL +0 -0
- {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.1.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.1.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,14 @@
|
|
1
1
|
from loguru import logger
|
2
2
|
import os
|
3
3
|
import time
|
4
|
-
|
4
|
+
from pathlib import Path
|
5
|
+
import shutil
|
5
6
|
from magic_pdf.libs.Constants import *
|
7
|
+
from magic_pdf.libs.clean_memory import clean_memory
|
6
8
|
from magic_pdf.model.model_list import AtomicModel
|
7
9
|
|
8
10
|
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
11
|
+
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
|
9
12
|
try:
|
10
13
|
import cv2
|
11
14
|
import yaml
|
@@ -23,6 +26,7 @@ try:
|
|
23
26
|
from unimernet.common.config import Config
|
24
27
|
import unimernet.tasks as tasks
|
25
28
|
from unimernet.processors import load_processor
|
29
|
+
from doclayout_yolo import YOLOv10
|
26
30
|
|
27
31
|
except ImportError as e:
|
28
32
|
logger.exception(e)
|
@@ -32,21 +36,24 @@ except ImportError as e:
|
|
32
36
|
exit(1)
|
33
37
|
|
34
38
|
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
35
|
-
from magic_pdf.model.pek_sub_modules.post_process import
|
39
|
+
from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
|
36
40
|
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
37
41
|
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
|
38
42
|
from magic_pdf.model.ppTableModel import ppTableModel
|
39
43
|
|
40
44
|
|
41
45
|
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
|
42
|
-
if table_model_type == STRUCT_EQTABLE:
|
43
|
-
table_model = StructTableModel(model_path, max_time=max_time
|
44
|
-
|
46
|
+
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
|
47
|
+
table_model = StructTableModel(model_path, max_time=max_time)
|
48
|
+
elif table_model_type == MODEL_NAME.TABLE_MASTER:
|
45
49
|
config = {
|
46
50
|
"model_dir": model_path,
|
47
51
|
"device": _device_
|
48
52
|
}
|
49
53
|
table_model = ppTableModel(config)
|
54
|
+
else:
|
55
|
+
logger.error("table model type not allow")
|
56
|
+
exit(1)
|
50
57
|
return table_model
|
51
58
|
|
52
59
|
|
@@ -58,12 +65,13 @@ def mfd_model_init(weight):
|
|
58
65
|
def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
|
59
66
|
args = argparse.Namespace(cfg_path=cfg_path, options=None)
|
60
67
|
cfg = Config(args)
|
61
|
-
cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.
|
68
|
+
cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
|
62
69
|
cfg.config.model.model_config.model_name = weight_dir
|
63
70
|
cfg.config.model.tokenizer_config.path = weight_dir
|
64
71
|
task = tasks.setup_task(cfg)
|
65
72
|
model = task.build_model(cfg)
|
66
|
-
model
|
73
|
+
model.to(_device_)
|
74
|
+
model.eval()
|
67
75
|
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
|
68
76
|
mfr_transform = transforms.Compose([vis_processor, ])
|
69
77
|
return [model, mfr_transform]
|
@@ -74,8 +82,16 @@ def layout_model_init(weight, config_file, device):
|
|
74
82
|
return model
|
75
83
|
|
76
84
|
|
77
|
-
def
|
78
|
-
model =
|
85
|
+
def doclayout_yolo_model_init(weight):
|
86
|
+
model = YOLOv10(weight)
|
87
|
+
return model
|
88
|
+
|
89
|
+
|
90
|
+
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
|
91
|
+
if lang is not None:
|
92
|
+
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
|
93
|
+
else:
|
94
|
+
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
|
79
95
|
return model
|
80
96
|
|
81
97
|
|
@@ -108,19 +124,27 @@ class AtomModelSingleton:
|
|
108
124
|
return cls._instance
|
109
125
|
|
110
126
|
def get_atom_model(self, atom_model_name: str, **kwargs):
|
111
|
-
|
112
|
-
|
113
|
-
|
127
|
+
lang = kwargs.get("lang", None)
|
128
|
+
layout_model_name = kwargs.get("layout_model_name", None)
|
129
|
+
key = (atom_model_name, layout_model_name, lang)
|
130
|
+
if key not in self._models:
|
131
|
+
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
|
132
|
+
return self._models[key]
|
114
133
|
|
115
134
|
|
116
135
|
def atom_model_init(model_name: str, **kwargs):
|
117
136
|
|
118
137
|
if model_name == AtomicModel.Layout:
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
138
|
+
if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
|
139
|
+
atom_model = layout_model_init(
|
140
|
+
kwargs.get("layout_weights"),
|
141
|
+
kwargs.get("layout_config_file"),
|
142
|
+
kwargs.get("device")
|
143
|
+
)
|
144
|
+
elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
|
145
|
+
atom_model = doclayout_yolo_model_init(
|
146
|
+
kwargs.get("doclayout_yolo_weights"),
|
147
|
+
)
|
124
148
|
elif model_name == AtomicModel.MFD:
|
125
149
|
atom_model = mfd_model_init(
|
126
150
|
kwargs.get("mfd_weights")
|
@@ -134,11 +158,12 @@ def atom_model_init(model_name: str, **kwargs):
|
|
134
158
|
elif model_name == AtomicModel.OCR:
|
135
159
|
atom_model = ocr_model_init(
|
136
160
|
kwargs.get("ocr_show_log"),
|
137
|
-
kwargs.get("det_db_box_thresh")
|
161
|
+
kwargs.get("det_db_box_thresh"),
|
162
|
+
kwargs.get("lang")
|
138
163
|
)
|
139
164
|
elif model_name == AtomicModel.Table:
|
140
165
|
atom_model = table_model_init(
|
141
|
-
kwargs.get("
|
166
|
+
kwargs.get("table_model_name"),
|
142
167
|
kwargs.get("table_model_path"),
|
143
168
|
kwargs.get("table_max_time"),
|
144
169
|
kwargs.get("device")
|
@@ -150,6 +175,23 @@ def atom_model_init(model_name: str, **kwargs):
|
|
150
175
|
return atom_model
|
151
176
|
|
152
177
|
|
178
|
+
# Unified crop img logic
|
179
|
+
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
|
180
|
+
crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
|
181
|
+
crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
|
182
|
+
# Create a white background with an additional width and height of 50
|
183
|
+
crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
|
184
|
+
crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
|
185
|
+
return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
|
186
|
+
|
187
|
+
# Crop image
|
188
|
+
crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
|
189
|
+
cropped_img = input_pil_img.crop(crop_box)
|
190
|
+
return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
|
191
|
+
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
|
192
|
+
return return_image, return_list
|
193
|
+
|
194
|
+
|
153
195
|
class CustomPEKModel:
|
154
196
|
|
155
197
|
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
|
@@ -169,22 +211,35 @@ class CustomPEKModel:
|
|
169
211
|
with open(config_path, "r", encoding='utf-8') as f:
|
170
212
|
self.configs = yaml.load(f, Loader=yaml.FullLoader)
|
171
213
|
# 初始化解析配置
|
172
|
-
|
173
|
-
|
214
|
+
|
215
|
+
# layout config
|
216
|
+
self.layout_config = kwargs.get("layout_config")
|
217
|
+
self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)
|
218
|
+
|
219
|
+
# formula config
|
220
|
+
self.formula_config = kwargs.get("formula_config")
|
221
|
+
self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
|
222
|
+
self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
|
223
|
+
self.apply_formula = self.formula_config.get("enable", True)
|
224
|
+
|
174
225
|
# table config
|
175
|
-
self.table_config = kwargs.get("table_config"
|
176
|
-
self.apply_table = self.table_config.get("
|
226
|
+
self.table_config = kwargs.get("table_config")
|
227
|
+
self.apply_table = self.table_config.get("enable", False)
|
177
228
|
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
|
178
|
-
self.
|
229
|
+
self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
|
230
|
+
|
231
|
+
# ocr config
|
179
232
|
self.apply_ocr = ocr
|
233
|
+
self.lang = kwargs.get("lang", None)
|
234
|
+
|
180
235
|
logger.info(
|
181
|
-
"DocAnalysis init, this may take some times
|
182
|
-
|
236
|
+
"DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
|
237
|
+
"apply_table: {}, table_model: {}, lang: {}".format(
|
238
|
+
self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
|
183
239
|
)
|
184
240
|
)
|
185
|
-
assert self.apply_layout, "DocAnalysis must contain layout model."
|
186
241
|
# 初始化解析方案
|
187
|
-
self.device = kwargs.get("device",
|
242
|
+
self.device = kwargs.get("device", "cpu")
|
188
243
|
logger.info("using device: {}".format(self.device))
|
189
244
|
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
|
190
245
|
logger.info("using models_dir: {}".format(models_dir))
|
@@ -193,17 +248,16 @@ class CustomPEKModel:
|
|
193
248
|
|
194
249
|
# 初始化公式识别
|
195
250
|
if self.apply_formula:
|
251
|
+
|
196
252
|
# 初始化公式检测模型
|
197
|
-
# self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
|
198
253
|
self.mfd_model = atom_model_manager.get_atom_model(
|
199
254
|
atom_model_name=AtomicModel.MFD,
|
200
|
-
mfd_weights=str(os.path.join(models_dir, self.configs["weights"][
|
255
|
+
mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
|
201
256
|
)
|
257
|
+
|
202
258
|
# 初始化公式解析模型
|
203
|
-
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][
|
259
|
+
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
|
204
260
|
mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
|
205
|
-
# self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
|
206
|
-
# self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
|
207
261
|
self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
|
208
262
|
atom_model_name=AtomicModel.MFR,
|
209
263
|
mfr_weight_dir=mfr_weight_dir,
|
@@ -212,17 +266,20 @@ class CustomPEKModel:
|
|
212
266
|
)
|
213
267
|
|
214
268
|
# 初始化layout模型
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
269
|
+
if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
|
270
|
+
self.layout_model = atom_model_manager.get_atom_model(
|
271
|
+
atom_model_name=AtomicModel.Layout,
|
272
|
+
layout_model_name=MODEL_NAME.LAYOUTLMv3,
|
273
|
+
layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
|
274
|
+
layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
|
275
|
+
device=self.device
|
276
|
+
)
|
277
|
+
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
|
278
|
+
self.layout_model = atom_model_manager.get_atom_model(
|
279
|
+
atom_model_name=AtomicModel.Layout,
|
280
|
+
layout_model_name=MODEL_NAME.DocLayout_YOLO,
|
281
|
+
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
|
282
|
+
)
|
226
283
|
# 初始化ocr
|
227
284
|
if self.apply_ocr:
|
228
285
|
|
@@ -230,37 +287,67 @@ class CustomPEKModel:
|
|
230
287
|
self.ocr_model = atom_model_manager.get_atom_model(
|
231
288
|
atom_model_name=AtomicModel.OCR,
|
232
289
|
ocr_show_log=show_log,
|
233
|
-
det_db_box_thresh=0.3
|
290
|
+
det_db_box_thresh=0.3,
|
291
|
+
lang=self.lang
|
234
292
|
)
|
235
293
|
# init table model
|
236
294
|
if self.apply_table:
|
237
|
-
table_model_dir = self.configs["weights"][self.
|
238
|
-
# self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
|
239
|
-
# max_time=self.table_max_time, _device_=self.device)
|
295
|
+
table_model_dir = self.configs["weights"][self.table_model_name]
|
240
296
|
self.table_model = atom_model_manager.get_atom_model(
|
241
297
|
atom_model_name=AtomicModel.Table,
|
242
|
-
|
298
|
+
table_model_name=self.table_model_name,
|
243
299
|
table_model_path=str(os.path.join(models_dir, table_model_dir)),
|
244
300
|
table_max_time=self.table_max_time,
|
245
301
|
device=self.device
|
246
302
|
)
|
247
303
|
|
304
|
+
home_directory = Path.home()
|
305
|
+
det_source = os.path.join(models_dir, table_model_dir, DETECT_MODEL_DIR)
|
306
|
+
rec_source = os.path.join(models_dir, table_model_dir, REC_MODEL_DIR)
|
307
|
+
det_dest_dir = os.path.join(home_directory, PP_DET_DIRECTORY)
|
308
|
+
rec_dest_dir = os.path.join(home_directory, PP_REC_DIRECTORY)
|
309
|
+
|
310
|
+
if not os.path.exists(det_dest_dir):
|
311
|
+
shutil.copytree(det_source, det_dest_dir)
|
312
|
+
if not os.path.exists(rec_dest_dir):
|
313
|
+
shutil.copytree(rec_source, rec_dest_dir)
|
314
|
+
|
248
315
|
logger.info('DocAnalysis init done!')
|
249
316
|
|
250
317
|
def __call__(self, image):
|
251
318
|
|
319
|
+
page_start = time.time()
|
320
|
+
|
252
321
|
latex_filling_list = []
|
253
322
|
mf_image_list = []
|
254
323
|
|
255
324
|
# layout检测
|
256
325
|
layout_start = time.time()
|
257
|
-
|
326
|
+
if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
|
327
|
+
# layoutlmv3
|
328
|
+
layout_res = self.layout_model(image, ignore_catids=[])
|
329
|
+
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
|
330
|
+
# doclayout_yolo
|
331
|
+
layout_res = []
|
332
|
+
doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
|
333
|
+
for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
|
334
|
+
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
335
|
+
new_item = {
|
336
|
+
'category_id': int(cla.item()),
|
337
|
+
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
338
|
+
'score': round(float(conf.item()), 3),
|
339
|
+
}
|
340
|
+
layout_res.append(new_item)
|
258
341
|
layout_cost = round(time.time() - layout_start, 2)
|
259
|
-
logger.info(f"layout detection
|
342
|
+
logger.info(f"layout detection time: {layout_cost}")
|
343
|
+
|
344
|
+
pil_img = Image.fromarray(image)
|
260
345
|
|
261
346
|
if self.apply_formula:
|
262
347
|
# 公式检测
|
263
|
-
|
348
|
+
mfd_start = time.time()
|
349
|
+
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
|
350
|
+
logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
|
264
351
|
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
265
352
|
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
266
353
|
new_item = {
|
@@ -271,7 +358,7 @@ class CustomPEKModel:
|
|
271
358
|
}
|
272
359
|
layout_res.append(new_item)
|
273
360
|
latex_filling_list.append(new_item)
|
274
|
-
bbox_img =
|
361
|
+
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
|
275
362
|
mf_image_list.append(bbox_img)
|
276
363
|
|
277
364
|
# 公式识别
|
@@ -281,7 +368,8 @@ class CustomPEKModel:
|
|
281
368
|
mfr_res = []
|
282
369
|
for mf_img in dataloader:
|
283
370
|
mf_img = mf_img.to(self.device)
|
284
|
-
|
371
|
+
with torch.no_grad():
|
372
|
+
output = self.mfr_model.generate({'image': mf_img})
|
285
373
|
mfr_res.extend(output['pred_str'])
|
286
374
|
for res, latex in zip(latex_filling_list, mfr_res):
|
287
375
|
res['latex'] = latex_rm_whitespace(latex)
|
@@ -303,23 +391,14 @@ class CustomPEKModel:
|
|
303
391
|
elif int(res['category_id']) in [5]:
|
304
392
|
table_res_list.append(res)
|
305
393
|
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
# Crop image
|
316
|
-
crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
|
317
|
-
cropped_img = input_pil_img.crop(crop_box)
|
318
|
-
return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
|
319
|
-
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
|
320
|
-
return return_image, return_list
|
321
|
-
|
322
|
-
pil_img = Image.fromarray(image)
|
394
|
+
if torch.cuda.is_available() and self.device != 'cpu':
|
395
|
+
properties = torch.cuda.get_device_properties(self.device)
|
396
|
+
total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
|
397
|
+
if total_memory <= 10:
|
398
|
+
gc_start = time.time()
|
399
|
+
clean_memory()
|
400
|
+
gc_time = round(time.time() - gc_start, 2)
|
401
|
+
logger.info(f"gc time: {gc_time}")
|
323
402
|
|
324
403
|
# ocr识别
|
325
404
|
if self.apply_ocr:
|
@@ -369,7 +448,7 @@ class CustomPEKModel:
|
|
369
448
|
})
|
370
449
|
|
371
450
|
ocr_cost = round(time.time() - ocr_start, 2)
|
372
|
-
logger.info(f"ocr
|
451
|
+
logger.info(f"ocr time: {ocr_cost}")
|
373
452
|
|
374
453
|
# 表格识别 table recognition
|
375
454
|
if self.apply_table:
|
@@ -377,33 +456,41 @@ class CustomPEKModel:
|
|
377
456
|
for res in table_res_list:
|
378
457
|
new_image, _ = crop_img(res, pil_img)
|
379
458
|
single_table_start_time = time.time()
|
380
|
-
logger.info("------------------table recognition processing begins-----------------")
|
459
|
+
# logger.info("------------------table recognition processing begins-----------------")
|
381
460
|
latex_code = None
|
382
461
|
html_code = None
|
383
|
-
if self.
|
462
|
+
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
|
384
463
|
with torch.no_grad():
|
385
|
-
|
464
|
+
table_result = self.table_model.predict(new_image, "html")
|
465
|
+
if len(table_result) > 0:
|
466
|
+
html_code = table_result[0]
|
386
467
|
else:
|
387
468
|
html_code = self.table_model.img2html(new_image)
|
388
469
|
|
389
470
|
run_time = time.time() - single_table_start_time
|
390
|
-
logger.info(f"------------table recognition processing ends within {run_time}s-----")
|
471
|
+
# logger.info(f"------------table recognition processing ends within {run_time}s-----")
|
391
472
|
if run_time > self.table_max_time:
|
392
473
|
logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
|
393
474
|
# 判断是否返回正常
|
394
475
|
|
395
476
|
if latex_code:
|
396
|
-
expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith(
|
397
|
-
'end{table}')
|
477
|
+
expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
|
398
478
|
if expected_ending:
|
399
479
|
res["latex"] = latex_code
|
400
480
|
else:
|
401
|
-
logger.warning(f"
|
481
|
+
logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
|
402
482
|
elif html_code:
|
403
|
-
|
483
|
+
expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
|
484
|
+
if expected_ending:
|
485
|
+
res["html"] = html_code
|
486
|
+
else:
|
487
|
+
logger.warning(f"table recognition processing fails, not found expected HTML table end")
|
404
488
|
else:
|
405
|
-
logger.warning(f"
|
406
|
-
|
407
|
-
|
489
|
+
logger.warning(f"table recognition processing fails, not get latex or html return")
|
490
|
+
logger.info(f"table time: {round(time.time() - table_start, 2)}")
|
491
|
+
|
492
|
+
logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
|
408
493
|
|
409
494
|
return layout_res
|
495
|
+
|
496
|
+
|
@@ -1,21 +1,45 @@
|
|
1
|
-
|
2
|
-
|
1
|
+
import re
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from struct_eqtable import build_model
|
5
|
+
|
6
|
+
|
3
7
|
class StructTableModel:
|
4
|
-
def __init__(self, model_path, max_new_tokens=
|
8
|
+
def __init__(self, model_path, max_new_tokens=1024, max_time=60):
|
5
9
|
# init
|
6
|
-
|
7
|
-
self.
|
8
|
-
|
9
|
-
|
10
|
-
|
10
|
+
assert torch.cuda.is_available(), "CUDA must be available for StructEqTable model."
|
11
|
+
self.model = build_model(
|
12
|
+
model_ckpt=model_path,
|
13
|
+
max_new_tokens=max_new_tokens,
|
14
|
+
max_time=max_time,
|
15
|
+
lmdeploy=False,
|
16
|
+
flash_attn=False,
|
17
|
+
batch_size=1,
|
18
|
+
).cuda()
|
19
|
+
self.default_format = "html"
|
20
|
+
|
21
|
+
def predict(self, images, output_format=None, **kwargs):
|
22
|
+
|
23
|
+
if output_format is None:
|
24
|
+
output_format = self.default_format
|
11
25
|
else:
|
12
|
-
|
26
|
+
if output_format not in ['latex', 'markdown', 'html']:
|
27
|
+
raise ValueError(f"Output format {output_format} is not supported.")
|
28
|
+
|
29
|
+
results = self.model(
|
30
|
+
images, output_format=output_format
|
31
|
+
)
|
32
|
+
|
33
|
+
if output_format == "html":
|
34
|
+
results = [self.minify_html(html) for html in results]
|
13
35
|
|
14
|
-
|
15
|
-
table_latex = self.model.forward(image)
|
16
|
-
return table_latex
|
36
|
+
return results
|
17
37
|
|
18
|
-
def
|
19
|
-
|
20
|
-
|
21
|
-
|
38
|
+
def minify_html(self, html):
|
39
|
+
# 移除多余的空白字符
|
40
|
+
html = re.sub(r'\s+', ' ', html)
|
41
|
+
# 移除行尾的空白字符
|
42
|
+
html = re.sub(r'\s*>\s*', '>', html)
|
43
|
+
# 移除标签前的空白字符
|
44
|
+
html = re.sub(r'\s*<\s*', '<', html)
|
45
|
+
return html.strip()
|
magic_pdf/model/ppTableModel.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
import cv2
|
1
2
|
from paddleocr.ppstructure.table.predict_table import TableSystem
|
2
3
|
from paddleocr.ppstructure.utility import init_args
|
3
4
|
from magic_pdf.libs.Constants import *
|
@@ -36,12 +37,13 @@ class ppTableModel(object):
|
|
36
37
|
- HTML (str): A string representing the HTML structure with content of the table.
|
37
38
|
"""
|
38
39
|
if isinstance(image, Image.Image):
|
39
|
-
image = np.
|
40
|
+
image = np.asarray(image)
|
41
|
+
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
40
42
|
pred_res, _ = self.table_sys(image)
|
41
43
|
pred_html = pred_res["html"]
|
42
|
-
res = '<td><table border="1">' + pred_html.replace("<html><body><table>", "").replace(
|
43
|
-
|
44
|
-
return
|
44
|
+
# res = '<td><table border="1">' + pred_html.replace("<html><body><table>", "").replace(
|
45
|
+
# "</table></body></html>","") + "</table></td>\n"
|
46
|
+
return pred_html
|
45
47
|
|
46
48
|
def parse_args(self, **kwargs):
|
47
49
|
parser = init_args()
|
@@ -52,11 +54,11 @@ class ppTableModel(object):
|
|
52
54
|
rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
|
53
55
|
rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
|
54
56
|
device = kwargs.get("device", "cpu")
|
55
|
-
use_gpu = True if device
|
57
|
+
use_gpu = True if device.startswith("cuda") else False
|
56
58
|
config = {
|
57
59
|
"use_gpu": use_gpu,
|
58
60
|
"table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
|
59
|
-
"table_algorithm":
|
61
|
+
"table_algorithm": "TableMaster",
|
60
62
|
"table_model_dir": table_model_dir,
|
61
63
|
"table_char_dict_path": table_char_dict_path,
|
62
64
|
"det_model_dir": det_model_dir,
|
@@ -18,8 +18,11 @@ def region_to_bbox(region):
|
|
18
18
|
|
19
19
|
|
20
20
|
class CustomPaddleModel:
|
21
|
-
def __init__(self, ocr: bool = False, show_log: bool = False):
|
22
|
-
|
21
|
+
def __init__(self, ocr: bool = False, show_log: bool = False, lang=None):
|
22
|
+
if lang is not None:
|
23
|
+
self.model = PPStructure(table=False, ocr=ocr, show_log=show_log, lang=lang)
|
24
|
+
else:
|
25
|
+
self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
|
23
26
|
|
24
27
|
def __call__(self, img):
|
25
28
|
try:
|
File without changes
|