magic-pdf 0.6.2b1__py3-none-any.whl → 0.7.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/dict2md/ocr_mkcontent.py +10 -3
- magic_pdf/libs/Constants.py +4 -1
- magic_pdf/libs/config_reader.py +10 -10
- magic_pdf/libs/draw_bbox.py +66 -1
- magic_pdf/libs/ocr_content_type.py +14 -0
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/doc_analyze_by_custom_model.py +10 -4
- magic_pdf/model/magic_model.py +4 -0
- magic_pdf/model/pdf_extract_kit.py +83 -39
- magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py +22 -0
- magic_pdf/resources/model_config/model_configs.yaml +4 -0
- magic_pdf/rw/AbsReaderWriter.py +1 -18
- magic_pdf/rw/DiskReaderWriter.py +32 -24
- magic_pdf/rw/S3ReaderWriter.py +83 -48
- magic_pdf/tools/cli.py +79 -0
- magic_pdf/tools/cli_dev.py +155 -0
- magic_pdf/tools/common.py +122 -0
- magic_pdf-0.7.0b1.dist-info/METADATA +421 -0
- {magic_pdf-0.6.2b1.dist-info → magic_pdf-0.7.0b1.dist-info}/RECORD +25 -27
- {magic_pdf-0.6.2b1.dist-info → magic_pdf-0.7.0b1.dist-info}/WHEEL +1 -1
- magic_pdf-0.7.0b1.dist-info/entry_points.txt +3 -0
- magic_pdf/cli/magicpdf.py +0 -359
- magic_pdf/pdf_parse_for_train.py +0 -685
- magic_pdf/train_utils/convert_to_train_format.py +0 -65
- magic_pdf/train_utils/extract_caption.py +0 -59
- magic_pdf/train_utils/remove_footer_header.py +0 -159
- magic_pdf/train_utils/vis_utils.py +0 -327
- magic_pdf-0.6.2b1.dist-info/METADATA +0 -344
- magic_pdf-0.6.2b1.dist-info/entry_points.txt +0 -2
- /magic_pdf/{cli → model/pek_sub_modules/structeqtable}/__init__.py +0 -0
- /magic_pdf/{train_utils → tools}/__init__.py +0 -0
- {magic_pdf-0.6.2b1.dist-info → magic_pdf-0.7.0b1.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.6.2b1.dist-info → magic_pdf-0.7.0b1.dist-info}/top_level.txt +0 -0
@@ -120,15 +120,20 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
|
|
120
120
|
if mode == 'nlp':
|
121
121
|
continue
|
122
122
|
elif mode == 'mm':
|
123
|
+
table_caption = ''
|
123
124
|
for block in para_block['blocks']: # 1st.拼table_caption
|
124
125
|
if block['type'] == BlockType.TableCaption:
|
125
|
-
|
126
|
+
table_caption = merge_para_with_text(block)
|
126
127
|
for block in para_block['blocks']: # 2nd.拼table_body
|
127
128
|
if block['type'] == BlockType.TableBody:
|
128
129
|
for line in block['lines']:
|
129
130
|
for span in line['spans']:
|
130
131
|
if span['type'] == ContentType.Table:
|
131
|
-
|
132
|
+
# if processed by table model
|
133
|
+
if span.get('latex', ''):
|
134
|
+
para_text += f"\n\n$\n {span['latex']}\n$\n\n"
|
135
|
+
else:
|
136
|
+
para_text += f"\n}) \n"
|
132
137
|
for block in para_block['blocks']: # 3rd.拼table_footnote
|
133
138
|
if block['type'] == BlockType.TableFootnote:
|
134
139
|
para_text += merge_para_with_text(block)
|
@@ -163,7 +168,7 @@ def merge_para_with_text(para_block):
|
|
163
168
|
else:
|
164
169
|
content = ocr_escape_special_markdown_char(content)
|
165
170
|
elif span_type == ContentType.InlineEquation:
|
166
|
-
content = f"${span['content']}$"
|
171
|
+
content = f" ${span['content']}$ "
|
167
172
|
elif span_type == ContentType.InterlineEquation:
|
168
173
|
content = f"\n$$\n{span['content']}\n$$\n"
|
169
174
|
|
@@ -249,6 +254,8 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
|
|
249
254
|
}
|
250
255
|
for block in para_block['blocks']:
|
251
256
|
if block['type'] == BlockType.TableBody:
|
257
|
+
if block["lines"][0]["spans"][0].get('latex', ''):
|
258
|
+
para_content['table_body'] = f"\n\n$\n {block['lines'][0]['spans'][0]['latex']}\n$\n\n"
|
252
259
|
para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
|
253
260
|
if block['type'] == BlockType.TableCaption:
|
254
261
|
para_content['table_caption'] = merge_para_with_text(block)
|
magic_pdf/libs/Constants.py
CHANGED
magic_pdf/libs/config_reader.py
CHANGED
@@ -57,16 +57,6 @@ def get_bucket_name(path):
|
|
57
57
|
return bucket
|
58
58
|
|
59
59
|
|
60
|
-
def get_local_dir():
|
61
|
-
config = read_config()
|
62
|
-
local_dir = config.get("temp-output-dir")
|
63
|
-
if local_dir is None:
|
64
|
-
logger.warning(f"'temp-output-dir' not found in {CONFIG_FILE_NAME}, use '/tmp' as default")
|
65
|
-
return "/tmp"
|
66
|
-
else:
|
67
|
-
return local_dir
|
68
|
-
|
69
|
-
|
70
60
|
def get_local_models_dir():
|
71
61
|
config = read_config()
|
72
62
|
models_dir = config.get("models-dir")
|
@@ -87,5 +77,15 @@ def get_device():
|
|
87
77
|
return device
|
88
78
|
|
89
79
|
|
80
|
+
def get_table_recog_config():
|
81
|
+
config = read_config()
|
82
|
+
table_config = config.get("table-config")
|
83
|
+
if table_config is None:
|
84
|
+
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
|
85
|
+
return json.loads('{"is_table_recog_enable": false, "max_time": 400}')
|
86
|
+
else:
|
87
|
+
return table_config
|
88
|
+
|
89
|
+
|
90
90
|
if __name__ == "__main__":
|
91
91
|
ak, sk, endpoint = get_s3_config("llm-raw")
|
magic_pdf/libs/draw_bbox.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from magic_pdf.libs.Constants import CROSS_PAGE
|
2
2
|
from magic_pdf.libs.commons import fitz # PyMuPDF
|
3
|
-
from magic_pdf.libs.ocr_content_type import ContentType, BlockType
|
3
|
+
from magic_pdf.libs.ocr_content_type import ContentType, BlockType, CategoryId
|
4
|
+
from magic_pdf.model.magic_model import MagicModel
|
4
5
|
|
5
6
|
|
6
7
|
def draw_bbox_without_number(i, bbox_list, page, rgb_config, fill_config):
|
@@ -225,3 +226,67 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path):
|
|
225
226
|
|
226
227
|
# Save the PDF
|
227
228
|
pdf_docs.save(f"{out_path}/spans.pdf")
|
229
|
+
|
230
|
+
|
231
|
+
def drow_model_bbox(model_list: list, pdf_bytes, out_path):
|
232
|
+
dropped_bbox_list = []
|
233
|
+
tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
|
234
|
+
imgs_body_list, imgs_caption_list = [], []
|
235
|
+
titles_list = []
|
236
|
+
texts_list = []
|
237
|
+
interequations_list = []
|
238
|
+
pdf_docs = fitz.open("pdf", pdf_bytes)
|
239
|
+
magic_model = MagicModel(model_list, pdf_docs)
|
240
|
+
for i in range(len(model_list)):
|
241
|
+
page_dropped_list = []
|
242
|
+
tables_body, tables_caption, tables_footnote = [], [], []
|
243
|
+
imgs_body, imgs_caption = [], []
|
244
|
+
titles = []
|
245
|
+
texts = []
|
246
|
+
interequations = []
|
247
|
+
page_info = magic_model.get_model_list(i)
|
248
|
+
layout_dets = page_info["layout_dets"]
|
249
|
+
for layout_det in layout_dets:
|
250
|
+
bbox = layout_det["bbox"]
|
251
|
+
if layout_det["category_id"] == CategoryId.Text:
|
252
|
+
texts.append(bbox)
|
253
|
+
elif layout_det["category_id"] == CategoryId.Title:
|
254
|
+
titles.append(bbox)
|
255
|
+
elif layout_det["category_id"] == CategoryId.TableBody:
|
256
|
+
tables_body.append(bbox)
|
257
|
+
elif layout_det["category_id"] == CategoryId.TableCaption:
|
258
|
+
tables_caption.append(bbox)
|
259
|
+
elif layout_det["category_id"] == CategoryId.TableFootnote:
|
260
|
+
tables_footnote.append(bbox)
|
261
|
+
elif layout_det["category_id"] == CategoryId.ImageBody:
|
262
|
+
imgs_body.append(bbox)
|
263
|
+
elif layout_det["category_id"] == CategoryId.ImageCaption:
|
264
|
+
imgs_caption.append(bbox)
|
265
|
+
elif layout_det["category_id"] == CategoryId.InterlineEquation_YOLO:
|
266
|
+
interequations.append(bbox)
|
267
|
+
elif layout_det["category_id"] == CategoryId.Abandon:
|
268
|
+
page_dropped_list.append(bbox)
|
269
|
+
|
270
|
+
tables_body_list.append(tables_body)
|
271
|
+
tables_caption_list.append(tables_caption)
|
272
|
+
tables_footnote_list.append(tables_footnote)
|
273
|
+
imgs_body_list.append(imgs_body)
|
274
|
+
imgs_caption_list.append(imgs_caption)
|
275
|
+
titles_list.append(titles)
|
276
|
+
texts_list.append(texts)
|
277
|
+
interequations_list.append(interequations)
|
278
|
+
dropped_bbox_list.append(page_dropped_list)
|
279
|
+
|
280
|
+
for i, page in enumerate(pdf_docs):
|
281
|
+
draw_bbox_with_number(i, dropped_bbox_list, page, [158, 158, 158], True) # color !
|
282
|
+
draw_bbox_with_number(i, tables_body_list, page, [204, 204, 0], True)
|
283
|
+
draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102], True)
|
284
|
+
draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204], True)
|
285
|
+
draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True)
|
286
|
+
draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255], True)
|
287
|
+
draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True)
|
288
|
+
draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True)
|
289
|
+
draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
|
290
|
+
|
291
|
+
# Save the PDF
|
292
|
+
pdf_docs.save(f"{out_path}/model.pdf")
|
@@ -19,3 +19,17 @@ class BlockType:
|
|
19
19
|
Footnote = "footnote"
|
20
20
|
Discarded = "discarded"
|
21
21
|
|
22
|
+
|
23
|
+
class CategoryId:
|
24
|
+
Title = 0
|
25
|
+
Text = 1
|
26
|
+
Abandon = 2
|
27
|
+
ImageBody = 3
|
28
|
+
ImageCaption = 4
|
29
|
+
TableBody = 5
|
30
|
+
TableCaption = 6
|
31
|
+
TableFootnote = 7
|
32
|
+
InterlineEquation_Layout = 8
|
33
|
+
InlineEquation = 13
|
34
|
+
InterlineEquation_YOLO = 14
|
35
|
+
OcrText = 15
|
magic_pdf/libs/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.
|
1
|
+
__version__ = "0.7.0b1"
|
@@ -4,7 +4,7 @@ import fitz
|
|
4
4
|
import numpy as np
|
5
5
|
from loguru import logger
|
6
6
|
|
7
|
-
from magic_pdf.libs.config_reader import get_local_models_dir, get_device
|
7
|
+
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
|
8
8
|
from magic_pdf.model.model_list import MODEL
|
9
9
|
import magic_pdf.model as model_config
|
10
10
|
|
@@ -37,8 +37,8 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
|
|
37
37
|
mat = fitz.Matrix(dpi / 72, dpi / 72)
|
38
38
|
pm = page.get_pixmap(matrix=mat, alpha=False)
|
39
39
|
|
40
|
-
#
|
41
|
-
if pm.width >
|
40
|
+
# If the width or height exceeds 9000 after scaling, do not scale further.
|
41
|
+
if pm.width > 9000 or pm.height > 9000:
|
42
42
|
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
|
43
43
|
|
44
44
|
img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
|
@@ -84,7 +84,13 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
|
|
84
84
|
# 从配置文件读取model-dir和device
|
85
85
|
local_models_dir = get_local_models_dir()
|
86
86
|
device = get_device()
|
87
|
-
|
87
|
+
table_config = get_table_recog_config()
|
88
|
+
model_input = {"ocr": ocr,
|
89
|
+
"show_log": show_log,
|
90
|
+
"models_dir": local_models_dir,
|
91
|
+
"device": device,
|
92
|
+
"table_config": table_config}
|
93
|
+
custom_model = CustomPEKModel(**model_input)
|
88
94
|
else:
|
89
95
|
logger.error("Not allow model_name!")
|
90
96
|
exit(1)
|
magic_pdf/model/magic_model.py
CHANGED
@@ -560,6 +560,10 @@ class MagicModel:
|
|
560
560
|
if category_id == 3:
|
561
561
|
span["type"] = ContentType.Image
|
562
562
|
elif category_id == 5:
|
563
|
+
# 获取table模型结果
|
564
|
+
latex = layout_det.get("latex", None)
|
565
|
+
if latex:
|
566
|
+
span["latex"] = latex
|
563
567
|
span["type"] = ContentType.Table
|
564
568
|
elif category_id == 13:
|
565
569
|
span["content"] = layout_det["latex"]
|
@@ -2,6 +2,8 @@ from loguru import logger
|
|
2
2
|
import os
|
3
3
|
import time
|
4
4
|
|
5
|
+
from magic_pdf.libs.Constants import TABLE_MAX_TIME_VALUE
|
6
|
+
|
5
7
|
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
6
8
|
try:
|
7
9
|
import cv2
|
@@ -10,6 +12,7 @@ try:
|
|
10
12
|
import numpy as np
|
11
13
|
import torch
|
12
14
|
import torchtext
|
15
|
+
|
13
16
|
if torchtext.__version__ >= "0.18.0":
|
14
17
|
torchtext.disable_torchtext_deprecation_warning()
|
15
18
|
from PIL import Image
|
@@ -24,12 +27,18 @@ except ImportError as e:
|
|
24
27
|
logger.exception(e)
|
25
28
|
logger.error(
|
26
29
|
'Required dependency not installed, please install by \n'
|
27
|
-
'"pip install magic-pdf[full]
|
30
|
+
'"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
|
28
31
|
exit(1)
|
29
32
|
|
30
33
|
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
31
34
|
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
|
32
35
|
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
36
|
+
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
|
37
|
+
|
38
|
+
|
39
|
+
def table_model_init(model_path, max_time, _device_='cpu'):
|
40
|
+
table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
|
41
|
+
return table_model
|
33
42
|
|
34
43
|
|
35
44
|
def mfd_model_init(weight):
|
@@ -95,10 +104,13 @@ class CustomPEKModel:
|
|
95
104
|
# 初始化解析配置
|
96
105
|
self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
|
97
106
|
self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
|
107
|
+
self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
|
108
|
+
self.apply_table = self.table_config.get("is_table_recog_enable", False)
|
109
|
+
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
|
98
110
|
self.apply_ocr = ocr
|
99
111
|
logger.info(
|
100
|
-
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
|
101
|
-
self.apply_layout, self.apply_formula, self.apply_ocr
|
112
|
+
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
|
113
|
+
self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
|
102
114
|
)
|
103
115
|
)
|
104
116
|
assert self.apply_layout, "DocAnalysis must contain layout model."
|
@@ -129,6 +141,10 @@ class CustomPEKModel:
|
|
129
141
|
if self.apply_ocr:
|
130
142
|
self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
|
131
143
|
|
144
|
+
# init structeqtable
|
145
|
+
if self.apply_table:
|
146
|
+
self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])),
|
147
|
+
max_time = self.table_max_time, _device_=self.device)
|
132
148
|
logger.info('DocAnalysis init done!')
|
133
149
|
|
134
150
|
def __call__(self, image):
|
@@ -172,50 +188,56 @@ class CustomPEKModel:
|
|
172
188
|
mfr_cost = round(time.time() - mfr_start, 2)
|
173
189
|
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
|
174
190
|
|
191
|
+
# Select regions for OCR / formula regions / table regions
|
192
|
+
ocr_res_list = []
|
193
|
+
table_res_list = []
|
194
|
+
single_page_mfdetrec_res = []
|
195
|
+
for res in layout_res:
|
196
|
+
if int(res['category_id']) in [13, 14]:
|
197
|
+
single_page_mfdetrec_res.append({
|
198
|
+
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
|
199
|
+
int(res['poly'][4]), int(res['poly'][5])],
|
200
|
+
})
|
201
|
+
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
|
202
|
+
ocr_res_list.append(res)
|
203
|
+
elif int(res['category_id']) in [5]:
|
204
|
+
table_res_list.append(res)
|
205
|
+
|
206
|
+
# Unified crop img logic
|
207
|
+
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
|
208
|
+
crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
|
209
|
+
crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
|
210
|
+
# Create a white background with an additional width and height of 50
|
211
|
+
crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
|
212
|
+
crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
|
213
|
+
return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
|
214
|
+
|
215
|
+
# Crop image
|
216
|
+
crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
|
217
|
+
cropped_img = input_pil_img.crop(crop_box)
|
218
|
+
return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
|
219
|
+
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
|
220
|
+
return return_image, return_list
|
221
|
+
|
222
|
+
pil_img = Image.fromarray(image)
|
223
|
+
|
175
224
|
# ocr识别
|
176
225
|
if self.apply_ocr:
|
177
226
|
ocr_start = time.time()
|
178
|
-
|
179
|
-
|
180
|
-
# 筛选出需要OCR的区域和公式区域
|
181
|
-
ocr_res_list = []
|
182
|
-
single_page_mfdetrec_res = []
|
183
|
-
for res in layout_res:
|
184
|
-
if int(res['category_id']) in [13, 14]:
|
185
|
-
single_page_mfdetrec_res.append({
|
186
|
-
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
|
187
|
-
int(res['poly'][4]), int(res['poly'][5])],
|
188
|
-
})
|
189
|
-
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
|
190
|
-
ocr_res_list.append(res)
|
191
|
-
|
192
|
-
# 对每一个需OCR处理的区域进行处理
|
227
|
+
# Process each area that requires OCR processing
|
193
228
|
for res in ocr_res_list:
|
194
|
-
|
195
|
-
xmax, ymax =
|
196
|
-
|
197
|
-
paste_x = 50
|
198
|
-
paste_y = 50
|
199
|
-
# 创建一个宽高各多50的白色背景
|
200
|
-
new_width = xmax - xmin + paste_x * 2
|
201
|
-
new_height = ymax - ymin + paste_y * 2
|
202
|
-
new_image = Image.new('RGB', (new_width, new_height), 'white')
|
203
|
-
|
204
|
-
# 裁剪图像
|
205
|
-
crop_box = (xmin, ymin, xmax, ymax)
|
206
|
-
cropped_img = pil_img.crop(crop_box)
|
207
|
-
new_image.paste(cropped_img, (paste_x, paste_y))
|
208
|
-
|
209
|
-
# 调整公式区域坐标
|
229
|
+
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
|
230
|
+
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
|
231
|
+
# Adjust the coordinates of the formula area
|
210
232
|
adjusted_mfdetrec_res = []
|
211
233
|
for mf_res in single_page_mfdetrec_res:
|
212
234
|
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
|
213
|
-
#
|
235
|
+
# Adjust the coordinates of the formula area to the coordinates relative to the cropping area
|
214
236
|
x0 = mf_xmin - xmin + paste_x
|
215
237
|
y0 = mf_ymin - ymin + paste_y
|
216
238
|
x1 = mf_xmax - xmin + paste_x
|
217
239
|
y1 = mf_ymax - ymin + paste_y
|
218
|
-
#
|
240
|
+
# Filter formula blocks outside the graph
|
219
241
|
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
|
220
242
|
continue
|
221
243
|
else:
|
@@ -223,17 +245,17 @@ class CustomPEKModel:
|
|
223
245
|
"bbox": [x0, y0, x1, y1],
|
224
246
|
})
|
225
247
|
|
226
|
-
# OCR
|
248
|
+
# OCR recognition
|
227
249
|
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
228
250
|
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
|
229
251
|
|
230
|
-
#
|
252
|
+
# Integration results
|
231
253
|
if ocr_res:
|
232
254
|
for box_ocr_res in ocr_res:
|
233
255
|
p1, p2, p3, p4 = box_ocr_res[0]
|
234
256
|
text, score = box_ocr_res[1]
|
235
257
|
|
236
|
-
#
|
258
|
+
# Convert the coordinates back to the original coordinate system
|
237
259
|
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
|
238
260
|
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
|
239
261
|
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
|
@@ -249,4 +271,26 @@ class CustomPEKModel:
|
|
249
271
|
ocr_cost = round(time.time() - ocr_start, 2)
|
250
272
|
logger.info(f"ocr cost: {ocr_cost}")
|
251
273
|
|
274
|
+
# 表格识别 table recognition
|
275
|
+
if self.apply_table:
|
276
|
+
table_start = time.time()
|
277
|
+
for res in table_res_list:
|
278
|
+
new_image, _ = crop_img(res, pil_img)
|
279
|
+
single_table_start_time = time.time()
|
280
|
+
logger.info("------------------table recognition processing begins-----------------")
|
281
|
+
with torch.no_grad():
|
282
|
+
latex_code = self.table_model.image2latex(new_image)[0]
|
283
|
+
run_time = time.time() - single_table_start_time
|
284
|
+
logger.info(f"------------table recognition processing ends within {run_time}s-----")
|
285
|
+
if run_time > self.table_max_time:
|
286
|
+
logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
|
287
|
+
# 判断是否返回正常
|
288
|
+
expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
|
289
|
+
if latex_code and expected_ending:
|
290
|
+
res["latex"] = latex_code
|
291
|
+
else:
|
292
|
+
logger.warning(f"------------table recognition processing fails----------")
|
293
|
+
table_cost = round(time.time() - table_start, 2)
|
294
|
+
logger.info(f"table cost: {table_cost}")
|
295
|
+
|
252
296
|
return layout_res
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from struct_eqtable.model import StructTable
|
2
|
+
from pypandoc import convert_text
|
3
|
+
class StructTableModel:
|
4
|
+
def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'):
|
5
|
+
# init
|
6
|
+
self.model_path = model_path
|
7
|
+
self.max_new_tokens = max_new_tokens # maximum output tokens length
|
8
|
+
self.max_time = max_time # timeout for processing in seconds
|
9
|
+
if device == 'cuda':
|
10
|
+
self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time).cuda()
|
11
|
+
else:
|
12
|
+
self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
|
13
|
+
|
14
|
+
def image2latex(self, image) -> str:
|
15
|
+
#
|
16
|
+
table_latex = self.model.forward(image)
|
17
|
+
return table_latex
|
18
|
+
|
19
|
+
def image2html(self, image) -> str:
|
20
|
+
table_latex = self.image2latex(image)
|
21
|
+
table_html = convert_text(table_latex, 'html', format='latex')
|
22
|
+
return table_html
|
magic_pdf/rw/AbsReaderWriter.py
CHANGED
@@ -2,33 +2,16 @@ from abc import ABC, abstractmethod
|
|
2
2
|
|
3
3
|
|
4
4
|
class AbsReaderWriter(ABC):
|
5
|
-
"""
|
6
|
-
同时支持二进制和文本读写的抽象类
|
7
|
-
"""
|
8
5
|
MODE_TXT = "text"
|
9
6
|
MODE_BIN = "binary"
|
10
|
-
|
11
|
-
def __init__(self, parent_path):
|
12
|
-
# 初始化代码可以在这里添加,如果需要的话
|
13
|
-
self.parent_path = parent_path # 对于本地目录是父目录,对于s3是会写到这个path下。
|
14
|
-
|
15
7
|
@abstractmethod
|
16
8
|
def read(self, path: str, mode=MODE_TXT):
|
17
|
-
"""
|
18
|
-
无论对于本地还是s3的路径,检查如果path是绝对路径,那么就不再 拼接parent_path, 如果是相对路径就拼接parent_path
|
19
|
-
"""
|
20
9
|
raise NotImplementedError
|
21
10
|
|
22
11
|
@abstractmethod
|
23
12
|
def write(self, content: str, path: str, mode=MODE_TXT):
|
24
|
-
"""
|
25
|
-
无论对于本地还是s3的路径,检查如果path是绝对路径,那么就不再 拼接parent_path, 如果是相对路径就拼接parent_path
|
26
|
-
"""
|
27
13
|
raise NotImplementedError
|
28
14
|
|
29
15
|
@abstractmethod
|
30
|
-
def
|
31
|
-
"""
|
32
|
-
无论对于本地还是s3的路径,检查如果path是绝对路径,那么就不再 拼接parent_path, 如果是相对路径就拼接parent_path
|
33
|
-
"""
|
16
|
+
def read_offset(self, path: str, offset=0, limit=None) -> bytes:
|
34
17
|
raise NotImplementedError
|
magic_pdf/rw/DiskReaderWriter.py
CHANGED
@@ -3,34 +3,29 @@ from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
|
|
3
3
|
from loguru import logger
|
4
4
|
|
5
5
|
|
6
|
-
MODE_TXT = "text"
|
7
|
-
MODE_BIN = "binary"
|
8
|
-
|
9
|
-
|
10
6
|
class DiskReaderWriter(AbsReaderWriter):
|
11
|
-
|
12
7
|
def __init__(self, parent_path, encoding="utf-8"):
|
13
8
|
self.path = parent_path
|
14
9
|
self.encoding = encoding
|
15
10
|
|
16
|
-
def read(self, path, mode=MODE_TXT):
|
11
|
+
def read(self, path, mode=AbsReaderWriter.MODE_TXT):
|
17
12
|
if os.path.isabs(path):
|
18
13
|
abspath = path
|
19
14
|
else:
|
20
15
|
abspath = os.path.join(self.path, path)
|
21
16
|
if not os.path.exists(abspath):
|
22
|
-
logger.error(f"
|
23
|
-
raise Exception(f"
|
24
|
-
if mode == MODE_TXT:
|
17
|
+
logger.error(f"file {abspath} not exists")
|
18
|
+
raise Exception(f"file {abspath} no exists")
|
19
|
+
if mode == AbsReaderWriter.MODE_TXT:
|
25
20
|
with open(abspath, "r", encoding=self.encoding) as f:
|
26
21
|
return f.read()
|
27
|
-
elif mode == MODE_BIN:
|
22
|
+
elif mode == AbsReaderWriter.MODE_BIN:
|
28
23
|
with open(abspath, "rb") as f:
|
29
24
|
return f.read()
|
30
25
|
else:
|
31
26
|
raise ValueError("Invalid mode. Use 'text' or 'binary'.")
|
32
27
|
|
33
|
-
def write(self, content, path, mode=MODE_TXT):
|
28
|
+
def write(self, content, path, mode=AbsReaderWriter.MODE_TXT):
|
34
29
|
if os.path.isabs(path):
|
35
30
|
abspath = path
|
36
31
|
else:
|
@@ -38,29 +33,42 @@ class DiskReaderWriter(AbsReaderWriter):
|
|
38
33
|
directory_path = os.path.dirname(abspath)
|
39
34
|
if not os.path.exists(directory_path):
|
40
35
|
os.makedirs(directory_path)
|
41
|
-
if mode == MODE_TXT:
|
36
|
+
if mode == AbsReaderWriter.MODE_TXT:
|
42
37
|
with open(abspath, "w", encoding=self.encoding, errors="replace") as f:
|
43
38
|
f.write(content)
|
44
39
|
|
45
|
-
elif mode == MODE_BIN:
|
40
|
+
elif mode == AbsReaderWriter.MODE_BIN:
|
46
41
|
with open(abspath, "wb") as f:
|
47
42
|
f.write(content)
|
48
43
|
else:
|
49
44
|
raise ValueError("Invalid mode. Use 'text' or 'binary'.")
|
50
45
|
|
51
|
-
def
|
52
|
-
|
46
|
+
def read_offset(self, path: str, offset=0, limit=None):
|
47
|
+
abspath = path
|
48
|
+
if not os.path.isabs(path):
|
49
|
+
abspath = os.path.join(self.path, path)
|
50
|
+
with open(abspath, "rb") as f:
|
51
|
+
f.seek(offset)
|
52
|
+
return f.read(limit)
|
53
53
|
|
54
54
|
|
55
|
-
# 使用示例
|
56
55
|
if __name__ == "__main__":
|
57
|
-
|
58
|
-
|
56
|
+
if 0:
|
57
|
+
file_path = "io/test/example.txt"
|
58
|
+
drw = DiskReaderWriter("D:\projects\papayfork\Magic-PDF\magic_pdf")
|
59
|
+
|
60
|
+
# 写入内容到文件
|
61
|
+
drw.write(b"Hello, World!", path="io/test/example.txt", mode="binary")
|
62
|
+
|
63
|
+
# 从文件读取内容
|
64
|
+
content = drw.read(path=file_path)
|
65
|
+
if content:
|
66
|
+
logger.info(f"从 {file_path} 读取的内容: {content}")
|
67
|
+
if 1:
|
68
|
+
drw = DiskReaderWriter("/opt/data/pdf/resources/test/io/")
|
69
|
+
content_bin = drw.read_offset("1.txt")
|
70
|
+
assert content_bin == b"ABCD!"
|
59
71
|
|
60
|
-
|
61
|
-
|
72
|
+
content_bin = drw.read_offset("1.txt", offset=1, limit=2)
|
73
|
+
assert content_bin == b"BC"
|
62
74
|
|
63
|
-
# 从文件读取内容
|
64
|
-
content = drw.read(path=file_path)
|
65
|
-
if content:
|
66
|
-
logger.info(f"从 {file_path} 读取的内容: {content}")
|