magic-pdf 0.9.2__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/dict2md/ocr_mkcontent.py +1 -1
- magic_pdf/libs/Constants.py +3 -1
- magic_pdf/libs/config_reader.py +1 -1
- magic_pdf/libs/draw_bbox.py +10 -4
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/pdf_extract_kit.py +42 -297
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +21 -0
- magic_pdf/model/sub_modules/mfd/__init__.py +0 -0
- magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +12 -0
- magic_pdf/model/sub_modules/mfd/yolov8/__init__.py +0 -0
- magic_pdf/model/sub_modules/mfr/__init__.py +0 -0
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +98 -0
- magic_pdf/model/sub_modules/mfr/unimernet/__init__.py +0 -0
- magic_pdf/model/sub_modules/model_init.py +144 -0
- magic_pdf/model/sub_modules/model_utils.py +51 -0
- magic_pdf/model/sub_modules/ocr/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +259 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +168 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +213 -0
- magic_pdf/model/sub_modules/reading_oreder/__init__.py +0 -0
- magic_pdf/model/sub_modules/reading_oreder/layoutreader/__init__.py +0 -0
- magic_pdf/model/sub_modules/reading_oreder/layoutreader/xycut.py +242 -0
- magic_pdf/model/sub_modules/table/__init__.py +0 -0
- magic_pdf/model/sub_modules/table/rapidtable/__init__.py +0 -0
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +14 -0
- magic_pdf/model/sub_modules/table/structeqtable/__init__.py +0 -0
- magic_pdf/model/{pek_sub_modules/structeqtable/StructTableModel.py → sub_modules/table/structeqtable/struct_eqtable.py} +3 -11
- magic_pdf/model/sub_modules/table/table_utils.py +11 -0
- magic_pdf/model/sub_modules/table/tablemaster/__init__.py +0 -0
- magic_pdf/model/{ppTableModel.py → sub_modules/table/tablemaster/tablemaster_paddle.py} +1 -1
- magic_pdf/para/para_split_v3.py +13 -15
- magic_pdf/pdf_parse_union_core_v2.py +56 -19
- magic_pdf/resources/model_config/model_configs.yaml +2 -1
- magic_pdf/tools/common.py +47 -3
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/METADATA +9 -3
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/RECORD +65 -44
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/WHEEL +1 -1
- magic_pdf/model/pek_sub_modules/post_process.py +0 -36
- magic_pdf/model/pek_sub_modules/self_modify.py +0 -388
- /magic_pdf/model/{pek_sub_modules → sub_modules}/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules/layoutlmv3 → sub_modules/layout}/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules/structeqtable → sub_modules/layout/doclayout_yolo}/__init__.py +0 -0
- /magic_pdf/model/{v3 → sub_modules/layout/layoutlmv3}/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/backbone.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/beit.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/deit.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/cord.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/data_collator.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/funsd.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/image_utils.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/xfund.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/model_init.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/rcnn_vl.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/visualizer.py +0 -0
- /magic_pdf/model/{v3 → sub_modules/reading_oreder/layoutreader}/helpers.py +0 -0
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/top_level.txt +0 -0
@@ -168,7 +168,7 @@ def merge_para_with_text(para_block):
|
|
168
168
|
# 如果是前一行带有-连字符,那么末尾不应该加空格
|
169
169
|
if __is_hyphen_at_line_end(content):
|
170
170
|
para_text += content[:-1]
|
171
|
-
elif len(content) == 1 and content not in ['A', 'I', 'a', 'i']:
|
171
|
+
elif len(content) == 1 and content not in ['A', 'I', 'a', 'i'] and not content.isdigit():
|
172
172
|
para_text += content
|
173
173
|
else: # 西方文本语境下 content间需要空格分隔
|
174
174
|
para_text += f"{content} "
|
magic_pdf/libs/Constants.py
CHANGED
magic_pdf/libs/config_reader.py
CHANGED
@@ -92,7 +92,7 @@ def get_table_recog_config():
|
|
92
92
|
table_config = config.get('table-config')
|
93
93
|
if table_config is None:
|
94
94
|
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
|
95
|
-
return json.loads(f'{{"model": "{MODEL_NAME.
|
95
|
+
return json.loads(f'{{"model": "{MODEL_NAME.RAPID_TABLE}","enable": false, "max_time": 400}}')
|
96
96
|
else:
|
97
97
|
return table_config
|
98
98
|
|
magic_pdf/libs/draw_bbox.py
CHANGED
@@ -369,10 +369,16 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
|
|
369
369
|
if block['type'] in [BlockType.Image, BlockType.Table]:
|
370
370
|
for sub_block in block['blocks']:
|
371
371
|
if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
372
|
+
if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
|
373
|
+
for line in sub_block['virtual_lines']:
|
374
|
+
bbox = line['bbox']
|
375
|
+
index = line['index']
|
376
|
+
page_line_list.append({'index': index, 'bbox': bbox})
|
377
|
+
else:
|
378
|
+
for line in sub_block['lines']:
|
379
|
+
bbox = line['bbox']
|
380
|
+
index = line['index']
|
381
|
+
page_line_list.append({'index': index, 'bbox': bbox})
|
376
382
|
elif sub_block['type'] in [BlockType.ImageCaption, BlockType.TableCaption, BlockType.ImageFootnote, BlockType.TableFootnote]:
|
377
383
|
for line in sub_block['lines']:
|
378
384
|
bbox = line['bbox']
|
magic_pdf/libs/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.9.2"
|
1
|
+
__version__ = "0.9.3"
|
@@ -1,195 +1,28 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import torch
|
1
3
|
from loguru import logger
|
2
4
|
import os
|
3
5
|
import time
|
4
|
-
|
5
|
-
import
|
6
|
-
from
|
7
|
-
from magic_pdf.libs.clean_memory import clean_memory
|
8
|
-
from magic_pdf.model.model_list import AtomicModel
|
6
|
+
import cv2
|
7
|
+
import yaml
|
8
|
+
from PIL import Image
|
9
9
|
|
10
10
|
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
11
11
|
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
|
12
|
+
|
12
13
|
try:
|
13
|
-
import cv2
|
14
|
-
import yaml
|
15
|
-
import argparse
|
16
|
-
import numpy as np
|
17
|
-
import torch
|
18
14
|
import torchtext
|
19
15
|
|
20
16
|
if torchtext.__version__ >= "0.18.0":
|
21
17
|
torchtext.disable_torchtext_deprecation_warning()
|
22
|
-
|
23
|
-
|
24
|
-
from torch.utils.data import Dataset, DataLoader
|
25
|
-
from ultralytics import YOLO
|
26
|
-
from unimernet.common.config import Config
|
27
|
-
import unimernet.tasks as tasks
|
28
|
-
from unimernet.processors import load_processor
|
29
|
-
from doclayout_yolo import YOLOv10
|
30
|
-
|
31
|
-
except ImportError as e:
|
32
|
-
logger.exception(e)
|
33
|
-
logger.error(
|
34
|
-
'Required dependency not installed, please install by \n'
|
35
|
-
'"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
|
36
|
-
exit(1)
|
37
|
-
|
38
|
-
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
39
|
-
from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
|
40
|
-
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
41
|
-
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
|
42
|
-
from magic_pdf.model.ppTableModel import ppTableModel
|
43
|
-
|
44
|
-
|
45
|
-
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
|
46
|
-
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
|
47
|
-
table_model = StructTableModel(model_path, max_time=max_time)
|
48
|
-
elif table_model_type == MODEL_NAME.TABLE_MASTER:
|
49
|
-
config = {
|
50
|
-
"model_dir": model_path,
|
51
|
-
"device": _device_
|
52
|
-
}
|
53
|
-
table_model = ppTableModel(config)
|
54
|
-
else:
|
55
|
-
logger.error("table model type not allow")
|
56
|
-
exit(1)
|
57
|
-
return table_model
|
58
|
-
|
59
|
-
|
60
|
-
def mfd_model_init(weight):
|
61
|
-
mfd_model = YOLO(weight)
|
62
|
-
return mfd_model
|
63
|
-
|
64
|
-
|
65
|
-
def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
|
66
|
-
args = argparse.Namespace(cfg_path=cfg_path, options=None)
|
67
|
-
cfg = Config(args)
|
68
|
-
cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
|
69
|
-
cfg.config.model.model_config.model_name = weight_dir
|
70
|
-
cfg.config.model.tokenizer_config.path = weight_dir
|
71
|
-
task = tasks.setup_task(cfg)
|
72
|
-
model = task.build_model(cfg)
|
73
|
-
model.to(_device_)
|
74
|
-
model.eval()
|
75
|
-
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
|
76
|
-
mfr_transform = transforms.Compose([vis_processor, ])
|
77
|
-
return [model, mfr_transform]
|
78
|
-
|
79
|
-
|
80
|
-
def layout_model_init(weight, config_file, device):
|
81
|
-
model = Layoutlmv3_Predictor(weight, config_file, device)
|
82
|
-
return model
|
83
|
-
|
84
|
-
|
85
|
-
def doclayout_yolo_model_init(weight):
|
86
|
-
model = YOLOv10(weight)
|
87
|
-
return model
|
88
|
-
|
89
|
-
|
90
|
-
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
|
91
|
-
if lang is not None:
|
92
|
-
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
|
93
|
-
else:
|
94
|
-
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
|
95
|
-
return model
|
96
|
-
|
97
|
-
|
98
|
-
class MathDataset(Dataset):
|
99
|
-
def __init__(self, image_paths, transform=None):
|
100
|
-
self.image_paths = image_paths
|
101
|
-
self.transform = transform
|
102
|
-
|
103
|
-
def __len__(self):
|
104
|
-
return len(self.image_paths)
|
105
|
-
|
106
|
-
def __getitem__(self, idx):
|
107
|
-
# if not pil image, then convert to pil image
|
108
|
-
if isinstance(self.image_paths[idx], str):
|
109
|
-
raw_image = Image.open(self.image_paths[idx])
|
110
|
-
else:
|
111
|
-
raw_image = self.image_paths[idx]
|
112
|
-
if self.transform:
|
113
|
-
image = self.transform(raw_image)
|
114
|
-
return image
|
115
|
-
|
116
|
-
|
117
|
-
class AtomModelSingleton:
|
118
|
-
_instance = None
|
119
|
-
_models = {}
|
120
|
-
|
121
|
-
def __new__(cls, *args, **kwargs):
|
122
|
-
if cls._instance is None:
|
123
|
-
cls._instance = super().__new__(cls)
|
124
|
-
return cls._instance
|
125
|
-
|
126
|
-
def get_atom_model(self, atom_model_name: str, **kwargs):
|
127
|
-
lang = kwargs.get("lang", None)
|
128
|
-
layout_model_name = kwargs.get("layout_model_name", None)
|
129
|
-
key = (atom_model_name, layout_model_name, lang)
|
130
|
-
if key not in self._models:
|
131
|
-
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
|
132
|
-
return self._models[key]
|
133
|
-
|
134
|
-
|
135
|
-
def atom_model_init(model_name: str, **kwargs):
|
136
|
-
|
137
|
-
if model_name == AtomicModel.Layout:
|
138
|
-
if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
|
139
|
-
atom_model = layout_model_init(
|
140
|
-
kwargs.get("layout_weights"),
|
141
|
-
kwargs.get("layout_config_file"),
|
142
|
-
kwargs.get("device")
|
143
|
-
)
|
144
|
-
elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
|
145
|
-
atom_model = doclayout_yolo_model_init(
|
146
|
-
kwargs.get("doclayout_yolo_weights"),
|
147
|
-
)
|
148
|
-
elif model_name == AtomicModel.MFD:
|
149
|
-
atom_model = mfd_model_init(
|
150
|
-
kwargs.get("mfd_weights")
|
151
|
-
)
|
152
|
-
elif model_name == AtomicModel.MFR:
|
153
|
-
atom_model = mfr_model_init(
|
154
|
-
kwargs.get("mfr_weight_dir"),
|
155
|
-
kwargs.get("mfr_cfg_path"),
|
156
|
-
kwargs.get("device")
|
157
|
-
)
|
158
|
-
elif model_name == AtomicModel.OCR:
|
159
|
-
atom_model = ocr_model_init(
|
160
|
-
kwargs.get("ocr_show_log"),
|
161
|
-
kwargs.get("det_db_box_thresh"),
|
162
|
-
kwargs.get("lang")
|
163
|
-
)
|
164
|
-
elif model_name == AtomicModel.Table:
|
165
|
-
atom_model = table_model_init(
|
166
|
-
kwargs.get("table_model_name"),
|
167
|
-
kwargs.get("table_model_path"),
|
168
|
-
kwargs.get("table_max_time"),
|
169
|
-
kwargs.get("device")
|
170
|
-
)
|
171
|
-
else:
|
172
|
-
logger.error("model name not allow")
|
173
|
-
exit(1)
|
174
|
-
|
175
|
-
return atom_model
|
176
|
-
|
18
|
+
except ImportError:
|
19
|
+
pass
|
177
20
|
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
|
184
|
-
crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
|
185
|
-
return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
|
186
|
-
|
187
|
-
# Crop image
|
188
|
-
crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
|
189
|
-
cropped_img = input_pil_img.crop(crop_box)
|
190
|
-
return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
|
191
|
-
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
|
192
|
-
return return_image, return_list
|
21
|
+
from magic_pdf.libs.Constants import *
|
22
|
+
from magic_pdf.model.model_list import AtomicModel
|
23
|
+
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
|
24
|
+
from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
|
25
|
+
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
|
193
26
|
|
194
27
|
|
195
28
|
class CustomPEKModel:
|
@@ -226,7 +59,7 @@ class CustomPEKModel:
|
|
226
59
|
self.table_config = kwargs.get("table_config")
|
227
60
|
self.apply_table = self.table_config.get("enable", False)
|
228
61
|
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
|
229
|
-
self.table_model_name = self.table_config.get("model", MODEL_NAME.
|
62
|
+
self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)
|
230
63
|
|
231
64
|
# ocr config
|
232
65
|
self.apply_ocr = ocr
|
@@ -235,7 +68,8 @@ class CustomPEKModel:
|
|
235
68
|
logger.info(
|
236
69
|
"DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
|
237
70
|
"apply_table: {}, table_model: {}, lang: {}".format(
|
238
|
-
self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
|
71
|
+
self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
|
72
|
+
self.lang
|
239
73
|
)
|
240
74
|
)
|
241
75
|
# 初始化解析方案
|
@@ -248,17 +82,17 @@ class CustomPEKModel:
|
|
248
82
|
|
249
83
|
# 初始化公式识别
|
250
84
|
if self.apply_formula:
|
251
|
-
|
252
85
|
# 初始化公式检测模型
|
253
86
|
self.mfd_model = atom_model_manager.get_atom_model(
|
254
87
|
atom_model_name=AtomicModel.MFD,
|
255
|
-
mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
|
88
|
+
mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])),
|
89
|
+
device=self.device
|
256
90
|
)
|
257
91
|
|
258
92
|
# 初始化公式解析模型
|
259
93
|
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
|
260
94
|
mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
|
261
|
-
self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
|
95
|
+
self.mfr_model = atom_model_manager.get_atom_model(
|
262
96
|
atom_model_name=AtomicModel.MFR,
|
263
97
|
mfr_weight_dir=mfr_weight_dir,
|
264
98
|
mfr_cfg_path=mfr_cfg_path,
|
@@ -278,7 +112,8 @@ class CustomPEKModel:
|
|
278
112
|
self.layout_model = atom_model_manager.get_atom_model(
|
279
113
|
atom_model_name=AtomicModel.Layout,
|
280
114
|
layout_model_name=MODEL_NAME.DocLayout_YOLO,
|
281
|
-
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
|
115
|
+
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
|
116
|
+
device=self.device
|
282
117
|
)
|
283
118
|
# 初始化ocr
|
284
119
|
if self.apply_ocr:
|
@@ -305,26 +140,15 @@ class CustomPEKModel:
|
|
305
140
|
|
306
141
|
page_start = time.time()
|
307
142
|
|
308
|
-
latex_filling_list = []
|
309
|
-
mf_image_list = []
|
310
|
-
|
311
143
|
# layout检测
|
312
144
|
layout_start = time.time()
|
145
|
+
layout_res = []
|
313
146
|
if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
|
314
147
|
# layoutlmv3
|
315
148
|
layout_res = self.layout_model(image, ignore_catids=[])
|
316
149
|
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
|
317
150
|
# doclayout_yolo
|
318
|
-
layout_res = []
|
319
|
-
doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
|
320
|
-
for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
|
321
|
-
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
322
|
-
new_item = {
|
323
|
-
'category_id': int(cla.item()),
|
324
|
-
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
325
|
-
'score': round(float(conf.item()), 3),
|
326
|
-
}
|
327
|
-
layout_res.append(new_item)
|
151
|
+
layout_res = self.layout_model.predict(image)
|
328
152
|
layout_cost = round(time.time() - layout_start, 2)
|
329
153
|
logger.info(f"layout detection time: {layout_cost}")
|
330
154
|
|
@@ -333,59 +157,21 @@ class CustomPEKModel:
|
|
333
157
|
if self.apply_formula:
|
334
158
|
# 公式检测
|
335
159
|
mfd_start = time.time()
|
336
|
-
mfd_res = self.mfd_model.predict(image
|
160
|
+
mfd_res = self.mfd_model.predict(image)
|
337
161
|
logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
|
338
|
-
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
339
|
-
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
340
|
-
new_item = {
|
341
|
-
'category_id': 13 + int(cla.item()),
|
342
|
-
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
343
|
-
'score': round(float(conf.item()), 2),
|
344
|
-
'latex': '',
|
345
|
-
}
|
346
|
-
layout_res.append(new_item)
|
347
|
-
latex_filling_list.append(new_item)
|
348
|
-
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
|
349
|
-
mf_image_list.append(bbox_img)
|
350
162
|
|
351
163
|
# 公式识别
|
352
164
|
mfr_start = time.time()
|
353
|
-
|
354
|
-
|
355
|
-
mfr_res = []
|
356
|
-
for mf_img in dataloader:
|
357
|
-
mf_img = mf_img.to(self.device)
|
358
|
-
with torch.no_grad():
|
359
|
-
output = self.mfr_model.generate({'image': mf_img})
|
360
|
-
mfr_res.extend(output['pred_str'])
|
361
|
-
for res, latex in zip(latex_filling_list, mfr_res):
|
362
|
-
res['latex'] = latex_rm_whitespace(latex)
|
165
|
+
formula_list = self.mfr_model.predict(mfd_res, image)
|
166
|
+
layout_res.extend(formula_list)
|
363
167
|
mfr_cost = round(time.time() - mfr_start, 2)
|
364
|
-
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
|
365
|
-
|
366
|
-
#
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
if int(res['category_id']) in [13, 14]:
|
372
|
-
single_page_mfdetrec_res.append({
|
373
|
-
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
|
374
|
-
int(res['poly'][4]), int(res['poly'][5])],
|
375
|
-
})
|
376
|
-
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
|
377
|
-
ocr_res_list.append(res)
|
378
|
-
elif int(res['category_id']) in [5]:
|
379
|
-
table_res_list.append(res)
|
380
|
-
|
381
|
-
if torch.cuda.is_available() and self.device != 'cpu':
|
382
|
-
properties = torch.cuda.get_device_properties(self.device)
|
383
|
-
total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
|
384
|
-
if total_memory <= 10:
|
385
|
-
gc_start = time.time()
|
386
|
-
clean_memory()
|
387
|
-
gc_time = round(time.time() - gc_start, 2)
|
388
|
-
logger.info(f"gc time: {gc_time}")
|
168
|
+
logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")
|
169
|
+
|
170
|
+
# 清理显存
|
171
|
+
clean_vram(self.device, vram_threshold=8)
|
172
|
+
|
173
|
+
# 从layout_res中获取ocr区域、表格区域、公式区域
|
174
|
+
ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)
|
389
175
|
|
390
176
|
# ocr识别
|
391
177
|
if self.apply_ocr:
|
@@ -393,23 +179,7 @@ class CustomPEKModel:
|
|
393
179
|
# Process each area that requires OCR processing
|
394
180
|
for res in ocr_res_list:
|
395
181
|
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
|
396
|
-
|
397
|
-
# Adjust the coordinates of the formula area
|
398
|
-
adjusted_mfdetrec_res = []
|
399
|
-
for mf_res in single_page_mfdetrec_res:
|
400
|
-
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
|
401
|
-
# Adjust the coordinates of the formula area to the coordinates relative to the cropping area
|
402
|
-
x0 = mf_xmin - xmin + paste_x
|
403
|
-
y0 = mf_ymin - ymin + paste_y
|
404
|
-
x1 = mf_xmax - xmin + paste_x
|
405
|
-
y1 = mf_ymax - ymin + paste_y
|
406
|
-
# Filter formula blocks outside the graph
|
407
|
-
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
|
408
|
-
continue
|
409
|
-
else:
|
410
|
-
adjusted_mfdetrec_res.append({
|
411
|
-
"bbox": [x0, y0, x1, y1],
|
412
|
-
})
|
182
|
+
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
|
413
183
|
|
414
184
|
# OCR recognition
|
415
185
|
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
@@ -417,22 +187,8 @@ class CustomPEKModel:
|
|
417
187
|
|
418
188
|
# Integration results
|
419
189
|
if ocr_res:
|
420
|
-
|
421
|
-
|
422
|
-
text, score = box_ocr_res[1]
|
423
|
-
|
424
|
-
# Convert the coordinates back to the original coordinate system
|
425
|
-
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
|
426
|
-
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
|
427
|
-
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
|
428
|
-
p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
|
429
|
-
|
430
|
-
layout_res.append({
|
431
|
-
'category_id': 15,
|
432
|
-
'poly': p1 + p2 + p3 + p4,
|
433
|
-
'score': round(score, 2),
|
434
|
-
'text': text,
|
435
|
-
})
|
190
|
+
ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
|
191
|
+
layout_res.extend(ocr_result_list)
|
436
192
|
|
437
193
|
ocr_cost = round(time.time() - ocr_start, 2)
|
438
194
|
logger.info(f"ocr time: {ocr_cost}")
|
@@ -443,41 +199,30 @@ class CustomPEKModel:
|
|
443
199
|
for res in table_res_list:
|
444
200
|
new_image, _ = crop_img(res, pil_img)
|
445
201
|
single_table_start_time = time.time()
|
446
|
-
# logger.info("------------------table recognition processing begins-----------------")
|
447
|
-
latex_code = None
|
448
202
|
html_code = None
|
449
203
|
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
|
450
204
|
with torch.no_grad():
|
451
205
|
table_result = self.table_model.predict(new_image, "html")
|
452
206
|
if len(table_result) > 0:
|
453
207
|
html_code = table_result[0]
|
454
|
-
|
208
|
+
elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
|
455
209
|
html_code = self.table_model.img2html(new_image)
|
456
|
-
|
210
|
+
elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
|
211
|
+
html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image)
|
457
212
|
run_time = time.time() - single_table_start_time
|
458
|
-
# logger.info(f"------------table recognition processing ends within {run_time}s-----")
|
459
213
|
if run_time > self.table_max_time:
|
460
|
-
logger.warning(f"
|
214
|
+
logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")
|
461
215
|
# 判断是否返回正常
|
462
|
-
|
463
|
-
if latex_code:
|
464
|
-
expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
|
465
|
-
if expected_ending:
|
466
|
-
res["latex"] = latex_code
|
467
|
-
else:
|
468
|
-
logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
|
469
|
-
elif html_code:
|
216
|
+
if html_code:
|
470
217
|
expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
|
471
218
|
if expected_ending:
|
472
219
|
res["html"] = html_code
|
473
220
|
else:
|
474
221
|
logger.warning(f"table recognition processing fails, not found expected HTML table end")
|
475
222
|
else:
|
476
|
-
logger.warning(f"table recognition processing fails, not get
|
223
|
+
logger.warning(f"table recognition processing fails, not get html return")
|
477
224
|
logger.info(f"table time: {round(time.time() - table_start, 2)}")
|
478
225
|
|
479
226
|
logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
|
480
227
|
|
481
228
|
return layout_res
|
482
|
-
|
483
|
-
|
@@ -0,0 +1,21 @@
|
|
1
|
+
from doclayout_yolo import YOLOv10
|
2
|
+
|
3
|
+
|
4
|
+
class DocLayoutYOLOModel(object):
|
5
|
+
def __init__(self, weight, device):
|
6
|
+
self.model = YOLOv10(weight)
|
7
|
+
self.device = device
|
8
|
+
|
9
|
+
def predict(self, image):
|
10
|
+
layout_res = []
|
11
|
+
doclayout_yolo_res = self.model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
|
12
|
+
for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(),
|
13
|
+
doclayout_yolo_res.boxes.cls.cpu()):
|
14
|
+
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
15
|
+
new_item = {
|
16
|
+
'category_id': int(cla.item()),
|
17
|
+
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
18
|
+
'score': round(float(conf.item()), 3),
|
19
|
+
}
|
20
|
+
layout_res.append(new_item)
|
21
|
+
return layout_res
|
File without changes
|
@@ -0,0 +1,12 @@
|
|
1
|
+
from ultralytics import YOLO
|
2
|
+
|
3
|
+
|
4
|
+
class YOLOv8MFDModel(object):
|
5
|
+
def __init__(self, weight, device='cpu'):
|
6
|
+
self.mfd_model = YOLO(weight)
|
7
|
+
self.device = device
|
8
|
+
|
9
|
+
def predict(self, image):
|
10
|
+
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
|
11
|
+
return mfd_res
|
12
|
+
|
File without changes
|
File without changes
|
@@ -0,0 +1,98 @@
|
|
1
|
+
import os
|
2
|
+
import argparse
|
3
|
+
import re
|
4
|
+
|
5
|
+
from PIL import Image
|
6
|
+
import torch
|
7
|
+
from torch.utils.data import Dataset, DataLoader
|
8
|
+
from torchvision import transforms
|
9
|
+
from unimernet.common.config import Config
|
10
|
+
import unimernet.tasks as tasks
|
11
|
+
from unimernet.processors import load_processor
|
12
|
+
|
13
|
+
|
14
|
+
class MathDataset(Dataset):
|
15
|
+
def __init__(self, image_paths, transform=None):
|
16
|
+
self.image_paths = image_paths
|
17
|
+
self.transform = transform
|
18
|
+
|
19
|
+
def __len__(self):
|
20
|
+
return len(self.image_paths)
|
21
|
+
|
22
|
+
def __getitem__(self, idx):
|
23
|
+
# if not pil image, then convert to pil image
|
24
|
+
if isinstance(self.image_paths[idx], str):
|
25
|
+
raw_image = Image.open(self.image_paths[idx])
|
26
|
+
else:
|
27
|
+
raw_image = self.image_paths[idx]
|
28
|
+
if self.transform:
|
29
|
+
image = self.transform(raw_image)
|
30
|
+
return image
|
31
|
+
|
32
|
+
|
33
|
+
def latex_rm_whitespace(s: str):
|
34
|
+
"""Remove unnecessary whitespace from LaTeX code.
|
35
|
+
"""
|
36
|
+
text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
|
37
|
+
letter = '[a-zA-Z]'
|
38
|
+
noletter = '[\W_^\d]'
|
39
|
+
names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
|
40
|
+
s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
|
41
|
+
news = s
|
42
|
+
while True:
|
43
|
+
s = news
|
44
|
+
news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
|
45
|
+
news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
|
46
|
+
news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
|
47
|
+
if news == s:
|
48
|
+
break
|
49
|
+
return s
|
50
|
+
|
51
|
+
|
52
|
+
class UnimernetModel(object):
|
53
|
+
def __init__(self, weight_dir, cfg_path, _device_='cpu'):
|
54
|
+
|
55
|
+
args = argparse.Namespace(cfg_path=cfg_path, options=None)
|
56
|
+
cfg = Config(args)
|
57
|
+
cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
|
58
|
+
cfg.config.model.model_config.model_name = weight_dir
|
59
|
+
cfg.config.model.tokenizer_config.path = weight_dir
|
60
|
+
task = tasks.setup_task(cfg)
|
61
|
+
self.model = task.build_model(cfg)
|
62
|
+
self.device = _device_
|
63
|
+
self.model.to(_device_)
|
64
|
+
self.model.eval()
|
65
|
+
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
|
66
|
+
self.mfr_transform = transforms.Compose([vis_processor, ])
|
67
|
+
|
68
|
+
def predict(self, mfd_res, image):
|
69
|
+
|
70
|
+
formula_list = []
|
71
|
+
mf_image_list = []
|
72
|
+
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
73
|
+
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
74
|
+
new_item = {
|
75
|
+
'category_id': 13 + int(cla.item()),
|
76
|
+
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
77
|
+
'score': round(float(conf.item()), 2),
|
78
|
+
'latex': '',
|
79
|
+
}
|
80
|
+
formula_list.append(new_item)
|
81
|
+
pil_img = Image.fromarray(image)
|
82
|
+
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
|
83
|
+
mf_image_list.append(bbox_img)
|
84
|
+
|
85
|
+
dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
86
|
+
dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
|
87
|
+
mfr_res = []
|
88
|
+
for mf_img in dataloader:
|
89
|
+
mf_img = mf_img.to(self.device)
|
90
|
+
with torch.no_grad():
|
91
|
+
output = self.model.generate({'image': mf_img})
|
92
|
+
mfr_res.extend(output['pred_str'])
|
93
|
+
for res, latex in zip(formula_list, mfr_res):
|
94
|
+
res['latex'] = latex_rm_whitespace(latex)
|
95
|
+
return formula_list
|
96
|
+
|
97
|
+
|
98
|
+
|
File without changes
|