magic-pdf 0.10.4__py3-none-any.whl → 0.10.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only.
- magic_pdf/config/constants.py +5 -0
- magic_pdf/data/data_reader_writer/base.py +13 -1
- magic_pdf/data/dataset.py +175 -4
- magic_pdf/data/utils.py +2 -2
- magic_pdf/dict2md/ocr_mkcontent.py +2 -2
- magic_pdf/filter/__init__.py +32 -0
- magic_pdf/filter/pdf_meta_scan.py +3 -2
- magic_pdf/libs/draw_bbox.py +11 -10
- magic_pdf/libs/pdf_check.py +30 -30
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/__init__.py +124 -0
- magic_pdf/model/doc_analyze_by_custom_model.py +119 -60
- magic_pdf/model/operators.py +190 -0
- magic_pdf/model/pdf_extract_kit.py +20 -1
- magic_pdf/model/sub_modules/model_init.py +13 -3
- magic_pdf/model/sub_modules/model_utils.py +11 -5
- magic_pdf/para/para_split_v3.py +2 -2
- magic_pdf/pdf_parse_by_ocr.py +4 -5
- magic_pdf/pdf_parse_by_txt.py +4 -5
- magic_pdf/pdf_parse_union_core_v2.py +10 -11
- magic_pdf/pipe/AbsPipe.py +3 -2
- magic_pdf/pipe/OCRPipe.py +54 -15
- magic_pdf/pipe/TXTPipe.py +5 -4
- magic_pdf/pipe/UNIPipe.py +82 -30
- magic_pdf/pipe/operators.py +138 -0
- magic_pdf/pre_proc/cut_image.py +2 -2
- magic_pdf/tools/common.py +108 -59
- magic_pdf/user_api.py +47 -24
- {magic_pdf-0.10.4.dist-info → magic_pdf-0.10.6.dist-info}/METADATA +7 -4
- {magic_pdf-0.10.4.dist-info → magic_pdf-0.10.6.dist-info}/RECORD +34 -32
- {magic_pdf-0.10.4.dist-info → magic_pdf-0.10.6.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.10.4.dist-info → magic_pdf-0.10.6.dist-info}/WHEEL +0 -0
- {magic_pdf-0.10.4.dist-info → magic_pdf-0.10.6.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.10.4.dist-info → magic_pdf-0.10.6.dist-info}/top_level.txt +0 -0
magic_pdf/model/doc_analyze_by_custom_model.py
CHANGED
@@ -1,14 +1,34 @@
+import os
 import time
 
 import fitz
 import numpy as np
 from loguru import logger
 
+# 关闭paddle的信号处理
+import paddle
+paddle.disable_signal_handler()
+
+os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
+os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
+
+try:
+    import torchtext
+
+    if torchtext.__version__ >= '0.18.0':
+        torchtext.disable_torchtext_deprecation_warning()
+except ImportError:
+    pass
+
+import magic_pdf.model as model_config
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import
-
+from magic_pdf.libs.config_reader import (get_device, get_formula_config,
+                                          get_layout_config,
+                                          get_local_models_dir,
+                                          get_table_recog_config)
 from magic_pdf.model.model_list import MODEL
-
+from magic_pdf.model.operators import InferenceResult
 
 
 def dict_compare(d1, d2):
@@ -19,25 +39,31 @@ def remove_duplicates_dicts(lst):
     unique_dicts = []
     for dict_item in lst:
         if not any(
-
+            dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
         ):
             unique_dicts.append(dict_item)
     return unique_dicts
 
 
-def load_images_from_pdf(
+def load_images_from_pdf(
+    pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None
+) -> list:
     try:
         from PIL import Image
     except ImportError:
-        logger.error(
+        logger.error('Pillow not installed, please install by pip.')
         exit(1)
 
     images = []
-    with fitz.open(
+    with fitz.open('pdf', pdf_bytes) as doc:
         pdf_page_num = doc.page_count
-        end_page_id =
+        end_page_id = (
+            end_page_id
+            if end_page_id is not None and end_page_id >= 0
+            else pdf_page_num - 1
+        )
         if end_page_id > pdf_page_num - 1:
-            logger.warning(
+            logger.warning('end_page_id is out of range, use images length')
            end_page_id = pdf_page_num - 1
 
         for index in range(0, doc.page_count):
@@ -50,11 +76,11 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
                 if pm.width > 4500 or pm.height > 4500:
                     pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
 
-                img = Image.frombytes(
+                img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
                 img = np.array(img)
-                img_dict = {
+                img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
             else:
-                img_dict = {
+                img_dict = {'img': [], 'width': 0, 'height': 0}
 
             images.append(img_dict)
     return images
@@ -69,117 +95,150 @@ class ModelSingleton:
             cls._instance = super().__new__(cls)
         return cls._instance
 
-    def get_model(
+    def get_model(
+        self,
+        ocr: bool,
+        show_log: bool,
+        lang=None,
+        layout_model=None,
+        formula_enable=None,
+        table_enable=None,
+    ):
         key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
         if key not in self._models:
-            self._models[key] = custom_model_init(
-
+            self._models[key] = custom_model_init(
+                ocr=ocr,
+                show_log=show_log,
+                lang=lang,
+                layout_model=layout_model,
+                formula_enable=formula_enable,
+                table_enable=table_enable,
+            )
         return self._models[key]
 
 
-def custom_model_init(
-
+def custom_model_init(
+    ocr: bool = False,
+    show_log: bool = False,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+):
 
     model = None
 
-    if model_config.__model_mode__ ==
-        logger.warning(
-
+    if model_config.__model_mode__ == 'lite':
+        logger.warning(
+            'The Lite mode is provided for developers to conduct testing only, and the output quality is '
+            'not guaranteed to be reliable.'
+        )
         model = MODEL.Paddle
-    elif model_config.__model_mode__ ==
+    elif model_config.__model_mode__ == 'full':
         model = MODEL.PEK
 
     if model_config.__use_inside_model__:
         model_init_start = time.time()
         if model == MODEL.Paddle:
             from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
+
             custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
         elif model == MODEL.PEK:
             from magic_pdf.model.pdf_extract_kit import CustomPEKModel
+
             # 从配置文件读取model-dir和device
             local_models_dir = get_local_models_dir()
             device = get_device()
 
             layout_config = get_layout_config()
             if layout_model is not None:
-                layout_config[
+                layout_config['model'] = layout_model
 
             formula_config = get_formula_config()
             if formula_enable is not None:
-                formula_config[
+                formula_config['enable'] = formula_enable
 
             table_config = get_table_recog_config()
             if table_enable is not None:
-                table_config[
+                table_config['enable'] = table_enable
 
             model_input = {
-
-
-
-
-
-
-
-
+                'ocr': ocr,
+                'show_log': show_log,
+                'models_dir': local_models_dir,
+                'device': device,
+                'table_config': table_config,
+                'layout_config': layout_config,
+                'formula_config': formula_config,
+                'lang': lang,
             }
 
             custom_model = CustomPEKModel(**model_input)
         else:
-            logger.error(
+            logger.error('Not allow model_name!')
            exit(1)
         model_init_cost = time.time() - model_init_start
-        logger.info(f
+        logger.info(f'model init cost: {model_init_cost}')
     else:
-        logger.error(
+        logger.error('use_inside_model is False, not allow to use inside model')
         exit(1)
 
     return custom_model
 
 
-def doc_analyze(
-
-
+def doc_analyze(
+    dataset: Dataset,
+    ocr: bool = False,
+    show_log: bool = False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+) -> InferenceResult:
 
-    if lang ==
+    if lang == '':
         lang = None
 
     model_manager = ModelSingleton()
-    custom_model = model_manager.get_model(
-
-
-        pdf_page_num = doc.page_count
-        end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
-        if end_page_id > pdf_page_num - 1:
-            logger.warning("end_page_id is out of range, use images length")
-            end_page_id = pdf_page_num - 1
-
-    images = load_images_from_pdf(pdf_bytes, start_page_id=start_page_id, end_page_id=end_page_id)
+    custom_model = model_manager.get_model(
+        ocr, show_log, lang, layout_model, formula_enable, table_enable
+    )
 
     model_json = []
     doc_analyze_start = time.time()
 
-
-
-
-
+    if end_page_id is None:
+        end_page_id = len(dataset)
+
+    for index in range(len(dataset)):
+        page_data = dataset.get_page(index)
+        img_dict = page_data.get_image()
+        img = img_dict['img']
+        page_width = img_dict['width']
+        page_height = img_dict['height']
         if start_page_id <= index <= end_page_id:
             page_start = time.time()
             result = custom_model(img)
             logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
         else:
            result = []
-
-
+
+        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+        page_dict = {'layout_dets': result, 'page_info': page_info}
         model_json.append(page_dict)
 
     gc_start = time.time()
     clean_memory()
     gc_time = round(time.time() - gc_start, 2)
-    logger.info(f
+    logger.info(f'gc time: {gc_time}')
 
     doc_analyze_time = round(time.time() - doc_analyze_start, 2)
-    doc_analyze_speed = round(
-    logger.info(
-
+    doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
+    logger.info(
+        f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
+        f' speed: {doc_analyze_speed} pages/second'
+    )
 
-    return model_json
+    return InferenceResult(model_json, dataset)
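With this change `doc_analyze` consumes a `Dataset` instead of raw PDF bytes and returns an `InferenceResult` rather than the bare `model_json` list. Below is a minimal usage sketch of the new call pattern; `PymuDocDataset` and `FileBasedDataWriter` are assumed to be the concrete `Dataset`/`DataWriter` implementations shipped with this release (they are not part of the hunks shown here).

```python
from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

# Wrap the PDF bytes in a Dataset; doc_analyze no longer accepts raw bytes.
with open('demo.pdf', 'rb') as f:
    ds = PymuDocDataset(f.read())  # assumed Dataset implementation

# Returns an InferenceResult bound to the dataset instead of a plain list.
infer_result = doc_analyze(ds, ocr=True)

# The raw per-page model output is still reachable through the wrapper.
model_json = infer_result.get_infer_res()

# Helpers added in the new magic_pdf/model/operators.py (next file below).
infer_result.draw_model('output/model.pdf')
infer_result.dump_model(FileBasedDataWriter('output'), 'model.json')
```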
magic_pdf/model/operators.py
ADDED
@@ -0,0 +1,190 @@
+import copy
+import json
+import os
+from typing import Callable
+
+from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
+from magic_pdf.filter import classify
+from magic_pdf.libs.draw_bbox import draw_model_bbox
+from magic_pdf.libs.version import __version__
+from magic_pdf.model import InferenceResultBase
+from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
+from magic_pdf.pipe.operators import PipeResult
+
+
+class InferenceResult(InferenceResultBase):
+    def __init__(self, inference_results: list, dataset: Dataset):
+        """Initialized method.
+
+        Args:
+            inference_results (list): the inference result generated by model
+            dataset (Dataset): the dataset related with model inference result
+        """
+        self._infer_res = inference_results
+        self._dataset = dataset
+
+    def draw_model(self, file_path: str) -> None:
+        """Draw model inference result.
+
+        Args:
+            file_path (str): the output file path
+        """
+        dir_name = os.path.dirname(file_path)
+        base_name = os.path.basename(file_path)
+        if not os.path.exists(dir_name):
+            os.makedirs(dir_name, exist_ok=True)
+        draw_model_bbox(
+            copy.deepcopy(self._infer_res), self._dataset, dir_name, base_name
+        )
+
+    def dump_model(self, writer: DataWriter, file_path: str):
+        """Dump model inference result to file.
+
+        Args:
+            writer (DataWriter): writer handle
+            file_path (str): the location of target file
+        """
+        writer.write_string(
+            file_path, json.dumps(self._infer_res, ensure_ascii=False, indent=4)
+        )
+
+    def get_infer_res(self):
+        """Get the inference result.
+
+        Returns:
+            list: the inference result generated by model
+        """
+        return self._infer_res
+
+    def apply(self, proc: Callable, *args, **kwargs):
+        """Apply callable method which.
+
+        Args:
+            proc (Callable): invoke proc as follows:
+                proc(inference_result, *args, **kwargs)
+
+        Returns:
+            Any: return the result generated by proc
+        """
+        return proc(copy.deepcopy(self._infer_res), *args, **kwargs)
+
+    def pipe_auto_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        """Post-proc the model inference result.
+            step1: classify the dataset type
+            step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
+
+        Args:
+            imageWriter (DataWriter): the image writer handle
+            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
+            end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
+            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
+            lang (str, optional): Defaults to None.
+
+        Returns:
+            PipeResult: the result
+        """
+
+        pdf_proc_method = classify(self._dataset.data_bits())
+
+        if pdf_proc_method == SupportedPdfParseMethod.TXT:
+            return self.pipe_txt_mode(
+                imageWriter, start_page_id, end_page_id, debug_mode, lang
+            )
+        else:
+            return self.pipe_ocr_mode(
+                imageWriter, start_page_id, end_page_id, debug_mode, lang
+            )
+
+    def pipe_txt_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        """Post-proc the model inference result, Extract the text using the
+        third library, such as `pymupdf`
+
+        Args:
+            imageWriter (DataWriter): the image writer handle
+            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
+            end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
+            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
+            lang (str, optional): Defaults to None.
+
+        Returns:
+            PipeResult: the result
+        """
+
+        def proc(*args, **kwargs) -> PipeResult:
+            res = pdf_parse_union(*args, **kwargs)
+            res['_parse_type'] = PARSE_TYPE_TXT
+            res['_version_name'] = __version__
+            if 'lang' in kwargs and kwargs['lang'] is not None:
+                res['lang'] = kwargs['lang']
+            return PipeResult(res, self._dataset)
+
+        res = self.apply(
+            proc,
+            self._dataset,
+            imageWriter,
+            SupportedPdfParseMethod.TXT,
+            start_page_id=start_page_id,
+            end_page_id=end_page_id,
+            debug_mode=debug_mode,
+            lang=lang,
+        )
+        return res
+
+    def pipe_ocr_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        """Post-proc the model inference result, Extract the text using `OCR`
+        technical.
+
+        Args:
+            imageWriter (DataWriter): the image writer handle
+            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
+            end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
+            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
+            lang (str, optional): Defaults to None.
+
+        Returns:
+            PipeResult: the result
+        """
+
+        def proc(*args, **kwargs) -> PipeResult:
+            res = pdf_parse_union(*args, **kwargs)
+            res['_parse_type'] = PARSE_TYPE_OCR
+            res['_version_name'] = __version__
+            if 'lang' in kwargs and kwargs['lang'] is not None:
+                res['lang'] = kwargs['lang']
+            return PipeResult(res, self._dataset)
+
+        res = self.apply(
+            proc,
+            self._dataset,
+            imageWriter,
+            SupportedPdfParseMethod.OCR,
+            start_page_id=start_page_id,
+            end_page_id=end_page_id,
+            debug_mode=debug_mode,
+            lang=lang,
+        )
+        return res
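The new `InferenceResult` carries the dataset alongside the raw model output, so post-processing becomes method calls on the result object. A hedged sketch of how the pipe methods compose follows; `PymuDocDataset` and `FileBasedDataWriter` are assumed concrete implementations, and `dump_md` is an assumption about the `PipeResult` API defined in the new `magic_pdf/pipe/operators.py`, whose body is not shown in this section.

```python
from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

with open('demo.pdf', 'rb') as f:
    ds = PymuDocDataset(f.read())  # assumed Dataset implementation

infer_result = doc_analyze(ds)  # InferenceResult, as shown in the diff above
image_writer = FileBasedDataWriter('output/images')

# pipe_auto_mode classifies the PDF (text layer vs. scanned) and dispatches
# to pipe_txt_mode or pipe_ocr_mode; both can also be called directly.
pipe_result = infer_result.pipe_auto_mode(image_writer)

# PipeResult comes from the new magic_pdf/pipe/operators.py (not shown here);
# dump_md is assumed to be one of its dump helpers.
pipe_result.dump_md(FileBasedDataWriter('output'), 'demo.md', 'images')
```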
magic_pdf/model/pdf_extract_kit.py
CHANGED
@@ -179,7 +179,25 @@ class CustomPEKModel:
             layout_res = self.layout_model(image, ignore_catids=[])
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             # doclayout_yolo
-
+            img_pil = Image.fromarray(image)
+            width, height = img_pil.size
+            # logger.info(f'width: {width}, height: {height}')
+            input_res = {"poly":[0,0,width,0,width,height,0,height]}
+            new_image, useful_list = crop_img(input_res, img_pil, crop_paste_x=width//2, crop_paste_y=0)
+            paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+            layout_res = self.layout_model.predict(new_image)
+            for res in layout_res:
+                p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
+                p1 = p1 - paste_x + xmin
+                p2 = p2 - paste_y + ymin
+                p3 = p3 - paste_x + xmin
+                p4 = p4 - paste_y + ymin
+                p5 = p5 - paste_x + xmin
+                p6 = p6 - paste_y + ymin
+                p7 = p7 - paste_x + xmin
+                p8 = p8 - paste_y + ymin
+                res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
+
         layout_cost = round(time.time() - layout_start, 2)
         logger.info(f'layout detection time: {layout_cost}')
 
@@ -215,6 +233,7 @@ class CustomPEKModel:
 
             # OCR recognition
             new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+
             if self.apply_ocr:
                 ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
             else:
magic_pdf/model/sub_modules/model_init.py
CHANGED
@@ -92,14 +92,24 @@ class AtomModelSingleton:
         return cls._instance
 
     def get_atom_model(self, atom_model_name: str, **kwargs):
+
         lang = kwargs.get('lang', None)
         layout_model_name = kwargs.get('layout_model_name', None)
-
+        table_model_name = kwargs.get('table_model_name', None)
+
+        if atom_model_name in [AtomicModel.OCR]:
+            key = (atom_model_name, lang)
+        elif atom_model_name in [AtomicModel.Layout]:
+            key = (atom_model_name, layout_model_name)
+        elif atom_model_name in [AtomicModel.Table]:
+            key = (atom_model_name, table_model_name)
+        else:
+            key = atom_model_name
+
         if key not in self._models:
             self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
         return self._models[key]
 
-
 def atom_model_init(model_name: str, **kwargs):
     atom_model = None
     if model_name == AtomicModel.Layout:
@@ -129,7 +139,7 @@ def atom_model_init(model_name: str, **kwargs):
         atom_model = ocr_model_init(
             kwargs.get('ocr_show_log'),
             kwargs.get('det_db_box_thresh'),
-            kwargs.get('lang')
+            kwargs.get('lang'),
         )
     elif model_name == AtomicModel.Table:
         atom_model = table_model_init(
magic_pdf/model/sub_modules/model_utils.py
CHANGED
@@ -42,10 +42,16 @@ def get_res_list_from_layout_res(layout_res):
 
 
 def clean_vram(device, vram_threshold=8):
+    total_memory = get_vram(device)
+    if total_memory and total_memory <= vram_threshold:
+        gc_start = time.time()
+        clean_memory()
+        gc_time = round(time.time() - gc_start, 2)
+        logger.info(f"gc time: {gc_time}")
+
+
+def get_vram(device):
     if torch.cuda.is_available() and device != 'cpu':
         total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # 将字节转换为 GB
-
-
-            clean_memory()
-            gc_time = round(time.time() - gc_start, 2)
-            logger.info(f"gc time: {gc_time}")
+        return total_memory
+    return None
magic_pdf/para/para_split_v3.py
CHANGED
@@ -112,8 +112,8 @@ def __is_list_or_index_block(block):
         line_mid_x = (line['bbox'][0] + line['bbox'][2]) / 2
         block_mid_x = (block['bbox_fs'][0] + block['bbox_fs'][2]) / 2
         if (
-            line['bbox'][0] - block['bbox_fs'][0] > 0.
-            and block['bbox_fs'][2] - line['bbox'][2] > 0.
+            line['bbox'][0] - block['bbox_fs'][0] > 0.7 * line_height
+            and block['bbox_fs'][2] - line['bbox'][2] > 0.7 * line_height
         ):
             external_sides_not_close_num += 1
         if abs(line_mid_x - block_mid_x) < line_height / 2:
magic_pdf/pdf_parse_by_ocr.py
CHANGED
@@ -1,9 +1,9 @@
 from magic_pdf.config.enums import SupportedPdfParseMethod
-from magic_pdf.data.dataset import
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
 
 
-def parse_pdf_by_ocr(
+def parse_pdf_by_ocr(dataset: Dataset,
                      model_list,
                      imageWriter,
                      start_page_id=0,
@@ -11,9 +11,8 @@ def parse_pdf_by_ocr(pdf_bytes,
                      debug_mode=False,
                      lang=None,
                      ):
-
-
-                           model_list,
+    return pdf_parse_union(model_list,
+                           dataset,
                            imageWriter,
                            SupportedPdfParseMethod.OCR,
                            start_page_id=start_page_id,
magic_pdf/pdf_parse_by_txt.py
CHANGED
@@ -1,10 +1,10 @@
 from magic_pdf.config.enums import SupportedPdfParseMethod
-from magic_pdf.data.dataset import
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
 
 
 def parse_pdf_by_txt(
-
+    dataset: Dataset,
     model_list,
     imageWriter,
     start_page_id=0,
@@ -12,9 +12,8 @@ def parse_pdf_by_txt(
     debug_mode=False,
     lang=None,
 ):
-
-
-                           model_list,
+    return pdf_parse_union(model_list,
+                           dataset,
                            imageWriter,
                            SupportedPdfParseMethod.TXT,
                            start_page_id=start_page_id,
|