magic-pdf 0.9.1__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/dict2md/ocr_mkcontent.py +1 -1
- magic_pdf/libs/Constants.py +3 -1
- magic_pdf/libs/config_reader.py +1 -1
- magic_pdf/libs/draw_bbox.py +10 -4
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/pdf_extract_kit.py +42 -310
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +21 -0
- magic_pdf/model/sub_modules/mfd/__init__.py +0 -0
- magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +12 -0
- magic_pdf/model/sub_modules/mfd/yolov8/__init__.py +0 -0
- magic_pdf/model/sub_modules/mfr/__init__.py +0 -0
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +98 -0
- magic_pdf/model/sub_modules/mfr/unimernet/__init__.py +0 -0
- magic_pdf/model/sub_modules/model_init.py +144 -0
- magic_pdf/model/sub_modules/model_utils.py +51 -0
- magic_pdf/model/sub_modules/ocr/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +259 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +168 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +213 -0
- magic_pdf/model/sub_modules/reading_oreder/__init__.py +0 -0
- magic_pdf/model/sub_modules/reading_oreder/layoutreader/__init__.py +0 -0
- magic_pdf/model/sub_modules/reading_oreder/layoutreader/xycut.py +242 -0
- magic_pdf/model/sub_modules/table/__init__.py +0 -0
- magic_pdf/model/sub_modules/table/rapidtable/__init__.py +0 -0
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +14 -0
- magic_pdf/model/sub_modules/table/structeqtable/__init__.py +0 -0
- magic_pdf/model/{pek_sub_modules/structeqtable/StructTableModel.py → sub_modules/table/structeqtable/struct_eqtable.py} +3 -11
- magic_pdf/model/sub_modules/table/table_utils.py +11 -0
- magic_pdf/model/sub_modules/table/tablemaster/__init__.py +0 -0
- magic_pdf/model/{ppTableModel.py → sub_modules/table/tablemaster/tablemaster_paddle.py} +1 -1
- magic_pdf/para/para_split_v3.py +13 -15
- magic_pdf/pdf_parse_union_core_v2.py +56 -19
- magic_pdf/resources/model_config/model_configs.yaml +2 -1
- magic_pdf/tools/common.py +47 -3
- {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/METADATA +35 -25
- {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/RECORD +65 -44
- {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/WHEEL +1 -1
- magic_pdf/model/pek_sub_modules/post_process.py +0 -36
- magic_pdf/model/pek_sub_modules/self_modify.py +0 -388
- /magic_pdf/model/{pek_sub_modules → sub_modules}/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules/layoutlmv3 → sub_modules/layout}/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules/structeqtable → sub_modules/layout/doclayout_yolo}/__init__.py +0 -0
- /magic_pdf/model/{v3 → sub_modules/layout/layoutlmv3}/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/backbone.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/beit.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/deit.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/cord.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/data_collator.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/funsd.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/image_utils.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/xfund.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/model_init.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/rcnn_vl.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/visualizer.py +0 -0
- /magic_pdf/model/{v3 → sub_modules/reading_oreder/layoutreader}/helpers.py +0 -0
- {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/top_level.txt +0 -0
@@ -168,7 +168,7 @@ def merge_para_with_text(para_block):
|
|
168
168
|
# 如果是前一行带有-连字符,那么末尾不应该加空格
|
169
169
|
if __is_hyphen_at_line_end(content):
|
170
170
|
para_text += content[:-1]
|
171
|
-
elif len(content) == 1 and content not in ['A', 'I', 'a', 'i']:
|
171
|
+
elif len(content) == 1 and content not in ['A', 'I', 'a', 'i'] and not content.isdigit():
|
172
172
|
para_text += content
|
173
173
|
else: # 西方文本语境下 content间需要空格分隔
|
174
174
|
para_text += f"{content} "
|
magic_pdf/libs/Constants.py
CHANGED
magic_pdf/libs/config_reader.py
CHANGED
@@ -92,7 +92,7 @@ def get_table_recog_config():
|
|
92
92
|
table_config = config.get('table-config')
|
93
93
|
if table_config is None:
|
94
94
|
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
|
95
|
-
return json.loads(f'{{"model": "{MODEL_NAME.
|
95
|
+
return json.loads(f'{{"model": "{MODEL_NAME.RAPID_TABLE}","enable": false, "max_time": 400}}')
|
96
96
|
else:
|
97
97
|
return table_config
|
98
98
|
|
magic_pdf/libs/draw_bbox.py
CHANGED
@@ -369,10 +369,16 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
|
|
369
369
|
if block['type'] in [BlockType.Image, BlockType.Table]:
|
370
370
|
for sub_block in block['blocks']:
|
371
371
|
if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
372
|
+
if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
|
373
|
+
for line in sub_block['virtual_lines']:
|
374
|
+
bbox = line['bbox']
|
375
|
+
index = line['index']
|
376
|
+
page_line_list.append({'index': index, 'bbox': bbox})
|
377
|
+
else:
|
378
|
+
for line in sub_block['lines']:
|
379
|
+
bbox = line['bbox']
|
380
|
+
index = line['index']
|
381
|
+
page_line_list.append({'index': index, 'bbox': bbox})
|
376
382
|
elif sub_block['type'] in [BlockType.ImageCaption, BlockType.TableCaption, BlockType.ImageFootnote, BlockType.TableFootnote]:
|
377
383
|
for line in sub_block['lines']:
|
378
384
|
bbox = line['bbox']
|
magic_pdf/libs/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.9.
|
1
|
+
__version__ = "0.9.3"
|
@@ -1,195 +1,28 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import torch
|
1
3
|
from loguru import logger
|
2
4
|
import os
|
3
5
|
import time
|
4
|
-
|
5
|
-
import
|
6
|
-
from
|
7
|
-
from magic_pdf.libs.clean_memory import clean_memory
|
8
|
-
from magic_pdf.model.model_list import AtomicModel
|
6
|
+
import cv2
|
7
|
+
import yaml
|
8
|
+
from PIL import Image
|
9
9
|
|
10
10
|
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
11
11
|
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
|
12
|
+
|
12
13
|
try:
|
13
|
-
import cv2
|
14
|
-
import yaml
|
15
|
-
import argparse
|
16
|
-
import numpy as np
|
17
|
-
import torch
|
18
14
|
import torchtext
|
19
15
|
|
20
16
|
if torchtext.__version__ >= "0.18.0":
|
21
17
|
torchtext.disable_torchtext_deprecation_warning()
|
22
|
-
|
23
|
-
|
24
|
-
from torch.utils.data import Dataset, DataLoader
|
25
|
-
from ultralytics import YOLO
|
26
|
-
from unimernet.common.config import Config
|
27
|
-
import unimernet.tasks as tasks
|
28
|
-
from unimernet.processors import load_processor
|
29
|
-
from doclayout_yolo import YOLOv10
|
30
|
-
|
31
|
-
except ImportError as e:
|
32
|
-
logger.exception(e)
|
33
|
-
logger.error(
|
34
|
-
'Required dependency not installed, please install by \n'
|
35
|
-
'"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
|
36
|
-
exit(1)
|
37
|
-
|
38
|
-
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
39
|
-
from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
|
40
|
-
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
41
|
-
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
|
42
|
-
from magic_pdf.model.ppTableModel import ppTableModel
|
43
|
-
|
44
|
-
|
45
|
-
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
|
46
|
-
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
|
47
|
-
table_model = StructTableModel(model_path, max_time=max_time)
|
48
|
-
elif table_model_type == MODEL_NAME.TABLE_MASTER:
|
49
|
-
config = {
|
50
|
-
"model_dir": model_path,
|
51
|
-
"device": _device_
|
52
|
-
}
|
53
|
-
table_model = ppTableModel(config)
|
54
|
-
else:
|
55
|
-
logger.error("table model type not allow")
|
56
|
-
exit(1)
|
57
|
-
return table_model
|
58
|
-
|
59
|
-
|
60
|
-
def mfd_model_init(weight):
|
61
|
-
mfd_model = YOLO(weight)
|
62
|
-
return mfd_model
|
63
|
-
|
64
|
-
|
65
|
-
def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
|
66
|
-
args = argparse.Namespace(cfg_path=cfg_path, options=None)
|
67
|
-
cfg = Config(args)
|
68
|
-
cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
|
69
|
-
cfg.config.model.model_config.model_name = weight_dir
|
70
|
-
cfg.config.model.tokenizer_config.path = weight_dir
|
71
|
-
task = tasks.setup_task(cfg)
|
72
|
-
model = task.build_model(cfg)
|
73
|
-
model.to(_device_)
|
74
|
-
model.eval()
|
75
|
-
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
|
76
|
-
mfr_transform = transforms.Compose([vis_processor, ])
|
77
|
-
return [model, mfr_transform]
|
78
|
-
|
79
|
-
|
80
|
-
def layout_model_init(weight, config_file, device):
|
81
|
-
model = Layoutlmv3_Predictor(weight, config_file, device)
|
82
|
-
return model
|
83
|
-
|
84
|
-
|
85
|
-
def doclayout_yolo_model_init(weight):
|
86
|
-
model = YOLOv10(weight)
|
87
|
-
return model
|
88
|
-
|
89
|
-
|
90
|
-
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
|
91
|
-
if lang is not None:
|
92
|
-
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
|
93
|
-
else:
|
94
|
-
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
|
95
|
-
return model
|
96
|
-
|
97
|
-
|
98
|
-
class MathDataset(Dataset):
|
99
|
-
def __init__(self, image_paths, transform=None):
|
100
|
-
self.image_paths = image_paths
|
101
|
-
self.transform = transform
|
102
|
-
|
103
|
-
def __len__(self):
|
104
|
-
return len(self.image_paths)
|
105
|
-
|
106
|
-
def __getitem__(self, idx):
|
107
|
-
# if not pil image, then convert to pil image
|
108
|
-
if isinstance(self.image_paths[idx], str):
|
109
|
-
raw_image = Image.open(self.image_paths[idx])
|
110
|
-
else:
|
111
|
-
raw_image = self.image_paths[idx]
|
112
|
-
if self.transform:
|
113
|
-
image = self.transform(raw_image)
|
114
|
-
return image
|
115
|
-
|
116
|
-
|
117
|
-
class AtomModelSingleton:
|
118
|
-
_instance = None
|
119
|
-
_models = {}
|
120
|
-
|
121
|
-
def __new__(cls, *args, **kwargs):
|
122
|
-
if cls._instance is None:
|
123
|
-
cls._instance = super().__new__(cls)
|
124
|
-
return cls._instance
|
125
|
-
|
126
|
-
def get_atom_model(self, atom_model_name: str, **kwargs):
|
127
|
-
lang = kwargs.get("lang", None)
|
128
|
-
layout_model_name = kwargs.get("layout_model_name", None)
|
129
|
-
key = (atom_model_name, layout_model_name, lang)
|
130
|
-
if key not in self._models:
|
131
|
-
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
|
132
|
-
return self._models[key]
|
133
|
-
|
134
|
-
|
135
|
-
def atom_model_init(model_name: str, **kwargs):
|
136
|
-
|
137
|
-
if model_name == AtomicModel.Layout:
|
138
|
-
if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
|
139
|
-
atom_model = layout_model_init(
|
140
|
-
kwargs.get("layout_weights"),
|
141
|
-
kwargs.get("layout_config_file"),
|
142
|
-
kwargs.get("device")
|
143
|
-
)
|
144
|
-
elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
|
145
|
-
atom_model = doclayout_yolo_model_init(
|
146
|
-
kwargs.get("doclayout_yolo_weights"),
|
147
|
-
)
|
148
|
-
elif model_name == AtomicModel.MFD:
|
149
|
-
atom_model = mfd_model_init(
|
150
|
-
kwargs.get("mfd_weights")
|
151
|
-
)
|
152
|
-
elif model_name == AtomicModel.MFR:
|
153
|
-
atom_model = mfr_model_init(
|
154
|
-
kwargs.get("mfr_weight_dir"),
|
155
|
-
kwargs.get("mfr_cfg_path"),
|
156
|
-
kwargs.get("device")
|
157
|
-
)
|
158
|
-
elif model_name == AtomicModel.OCR:
|
159
|
-
atom_model = ocr_model_init(
|
160
|
-
kwargs.get("ocr_show_log"),
|
161
|
-
kwargs.get("det_db_box_thresh"),
|
162
|
-
kwargs.get("lang")
|
163
|
-
)
|
164
|
-
elif model_name == AtomicModel.Table:
|
165
|
-
atom_model = table_model_init(
|
166
|
-
kwargs.get("table_model_name"),
|
167
|
-
kwargs.get("table_model_path"),
|
168
|
-
kwargs.get("table_max_time"),
|
169
|
-
kwargs.get("device")
|
170
|
-
)
|
171
|
-
else:
|
172
|
-
logger.error("model name not allow")
|
173
|
-
exit(1)
|
174
|
-
|
175
|
-
return atom_model
|
176
|
-
|
177
|
-
|
178
|
-
# Unified crop img logic
|
179
|
-
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
|
180
|
-
crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
|
181
|
-
crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
|
182
|
-
# Create a white background with an additional width and height of 50
|
183
|
-
crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
|
184
|
-
crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
|
185
|
-
return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
|
18
|
+
except ImportError:
|
19
|
+
pass
|
186
20
|
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
return return_image, return_list
|
21
|
+
from magic_pdf.libs.Constants import *
|
22
|
+
from magic_pdf.model.model_list import AtomicModel
|
23
|
+
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
|
24
|
+
from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
|
25
|
+
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
|
193
26
|
|
194
27
|
|
195
28
|
class CustomPEKModel:
|
@@ -226,7 +59,7 @@ class CustomPEKModel:
|
|
226
59
|
self.table_config = kwargs.get("table_config")
|
227
60
|
self.apply_table = self.table_config.get("enable", False)
|
228
61
|
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
|
229
|
-
self.table_model_name = self.table_config.get("model", MODEL_NAME.
|
62
|
+
self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)
|
230
63
|
|
231
64
|
# ocr config
|
232
65
|
self.apply_ocr = ocr
|
@@ -235,7 +68,8 @@ class CustomPEKModel:
|
|
235
68
|
logger.info(
|
236
69
|
"DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
|
237
70
|
"apply_table: {}, table_model: {}, lang: {}".format(
|
238
|
-
self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
|
71
|
+
self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
|
72
|
+
self.lang
|
239
73
|
)
|
240
74
|
)
|
241
75
|
# 初始化解析方案
|
@@ -248,17 +82,17 @@ class CustomPEKModel:
|
|
248
82
|
|
249
83
|
# 初始化公式识别
|
250
84
|
if self.apply_formula:
|
251
|
-
|
252
85
|
# 初始化公式检测模型
|
253
86
|
self.mfd_model = atom_model_manager.get_atom_model(
|
254
87
|
atom_model_name=AtomicModel.MFD,
|
255
|
-
mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
|
88
|
+
mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])),
|
89
|
+
device=self.device
|
256
90
|
)
|
257
91
|
|
258
92
|
# 初始化公式解析模型
|
259
93
|
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
|
260
94
|
mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
|
261
|
-
self.mfr_model
|
95
|
+
self.mfr_model = atom_model_manager.get_atom_model(
|
262
96
|
atom_model_name=AtomicModel.MFR,
|
263
97
|
mfr_weight_dir=mfr_weight_dir,
|
264
98
|
mfr_cfg_path=mfr_cfg_path,
|
@@ -278,12 +112,11 @@ class CustomPEKModel:
|
|
278
112
|
self.layout_model = atom_model_manager.get_atom_model(
|
279
113
|
atom_model_name=AtomicModel.Layout,
|
280
114
|
layout_model_name=MODEL_NAME.DocLayout_YOLO,
|
281
|
-
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
|
115
|
+
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
|
116
|
+
device=self.device
|
282
117
|
)
|
283
118
|
# 初始化ocr
|
284
119
|
if self.apply_ocr:
|
285
|
-
|
286
|
-
# self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
|
287
120
|
self.ocr_model = atom_model_manager.get_atom_model(
|
288
121
|
atom_model_name=AtomicModel.OCR,
|
289
122
|
ocr_show_log=show_log,
|
@@ -301,43 +134,21 @@ class CustomPEKModel:
|
|
301
134
|
device=self.device
|
302
135
|
)
|
303
136
|
|
304
|
-
home_directory = Path.home()
|
305
|
-
det_source = os.path.join(models_dir, table_model_dir, DETECT_MODEL_DIR)
|
306
|
-
rec_source = os.path.join(models_dir, table_model_dir, REC_MODEL_DIR)
|
307
|
-
det_dest_dir = os.path.join(home_directory, PP_DET_DIRECTORY)
|
308
|
-
rec_dest_dir = os.path.join(home_directory, PP_REC_DIRECTORY)
|
309
|
-
|
310
|
-
if not os.path.exists(det_dest_dir):
|
311
|
-
shutil.copytree(det_source, det_dest_dir)
|
312
|
-
if not os.path.exists(rec_dest_dir):
|
313
|
-
shutil.copytree(rec_source, rec_dest_dir)
|
314
|
-
|
315
137
|
logger.info('DocAnalysis init done!')
|
316
138
|
|
317
139
|
def __call__(self, image):
|
318
140
|
|
319
141
|
page_start = time.time()
|
320
142
|
|
321
|
-
latex_filling_list = []
|
322
|
-
mf_image_list = []
|
323
|
-
|
324
143
|
# layout检测
|
325
144
|
layout_start = time.time()
|
145
|
+
layout_res = []
|
326
146
|
if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
|
327
147
|
# layoutlmv3
|
328
148
|
layout_res = self.layout_model(image, ignore_catids=[])
|
329
149
|
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
|
330
150
|
# doclayout_yolo
|
331
|
-
layout_res =
|
332
|
-
doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
|
333
|
-
for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
|
334
|
-
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
335
|
-
new_item = {
|
336
|
-
'category_id': int(cla.item()),
|
337
|
-
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
338
|
-
'score': round(float(conf.item()), 3),
|
339
|
-
}
|
340
|
-
layout_res.append(new_item)
|
151
|
+
layout_res = self.layout_model.predict(image)
|
341
152
|
layout_cost = round(time.time() - layout_start, 2)
|
342
153
|
logger.info(f"layout detection time: {layout_cost}")
|
343
154
|
|
@@ -346,59 +157,21 @@ class CustomPEKModel:
|
|
346
157
|
if self.apply_formula:
|
347
158
|
# 公式检测
|
348
159
|
mfd_start = time.time()
|
349
|
-
mfd_res = self.mfd_model.predict(image
|
160
|
+
mfd_res = self.mfd_model.predict(image)
|
350
161
|
logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
|
351
|
-
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
352
|
-
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
353
|
-
new_item = {
|
354
|
-
'category_id': 13 + int(cla.item()),
|
355
|
-
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
356
|
-
'score': round(float(conf.item()), 2),
|
357
|
-
'latex': '',
|
358
|
-
}
|
359
|
-
layout_res.append(new_item)
|
360
|
-
latex_filling_list.append(new_item)
|
361
|
-
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
|
362
|
-
mf_image_list.append(bbox_img)
|
363
162
|
|
364
163
|
# 公式识别
|
365
164
|
mfr_start = time.time()
|
366
|
-
|
367
|
-
|
368
|
-
mfr_res = []
|
369
|
-
for mf_img in dataloader:
|
370
|
-
mf_img = mf_img.to(self.device)
|
371
|
-
with torch.no_grad():
|
372
|
-
output = self.mfr_model.generate({'image': mf_img})
|
373
|
-
mfr_res.extend(output['pred_str'])
|
374
|
-
for res, latex in zip(latex_filling_list, mfr_res):
|
375
|
-
res['latex'] = latex_rm_whitespace(latex)
|
165
|
+
formula_list = self.mfr_model.predict(mfd_res, image)
|
166
|
+
layout_res.extend(formula_list)
|
376
167
|
mfr_cost = round(time.time() - mfr_start, 2)
|
377
|
-
logger.info(f"formula nums: {len(
|
378
|
-
|
379
|
-
#
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
if int(res['category_id']) in [13, 14]:
|
385
|
-
single_page_mfdetrec_res.append({
|
386
|
-
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
|
387
|
-
int(res['poly'][4]), int(res['poly'][5])],
|
388
|
-
})
|
389
|
-
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
|
390
|
-
ocr_res_list.append(res)
|
391
|
-
elif int(res['category_id']) in [5]:
|
392
|
-
table_res_list.append(res)
|
393
|
-
|
394
|
-
if torch.cuda.is_available() and self.device != 'cpu':
|
395
|
-
properties = torch.cuda.get_device_properties(self.device)
|
396
|
-
total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
|
397
|
-
if total_memory <= 10:
|
398
|
-
gc_start = time.time()
|
399
|
-
clean_memory()
|
400
|
-
gc_time = round(time.time() - gc_start, 2)
|
401
|
-
logger.info(f"gc time: {gc_time}")
|
168
|
+
logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")
|
169
|
+
|
170
|
+
# 清理显存
|
171
|
+
clean_vram(self.device, vram_threshold=8)
|
172
|
+
|
173
|
+
# 从layout_res中获取ocr区域、表格区域、公式区域
|
174
|
+
ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)
|
402
175
|
|
403
176
|
# ocr识别
|
404
177
|
if self.apply_ocr:
|
@@ -406,23 +179,7 @@ class CustomPEKModel:
|
|
406
179
|
# Process each area that requires OCR processing
|
407
180
|
for res in ocr_res_list:
|
408
181
|
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
|
409
|
-
|
410
|
-
# Adjust the coordinates of the formula area
|
411
|
-
adjusted_mfdetrec_res = []
|
412
|
-
for mf_res in single_page_mfdetrec_res:
|
413
|
-
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
|
414
|
-
# Adjust the coordinates of the formula area to the coordinates relative to the cropping area
|
415
|
-
x0 = mf_xmin - xmin + paste_x
|
416
|
-
y0 = mf_ymin - ymin + paste_y
|
417
|
-
x1 = mf_xmax - xmin + paste_x
|
418
|
-
y1 = mf_ymax - ymin + paste_y
|
419
|
-
# Filter formula blocks outside the graph
|
420
|
-
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
|
421
|
-
continue
|
422
|
-
else:
|
423
|
-
adjusted_mfdetrec_res.append({
|
424
|
-
"bbox": [x0, y0, x1, y1],
|
425
|
-
})
|
182
|
+
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
|
426
183
|
|
427
184
|
# OCR recognition
|
428
185
|
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
@@ -430,22 +187,8 @@ class CustomPEKModel:
|
|
430
187
|
|
431
188
|
# Integration results
|
432
189
|
if ocr_res:
|
433
|
-
|
434
|
-
|
435
|
-
text, score = box_ocr_res[1]
|
436
|
-
|
437
|
-
# Convert the coordinates back to the original coordinate system
|
438
|
-
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
|
439
|
-
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
|
440
|
-
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
|
441
|
-
p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
|
442
|
-
|
443
|
-
layout_res.append({
|
444
|
-
'category_id': 15,
|
445
|
-
'poly': p1 + p2 + p3 + p4,
|
446
|
-
'score': round(score, 2),
|
447
|
-
'text': text,
|
448
|
-
})
|
190
|
+
ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
|
191
|
+
layout_res.extend(ocr_result_list)
|
449
192
|
|
450
193
|
ocr_cost = round(time.time() - ocr_start, 2)
|
451
194
|
logger.info(f"ocr time: {ocr_cost}")
|
@@ -456,41 +199,30 @@ class CustomPEKModel:
|
|
456
199
|
for res in table_res_list:
|
457
200
|
new_image, _ = crop_img(res, pil_img)
|
458
201
|
single_table_start_time = time.time()
|
459
|
-
# logger.info("------------------table recognition processing begins-----------------")
|
460
|
-
latex_code = None
|
461
202
|
html_code = None
|
462
203
|
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
|
463
204
|
with torch.no_grad():
|
464
205
|
table_result = self.table_model.predict(new_image, "html")
|
465
206
|
if len(table_result) > 0:
|
466
207
|
html_code = table_result[0]
|
467
|
-
|
208
|
+
elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
|
468
209
|
html_code = self.table_model.img2html(new_image)
|
469
|
-
|
210
|
+
elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
|
211
|
+
html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image)
|
470
212
|
run_time = time.time() - single_table_start_time
|
471
|
-
# logger.info(f"------------table recognition processing ends within {run_time}s-----")
|
472
213
|
if run_time > self.table_max_time:
|
473
|
-
logger.warning(f"
|
214
|
+
logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")
|
474
215
|
# 判断是否返回正常
|
475
|
-
|
476
|
-
if latex_code:
|
477
|
-
expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
|
478
|
-
if expected_ending:
|
479
|
-
res["latex"] = latex_code
|
480
|
-
else:
|
481
|
-
logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
|
482
|
-
elif html_code:
|
216
|
+
if html_code:
|
483
217
|
expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
|
484
218
|
if expected_ending:
|
485
219
|
res["html"] = html_code
|
486
220
|
else:
|
487
221
|
logger.warning(f"table recognition processing fails, not found expected HTML table end")
|
488
222
|
else:
|
489
|
-
logger.warning(f"table recognition processing fails, not get
|
223
|
+
logger.warning(f"table recognition processing fails, not get html return")
|
490
224
|
logger.info(f"table time: {round(time.time() - table_start, 2)}")
|
491
225
|
|
492
226
|
logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
|
493
227
|
|
494
228
|
return layout_res
|
495
|
-
|
496
|
-
|
@@ -0,0 +1,21 @@
|
|
1
|
+
from doclayout_yolo import YOLOv10
|
2
|
+
|
3
|
+
|
4
|
+
class DocLayoutYOLOModel(object):
    """Page-layout detector backed by a DocLayout-YOLO (``YOLOv10``) checkpoint."""

    def __init__(self, weight, device):
        """Load the checkpoint at ``weight``; ``device`` is passed to every predict call."""
        self.model = YOLOv10(weight)
        self.device = device

    def predict(self, image):
        """Detect layout regions on a single page image.

        Returns:
            A list with one dict per detected box:
            ``category_id`` (int class index), ``poly`` (8 ints: the corners
            (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax), truncated
            to int) and ``score`` (confidence rounded to 3 decimals).
        """
        result = self.model.predict(
            image,
            imgsz=1024,
            conf=0.25,
            iou=0.45,
            verbose=True,
            device=self.device,
        )[0]
        boxes = result.boxes
        detections = []
        for box, score, label in zip(boxes.xyxy.cpu(), boxes.conf.cpu(), boxes.cls.cpu()):
            x0, y0, x1, y1 = (int(v.item()) for v in box)
            detections.append({
                'category_id': int(label.item()),
                'poly': [x0, y0, x1, y0, x1, y1, x0, y1],
                'score': round(float(score.item()), 3),
            })
        return detections
|
File without changes
|
@@ -0,0 +1,12 @@
|
|
1
|
+
from ultralytics import YOLO
|
2
|
+
|
3
|
+
|
4
|
+
class YOLOv8MFDModel(object):
    """Math-formula detector wrapping an Ultralytics ``YOLO`` checkpoint."""

    def __init__(self, weight, device='cpu'):
        """Load the detector weights; ``device`` is forwarded to each predict call."""
        self.mfd_model = YOLO(weight)
        self.device = device

    def predict(self, image):
        """Run formula detection on ``image`` and return the first (only) result object."""
        results = self.mfd_model.predict(
            image,
            imgsz=1888,
            conf=0.25,
            iou=0.45,
            verbose=True,
            device=self.device,
        )
        return results[0]
|
12
|
+
|
File without changes
|
File without changes
|
@@ -0,0 +1,98 @@
|
|
1
|
+
import os
|
2
|
+
import argparse
|
3
|
+
import re
|
4
|
+
|
5
|
+
from PIL import Image
|
6
|
+
import torch
|
7
|
+
from torch.utils.data import Dataset, DataLoader
|
8
|
+
from torchvision import transforms
|
9
|
+
from unimernet.common.config import Config
|
10
|
+
import unimernet.tasks as tasks
|
11
|
+
from unimernet.processors import load_processor
|
12
|
+
|
13
|
+
|
14
|
+
class MathDataset(Dataset):
    """Map-style dataset over formula images, used to batch formula recognition.

    Each entry of ``image_paths`` is either a filesystem path (``str``), which
    is opened with PIL on access, or an already-loaded image object, which is
    used as is. When ``transform`` is given it is applied to every item;
    otherwise the raw item is returned.
    """

    def __init__(self, image_paths, transform=None):
        """
        Args:
            image_paths: sequence of image paths and/or image objects.
            transform: optional callable applied to each raw image.
        """
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        item = self.image_paths[idx]
        # Paths are opened here; anything else is assumed to already be an image.
        raw_image = Image.open(item) if isinstance(item, str) else item
        # No transform: return the item unchanged instead of hitting an
        # unassigned local variable.
        if self.transform is None:
            return raw_image
        return self.transform(raw_image)
|
31
|
+
|
32
|
+
|
33
|
+
def latex_rm_whitespace(s: str) -> str:
    """Remove unnecessary whitespace from LaTeX code.

    Two normalisations are applied:

    1. Groups written as ``\\operatorname {...}``, ``\\mathrm {...}``,
       ``\\text {...}`` or ``\\mathbf {...}`` (a space before the brace,
       optionally preceded by one whitespace character and ``*``) have every
       space inside the match removed, e.g. ``\\mathrm {d x}`` ->
       ``\\mathrm{dx}``.
    2. Whitespace is deleted between a "non-letter" (non-word character,
       digit or underscore) and a following non-letter or ASCII letter, and
       between an ASCII letter and a following non-letter. Whitespace between
       two ASCII letters is kept. A backslash that begins an explicit LaTeX
       space (``\\ ``) is never used as the left-hand character. These
       substitutions repeat until the string stops changing.
    """
    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
    letter = '[a-zA-Z]'
    # Raw string: '\W' and '\d' are invalid escapes in a normal string literal.
    noletter = r'[\W_^\d]'

    # Group 1 spans the whole match, so this strips spaces from each match in place.
    s = re.sub(text_reg, lambda match: match.group(1).replace(' ', ''), s)

    patterns = (
        r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter),
        r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter),
        r'(%s)\s+?(%s)' % (letter, noletter),
    )
    # Iterate to a fixed point: joining one pair can expose another.
    while True:
        news = s
        for pattern in patterns:
            news = re.sub(pattern, r'\1\2', news)
        if news == s:
            return s
        s = news
|
50
|
+
|
51
|
+
|
52
|
+
class UnimernetModel(object):
    """UniMERNet formula-recognition model: turns detected formula crops into LaTeX."""

    def __init__(self, weight_dir, cfg_path, _device_='cpu'):
        """Build the model from a UniMERNet config and a local weight directory.

        Args:
            weight_dir: directory containing ``pytorch_model.pth`` plus the
                model/tokenizer files; it overrides the paths in the config.
            cfg_path: path to the UniMERNet config file.
            _device_: torch device the model is moved to (default ``'cpu'``).
        """
        args = argparse.Namespace(cfg_path=cfg_path, options=None)
        cfg = Config(args)
        # Point the config at the local weights and tokenizer.
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
        cfg.config.model.model_config.model_name = weight_dir
        cfg.config.model.tokenizer_config.path = weight_dir
        task = tasks.setup_task(cfg)
        self.model = task.build_model(cfg)
        self.device = _device_
        self.model.to(_device_)
        self.model.eval()
        vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
        self.mfr_transform = transforms.Compose([vis_processor, ])

    def predict(self, mfd_res, image, batch_size=64):
        """Recognise the formulas located by the formula detector.

        Args:
            mfd_res: detection result exposing ``boxes.xyxy``, ``boxes.conf``
                and ``boxes.cls`` tensors (e.g. what the YOLO formula
                detector's ``predict`` returns).
            image: the page image; it is passed to ``Image.fromarray``, so it
                is presumably a numpy array — TODO confirm against caller.
            batch_size: number of formula crops per ``generate`` call
                (default 64, the previous hard-coded value).

        Returns:
            One dict per detected box: ``category_id`` (13 + detector class
            index), ``poly`` (corners (xmin, ymin), (xmax, ymin), (xmax, ymax),
            (xmin, ymax) as ints), ``score`` (confidence rounded to 2
            decimals) and ``latex`` (recognised LaTeX with redundant
            whitespace removed; stays ``''`` if no prediction was produced).
        """
        formula_list = []
        mf_image_list = []
        pil_img = None
        boxes = mfd_res.boxes
        for xyxy, conf, cla in zip(boxes.xyxy.cpu(), boxes.conf.cpu(), boxes.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            formula_list.append({
                'category_id': 13 + int(cla.item()),
                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                'score': round(float(conf.item()), 2),
                'latex': '',
            })
            if pil_img is None:
                # Convert the full page once rather than once per formula.
                pil_img = Image.fromarray(image)
            mf_image_list.append(pil_img.crop((xmin, ymin, xmax, ymax)))

        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({'image': mf_img})
            mfr_res.extend(output['pred_str'])
        # Predictions come back in crop order, which matches formula_list.
        for res, latex in zip(formula_list, mfr_res):
            res['latex'] = latex_rm_whitespace(latex)
        return formula_list
|
96
|
+
|
97
|
+
|
98
|
+
|