magic-pdf 0.7.0a1__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/dict2md/ocr_mkcontent.py +4 -0
- magic_pdf/libs/Constants.py +30 -1
- magic_pdf/libs/draw_bbox.py +66 -1
- magic_pdf/libs/ocr_content_type.py +14 -0
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/doc_analyze_by_custom_model.py +2 -2
- magic_pdf/model/magic_model.py +3 -0
- magic_pdf/model/pdf_extract_kit.py +94 -70
- magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py +0 -1
- magic_pdf/model/ppTableModel.py +67 -0
- magic_pdf/para/para_split_v2.py +50 -47
- magic_pdf/resources/model_config/model_configs.yaml +3 -1
- magic_pdf/tools/cli_dev.py +8 -9
- magic_pdf/tools/common.py +4 -1
- magic_pdf-0.7.1.dist-info/METADATA +417 -0
- {magic_pdf-0.7.0a1.dist-info → magic_pdf-0.7.1.dist-info}/RECORD +20 -19
- magic_pdf-0.7.0a1.dist-info/METADATA +0 -362
- {magic_pdf-0.7.0a1.dist-info → magic_pdf-0.7.1.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.7.0a1.dist-info → magic_pdf-0.7.1.dist-info}/WHEEL +0 -0
- {magic_pdf-0.7.0a1.dist-info → magic_pdf-0.7.1.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.7.0a1.dist-info → magic_pdf-0.7.1.dist-info}/top_level.txt +0 -0
@@ -132,6 +132,8 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
|
|
132
132
|
# if processed by table model
|
133
133
|
if span.get('latex', ''):
|
134
134
|
para_text += f"\n\n$\n {span['latex']}\n$\n\n"
|
135
|
+
elif span.get('html', ''):
|
136
|
+
para_text += f"\n\n{span['html']}\n\n"
|
135
137
|
else:
|
136
138
|
para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
|
137
139
|
for block in para_block['blocks']: # 3rd.拼table_footnote
|
@@ -256,6 +258,8 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
|
|
256
258
|
if block['type'] == BlockType.TableBody:
|
257
259
|
if block["lines"][0]["spans"][0].get('latex', ''):
|
258
260
|
para_content['table_body'] = f"\n\n$\n {block['lines'][0]['spans'][0]['latex']}\n$\n\n"
|
261
|
+
elif block["lines"][0]["spans"][0].get('html', ''):
|
262
|
+
para_content['table_body'] = f"\n\n{block['lines'][0]['spans'][0]['html']}\n\n"
|
259
263
|
para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
|
260
264
|
if block['type'] == BlockType.TableCaption:
|
261
265
|
para_content['table_caption'] = merge_para_with_text(block)
|
magic_pdf/libs/Constants.py
CHANGED
@@ -8,4 +8,33 @@ CROSS_PAGE = "cross_page"
|
|
8
8
|
block维度自定义字段
|
9
9
|
"""
|
10
10
|
# block中lines是否被删除
|
11
|
-
LINES_DELETED = "lines_deleted"
|
11
|
+
LINES_DELETED = "lines_deleted"
|
12
|
+
|
13
|
+
# struct eqtable
|
14
|
+
STRUCT_EQTABLE = "struct_eqtable"
|
15
|
+
|
16
|
+
# table recognition max time default value
|
17
|
+
TABLE_MAX_TIME_VALUE = 400
|
18
|
+
|
19
|
+
# pp_table_result_max_length
|
20
|
+
TABLE_MAX_LEN = 480
|
21
|
+
|
22
|
+
# pp table structure algorithm
|
23
|
+
TABLE_MASTER = "TableMaster"
|
24
|
+
|
25
|
+
# table master structure dict
|
26
|
+
TABLE_MASTER_DICT = "table_master_structure_dict.txt"
|
27
|
+
|
28
|
+
# table master dir
|
29
|
+
TABLE_MASTER_DIR = "table_structure_tablemaster_infer/"
|
30
|
+
|
31
|
+
# pp detect model dir
|
32
|
+
DETECT_MODEL_DIR = "ch_PP-OCRv3_det_infer"
|
33
|
+
|
34
|
+
# pp rec model dir
|
35
|
+
REC_MODEL_DIR = "ch_PP-OCRv3_rec_infer"
|
36
|
+
|
37
|
+
# pp rec char dict path
|
38
|
+
REC_CHAR_DICT = "ppocr_keys_v1.txt"
|
39
|
+
|
40
|
+
|
magic_pdf/libs/draw_bbox.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from magic_pdf.libs.Constants import CROSS_PAGE
|
2
2
|
from magic_pdf.libs.commons import fitz # PyMuPDF
|
3
|
-
from magic_pdf.libs.ocr_content_type import ContentType, BlockType
|
3
|
+
from magic_pdf.libs.ocr_content_type import ContentType, BlockType, CategoryId
|
4
|
+
from magic_pdf.model.magic_model import MagicModel
|
4
5
|
|
5
6
|
|
6
7
|
def draw_bbox_without_number(i, bbox_list, page, rgb_config, fill_config):
|
@@ -225,3 +226,67 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path):
|
|
225
226
|
|
226
227
|
# Save the PDF
|
227
228
|
pdf_docs.save(f"{out_path}/spans.pdf")
|
229
|
+
|
230
|
+
|
231
|
+
def drow_model_bbox(model_list: list, pdf_bytes, out_path):
|
232
|
+
dropped_bbox_list = []
|
233
|
+
tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
|
234
|
+
imgs_body_list, imgs_caption_list = [], []
|
235
|
+
titles_list = []
|
236
|
+
texts_list = []
|
237
|
+
interequations_list = []
|
238
|
+
pdf_docs = fitz.open("pdf", pdf_bytes)
|
239
|
+
magic_model = MagicModel(model_list, pdf_docs)
|
240
|
+
for i in range(len(model_list)):
|
241
|
+
page_dropped_list = []
|
242
|
+
tables_body, tables_caption, tables_footnote = [], [], []
|
243
|
+
imgs_body, imgs_caption = [], []
|
244
|
+
titles = []
|
245
|
+
texts = []
|
246
|
+
interequations = []
|
247
|
+
page_info = magic_model.get_model_list(i)
|
248
|
+
layout_dets = page_info["layout_dets"]
|
249
|
+
for layout_det in layout_dets:
|
250
|
+
bbox = layout_det["bbox"]
|
251
|
+
if layout_det["category_id"] == CategoryId.Text:
|
252
|
+
texts.append(bbox)
|
253
|
+
elif layout_det["category_id"] == CategoryId.Title:
|
254
|
+
titles.append(bbox)
|
255
|
+
elif layout_det["category_id"] == CategoryId.TableBody:
|
256
|
+
tables_body.append(bbox)
|
257
|
+
elif layout_det["category_id"] == CategoryId.TableCaption:
|
258
|
+
tables_caption.append(bbox)
|
259
|
+
elif layout_det["category_id"] == CategoryId.TableFootnote:
|
260
|
+
tables_footnote.append(bbox)
|
261
|
+
elif layout_det["category_id"] == CategoryId.ImageBody:
|
262
|
+
imgs_body.append(bbox)
|
263
|
+
elif layout_det["category_id"] == CategoryId.ImageCaption:
|
264
|
+
imgs_caption.append(bbox)
|
265
|
+
elif layout_det["category_id"] == CategoryId.InterlineEquation_YOLO:
|
266
|
+
interequations.append(bbox)
|
267
|
+
elif layout_det["category_id"] == CategoryId.Abandon:
|
268
|
+
page_dropped_list.append(bbox)
|
269
|
+
|
270
|
+
tables_body_list.append(tables_body)
|
271
|
+
tables_caption_list.append(tables_caption)
|
272
|
+
tables_footnote_list.append(tables_footnote)
|
273
|
+
imgs_body_list.append(imgs_body)
|
274
|
+
imgs_caption_list.append(imgs_caption)
|
275
|
+
titles_list.append(titles)
|
276
|
+
texts_list.append(texts)
|
277
|
+
interequations_list.append(interequations)
|
278
|
+
dropped_bbox_list.append(page_dropped_list)
|
279
|
+
|
280
|
+
for i, page in enumerate(pdf_docs):
|
281
|
+
draw_bbox_with_number(i, dropped_bbox_list, page, [158, 158, 158], True) # color !
|
282
|
+
draw_bbox_with_number(i, tables_body_list, page, [204, 204, 0], True)
|
283
|
+
draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102], True)
|
284
|
+
draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204], True)
|
285
|
+
draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True)
|
286
|
+
draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255], True)
|
287
|
+
draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True)
|
288
|
+
draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True)
|
289
|
+
draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
|
290
|
+
|
291
|
+
# Save the PDF
|
292
|
+
pdf_docs.save(f"{out_path}/model.pdf")
|
@@ -19,3 +19,17 @@ class BlockType:
|
|
19
19
|
Footnote = "footnote"
|
20
20
|
Discarded = "discarded"
|
21
21
|
|
22
|
+
|
23
|
+
class CategoryId:
|
24
|
+
Title = 0
|
25
|
+
Text = 1
|
26
|
+
Abandon = 2
|
27
|
+
ImageBody = 3
|
28
|
+
ImageCaption = 4
|
29
|
+
TableBody = 5
|
30
|
+
TableCaption = 6
|
31
|
+
TableFootnote = 7
|
32
|
+
InterlineEquation_Layout = 8
|
33
|
+
InlineEquation = 13
|
34
|
+
InterlineEquation_YOLO = 14
|
35
|
+
OcrText = 15
|
magic_pdf/libs/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.7.0a1"
|
1
|
+
__version__ = "0.7.1"
|
@@ -37,8 +37,8 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
|
|
37
37
|
mat = fitz.Matrix(dpi / 72, dpi / 72)
|
38
38
|
pm = page.get_pixmap(matrix=mat, alpha=False)
|
39
39
|
|
40
|
-
#
|
41
|
-
if pm.width >
|
40
|
+
# If the width or height exceeds 9000 after scaling, do not scale further.
|
41
|
+
if pm.width > 9000 or pm.height > 9000:
|
42
42
|
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
|
43
43
|
|
44
44
|
img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
|
magic_pdf/model/magic_model.py
CHANGED
@@ -562,8 +562,11 @@ class MagicModel:
|
|
562
562
|
elif category_id == 5:
|
563
563
|
# 获取table模型结果
|
564
564
|
latex = layout_det.get("latex", None)
|
565
|
+
html = layout_det.get("html", None)
|
565
566
|
if latex:
|
566
567
|
span["latex"] = latex
|
568
|
+
elif html:
|
569
|
+
span["html"] = html
|
567
570
|
span["type"] = ContentType.Table
|
568
571
|
elif category_id == 13:
|
569
572
|
span["content"] = layout_det["latex"]
|
@@ -2,6 +2,7 @@ from loguru import logger
|
|
2
2
|
import os
|
3
3
|
import time
|
4
4
|
|
5
|
+
from magic_pdf.libs.Constants import *
|
5
6
|
|
6
7
|
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
7
8
|
try:
|
@@ -26,17 +27,25 @@ except ImportError as e:
|
|
26
27
|
logger.exception(e)
|
27
28
|
logger.error(
|
28
29
|
'Required dependency not installed, please install by \n'
|
29
|
-
'"pip install magic-pdf[full]
|
30
|
+
'"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
|
30
31
|
exit(1)
|
31
32
|
|
32
33
|
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
33
34
|
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
|
34
35
|
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
35
36
|
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
37
|
+
from magic_pdf.model.ppTableModel import ppTableModel
|
38
|
+
|
39
|
+
|
40
|
+
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
|
41
|
+
if table_model_type == STRUCT_EQTABLE:
|
42
|
+
table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
|
43
|
+
else:
|
44
|
+
config = {
|
45
|
+
"model_dir": model_path,
|
46
|
+
"device": _device_
|
47
|
+
}
|
48
|
+
table_model = ppTableModel(config)
|
40
49
|
return table_model
|
41
50
|
|
42
51
|
|
@@ -103,8 +112,11 @@ class CustomPEKModel:
|
|
103
112
|
# 初始化解析配置
|
104
113
|
self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
|
105
114
|
self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
|
115
|
+
# table config
|
106
116
|
self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
|
107
117
|
self.apply_table = self.table_config.get("is_table_recog_enable", False)
|
118
|
+
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
|
119
|
+
self.table_model_type = self.table_config.get("model", TABLE_MASTER)
|
108
120
|
self.apply_ocr = ocr
|
109
121
|
logger.info(
|
110
122
|
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
|
@@ -139,11 +151,11 @@ class CustomPEKModel:
|
|
139
151
|
if self.apply_ocr:
|
140
152
|
self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
|
141
153
|
|
142
|
-
# init
|
154
|
+
# init table model
|
143
155
|
if self.apply_table:
|
144
|
-
|
145
|
-
self.table_model = table_model_init(str(os.path.join(models_dir,
|
146
|
-
max_time=
|
156
|
+
table_model_dir = self.configs["weights"][self.table_model_type]
|
157
|
+
self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
|
158
|
+
max_time=self.table_max_time, _device_=self.device)
|
147
159
|
logger.info('DocAnalysis init done!')
|
148
160
|
|
149
161
|
def __call__(self, image):
|
@@ -187,50 +199,56 @@ class CustomPEKModel:
|
|
187
199
|
mfr_cost = round(time.time() - mfr_start, 2)
|
188
200
|
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
|
189
201
|
|
202
|
+
# Select regions for OCR / formula regions / table regions
|
203
|
+
ocr_res_list = []
|
204
|
+
table_res_list = []
|
205
|
+
single_page_mfdetrec_res = []
|
206
|
+
for res in layout_res:
|
207
|
+
if int(res['category_id']) in [13, 14]:
|
208
|
+
single_page_mfdetrec_res.append({
|
209
|
+
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
|
210
|
+
int(res['poly'][4]), int(res['poly'][5])],
|
211
|
+
})
|
212
|
+
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
|
213
|
+
ocr_res_list.append(res)
|
214
|
+
elif int(res['category_id']) in [5]:
|
215
|
+
table_res_list.append(res)
|
216
|
+
|
217
|
+
# Unified crop img logic
|
218
|
+
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
|
219
|
+
crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
|
220
|
+
crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
|
221
|
+
# Create a white background with an additional width and height of 50
|
222
|
+
crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
|
223
|
+
crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
|
224
|
+
return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
|
225
|
+
|
226
|
+
# Crop image
|
227
|
+
crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
|
228
|
+
cropped_img = input_pil_img.crop(crop_box)
|
229
|
+
return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
|
230
|
+
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
|
231
|
+
return return_image, return_list
|
232
|
+
|
233
|
+
pil_img = Image.fromarray(image)
|
234
|
+
|
190
235
|
# ocr识别
|
191
236
|
if self.apply_ocr:
|
192
237
|
ocr_start = time.time()
|
193
|
-
|
194
|
-
|
195
|
-
# 筛选出需要OCR的区域和公式区域
|
196
|
-
ocr_res_list = []
|
197
|
-
single_page_mfdetrec_res = []
|
198
|
-
for res in layout_res:
|
199
|
-
if int(res['category_id']) in [13, 14]:
|
200
|
-
single_page_mfdetrec_res.append({
|
201
|
-
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
|
202
|
-
int(res['poly'][4]), int(res['poly'][5])],
|
203
|
-
})
|
204
|
-
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
|
205
|
-
ocr_res_list.append(res)
|
206
|
-
|
207
|
-
# 对每一个需OCR处理的区域进行处理
|
238
|
+
# Process each area that requires OCR processing
|
208
239
|
for res in ocr_res_list:
|
209
|
-
|
210
|
-
xmax, ymax =
|
211
|
-
|
212
|
-
paste_x = 50
|
213
|
-
paste_y = 50
|
214
|
-
# 创建一个宽高各多50的白色背景
|
215
|
-
new_width = xmax - xmin + paste_x * 2
|
216
|
-
new_height = ymax - ymin + paste_y * 2
|
217
|
-
new_image = Image.new('RGB', (new_width, new_height), 'white')
|
218
|
-
|
219
|
-
# 裁剪图像
|
220
|
-
crop_box = (xmin, ymin, xmax, ymax)
|
221
|
-
cropped_img = pil_img.crop(crop_box)
|
222
|
-
new_image.paste(cropped_img, (paste_x, paste_y))
|
223
|
-
|
224
|
-
# 调整公式区域坐标
|
240
|
+
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
|
241
|
+
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
|
242
|
+
# Adjust the coordinates of the formula area
|
225
243
|
adjusted_mfdetrec_res = []
|
226
244
|
for mf_res in single_page_mfdetrec_res:
|
227
245
|
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
|
228
|
-
#
|
246
|
+
# Adjust the coordinates of the formula area to the coordinates relative to the cropping area
|
229
247
|
x0 = mf_xmin - xmin + paste_x
|
230
248
|
y0 = mf_ymin - ymin + paste_y
|
231
249
|
x1 = mf_xmax - xmin + paste_x
|
232
250
|
y1 = mf_ymax - ymin + paste_y
|
233
|
-
#
|
251
|
+
# Filter formula blocks outside the graph
|
234
252
|
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
|
235
253
|
continue
|
236
254
|
else:
|
@@ -238,17 +256,17 @@ class CustomPEKModel:
|
|
238
256
|
"bbox": [x0, y0, x1, y1],
|
239
257
|
})
|
240
258
|
|
241
|
-
# OCR
|
259
|
+
# OCR recognition
|
242
260
|
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
243
261
|
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
|
244
262
|
|
245
|
-
#
|
263
|
+
# Integration results
|
246
264
|
if ocr_res:
|
247
265
|
for box_ocr_res in ocr_res:
|
248
266
|
p1, p2, p3, p4 = box_ocr_res[0]
|
249
267
|
text, score = box_ocr_res[1]
|
250
268
|
|
251
|
-
#
|
269
|
+
# Convert the coordinates back to the original coordinate system
|
252
270
|
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
|
253
271
|
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
|
254
272
|
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
|
@@ -266,30 +284,36 @@ class CustomPEKModel:
|
|
266
284
|
|
267
285
|
# 表格识别 table recognition
|
268
286
|
if self.apply_table:
|
269
|
-
|
270
|
-
for
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
287
|
+
table_start = time.time()
|
288
|
+
for res in table_res_list:
|
289
|
+
new_image, _ = crop_img(res, pil_img)
|
290
|
+
single_table_start_time = time.time()
|
291
|
+
logger.info("------------------table recognition processing begins-----------------")
|
292
|
+
latex_code = None
|
293
|
+
html_code = None
|
294
|
+
with torch.no_grad():
|
295
|
+
if self.table_model_type == STRUCT_EQTABLE:
|
296
|
+
latex_code = self.table_model.image2latex(new_image)[0]
|
297
|
+
else:
|
298
|
+
html_code = self.table_model.img2html(new_image)
|
299
|
+
run_time = time.time() - single_table_start_time
|
300
|
+
logger.info(f"------------table recognition processing ends within {run_time}s-----")
|
301
|
+
if run_time > self.table_max_time:
|
302
|
+
logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
|
303
|
+
# 判断是否返回正常
|
304
|
+
|
305
|
+
if latex_code:
|
306
|
+
expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith(
|
307
|
+
'end{table}')
|
308
|
+
if expected_ending:
|
309
|
+
res["latex"] = latex_code
|
310
|
+
else:
|
311
|
+
logger.warning(f"------------table recognition processing fails----------")
|
312
|
+
elif html_code:
|
313
|
+
res["html"] = html_code
|
314
|
+
else:
|
315
|
+
logger.warning(f"------------table recognition processing fails----------")
|
316
|
+
table_cost = round(time.time() - table_start, 2)
|
317
|
+
logger.info(f"table cost: {table_cost}")
|
294
318
|
|
295
319
|
return layout_res
|
@@ -0,0 +1,67 @@
|
|
1
|
+
from paddleocr.ppstructure.table.predict_table import TableSystem
|
2
|
+
from paddleocr.ppstructure.utility import init_args
|
3
|
+
from magic_pdf.libs.Constants import *
|
4
|
+
import os
|
5
|
+
from PIL import Image
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
|
9
|
+
class ppTableModel(object):
|
10
|
+
"""
|
11
|
+
This class is responsible for converting image of table into HTML format using a pre-trained model.
|
12
|
+
|
13
|
+
Attributes:
|
14
|
+
- table_sys: An instance of TableSystem initialized with parsed arguments.
|
15
|
+
|
16
|
+
Methods:
|
17
|
+
- __init__(config): Initializes the model with configuration parameters.
|
18
|
+
- img2html(image): Converts a PIL Image or NumPy array to HTML string.
|
19
|
+
- parse_args(**kwargs): Parses configuration arguments.
|
20
|
+
"""
|
21
|
+
|
22
|
+
def __init__(self, config):
|
23
|
+
"""
|
24
|
+
Parameters:
|
25
|
+
- config (dict): Configuration dictionary containing model_dir and device.
|
26
|
+
"""
|
27
|
+
args = self.parse_args(**config)
|
28
|
+
self.table_sys = TableSystem(args)
|
29
|
+
|
30
|
+
def img2html(self, image):
|
31
|
+
"""
|
32
|
+
Parameters:
|
33
|
+
- image (PIL.Image or np.ndarray): The image of the table to be converted.
|
34
|
+
|
35
|
+
Return:
|
36
|
+
- HTML (str): A string representing the HTML structure with content of the table.
|
37
|
+
"""
|
38
|
+
if isinstance(image, Image.Image):
|
39
|
+
image = np.array(image)
|
40
|
+
pred_res, _ = self.table_sys(image)
|
41
|
+
pred_html = pred_res["html"]
|
42
|
+
res = '<td><table border="1">' + pred_html.replace("<html><body><table>", "").replace("</table></body></html>",
|
43
|
+
"") + "</table></td>\n"
|
44
|
+
return res
|
45
|
+
|
46
|
+
def parse_args(self, **kwargs):
|
47
|
+
parser = init_args()
|
48
|
+
model_dir = kwargs.get("model_dir")
|
49
|
+
table_model_dir = os.path.join(model_dir, TABLE_MASTER_DIR)
|
50
|
+
table_char_dict_path = os.path.join(model_dir, TABLE_MASTER_DICT)
|
51
|
+
det_model_dir = os.path.join(model_dir, DETECT_MODEL_DIR)
|
52
|
+
rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
|
53
|
+
rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
|
54
|
+
device = kwargs.get("device", "cpu")
|
55
|
+
use_gpu = True if device == "cuda" else False
|
56
|
+
config = {
|
57
|
+
"use_gpu": use_gpu,
|
58
|
+
"table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
|
59
|
+
"table_algorithm": TABLE_MASTER,
|
60
|
+
"table_model_dir": table_model_dir,
|
61
|
+
"table_char_dict_path": table_char_dict_path,
|
62
|
+
"det_model_dir": det_model_dir,
|
63
|
+
"rec_model_dir": rec_model_dir,
|
64
|
+
"rec_char_dict_path": rec_char_dict_path,
|
65
|
+
}
|
66
|
+
parser.set_defaults(**config)
|
67
|
+
return parser.parse_args([])
|
magic_pdf/para/para_split_v2.py
CHANGED
@@ -100,59 +100,62 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
|
|
100
100
|
|
101
101
|
if lang != 'en':
|
102
102
|
return lines, None
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
103
|
+
|
104
|
+
total_lines = len(lines)
|
105
|
+
line_fea_encode = []
|
106
|
+
"""
|
107
|
+
对每一行进行特征编码,编码规则如下:
|
108
|
+
1. 如果行顶格,且大写字母开头或者数字开头,编码为1
|
109
|
+
2. 如果顶格,其他非大写开头编码为4
|
110
|
+
3. 如果非顶格,首字符大写,编码为2
|
111
|
+
4. 如果非顶格,首字符非大写编码为3
|
112
|
+
"""
|
113
|
+
if len(lines) > 0:
|
114
|
+
x_map_tag_dict, min_x_tag = cluster_line_x(lines)
|
115
|
+
for l in lines:
|
116
|
+
span_text = __get_span_text(l['spans'][0])
|
117
|
+
if not span_text:
|
118
|
+
line_fea_encode.append(0)
|
119
|
+
continue
|
120
|
+
first_char = span_text[0]
|
121
|
+
layout = __find_layout_bbox_by_line(l['bbox'], new_layout_bboxes)
|
122
|
+
if not layout:
|
123
|
+
line_fea_encode.append(0)
|
124
|
+
else:
|
125
|
+
#
|
126
|
+
if x_map_tag_dict[round(l['bbox'][0])] == min_x_tag:
|
127
|
+
# if first_char.isupper() or first_char.isdigit() or not first_char.isalnum():
|
128
|
+
if not first_char.isalnum() or if_match_reference_list(span_text):
|
129
|
+
line_fea_encode.append(1)
|
130
|
+
else:
|
131
|
+
line_fea_encode.append(4)
|
121
132
|
else:
|
122
|
-
|
123
|
-
|
124
|
-
# if first_char.isupper() or first_char.isdigit() or not first_char.isalnum():
|
125
|
-
if not first_char.isalnum() or if_match_reference_list(span_text):
|
126
|
-
line_fea_encode.append(1)
|
127
|
-
else:
|
128
|
-
line_fea_encode.append(4)
|
133
|
+
if first_char.isupper():
|
134
|
+
line_fea_encode.append(2)
|
129
135
|
else:
|
130
|
-
|
131
|
-
line_fea_encode.append(2)
|
132
|
-
else:
|
133
|
-
line_fea_encode.append(3)
|
136
|
+
line_fea_encode.append(3)
|
134
137
|
|
135
|
-
|
138
|
+
# 然后根据编码进行分段, 选出来 1,2,3连续出现至少2次的行,认为是列表。
|
136
139
|
|
137
|
-
|
138
|
-
|
140
|
+
list_indice, list_start_idx = find_repeating_patterns2(line_fea_encode)
|
141
|
+
if len(list_indice) > 0:
|
142
|
+
if debug_able:
|
143
|
+
logger.info(f"发现了列表,列表行数:{list_indice}, {list_start_idx}")
|
144
|
+
|
145
|
+
# TODO check一下这个特列表里缩进的行左侧是不是对齐的。
|
146
|
+
segments = []
|
147
|
+
for start, end in list_indice:
|
148
|
+
for i in range(start, end + 1):
|
149
|
+
if i > 0:
|
150
|
+
if line_fea_encode[i] == 4:
|
151
|
+
if debug_able:
|
152
|
+
logger.info(f"列表行的第{i}行不是顶格的")
|
153
|
+
break
|
154
|
+
else:
|
139
155
|
if debug_able:
|
140
|
-
logger.info(f"
|
141
|
-
|
142
|
-
# TODO check一下这个特列表里缩进的行左侧是不是对齐的。
|
143
|
-
segments = []
|
144
|
-
for start, end in list_indice:
|
145
|
-
for i in range(start, end + 1):
|
146
|
-
if i > 0:
|
147
|
-
if line_fea_encode[i] == 4:
|
148
|
-
if debug_able:
|
149
|
-
logger.info(f"列表行的第{i}行不是顶格的")
|
150
|
-
break
|
151
|
-
else:
|
152
|
-
if debug_able:
|
153
|
-
logger.info(f"列表行的第{start}到第{end}行是列表")
|
156
|
+
logger.info(f"列表行的第{start}到第{end}行是列表")
|
154
157
|
|
155
|
-
|
158
|
+
return split_indices(total_lines, list_indice), list_start_idx
|
156
159
|
|
157
160
|
|
158
161
|
def cluster_line_x(lines: list) -> dict:
|
@@ -3,6 +3,7 @@ config:
|
|
3
3
|
layout: True
|
4
4
|
formula: True
|
5
5
|
table_config:
|
6
|
+
model: TableMaster
|
6
7
|
is_table_recog_enable: False
|
7
8
|
max_time: 400
|
8
9
|
|
@@ -10,4 +11,5 @@ weights:
|
|
10
11
|
layout: Layout/model_final.pth
|
11
12
|
mfd: MFD/weights.pt
|
12
13
|
mfr: MFR/UniMERNet
|
13
|
-
|
14
|
+
struct_eqtable: TabRec/StructEqTable
|
15
|
+
TableMaster: TabRec/TableMaster
|