magic-pdf 0.9.3__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/config/constants.py +53 -0
- magic_pdf/config/drop_reason.py +35 -0
- magic_pdf/config/drop_tag.py +19 -0
- magic_pdf/config/make_content_config.py +11 -0
- magic_pdf/{libs/ModelBlockTypeEnum.py → config/model_block_type.py} +2 -1
- magic_pdf/data/read_api.py +1 -1
- magic_pdf/dict2md/mkcontent.py +226 -185
- magic_pdf/dict2md/ocr_mkcontent.py +11 -11
- magic_pdf/filter/pdf_meta_scan.py +101 -79
- magic_pdf/integrations/rag/utils.py +4 -5
- magic_pdf/libs/config_reader.py +5 -5
- magic_pdf/libs/draw_bbox.py +3 -2
- magic_pdf/libs/pdf_image_tools.py +36 -12
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/doc_analyze_by_custom_model.py +2 -0
- magic_pdf/model/magic_model.py +13 -13
- magic_pdf/model/pdf_extract_kit.py +122 -76
- magic_pdf/model/sub_modules/model_init.py +40 -35
- magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +33 -7
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +12 -4
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +2 -0
- magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +30 -28
- magic_pdf/para/para_split.py +411 -248
- magic_pdf/para/para_split_v2.py +352 -182
- magic_pdf/para/para_split_v3.py +110 -53
- magic_pdf/pdf_parse_by_ocr.py +2 -0
- magic_pdf/pdf_parse_by_txt.py +2 -0
- magic_pdf/pdf_parse_union_core.py +174 -100
- magic_pdf/pdf_parse_union_core_v2.py +202 -36
- magic_pdf/pipe/AbsPipe.py +28 -44
- magic_pdf/pipe/OCRPipe.py +5 -5
- magic_pdf/pipe/TXTPipe.py +5 -6
- magic_pdf/pipe/UNIPipe.py +24 -25
- magic_pdf/post_proc/pdf_post_filter.py +7 -14
- magic_pdf/pre_proc/cut_image.py +9 -11
- magic_pdf/pre_proc/equations_replace.py +203 -212
- magic_pdf/pre_proc/ocr_detect_all_bboxes.py +235 -49
- magic_pdf/pre_proc/ocr_dict_merge.py +5 -5
- magic_pdf/pre_proc/ocr_span_list_modify.py +122 -63
- magic_pdf/pre_proc/pdf_pre_filter.py +37 -33
- magic_pdf/pre_proc/remove_bbox_overlap.py +20 -18
- magic_pdf/pre_proc/remove_colored_strip_bbox.py +36 -14
- magic_pdf/pre_proc/remove_footer_header.py +2 -5
- magic_pdf/pre_proc/remove_rotate_bbox.py +111 -63
- magic_pdf/pre_proc/resolve_bbox_conflict.py +10 -17
- magic_pdf/spark/spark_api.py +15 -17
- magic_pdf/tools/cli.py +3 -4
- magic_pdf/tools/cli_dev.py +6 -9
- magic_pdf/tools/common.py +26 -36
- magic_pdf/user_api.py +29 -38
- {magic_pdf-0.9.3.dist-info → magic_pdf-0.10.0.dist-info}/METADATA +11 -12
- {magic_pdf-0.9.3.dist-info → magic_pdf-0.10.0.dist-info}/RECORD +57 -58
- magic_pdf/libs/Constants.py +0 -55
- magic_pdf/libs/MakeContentConfig.py +0 -11
- magic_pdf/libs/drop_reason.py +0 -27
- magic_pdf/libs/drop_tag.py +0 -19
- magic_pdf/para/para_pipeline.py +0 -297
- /magic_pdf/{libs → config}/ocr_content_type.py +0 -0
- {magic_pdf-0.9.3.dist-info → magic_pdf-0.10.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.9.3.dist-info → magic_pdf-0.10.0.dist-info}/WHEEL +0 -0
- {magic_pdf-0.9.3.dist-info → magic_pdf-0.10.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.9.3.dist-info → magic_pdf-0.10.0.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,12 @@
|
|
1
|
-
|
2
|
-
import torch
|
3
|
-
from loguru import logger
|
1
|
+
# flake8: noqa
|
4
2
|
import os
|
5
3
|
import time
|
4
|
+
|
6
5
|
import cv2
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
7
8
|
import yaml
|
9
|
+
from loguru import logger
|
8
10
|
from PIL import Image
|
9
11
|
|
10
12
|
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
@@ -13,16 +15,18 @@ os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
|
|
13
15
|
try:
|
14
16
|
import torchtext
|
15
17
|
|
16
|
-
if torchtext.__version__ >=
|
18
|
+
if torchtext.__version__ >= '0.18.0':
|
17
19
|
torchtext.disable_torchtext_deprecation_warning()
|
18
20
|
except ImportError:
|
19
21
|
pass
|
20
22
|
|
21
|
-
from magic_pdf.
|
23
|
+
from magic_pdf.config.constants import *
|
22
24
|
from magic_pdf.model.model_list import AtomicModel
|
23
25
|
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
|
24
|
-
from magic_pdf.model.sub_modules.model_utils import
|
25
|
-
|
26
|
+
from magic_pdf.model.sub_modules.model_utils import (
|
27
|
+
clean_vram, crop_img, get_res_list_from_layout_res)
|
28
|
+
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
|
29
|
+
get_adjusted_mfdetrec_res, get_ocr_result_list)
|
26
30
|
|
27
31
|
|
28
32
|
class CustomPEKModel:
|
@@ -41,42 +45,54 @@ class CustomPEKModel:
|
|
41
45
|
model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
|
42
46
|
# 构建 model_configs.yaml 文件的完整路径
|
43
47
|
config_path = os.path.join(model_config_dir, 'model_configs.yaml')
|
44
|
-
with open(config_path,
|
48
|
+
with open(config_path, 'r', encoding='utf-8') as f:
|
45
49
|
self.configs = yaml.load(f, Loader=yaml.FullLoader)
|
46
50
|
# 初始化解析配置
|
47
51
|
|
48
52
|
# layout config
|
49
|
-
self.layout_config = kwargs.get(
|
50
|
-
self.layout_model_name = self.layout_config.get(
|
53
|
+
self.layout_config = kwargs.get('layout_config')
|
54
|
+
self.layout_model_name = self.layout_config.get(
|
55
|
+
'model', MODEL_NAME.DocLayout_YOLO
|
56
|
+
)
|
51
57
|
|
52
58
|
# formula config
|
53
|
-
self.formula_config = kwargs.get(
|
54
|
-
self.mfd_model_name = self.formula_config.get(
|
55
|
-
|
56
|
-
|
59
|
+
self.formula_config = kwargs.get('formula_config')
|
60
|
+
self.mfd_model_name = self.formula_config.get(
|
61
|
+
'mfd_model', MODEL_NAME.YOLO_V8_MFD
|
62
|
+
)
|
63
|
+
self.mfr_model_name = self.formula_config.get(
|
64
|
+
'mfr_model', MODEL_NAME.UniMerNet_v2_Small
|
65
|
+
)
|
66
|
+
self.apply_formula = self.formula_config.get('enable', True)
|
57
67
|
|
58
68
|
# table config
|
59
|
-
self.table_config = kwargs.get(
|
60
|
-
self.apply_table = self.table_config.get(
|
61
|
-
self.table_max_time = self.table_config.get(
|
62
|
-
self.table_model_name = self.table_config.get(
|
69
|
+
self.table_config = kwargs.get('table_config')
|
70
|
+
self.apply_table = self.table_config.get('enable', False)
|
71
|
+
self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
|
72
|
+
self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
|
63
73
|
|
64
74
|
# ocr config
|
65
75
|
self.apply_ocr = ocr
|
66
|
-
self.lang = kwargs.get(
|
76
|
+
self.lang = kwargs.get('lang', None)
|
67
77
|
|
68
78
|
logger.info(
|
69
|
-
|
70
|
-
|
71
|
-
self.layout_model_name,
|
72
|
-
self.
|
79
|
+
'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
|
80
|
+
'apply_table: {}, table_model: {}, lang: {}'.format(
|
81
|
+
self.layout_model_name,
|
82
|
+
self.apply_formula,
|
83
|
+
self.apply_ocr,
|
84
|
+
self.apply_table,
|
85
|
+
self.table_model_name,
|
86
|
+
self.lang,
|
73
87
|
)
|
74
88
|
)
|
75
89
|
# 初始化解析方案
|
76
|
-
self.device = kwargs.get(
|
77
|
-
logger.info(
|
78
|
-
models_dir = kwargs.get(
|
79
|
-
|
90
|
+
self.device = kwargs.get('device', 'cpu')
|
91
|
+
logger.info('using device: {}'.format(self.device))
|
92
|
+
models_dir = kwargs.get(
|
93
|
+
'models_dir', os.path.join(root_dir, 'resources', 'models')
|
94
|
+
)
|
95
|
+
logger.info('using models_dir: {}'.format(models_dir))
|
80
96
|
|
81
97
|
atom_model_manager = AtomModelSingleton()
|
82
98
|
|
@@ -85,18 +101,24 @@ class CustomPEKModel:
|
|
85
101
|
# 初始化公式检测模型
|
86
102
|
self.mfd_model = atom_model_manager.get_atom_model(
|
87
103
|
atom_model_name=AtomicModel.MFD,
|
88
|
-
mfd_weights=str(
|
89
|
-
|
104
|
+
mfd_weights=str(
|
105
|
+
os.path.join(
|
106
|
+
models_dir, self.configs['weights'][self.mfd_model_name]
|
107
|
+
)
|
108
|
+
),
|
109
|
+
device=self.device,
|
90
110
|
)
|
91
111
|
|
92
112
|
# 初始化公式解析模型
|
93
|
-
mfr_weight_dir = str(
|
94
|
-
|
113
|
+
mfr_weight_dir = str(
|
114
|
+
os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
|
115
|
+
)
|
116
|
+
mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
|
95
117
|
self.mfr_model = atom_model_manager.get_atom_model(
|
96
118
|
atom_model_name=AtomicModel.MFR,
|
97
119
|
mfr_weight_dir=mfr_weight_dir,
|
98
120
|
mfr_cfg_path=mfr_cfg_path,
|
99
|
-
device=self.device
|
121
|
+
device=self.device,
|
100
122
|
)
|
101
123
|
|
102
124
|
# 初始化layout模型
|
@@ -104,42 +126,51 @@ class CustomPEKModel:
|
|
104
126
|
self.layout_model = atom_model_manager.get_atom_model(
|
105
127
|
atom_model_name=AtomicModel.Layout,
|
106
128
|
layout_model_name=MODEL_NAME.LAYOUTLMv3,
|
107
|
-
layout_weights=str(
|
108
|
-
|
109
|
-
|
129
|
+
layout_weights=str(
|
130
|
+
os.path.join(
|
131
|
+
models_dir, self.configs['weights'][self.layout_model_name]
|
132
|
+
)
|
133
|
+
),
|
134
|
+
layout_config_file=str(
|
135
|
+
os.path.join(
|
136
|
+
model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
|
137
|
+
)
|
138
|
+
),
|
139
|
+
device=self.device,
|
110
140
|
)
|
111
141
|
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
|
112
142
|
self.layout_model = atom_model_manager.get_atom_model(
|
113
143
|
atom_model_name=AtomicModel.Layout,
|
114
144
|
layout_model_name=MODEL_NAME.DocLayout_YOLO,
|
115
|
-
doclayout_yolo_weights=str(
|
116
|
-
|
145
|
+
doclayout_yolo_weights=str(
|
146
|
+
os.path.join(
|
147
|
+
models_dir, self.configs['weights'][self.layout_model_name]
|
148
|
+
)
|
149
|
+
),
|
150
|
+
device=self.device,
|
117
151
|
)
|
118
152
|
# 初始化ocr
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
)
|
153
|
+
self.ocr_model = atom_model_manager.get_atom_model(
|
154
|
+
atom_model_name=AtomicModel.OCR,
|
155
|
+
ocr_show_log=show_log,
|
156
|
+
det_db_box_thresh=0.3,
|
157
|
+
lang=self.lang
|
158
|
+
)
|
126
159
|
# init table model
|
127
160
|
if self.apply_table:
|
128
|
-
table_model_dir = self.configs[
|
161
|
+
table_model_dir = self.configs['weights'][self.table_model_name]
|
129
162
|
self.table_model = atom_model_manager.get_atom_model(
|
130
163
|
atom_model_name=AtomicModel.Table,
|
131
164
|
table_model_name=self.table_model_name,
|
132
165
|
table_model_path=str(os.path.join(models_dir, table_model_dir)),
|
133
166
|
table_max_time=self.table_max_time,
|
134
|
-
device=self.device
|
167
|
+
device=self.device,
|
135
168
|
)
|
136
169
|
|
137
170
|
logger.info('DocAnalysis init done!')
|
138
171
|
|
139
172
|
def __call__(self, image):
|
140
173
|
|
141
|
-
page_start = time.time()
|
142
|
-
|
143
174
|
# layout检测
|
144
175
|
layout_start = time.time()
|
145
176
|
layout_res = []
|
@@ -150,7 +181,7 @@ class CustomPEKModel:
|
|
150
181
|
# doclayout_yolo
|
151
182
|
layout_res = self.layout_model.predict(image)
|
152
183
|
layout_cost = round(time.time() - layout_start, 2)
|
153
|
-
logger.info(f
|
184
|
+
logger.info(f'layout detection time: {layout_cost}')
|
154
185
|
|
155
186
|
pil_img = Image.fromarray(image)
|
156
187
|
|
@@ -158,40 +189,47 @@ class CustomPEKModel:
|
|
158
189
|
# 公式检测
|
159
190
|
mfd_start = time.time()
|
160
191
|
mfd_res = self.mfd_model.predict(image)
|
161
|
-
logger.info(f
|
192
|
+
logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')
|
162
193
|
|
163
194
|
# 公式识别
|
164
195
|
mfr_start = time.time()
|
165
196
|
formula_list = self.mfr_model.predict(mfd_res, image)
|
166
197
|
layout_res.extend(formula_list)
|
167
198
|
mfr_cost = round(time.time() - mfr_start, 2)
|
168
|
-
logger.info(f
|
199
|
+
logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
|
169
200
|
|
170
201
|
# 清理显存
|
171
202
|
clean_vram(self.device, vram_threshold=8)
|
172
203
|
|
173
204
|
# 从layout_res中获取ocr区域、表格区域、公式区域
|
174
|
-
ocr_res_list, table_res_list, single_page_mfdetrec_res =
|
205
|
+
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
|
206
|
+
get_res_list_from_layout_res(layout_res)
|
207
|
+
)
|
175
208
|
|
176
209
|
# ocr识别
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
210
|
+
ocr_start = time.time()
|
211
|
+
# Process each area that requires OCR processing
|
212
|
+
for res in ocr_res_list:
|
213
|
+
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
|
214
|
+
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
|
215
|
+
|
216
|
+
# OCR recognition
|
217
|
+
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
218
|
+
if self.apply_ocr:
|
186
219
|
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
|
220
|
+
else:
|
221
|
+
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
|
187
222
|
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
223
|
+
# Integration results
|
224
|
+
if ocr_res:
|
225
|
+
ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
|
226
|
+
layout_res.extend(ocr_result_list)
|
192
227
|
|
193
|
-
|
228
|
+
ocr_cost = round(time.time() - ocr_start, 2)
|
229
|
+
if self.apply_ocr:
|
194
230
|
logger.info(f"ocr time: {ocr_cost}")
|
231
|
+
else:
|
232
|
+
logger.info(f"det time: {ocr_cost}")
|
195
233
|
|
196
234
|
# 表格识别 table recognition
|
197
235
|
if self.apply_table:
|
@@ -202,27 +240,35 @@ class CustomPEKModel:
|
|
202
240
|
html_code = None
|
203
241
|
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
|
204
242
|
with torch.no_grad():
|
205
|
-
table_result = self.table_model.predict(new_image,
|
243
|
+
table_result = self.table_model.predict(new_image, 'html')
|
206
244
|
if len(table_result) > 0:
|
207
245
|
html_code = table_result[0]
|
208
246
|
elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
|
209
247
|
html_code = self.table_model.img2html(new_image)
|
210
248
|
elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
|
211
|
-
html_code, table_cell_bboxes, elapse = self.table_model.predict(
|
249
|
+
html_code, table_cell_bboxes, elapse = self.table_model.predict(
|
250
|
+
new_image
|
251
|
+
)
|
212
252
|
run_time = time.time() - single_table_start_time
|
213
253
|
if run_time > self.table_max_time:
|
214
|
-
logger.warning(
|
254
|
+
logger.warning(
|
255
|
+
f'table recognition processing exceeds max time {self.table_max_time}s'
|
256
|
+
)
|
215
257
|
# 判断是否返回正常
|
216
258
|
if html_code:
|
217
|
-
expected_ending = html_code.strip().endswith(
|
259
|
+
expected_ending = html_code.strip().endswith(
|
260
|
+
'</html>'
|
261
|
+
) or html_code.strip().endswith('</table>')
|
218
262
|
if expected_ending:
|
219
|
-
res[
|
263
|
+
res['html'] = html_code
|
220
264
|
else:
|
221
|
-
logger.warning(
|
265
|
+
logger.warning(
|
266
|
+
'table recognition processing fails, not found expected HTML table end'
|
267
|
+
)
|
222
268
|
else:
|
223
|
-
logger.warning(
|
224
|
-
|
225
|
-
|
226
|
-
|
269
|
+
logger.warning(
|
270
|
+
'table recognition processing fails, not get html return'
|
271
|
+
)
|
272
|
+
logger.info(f'table time: {round(time.time() - table_start, 2)}')
|
227
273
|
|
228
274
|
return layout_res
|
@@ -1,17 +1,22 @@
|
|
1
1
|
from loguru import logger
|
2
2
|
|
3
|
-
from magic_pdf.
|
3
|
+
from magic_pdf.config.constants import MODEL_NAME
|
4
4
|
from magic_pdf.model.model_list import AtomicModel
|
5
|
-
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import
|
6
|
-
|
5
|
+
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
|
6
|
+
DocLayoutYOLOModel
|
7
|
+
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
|
8
|
+
Layoutlmv3_Predictor
|
7
9
|
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
|
8
|
-
|
9
10
|
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
|
10
|
-
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import
|
11
|
+
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
|
12
|
+
ModifiedPaddleOCR
|
13
|
+
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
|
14
|
+
RapidTableModel
|
11
15
|
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
|
12
|
-
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import
|
13
|
-
|
14
|
-
from magic_pdf.model.sub_modules.table.
|
16
|
+
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
|
17
|
+
StructTableModel
|
18
|
+
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
|
19
|
+
TableMasterPaddleModel
|
15
20
|
|
16
21
|
|
17
22
|
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
|
@@ -19,14 +24,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
|
|
19
24
|
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
|
20
25
|
elif table_model_type == MODEL_NAME.TABLE_MASTER:
|
21
26
|
config = {
|
22
|
-
|
23
|
-
|
27
|
+
'model_dir': model_path,
|
28
|
+
'device': _device_
|
24
29
|
}
|
25
30
|
table_model = TableMasterPaddleModel(config)
|
26
31
|
elif table_model_type == MODEL_NAME.RAPID_TABLE:
|
27
32
|
table_model = RapidTableModel()
|
28
33
|
else:
|
29
|
-
logger.error(
|
34
|
+
logger.error('table model type not allow')
|
30
35
|
exit(1)
|
31
36
|
|
32
37
|
return table_model
|
@@ -58,7 +63,7 @@ def ocr_model_init(show_log: bool = False,
|
|
58
63
|
use_dilation=True,
|
59
64
|
det_db_unclip_ratio=1.8,
|
60
65
|
):
|
61
|
-
if lang is not None:
|
66
|
+
if lang is not None and lang != '':
|
62
67
|
model = ModifiedPaddleOCR(
|
63
68
|
show_log=show_log,
|
64
69
|
det_db_box_thresh=det_db_box_thresh,
|
@@ -87,8 +92,8 @@ class AtomModelSingleton:
|
|
87
92
|
return cls._instance
|
88
93
|
|
89
94
|
def get_atom_model(self, atom_model_name: str, **kwargs):
|
90
|
-
lang = kwargs.get(
|
91
|
-
layout_model_name = kwargs.get(
|
95
|
+
lang = kwargs.get('lang', None)
|
96
|
+
layout_model_name = kwargs.get('layout_model_name', None)
|
92
97
|
key = (atom_model_name, layout_model_name, lang)
|
93
98
|
if key not in self._models:
|
94
99
|
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
|
@@ -98,47 +103,47 @@ class AtomModelSingleton:
|
|
98
103
|
def atom_model_init(model_name: str, **kwargs):
    """Build the atomic model selected by ``model_name``.

    Each branch reads its options from ``kwargs`` with ``dict.get``, so an
    option that was not supplied is forwarded to the initialiser as ``None``.

    Args:
        model_name: One of the ``AtomicModel`` identifiers (Layout, MFD,
            MFR, OCR, Table).
        **kwargs: Options for the selected model. Keys read here:
            ``layout_model_name``, ``layout_weights``, ``layout_config_file``,
            ``doclayout_yolo_weights``, ``mfd_weights``, ``mfr_weight_dir``,
            ``mfr_cfg_path``, ``ocr_show_log``, ``det_db_box_thresh``,
            ``lang``, ``table_model_name``, ``table_model_path``,
            ``table_max_time`` and ``device``.

    Returns:
        The object produced by the matching ``*_model_init`` function.

    Note:
        Logs an error and calls ``exit(1)`` when ``model_name`` is not
        recognised, and also when no model object was produced (for example
        a Layout request whose ``layout_model_name`` matches neither
        supported layout model).
    """
    option = kwargs.get
    atom_model = None

    if model_name == AtomicModel.Layout:
        layout_name = option('layout_model_name')
        if layout_name == MODEL_NAME.LAYOUTLMv3:
            weights = option('layout_weights')
            config_file = option('layout_config_file')
            atom_model = layout_model_init(weights, config_file, option('device'))
        elif layout_name == MODEL_NAME.DocLayout_YOLO:
            weights = option('doclayout_yolo_weights')
            atom_model = doclayout_yolo_model_init(weights, option('device'))
        # Any other layout name leaves atom_model as None and is reported
        # by the final check below.
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(option('mfd_weights'), option('device'))
    elif model_name == AtomicModel.MFR:
        weight_dir = option('mfr_weight_dir')
        cfg_path = option('mfr_cfg_path')
        atom_model = mfr_model_init(weight_dir, cfg_path, option('device'))
    elif model_name == AtomicModel.OCR:
        show_log = option('ocr_show_log')
        box_thresh = option('det_db_box_thresh')
        atom_model = ocr_model_init(show_log, box_thresh, option('lang'))
    elif model_name == AtomicModel.Table:
        table_name = option('table_model_name')
        table_path = option('table_model_path')
        max_time = option('table_max_time')
        atom_model = table_model_init(table_name, table_path, max_time, option('device'))
    else:
        logger.error('model name not allow')
        exit(1)

    if atom_model is not None:
        return atom_model

    logger.error('model init failed')
    exit(1)
|
@@ -71,7 +71,13 @@ def remove_intervals(original, masks):
|
|
71
71
|
|
72
72
|
def update_det_boxes(dt_boxes, mfd_res):
|
73
73
|
new_dt_boxes = []
|
74
|
+
angle_boxes_list = []
|
74
75
|
for text_box in dt_boxes:
|
76
|
+
|
77
|
+
if calculate_is_angle(text_box):
|
78
|
+
angle_boxes_list.append(text_box)
|
79
|
+
continue
|
80
|
+
|
75
81
|
text_bbox = points_to_bbox(text_box)
|
76
82
|
masks_list = []
|
77
83
|
for mf_box in mfd_res:
|
@@ -85,6 +91,9 @@ def update_det_boxes(dt_boxes, mfd_res):
|
|
85
91
|
temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
|
86
92
|
if len(temp_dt_box) > 0:
|
87
93
|
new_dt_boxes.extend(temp_dt_box)
|
94
|
+
|
95
|
+
new_dt_boxes.extend(angle_boxes_list)
|
96
|
+
|
88
97
|
return new_dt_boxes
|
89
98
|
|
90
99
|
|
@@ -143,9 +152,11 @@ def merge_det_boxes(dt_boxes):
|
|
143
152
|
angle_boxes_list = []
|
144
153
|
for text_box in dt_boxes:
|
145
154
|
text_bbox = points_to_bbox(text_box)
|
146
|
-
|
155
|
+
|
156
|
+
if calculate_is_angle(text_box):
|
147
157
|
angle_boxes_list.append(text_box)
|
148
158
|
continue
|
159
|
+
|
149
160
|
text_box_dict = {
|
150
161
|
'bbox': text_bbox,
|
151
162
|
'type': 'text',
|
@@ -200,15 +211,21 @@ def get_ocr_result_list(ocr_res, useful_list):
|
|
200
211
|
ocr_result_list = []
|
201
212
|
for box_ocr_res in ocr_res:
|
202
213
|
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
214
|
+
if len(box_ocr_res) == 2:
|
215
|
+
p1, p2, p3, p4 = box_ocr_res[0]
|
216
|
+
text, score = box_ocr_res[1]
|
217
|
+
else:
|
218
|
+
p1, p2, p3, p4 = box_ocr_res
|
219
|
+
text, score = "", 1
|
220
|
+
# average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
|
221
|
+
# if average_angle_degrees > 0.5:
|
222
|
+
poly = [p1, p2, p3, p4]
|
223
|
+
if calculate_is_angle(poly):
|
207
224
|
# logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
|
208
225
|
# 与x轴的夹角超过0.5度,对边界做一下矫正
|
209
226
|
# 计算几何中心
|
210
|
-
x_center = sum(point[0] for point in
|
211
|
-
y_center = sum(point[1] for point in
|
227
|
+
x_center = sum(point[0] for point in poly) / 4
|
228
|
+
y_center = sum(point[1] for point in poly) / 4
|
212
229
|
new_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
|
213
230
|
new_width = p3[0] - p1[0]
|
214
231
|
p1 = [x_center - new_width / 2, y_center - new_height / 2]
|
@@ -257,3 +274,12 @@ def calculate_angle_degrees(poly):
|
|
257
274
|
# logger.info(f"average_angle_degrees: {average_angle_degrees}")
|
258
275
|
return average_angle_degrees
|
259
276
|
|
277
|
+
|
278
|
+
def calculate_is_angle(poly, lower_ratio=0.8, upper_ratio=1.2):
    """Report whether a four-point text box should be treated as tilted.

    The box height is the mean of the two vertical edge lengths
    (``p4.y - p1.y`` and ``p3.y - p2.y``). The box counts as upright when
    the vertical rise along the ``p1 -> p3`` diagonal lies within
    ``[lower_ratio * height, upper_ratio * height]``.

    Args:
        poly: Sequence of four ``(x, y)`` points ``p1..p4``; presumably
            ordered top-left, top-right, bottom-right, bottom-left
            (TODO confirm against the detector's output order).
        lower_ratio: Lower bound of the accepted rise/height ratio.
        upper_ratio: Upper bound of the accepted rise/height ratio.

    Returns:
        ``False`` when the diagonal rise is inside the accepted band,
        ``True`` otherwise. With the default ratios a negative ``height``
        always yields ``True``, because the band
        ``[0.8 * height, 1.2 * height]`` is then empty.
    """
    p1, p2, p3, p4 = poly
    height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
    diagonal_rise = p3[1] - p1[1]
    within_band = lower_ratio * height <= diagonal_rise <= upper_ratio * height
    return not within_band
|
@@ -78,9 +78,18 @@ class ModifiedPaddleOCR(PaddleOCR):
|
|
78
78
|
for idx, img in enumerate(imgs):
|
79
79
|
img = preprocess_image(img)
|
80
80
|
dt_boxes, elapse = self.text_detector(img)
|
81
|
-
if
|
81
|
+
if dt_boxes is None:
|
82
82
|
ocr_res.append(None)
|
83
83
|
continue
|
84
|
+
dt_boxes = sorted_boxes(dt_boxes)
|
85
|
+
# merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
|
86
|
+
dt_boxes = merge_det_boxes(dt_boxes)
|
87
|
+
if mfd_res:
|
88
|
+
bef = time.time()
|
89
|
+
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
|
90
|
+
aft = time.time()
|
91
|
+
logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
|
92
|
+
len(dt_boxes), aft - bef))
|
84
93
|
tmp_res = [box.tolist() for box in dt_boxes]
|
85
94
|
ocr_res.append(tmp_res)
|
86
95
|
return ocr_res
|
@@ -125,9 +134,8 @@ class ModifiedPaddleOCR(PaddleOCR):
|
|
125
134
|
|
126
135
|
dt_boxes = sorted_boxes(dt_boxes)
|
127
136
|
|
128
|
-
#
|
129
|
-
|
130
|
-
|
137
|
+
# merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
|
138
|
+
dt_boxes = merge_det_boxes(dt_boxes)
|
131
139
|
|
132
140
|
if mfd_res:
|
133
141
|
bef = time.time()
|
@@ -10,5 +10,7 @@ class RapidTableModel(object):
|
|
10
10
|
|
11
11
|
def predict(self, image):
    """Run OCR and then table recognition on a single table image.

    Args:
        image: Anything ``np.asarray`` accepts.

    Returns:
        A ``(html_code, table_cell_bboxes, elapse)`` tuple taken from
        ``self.table_model``, or ``(None, None, None)`` when the OCR engine
        returns ``None`` as its result.
    """
    # Convert once and hand the same array to both stages instead of
    # calling np.asarray twice.
    # NOTE(review): assumes neither stage mutates its input in place - confirm.
    image_array = np.asarray(image)

    ocr_result, _ = self.ocr_engine(image_array)
    if ocr_result is None:
        # No OCR result to feed the table model; keep the 3-tuple shape of
        # the normal return.
        return None, None, None

    html_code, table_cell_bboxes, elapse = self.table_model(image_array, ocr_result)
    return html_code, table_cell_bboxes, elapse
|