magic-pdf 0.6.1__py3-none-any.whl → 0.7.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/dict2md/ocr_mkcontent.py +20 -7
- magic_pdf/libs/config_reader.py +28 -10
- magic_pdf/libs/language.py +12 -0
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/__init__.py +1 -1
- magic_pdf/model/doc_analyze_by_custom_model.py +35 -3
- magic_pdf/model/magic_model.py +49 -41
- magic_pdf/model/pdf_extract_kit.py +155 -60
- magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py +7 -6
- magic_pdf/model/pek_sub_modules/self_modify.py +87 -43
- magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py +22 -0
- magic_pdf/model/pp_structure_v2.py +1 -1
- magic_pdf/pdf_parse_union_core.py +4 -2
- magic_pdf/pre_proc/citationmarker_remove.py +5 -1
- magic_pdf/pre_proc/ocr_detect_all_bboxes.py +40 -2
- magic_pdf/pre_proc/ocr_span_list_modify.py +12 -7
- magic_pdf/resources/fasttext-langdetect/lid.176.ftz +0 -0
- magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +2 -2
- magic_pdf/resources/model_config/model_configs.yaml +4 -0
- magic_pdf/rw/AbsReaderWriter.py +1 -18
- magic_pdf/rw/DiskReaderWriter.py +32 -24
- magic_pdf/rw/S3ReaderWriter.py +83 -48
- magic_pdf/tools/cli.py +79 -0
- magic_pdf/tools/cli_dev.py +156 -0
- magic_pdf/tools/common.py +119 -0
- {magic_pdf-0.6.1.dist-info → magic_pdf-0.7.0a1.dist-info}/METADATA +120 -72
- {magic_pdf-0.6.1.dist-info → magic_pdf-0.7.0a1.dist-info}/RECORD +34 -35
- {magic_pdf-0.6.1.dist-info → magic_pdf-0.7.0a1.dist-info}/WHEEL +1 -1
- magic_pdf-0.7.0a1.dist-info/entry_points.txt +3 -0
- magic_pdf/cli/magicpdf.py +0 -337
- magic_pdf/pdf_parse_for_train.py +0 -685
- magic_pdf/train_utils/convert_to_train_format.py +0 -65
- magic_pdf/train_utils/extract_caption.py +0 -59
- magic_pdf/train_utils/remove_footer_header.py +0 -159
- magic_pdf/train_utils/vis_utils.py +0 -327
- magic_pdf-0.6.1.dist-info/entry_points.txt +0 -2
- /magic_pdf/libs/{math.py → local_math.py} +0 -0
- /magic_pdf/{cli → model/pek_sub_modules/structeqtable}/__init__.py +0 -0
- /magic_pdf/{train_utils → tools}/__init__.py +0 -0
- {magic_pdf-0.6.1.dist-info → magic_pdf-0.7.0a1.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.6.1.dist-info → magic_pdf-0.7.0a1.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,19 @@
|
|
1
1
|
from loguru import logger
|
2
2
|
import os
|
3
|
+
import time
|
4
|
+
|
5
|
+
|
6
|
+
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
3
7
|
try:
|
4
8
|
import cv2
|
5
9
|
import yaml
|
6
|
-
import time
|
7
10
|
import argparse
|
8
11
|
import numpy as np
|
9
12
|
import torch
|
13
|
+
import torchtext
|
10
14
|
|
11
|
-
|
15
|
+
if torchtext.__version__ >= "0.18.0":
|
16
|
+
torchtext.disable_torchtext_deprecation_warning()
|
12
17
|
from PIL import Image
|
13
18
|
from torchvision import transforms
|
14
19
|
from torch.utils.data import Dataset, DataLoader
|
@@ -17,13 +22,23 @@ try:
|
|
17
22
|
import unimernet.tasks as tasks
|
18
23
|
from unimernet.processors import load_processor
|
19
24
|
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
+
except ImportError as e:
|
26
|
+
logger.exception(e)
|
27
|
+
logger.error(
|
28
|
+
'Required dependency not installed, please install by \n'
|
29
|
+
'"pip install magic-pdf[full] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
|
25
30
|
exit(1)
|
26
31
|
|
32
|
+
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
33
|
+
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
|
34
|
+
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
35
|
+
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
|
36
|
+
|
37
|
+
|
38
|
+
def table_model_init(model_path, max_time=400, _device_='cpu'):
|
39
|
+
table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
|
40
|
+
return table_model
|
41
|
+
|
27
42
|
|
28
43
|
def mfd_model_init(weight):
|
29
44
|
mfd_model = YOLO(weight)
|
@@ -83,15 +98,17 @@ class CustomPEKModel:
|
|
83
98
|
model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
|
84
99
|
# 构建 model_configs.yaml 文件的完整路径
|
85
100
|
config_path = os.path.join(model_config_dir, 'model_configs.yaml')
|
86
|
-
with open(config_path, "r") as f:
|
101
|
+
with open(config_path, "r", encoding='utf-8') as f:
|
87
102
|
self.configs = yaml.load(f, Loader=yaml.FullLoader)
|
88
103
|
# 初始化解析配置
|
89
104
|
self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
|
90
105
|
self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
|
106
|
+
self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
|
107
|
+
self.apply_table = self.table_config.get("is_table_recog_enable", False)
|
91
108
|
self.apply_ocr = ocr
|
92
109
|
logger.info(
|
93
|
-
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
|
94
|
-
self.apply_layout, self.apply_formula, self.apply_ocr
|
110
|
+
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
|
111
|
+
self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
|
95
112
|
)
|
96
113
|
)
|
97
114
|
assert self.apply_layout, "DocAnalysis must contain layout model."
|
@@ -99,6 +116,7 @@ class CustomPEKModel:
|
|
99
116
|
self.device = kwargs.get("device", self.configs["config"]["device"])
|
100
117
|
logger.info("using device: {}".format(self.device))
|
101
118
|
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
|
119
|
+
logger.info("using models_dir: {}".format(models_dir))
|
102
120
|
|
103
121
|
# 初始化公式识别
|
104
122
|
if self.apply_formula:
|
@@ -121,6 +139,11 @@ class CustomPEKModel:
|
|
121
139
|
if self.apply_ocr:
|
122
140
|
self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
|
123
141
|
|
142
|
+
# init structeqtable
|
143
|
+
if self.apply_table:
|
144
|
+
max_time = self.table_config.get("max_time", 400)
|
145
|
+
self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])),
|
146
|
+
max_time=max_time, _device_=self.device)
|
124
147
|
logger.info('DocAnalysis init done!')
|
125
148
|
|
126
149
|
def __call__(self, image):
|
@@ -134,67 +157,139 @@ class CustomPEKModel:
|
|
134
157
|
layout_cost = round(time.time() - layout_start, 2)
|
135
158
|
logger.info(f"layout detection cost: {layout_cost}")
|
136
159
|
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
mf_img
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
res
|
163
|
-
|
164
|
-
|
160
|
+
if self.apply_formula:
|
161
|
+
# 公式检测
|
162
|
+
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
|
163
|
+
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
164
|
+
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
165
|
+
new_item = {
|
166
|
+
'category_id': 13 + int(cla.item()),
|
167
|
+
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
168
|
+
'score': round(float(conf.item()), 2),
|
169
|
+
'latex': '',
|
170
|
+
}
|
171
|
+
layout_res.append(new_item)
|
172
|
+
latex_filling_list.append(new_item)
|
173
|
+
bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
|
174
|
+
mf_image_list.append(bbox_img)
|
175
|
+
|
176
|
+
# 公式识别
|
177
|
+
mfr_start = time.time()
|
178
|
+
dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
179
|
+
dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
|
180
|
+
mfr_res = []
|
181
|
+
for mf_img in dataloader:
|
182
|
+
mf_img = mf_img.to(self.device)
|
183
|
+
output = self.mfr_model.generate({'image': mf_img})
|
184
|
+
mfr_res.extend(output['pred_str'])
|
185
|
+
for res, latex in zip(latex_filling_list, mfr_res):
|
186
|
+
res['latex'] = latex_rm_whitespace(latex)
|
187
|
+
mfr_cost = round(time.time() - mfr_start, 2)
|
188
|
+
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
|
165
189
|
|
166
190
|
# ocr识别
|
167
191
|
if self.apply_ocr:
|
168
192
|
ocr_start = time.time()
|
169
193
|
pil_img = Image.fromarray(image)
|
194
|
+
|
195
|
+
# 筛选出需要OCR的区域和公式区域
|
196
|
+
ocr_res_list = []
|
170
197
|
single_page_mfdetrec_res = []
|
171
198
|
for res in layout_res:
|
172
199
|
if int(res['category_id']) in [13, 14]:
|
173
|
-
xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
|
174
|
-
xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
|
175
200
|
single_page_mfdetrec_res.append({
|
176
|
-
"bbox": [
|
201
|
+
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
|
202
|
+
int(res['poly'][4]), int(res['poly'][5])],
|
177
203
|
})
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
204
|
+
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
|
205
|
+
ocr_res_list.append(res)
|
206
|
+
|
207
|
+
# 对每一个需OCR处理的区域进行处理
|
208
|
+
for res in ocr_res_list:
|
209
|
+
xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
|
210
|
+
xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
|
211
|
+
|
212
|
+
paste_x = 50
|
213
|
+
paste_y = 50
|
214
|
+
# 创建一个宽高各多50的白色背景
|
215
|
+
new_width = xmax - xmin + paste_x * 2
|
216
|
+
new_height = ymax - ymin + paste_y * 2
|
217
|
+
new_image = Image.new('RGB', (new_width, new_height), 'white')
|
218
|
+
|
219
|
+
# 裁剪图像
|
220
|
+
crop_box = (xmin, ymin, xmax, ymax)
|
221
|
+
cropped_img = pil_img.crop(crop_box)
|
222
|
+
new_image.paste(cropped_img, (paste_x, paste_y))
|
223
|
+
|
224
|
+
# 调整公式区域坐标
|
225
|
+
adjusted_mfdetrec_res = []
|
226
|
+
for mf_res in single_page_mfdetrec_res:
|
227
|
+
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
|
228
|
+
# 将公式区域坐标调整为相对于裁剪区域的坐标
|
229
|
+
x0 = mf_xmin - xmin + paste_x
|
230
|
+
y0 = mf_ymin - ymin + paste_y
|
231
|
+
x1 = mf_xmax - xmin + paste_x
|
232
|
+
y1 = mf_ymax - ymin + paste_y
|
233
|
+
# 过滤在图外的公式块
|
234
|
+
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
|
235
|
+
continue
|
236
|
+
else:
|
237
|
+
adjusted_mfdetrec_res.append({
|
238
|
+
"bbox": [x0, y0, x1, y1],
|
239
|
+
})
|
240
|
+
|
241
|
+
# OCR识别
|
242
|
+
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
243
|
+
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
|
244
|
+
|
245
|
+
# 整合结果
|
246
|
+
if ocr_res:
|
247
|
+
for box_ocr_res in ocr_res:
|
248
|
+
p1, p2, p3, p4 = box_ocr_res[0]
|
249
|
+
text, score = box_ocr_res[1]
|
250
|
+
|
251
|
+
# 将坐标转换回原图坐标系
|
252
|
+
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
|
253
|
+
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
|
254
|
+
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
|
255
|
+
p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
|
256
|
+
|
257
|
+
layout_res.append({
|
258
|
+
'category_id': 15,
|
259
|
+
'poly': p1 + p2 + p3 + p4,
|
260
|
+
'score': round(score, 2),
|
261
|
+
'text': text,
|
262
|
+
})
|
263
|
+
|
197
264
|
ocr_cost = round(time.time() - ocr_start, 2)
|
198
265
|
logger.info(f"ocr cost: {ocr_cost}")
|
199
266
|
|
267
|
+
# 表格识别 table recognition
|
268
|
+
if self.apply_table:
|
269
|
+
pil_img = Image.fromarray(image)
|
270
|
+
for layout in layout_res:
|
271
|
+
if layout.get("category_id", -1) == 5:
|
272
|
+
poly = layout["poly"]
|
273
|
+
xmin, ymin = int(poly[0]), int(poly[1])
|
274
|
+
xmax, ymax = int(poly[4]), int(poly[5])
|
275
|
+
|
276
|
+
paste_x = 50
|
277
|
+
paste_y = 50
|
278
|
+
# 创建一个宽高各多50的白色背景 create a whiteboard with 50 larger width and length
|
279
|
+
new_width = xmax - xmin + paste_x * 2
|
280
|
+
new_height = ymax - ymin + paste_y * 2
|
281
|
+
new_image = Image.new('RGB', (new_width, new_height), 'white')
|
282
|
+
|
283
|
+
# 裁剪图像 crop image
|
284
|
+
crop_box = (xmin, ymin, xmax, ymax)
|
285
|
+
cropped_img = pil_img.crop(crop_box)
|
286
|
+
new_image.paste(cropped_img, (paste_x, paste_y))
|
287
|
+
start_time = time.time()
|
288
|
+
logger.info("------------------table recognition processing begins-----------------")
|
289
|
+
latex_code = self.table_model.image2latex(new_image)[0]
|
290
|
+
end_time = time.time()
|
291
|
+
run_time = end_time - start_time
|
292
|
+
logger.info(f"------------table recognition processing ends within {run_time}s-----")
|
293
|
+
layout["latex"] = latex_code
|
294
|
+
|
200
295
|
return layout_res
|
@@ -79,12 +79,13 @@ def setup(args, device):
|
|
79
79
|
cfg.freeze()
|
80
80
|
default_setup(cfg, args)
|
81
81
|
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
82
|
+
#@todo 可以删掉这块?
|
83
|
+
# register_coco_instances(
|
84
|
+
# "scihub_train",
|
85
|
+
# {},
|
86
|
+
# cfg.SCIHUB_DATA_DIR_TRAIN + ".json",
|
87
|
+
# cfg.SCIHUB_DATA_DIR_TRAIN
|
88
|
+
# )
|
88
89
|
|
89
90
|
return cfg
|
90
91
|
|
@@ -10,12 +10,17 @@ from paddleocr import PaddleOCR
|
|
10
10
|
from paddleocr.ppocr.utils.logging import get_logger
|
11
11
|
from paddleocr.ppocr.utils.utility import check_and_read, alpha_to_color, binarize_img
|
12
12
|
from paddleocr.tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
|
13
|
+
|
14
|
+
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
|
15
|
+
|
13
16
|
logger = get_logger()
|
14
17
|
|
18
|
+
|
15
19
|
def img_decode(content: bytes):
|
16
20
|
np_arr = np.frombuffer(content, dtype=np.uint8)
|
17
21
|
return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
|
18
22
|
|
23
|
+
|
19
24
|
def check_img(img):
|
20
25
|
if isinstance(img, bytes):
|
21
26
|
img = img_decode(img)
|
@@ -51,6 +56,7 @@ def check_img(img):
|
|
51
56
|
|
52
57
|
return img
|
53
58
|
|
59
|
+
|
54
60
|
def sorted_boxes(dt_boxes):
|
55
61
|
"""
|
56
62
|
Sort text boxes in order from top to bottom, left to right
|
@@ -75,49 +81,87 @@ def sorted_boxes(dt_boxes):
|
|
75
81
|
return _boxes
|
76
82
|
|
77
83
|
|
78
|
-
def
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
if
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
84
|
+
def bbox_to_points(bbox):
|
85
|
+
""" 将bbox格式转换为四个顶点的数组 """
|
86
|
+
x0, y0, x1, y1 = bbox
|
87
|
+
return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
|
88
|
+
|
89
|
+
|
90
|
+
def points_to_bbox(points):
|
91
|
+
""" 将四个顶点的数组转换为bbox格式 """
|
92
|
+
x0, y0 = points[0]
|
93
|
+
x1, _ = points[1]
|
94
|
+
_, y1 = points[2]
|
95
|
+
return [x0, y0, x1, y1]
|
96
|
+
|
97
|
+
|
98
|
+
def merge_intervals(intervals):
|
99
|
+
# Sort the intervals based on the start value
|
100
|
+
intervals.sort(key=lambda x: x[0])
|
101
|
+
|
102
|
+
merged = []
|
103
|
+
for interval in intervals:
|
104
|
+
# If the list of merged intervals is empty or if the current
|
105
|
+
# interval does not overlap with the previous, simply append it.
|
106
|
+
if not merged or merged[-1][1] < interval[0]:
|
107
|
+
merged.append(interval)
|
108
|
+
else:
|
109
|
+
# Otherwise, there is overlap, so we merge the current and previous intervals.
|
110
|
+
merged[-1][1] = max(merged[-1][1], interval[1])
|
111
|
+
|
112
|
+
return merged
|
113
|
+
|
114
|
+
|
115
|
+
def remove_intervals(original, masks):
|
116
|
+
# Merge all mask intervals
|
117
|
+
merged_masks = merge_intervals(masks)
|
118
|
+
|
119
|
+
result = []
|
120
|
+
original_start, original_end = original
|
121
|
+
|
122
|
+
for mask in merged_masks:
|
123
|
+
mask_start, mask_end = mask
|
124
|
+
|
125
|
+
# If the mask starts after the original range, ignore it
|
126
|
+
if mask_start > original_end:
|
127
|
+
continue
|
128
|
+
|
129
|
+
# If the mask ends before the original range starts, ignore it
|
130
|
+
if mask_end < original_start:
|
131
|
+
continue
|
132
|
+
|
133
|
+
# Remove the masked part from the original range
|
134
|
+
if original_start < mask_start:
|
135
|
+
result.append([original_start, mask_start - 1])
|
136
|
+
|
137
|
+
original_start = max(mask_end + 1, original_start)
|
138
|
+
|
139
|
+
# Add the remaining part of the original range, if any
|
140
|
+
if original_start <= original_end:
|
141
|
+
result.append([original_start, original_end])
|
142
|
+
|
143
|
+
return result
|
144
|
+
|
145
|
+
|
146
|
+
def update_det_boxes(dt_boxes, mfd_res):
|
147
|
+
new_dt_boxes = []
|
148
|
+
for text_box in dt_boxes:
|
149
|
+
text_bbox = points_to_bbox(text_box)
|
150
|
+
masks_list = []
|
151
|
+
for mf_box in mfd_res:
|
152
|
+
mf_bbox = mf_box['bbox']
|
153
|
+
if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
|
154
|
+
masks_list.append([mf_bbox[0], mf_bbox[2]])
|
155
|
+
text_x_range = [text_bbox[0], text_bbox[2]]
|
156
|
+
text_remove_mask_range = remove_intervals(text_x_range, masks_list)
|
157
|
+
temp_dt_box = []
|
158
|
+
for text_remove_mask in text_remove_mask_range:
|
159
|
+
temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
|
160
|
+
if len(temp_dt_box) > 0:
|
161
|
+
new_dt_boxes.extend(temp_dt_box)
|
119
162
|
return new_dt_boxes
|
120
163
|
|
164
|
+
|
121
165
|
class ModifiedPaddleOCR(PaddleOCR):
|
122
166
|
def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):
|
123
167
|
"""
|
@@ -197,7 +241,7 @@ class ModifiedPaddleOCR(PaddleOCR):
|
|
197
241
|
if not rec:
|
198
242
|
return cls_res
|
199
243
|
return ocr_res
|
200
|
-
|
244
|
+
|
201
245
|
def __call__(self, img, cls=True, mfd_res=None):
|
202
246
|
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
203
247
|
|
@@ -226,7 +270,7 @@ class ModifiedPaddleOCR(PaddleOCR):
|
|
226
270
|
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
|
227
271
|
aft = time.time()
|
228
272
|
logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
|
229
|
-
len(dt_boxes), aft-bef))
|
273
|
+
len(dt_boxes), aft - bef))
|
230
274
|
|
231
275
|
for bno in range(len(dt_boxes)):
|
232
276
|
tmp_box = copy.deepcopy(dt_boxes[bno])
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from struct_eqtable.model import StructTable
|
2
|
+
from pypandoc import convert_text
|
3
|
+
class StructTableModel:
|
4
|
+
def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'):
|
5
|
+
# init
|
6
|
+
self.model_path = model_path
|
7
|
+
self.max_new_tokens = max_new_tokens # maximum output tokens length
|
8
|
+
self.max_time = max_time # timeout for processing in seconds
|
9
|
+
if device == 'cuda':
|
10
|
+
self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time).cuda()
|
11
|
+
else:
|
12
|
+
self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
|
13
|
+
|
14
|
+
def image2latex(self, image) -> str:
|
15
|
+
#
|
16
|
+
table_latex = self.model.forward(image)
|
17
|
+
return table_latex
|
18
|
+
|
19
|
+
def image2html(self, image) -> str:
|
20
|
+
table_latex = self.image2latex(image)
|
21
|
+
table_html = convert_text(table_latex, 'html', format='latex')
|
22
|
+
return table_html
|
@@ -5,7 +5,7 @@ from loguru import logger
|
|
5
5
|
try:
|
6
6
|
from paddleocr import PPStructure
|
7
7
|
except ImportError:
|
8
|
-
logger.error('paddleocr not installed, please install by "pip install magic-pdf[
|
8
|
+
logger.error('paddleocr not installed, please install by "pip install magic-pdf[lite]"')
|
9
9
|
exit(1)
|
10
10
|
|
11
11
|
|
@@ -7,7 +7,7 @@ from magic_pdf.layout.layout_sort import get_bboxes_layout, LAYOUT_UNPROC, get_c
|
|
7
7
|
from magic_pdf.libs.convert_utils import dict_to_list
|
8
8
|
from magic_pdf.libs.drop_reason import DropReason
|
9
9
|
from magic_pdf.libs.hash_utils import compute_md5
|
10
|
-
from magic_pdf.libs.
|
10
|
+
from magic_pdf.libs.local_math import float_equal
|
11
11
|
from magic_pdf.libs.ocr_content_type import ContentType
|
12
12
|
from magic_pdf.model.magic_model import MagicModel
|
13
13
|
from magic_pdf.para.para_split_v2 import para_split
|
@@ -111,7 +111,8 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter,
|
|
111
111
|
spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter)
|
112
112
|
|
113
113
|
'''将所有区块的bbox整理到一起'''
|
114
|
-
#
|
114
|
+
# interline_equation_blocks参数不够准,后面切换到interline_equations上
|
115
|
+
interline_equation_blocks = []
|
115
116
|
if len(interline_equation_blocks) > 0:
|
116
117
|
all_bboxes, all_discarded_blocks, drop_reasons = ocr_prepare_bboxes_for_layout_split(
|
117
118
|
img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
|
@@ -120,6 +121,7 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter,
|
|
120
121
|
all_bboxes, all_discarded_blocks, drop_reasons = ocr_prepare_bboxes_for_layout_split(
|
121
122
|
img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
|
122
123
|
interline_equations, page_w, page_h)
|
124
|
+
|
123
125
|
if len(drop_reasons) > 0:
|
124
126
|
need_drop = True
|
125
127
|
drop_reason.append(DropReason.OVERLAP_BLOCKS_CAN_NOT_SEPARATION)
|
@@ -135,7 +135,11 @@ def remove_citation_marker(with_char_text_blcoks):
|
|
135
135
|
|
136
136
|
if max_font_sz-span_font_sz<1: # 先以字体过滤正文,如果是正文就不再继续判断了
|
137
137
|
continue
|
138
|
-
|
138
|
+
|
139
|
+
# 对被除数为0的情况进行过滤
|
140
|
+
if span_hi==0 or min_font_sz==0:
|
141
|
+
continue
|
142
|
+
|
139
143
|
if (base_span_mid_y-span_mid_y)/span_hi>0.2 or (base_span_mid_y-span_mid_y>0 and abs(span_font_sz-min_font_sz)/min_font_sz<0.1):
|
140
144
|
"""
|
141
145
|
1. 它的前一个char如果是句号或者逗号的话,那么肯定是角标而不是公式
|
@@ -36,9 +36,12 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
|
|
36
36
|
all_bboxes = fix_text_overlap_title_blocks(all_bboxes)
|
37
37
|
'''任何框体与舍弃框重叠,优先信任舍弃框'''
|
38
38
|
all_bboxes = remove_need_drop_blocks(all_bboxes, discarded_blocks)
|
39
|
-
|
39
|
+
|
40
|
+
# interline_equation 与title或text框冲突的情况,分两种情况处理
|
40
41
|
'''interline_equation框与文本类型框iou比较接近1的时候,信任行间公式框'''
|
42
|
+
all_bboxes = fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes)
|
41
43
|
'''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框'''
|
44
|
+
# 通过后续大框套小框逻辑删除
|
42
45
|
|
43
46
|
'''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)'''
|
44
47
|
for discarded in discarded_blocks:
|
@@ -57,6 +60,34 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
|
|
57
60
|
return all_bboxes, all_discarded_blocks, drop_reasons
|
58
61
|
|
59
62
|
|
63
|
+
def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
|
64
|
+
# 先提取所有text和interline block
|
65
|
+
text_blocks = []
|
66
|
+
for block in all_bboxes:
|
67
|
+
if block[7] == BlockType.Text:
|
68
|
+
text_blocks.append(block)
|
69
|
+
interline_equation_blocks = []
|
70
|
+
for block in all_bboxes:
|
71
|
+
if block[7] == BlockType.InterlineEquation:
|
72
|
+
interline_equation_blocks.append(block)
|
73
|
+
|
74
|
+
need_remove = []
|
75
|
+
|
76
|
+
for interline_equation_block in interline_equation_blocks:
|
77
|
+
for text_block in text_blocks:
|
78
|
+
interline_equation_block_bbox = interline_equation_block[:4]
|
79
|
+
text_block_bbox = text_block[:4]
|
80
|
+
if calculate_iou(interline_equation_block_bbox, text_block_bbox) > 0.8:
|
81
|
+
if text_block not in need_remove:
|
82
|
+
need_remove.append(text_block)
|
83
|
+
|
84
|
+
if len(need_remove) > 0:
|
85
|
+
for block in need_remove:
|
86
|
+
all_bboxes.remove(block)
|
87
|
+
|
88
|
+
return all_bboxes
|
89
|
+
|
90
|
+
|
60
91
|
def fix_text_overlap_title_blocks(all_bboxes):
|
61
92
|
# 先提取所有text和title block
|
62
93
|
text_blocks = []
|
@@ -68,12 +99,19 @@ def fix_text_overlap_title_blocks(all_bboxes):
|
|
68
99
|
if block[7] == BlockType.Title:
|
69
100
|
title_blocks.append(block)
|
70
101
|
|
102
|
+
need_remove = []
|
103
|
+
|
71
104
|
for text_block in text_blocks:
|
72
105
|
for title_block in title_blocks:
|
73
106
|
text_block_bbox = text_block[:4]
|
74
107
|
title_block_bbox = title_block[:4]
|
75
108
|
if calculate_iou(text_block_bbox, title_block_bbox) > 0.8:
|
76
|
-
|
109
|
+
if title_block not in need_remove:
|
110
|
+
need_remove.append(title_block)
|
111
|
+
|
112
|
+
if len(need_remove) > 0:
|
113
|
+
for block in need_remove:
|
114
|
+
all_bboxes.remove(block)
|
77
115
|
|
78
116
|
return all_bboxes
|
79
117
|
|
@@ -5,19 +5,24 @@ from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, g
|
|
5
5
|
from magic_pdf.libs.drop_tag import DropTag
|
6
6
|
from magic_pdf.libs.ocr_content_type import ContentType, BlockType
|
7
7
|
|
8
|
+
|
8
9
|
def remove_overlaps_low_confidence_spans(spans):
|
9
10
|
dropped_spans = []
|
10
11
|
# 删除重叠spans中置信度低的的那些
|
11
12
|
for span1 in spans:
|
12
13
|
for span2 in spans:
|
13
14
|
if span1 != span2:
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
15
|
+
# span1 或 span2 任何一个都不应该在 dropped_spans 中
|
16
|
+
if span1 in dropped_spans or span2 in dropped_spans:
|
17
|
+
continue
|
18
|
+
else:
|
19
|
+
if calculate_iou(span1['bbox'], span2['bbox']) > 0.9:
|
20
|
+
if span1['score'] < span2['score']:
|
21
|
+
span_need_remove = span1
|
22
|
+
else:
|
23
|
+
span_need_remove = span2
|
24
|
+
if span_need_remove is not None and span_need_remove not in dropped_spans:
|
25
|
+
dropped_spans.append(span_need_remove)
|
21
26
|
|
22
27
|
if len(dropped_spans) > 0:
|
23
28
|
for span_need_remove in dropped_spans:
|
Binary file
|