magic-pdf 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
- magic_pdf/libs/boxbase.py +5 -2
- magic_pdf/libs/draw_bbox.py +14 -2
- magic_pdf/libs/language.py +9 -0
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/batch_analyze.py +103 -99
- magic_pdf/model/doc_analyze_by_custom_model.py +77 -18
- magic_pdf/model/pdf_extract_kit.py +23 -21
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +7 -3
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +1 -1
- magic_pdf/model/sub_modules/model_init.py +4 -3
- magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +33 -26
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +25 -6
- magic_pdf/pdf_parse_union_core_v2.py +137 -32
- magic_pdf/post_proc/llm_aided.py +59 -26
- magic_pdf/post_proc/llm_aided_ocr.py +689 -0
- magic_pdf/pre_proc/ocr_span_list_modify.py +1 -1
- magic_pdf/resources/model_config/model_configs.yaml +2 -2
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/METADATA +50 -41
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/RECORD +23 -22
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/WHEEL +1 -1
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/top_level.txt +0 -0
magic_pdf/libs/boxbase.py
CHANGED
@@ -185,10 +185,13 @@ def calculate_iou(bbox1, bbox2):
     bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
     bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])

+    if any([bbox1_area == 0, bbox2_area == 0]):
+        return 0
+
     # Compute the intersection over union by taking the intersection area
     # and dividing it by the sum of both areas minus the intersection area
-    iou = intersection_area / float(bbox1_area + bbox2_area -
-
+    iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area)
+
     return iou


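The new guard makes calculate_iou return 0 when either box is degenerate (zero area); previously, two zero-area boxes produced a zero denominator. A minimal, self-contained sketch of the behaviour (illustrative only, with its own intersection computation; not the library function itself):

def iou_with_guard(bbox1, bbox2):
    # Boxes are (x0, y0, x1, y1); the intersection is computed inline for this example.
    x0, y0 = max(bbox1[0], bbox2[0]), max(bbox1[1], bbox2[1])
    x1, y1 = min(bbox1[2], bbox2[2]), min(bbox1[3], bbox2[3])
    intersection_area = max(0, x1 - x0) * max(0, y1 - y0)
    bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
    if any([bbox1_area == 0, bbox2_area == 0]):
        return 0  # the early exit added in 1.1.0
    return intersection_area / float(bbox1_area + bbox2_area - intersection_area)

print(iou_with_guard((0, 0, 10, 0), (0, 0, 10, 0)))  # 0; without the guard this is 0/0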
magic_pdf/libs/draw_bbox.py
CHANGED
@@ -362,12 +362,24 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
     for page in pdf_info:
         page_line_list = []
         for block in page['preproc_blocks']:
-            if block['type'] in [BlockType.Text
+            if block['type'] in [BlockType.Text]:
                 for line in block['lines']:
                     bbox = line['bbox']
                     index = line['index']
                     page_line_list.append({'index': index, 'bbox': bbox})
-
+            elif block['type'] in [BlockType.Title, BlockType.InterlineEquation]:
+                if 'virtual_lines' in block:
+                    if len(block['virtual_lines']) > 0 and block['virtual_lines'][0].get('index', None) is not None:
+                        for line in block['virtual_lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
+                    else:
+                        for line in block['lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
+            elif block['type'] in [BlockType.Image, BlockType.Table]:
                 for sub_block in block['blocks']:
                     if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
                         if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
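The new branches in draw_line_sort_bbox all collect the same {'index', 'bbox'} records and differ only in whether a block's virtual_lines carry usable indices. As an editorial sketch only (this helper does not exist in draw_bbox.py), the new Title/InterlineEquation branch is equivalent to:

def collect_sortable_lines(block):
    # Hypothetical helper mirroring the Title/InterlineEquation branch above.
    virtual_lines = block.get('virtual_lines')
    if virtual_lines is None:
        return []
    if len(virtual_lines) > 0 and virtual_lines[0].get('index', None) is not None:
        lines = virtual_lines
    else:
        lines = block['lines']
    return [{'index': line['index'], 'bbox': line['bbox']} for line in lines]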
magic_pdf/libs/language.py
CHANGED
@@ -12,12 +12,20 @@ if not os.getenv("FTLANG_CACHE"):
 from fast_langdetect import detect_language


+def remove_invalid_surrogates(text):
+    # Remove invalid UTF-16 surrogate pairs
+    return ''.join(c for c in text if not (0xD800 <= ord(c) <= 0xDFFF))
+
+
 def detect_lang(text: str) -> str:

     if len(text) == 0:
         return ""

     text = text.replace("\n", "")
+    text = remove_invalid_surrogates(text)
+
+    # print(text)
     try:
         lang_upper = detect_language(text)
     except:
@@ -37,3 +45,4 @@ if __name__ == '__main__':
     print(detect_lang("<html>This is a test</html>"))
     print(detect_lang("这个是中文测试。"))
     print(detect_lang("<html>这个是中文测试。</html>"))
+    print(detect_lang("〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试"))
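The new helper drops lone UTF-16 surrogate code points (U+D800–U+DFFF), which cannot be encoded as UTF-8 and can break language detection, as in the new test string above. A quick standalone illustration using the same one-liner:

def remove_invalid_surrogates(text):
    # Identical to the helper added in language.py.
    return ''.join(c for c in text if not (0xD800 <= ord(c) <= 0xDFFF))

print(remove_invalid_surrogates("〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试"))
# -> '〖〗这是个包含utf-16的中文测试'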
magic_pdf/libs/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "1.0.0"
+__version__ = "1.1.0"
magic_pdf/model/batch_analyze.py
CHANGED
@@ -7,19 +7,19 @@ from loguru import logger
 from PIL import Image

 from magic_pdf.config.constants import MODEL_NAME
-from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
-from magic_pdf.data.dataset import Dataset
-from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_device
-from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
+# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
+# from magic_pdf.data.dataset import Dataset
+# from magic_pdf.libs.clean_memory import clean_memory
+# from magic_pdf.libs.config_reader import get_device
+# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
-from magic_pdf.operators.models import InferenceResult
+# from magic_pdf.operators.models import InferenceResult

-YOLO_LAYOUT_BASE_BATCH_SIZE =
+YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
 MFR_BASE_BATCH_SIZE = 16

@@ -44,19 +44,20 @@ class BatchAnalyze:
         modified_images = []
         for image_index, image in enumerate(images):
             pil_img = Image.fromarray(image)
-            width, height = pil_img.size
-            if height > width:
-                ...
-            else:
-                ...
+            # width, height = pil_img.size
+            # if height > width:
+            #     input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
+            #     new_image, useful_list = crop_img(
+            #         input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
+            #     )
+            #     layout_images.append(new_image)
+            #     modified_images.append([image_index, useful_list])
+            # else:
             layout_images.append(pil_img)

         images_layout_res += self.model.layout_model.batch_predict(
-            layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+            # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+            layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
         )

         for image_index, useful_list in modified_images:
@@ -78,7 +79,8 @@ class BatchAnalyze:
         # formula detection
         mfd_start_time = time.time()
         images_mfd_res = self.model.mfd_model.batch_predict(
-            images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+            # images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+            images, MFD_BASE_BATCH_SIZE
         )
         logger.info(
             f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
@@ -91,10 +93,12 @@ class BatchAnalyze:
             images,
             batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
         )
+        mfr_count = 0
         for image_index in range(len(images)):
             images_layout_res[image_index] += images_formula_list[image_index]
+            mfr_count += len(images_formula_list[image_index])
         logger.info(
-            f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {
+            f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
         )

         # free GPU memory
@@ -159,7 +163,7 @@ class BatchAnalyze:
                 elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
                     html_code = self.model.table_model.img2html(new_image)
                 elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
-                    html_code, table_cell_bboxes, elapse = (
+                    html_code, table_cell_bboxes, logic_points, elapse = (
                         self.model.table_model.predict(new_image)
                     )
                     run_time = time.time() - single_table_start_time
@@ -195,81 +199,81 @@ class BatchAnalyze:
         return images_layout_res


-def doc_batch_analyze(
-    ...
-) -> InferenceResult:
-    ...
+# def doc_batch_analyze(
+#     dataset: Dataset,
+#     ocr: bool = False,
+#     show_log: bool = False,
+#     start_page_id=0,
+#     end_page_id=None,
+#     lang=None,
+#     layout_model=None,
+#     formula_enable=None,
+#     table_enable=None,
+#     batch_ratio: int | None = None,
+# ) -> InferenceResult:
+#     """Perform batch analysis on a document dataset.
+#
+#     Args:
+#         dataset (Dataset): The dataset containing document pages to be analyzed.
+#         ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
+#         show_log (bool, optional): Flag to enable logging. Defaults to False.
+#         start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
+#         end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
+#         lang (str, optional): Language for OCR. Defaults to None.
+#         layout_model (optional): Layout model to be used for analysis. Defaults to None.
+#         formula_enable (optional): Flag to enable formula detection. Defaults to None.
+#         table_enable (optional): Flag to enable table detection. Defaults to None.
+#         batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
+#
+#     Raises:
+#         CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
+#
+#     Returns:
+#         InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
+#     """
+#
+#     if not torch.cuda.is_available():
+#         raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
+#
+#     lang = None if lang == '' else lang
+#     # TODO: auto detect batch size
+#     batch_ratio = 1 if batch_ratio is None else batch_ratio
+#     end_page_id = end_page_id if end_page_id else len(dataset)
+#
+#     model_manager = ModelSingleton()
+#     custom_model: CustomPEKModel = model_manager.get_model(
+#         ocr, show_log, lang, layout_model, formula_enable, table_enable
+#     )
+#     batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+#
+#     model_json = []
+#
+#     # batch analyze
+#     images = []
+#     for index in range(len(dataset)):
+#         if start_page_id <= index <= end_page_id:
+#             page_data = dataset.get_page(index)
+#             img_dict = page_data.get_image()
+#             images.append(img_dict['img'])
+#     analyze_result = batch_model(images)
+#
+#     for index in range(len(dataset)):
+#         page_data = dataset.get_page(index)
+#         img_dict = page_data.get_image()
+#         page_width = img_dict['width']
+#         page_height = img_dict['height']
+#         if start_page_id <= index <= end_page_id:
+#             result = analyze_result.pop(0)
+#         else:
+#             result = []
+#
+#         page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+#         page_dict = {'layout_dets': result, 'page_info': page_info}
+#         model_json.append(page_dict)
+#
+#     # TODO: clean memory when gpu memory is not enough
+#     clean_memory_start_time = time.time()
+#     clean_memory(get_device())
+#     logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
+#
+#     return InferenceResult(model_json, dataset)
magic_pdf/model/doc_analyze_by_custom_model.py
CHANGED
@@ -3,8 +3,12 @@ import time

 # disable paddle's signal handler
 import paddle
+import torch
 from loguru import logger

+from magic_pdf.model.batch_analyze import BatchAnalyze
+from magic_pdf.model.sub_modules.model_utils import get_vram
+
 paddle.disable_signal_handler()

 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates
@@ -154,33 +158,88 @@ def doc_analyze(
     table_enable=None,
 ) -> InferenceResult:

+    end_page_id = end_page_id if end_page_id else len(dataset) - 1
+
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(
         ocr, show_log, lang, layout_model, formula_enable, table_enable
     )

+    batch_analyze = False
+    device = get_device()
+
+    npu_support = False
+    if str(device).startswith("npu"):
+        import torch_npu
+        if torch_npu.npu.is_available():
+            npu_support = True
+
+    if torch.cuda.is_available() and device != 'cpu' or npu_support:
+        gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
+        if gpu_memory is not None and gpu_memory >= 8:
+
+            if 8 <= gpu_memory < 10:
+                batch_ratio = 2
+            elif 10 <= gpu_memory <= 12:
+                batch_ratio = 4
+            elif 12 < gpu_memory <= 16:
+                batch_ratio = 8
+            elif 16 < gpu_memory <= 24:
+                batch_ratio = 16
+            else:
+                batch_ratio = 32
+
+            if batch_ratio >= 1:
+                logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
+                batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+                batch_analyze = True
+
     model_json = []
     doc_analyze_start = time.time()

-    if
-        ...
+    if batch_analyze:
+        # batch analyze
+        images = []
+        for index in range(len(dataset)):
+            if start_page_id <= index <= end_page_id:
+                page_data = dataset.get_page(index)
+                img_dict = page_data.get_image()
+                images.append(img_dict['img'])
+        analyze_result = batch_model(images)
+
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            page_width = img_dict['width']
+            page_height = img_dict['height']
+            if start_page_id <= index <= end_page_id:
+                result = analyze_result.pop(0)
+            else:
+                result = []
+
+            page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)

-
-
-
+    else:
+        # single analyze
+
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            img = img_dict['img']
+            page_width = img_dict['width']
+            page_height = img_dict['height']
+            if start_page_id <= index <= end_page_id:
+                page_start = time.time()
+                result = custom_model(img)
+                logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
+            else:
+                result = []
+
+            page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)

     gc_start = time.time()
     clean_memory(get_device())
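doc_analyze now decides between batched and single-page analysis from the available VRAM, with the VIRTUAL_VRAM_SIZE environment variable taking precedence over detection. A standalone sketch of that selection logic (a hypothetical helper, not part of magic-pdf, mirroring the thresholds shown above):

import os

def pick_batch_ratio(detected_vram_gb: float) -> int | None:
    # VIRTUAL_VRAM_SIZE, if set, overrides the detected VRAM (in GB).
    gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(detected_vram_gb)))
    if gpu_memory < 8:
        return None  # falls back to page-by-page (non-batch) analysis
    if 8 <= gpu_memory < 10:
        return 2
    elif 10 <= gpu_memory <= 12:
        return 4
    elif 12 < gpu_memory <= 16:
        return 8
    elif 16 < gpu_memory <= 24:
        return 16
    return 32

print(pick_batch_ratio(22.4))  # 16, unless VIRTUAL_VRAM_SIZE says otherwise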
magic_pdf/model/pdf_extract_kit.py
CHANGED
@@ -69,6 +69,7 @@ class CustomPEKModel:
         self.apply_table = self.table_config.get('enable', False)
         self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
         self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
+        self.table_sub_model_name = self.table_config.get('sub_model', None)

         # ocr config
         self.apply_ocr = ocr
@@ -144,7 +145,7 @@ class CustomPEKModel:
                        model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
                    )
                ),
-                device=self.device,
+                device='cpu' if str(self.device).startswith("mps") else self.device,
            )
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            self.layout_model = atom_model_manager.get_atom_model(
@@ -174,6 +175,7 @@ class CustomPEKModel:
                table_max_time=self.table_max_time,
                device=self.device,
                ocr_engine=self.ocr_model,
+                table_sub_model_name=self.table_sub_model_name
            )

        logger.info('DocAnalysis init done!')
@@ -192,24 +194,24 @@ class CustomPEKModel:
            layout_res = self.layout_model(image, ignore_catids=[])
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo
-            if height > width:
-                ...
-            else:
-                ...
+            # if height > width:
+            #     input_res = {"poly":[0,0,width,0,width,height,0,height]}
+            #     new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
+            #     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+            #     layout_res = self.layout_model.predict(new_image)
+            #     for res in layout_res:
+            #         p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
+            #         p1 = p1 - paste_x + xmin
+            #         p2 = p2 - paste_y + ymin
+            #         p3 = p3 - paste_x + xmin
+            #         p4 = p4 - paste_y + ymin
+            #         p5 = p5 - paste_x + xmin
+            #         p6 = p6 - paste_y + ymin
+            #         p7 = p7 - paste_x + xmin
+            #         p8 = p8 - paste_y + ymin
+            #         res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
+            # else:
+            layout_res = self.layout_model.predict(image)

        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f'layout detection time: {layout_cost}')
@@ -228,7 +230,7 @@ class CustomPEKModel:
        logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')

        # free GPU memory
-        clean_vram(self.device, vram_threshold=
+        clean_vram(self.device, vram_threshold=6)

        # get the OCR, table, and formula regions from layout_res
        ocr_res_list, table_res_list, single_page_mfdetrec_res = (
@@ -276,7 +278,7 @@ class CustomPEKModel:
            elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                html_code = self.table_model.img2html(new_image)
            elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
-                html_code, table_cell_bboxes, elapse = self.table_model.predict(
+                html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
                    new_image
                )
                run_time = time.time() - single_table_start_time
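Both call sites above now unpack four values from the RapidTable predict() call; logic_points is new in 1.1.0, and code outside this package that still unpacks three values would fail with a ValueError. A hedged compatibility sketch for such callers (table_model and new_image stand for whatever objects the caller already has):

result = table_model.predict(new_image)
if len(result) == 4:
    # magic-pdf 1.1.0: rapid_table also returns the table's logic points
    html_code, table_cell_bboxes, logic_points, elapse = result
else:
    # older 3-value return
    html_code, table_cell_bboxes, elapse = result
    logic_points = None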
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py
CHANGED
@@ -9,7 +9,11 @@ class DocLayoutYOLOModel(object):
     def predict(self, image):
         layout_res = []
         doclayout_yolo_res = self.model.predict(
-            image,
+            image,
+            imgsz=1280,
+            conf=0.10,
+            iou=0.45,
+            verbose=False, device=self.device
         )[0]
         for xyxy, conf, cla in zip(
             doclayout_yolo_res.boxes.xyxy.cpu(),
@@ -32,8 +36,8 @@ class DocLayoutYOLOModel(object):
                 image_res.cpu()
                 for image_res in self.model.predict(
                     images[index : index + batch_size],
-                    imgsz=
-                    conf=0.
+                    imgsz=1280,
+                    conf=0.10,
                     iou=0.45,
                     verbose=False,
                     device=self.device,
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
CHANGED
@@ -89,7 +89,7 @@ class UnimernetModel(object):
             mf_image_list.append(bbox_img)

         dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-        dataloader = DataLoader(dataset, batch_size=
+        dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
         mfr_res = []
         for mf_img in dataloader:
             mf_img = mf_img.to(self.device)
magic_pdf/model/sub_modules/model_init.py
CHANGED
@@ -21,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
     TableMasterPaddleModel


-def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
     elif table_model_type == MODEL_NAME.TABLE_MASTER:
@@ -31,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
         }
         table_model = TableMasterPaddleModel(config)
     elif table_model_type == MODEL_NAME.RAPID_TABLE:
-        table_model = RapidTableModel(ocr_engine)
+        table_model = RapidTableModel(ocr_engine, table_sub_model_name)
     else:
         logger.error('table model type not allow')
         exit(1)
@@ -163,7 +163,8 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('table_model_path'),
             kwargs.get('table_max_time'),
             kwargs.get('device'),
-            kwargs.get('ocr_engine')
+            kwargs.get('ocr_engine'),
+            kwargs.get('table_sub_model_name')
         )
     elif model_name == AtomicModel.LangDetect:
         if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect: