magic-pdf 0.10.5__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/config/constants.py +7 -0
- magic_pdf/config/exceptions.py +7 -0
- magic_pdf/data/data_reader_writer/base.py +13 -1
- magic_pdf/data/data_reader_writer/filebase.py +1 -1
- magic_pdf/data/data_reader_writer/multi_bucket_s3.py +8 -6
- magic_pdf/data/dataset.py +188 -5
- magic_pdf/data/read_api.py +59 -12
- magic_pdf/data/utils.py +35 -0
- magic_pdf/dict2md/ocr_mkcontent.py +16 -15
- magic_pdf/filter/__init__.py +32 -0
- magic_pdf/filter/pdf_meta_scan.py +3 -2
- magic_pdf/libs/clean_memory.py +11 -4
- magic_pdf/libs/config_reader.py +9 -0
- magic_pdf/libs/draw_bbox.py +19 -22
- magic_pdf/libs/language.py +3 -0
- magic_pdf/libs/pdf_check.py +30 -30
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/__init__.py +1 -1
- magic_pdf/model/batch_analyze.py +275 -0
- magic_pdf/model/doc_analyze_by_custom_model.py +104 -92
- magic_pdf/model/magic_model.py +4 -435
- magic_pdf/model/model_list.py +1 -0
- magic_pdf/model/pdf_extract_kit.py +35 -5
- magic_pdf/model/sub_modules/language_detection/__init__.py +1 -0
- magic_pdf/model/sub_modules/language_detection/utils.py +82 -0
- magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +139 -0
- magic_pdf/model/sub_modules/language_detection/yolov11/__init__.py +1 -0
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +44 -7
- magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +21 -2
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +70 -27
- magic_pdf/model/sub_modules/model_init.py +43 -7
- magic_pdf/model/sub_modules/model_utils.py +17 -5
- magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +51 -1
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +32 -6
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +42 -7
- magic_pdf/operators/__init__.py +94 -0
- magic_pdf/operators/models.py +154 -0
- magic_pdf/operators/pipes.py +191 -0
- magic_pdf/pdf_parse_union_core_v2.py +77 -27
- magic_pdf/post_proc/__init__.py +1 -0
- magic_pdf/post_proc/llm_aided.py +133 -0
- magic_pdf/pre_proc/ocr_span_list_modify.py +8 -0
- magic_pdf/pre_proc/remove_bbox_overlap.py +1 -1
- magic_pdf/resources/yolov11-langdetect/yolo_v11_ft.pt +0 -0
- magic_pdf/tools/cli.py +36 -11
- magic_pdf/tools/common.py +120 -61
- magic_pdf/utils/office_to_pdf.py +29 -0
- {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/METADATA +78 -25
- {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/RECORD +54 -55
- magic_pdf/para/__init__.py +0 -0
- magic_pdf/pdf_parse_by_ocr.py +0 -23
- magic_pdf/pdf_parse_by_txt.py +0 -24
- magic_pdf/pipe/AbsPipe.py +0 -98
- magic_pdf/pipe/OCRPipe.py +0 -41
- magic_pdf/pipe/TXTPipe.py +0 -41
- magic_pdf/pipe/UNIPipe.py +0 -98
- magic_pdf/pipe/__init__.py +0 -0
- magic_pdf/rw/AbsReaderWriter.py +0 -17
- magic_pdf/rw/DiskReaderWriter.py +0 -74
- magic_pdf/rw/S3ReaderWriter.py +0 -142
- magic_pdf/rw/__init__.py +0 -0
- magic_pdf/user_api.py +0 -121
- /magic_pdf/{para → post_proc}/para_split_v3.py +0 -0
- {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/WHEEL +0 -0
- {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/top_level.txt +0 -0
magic_pdf/libs/draw_bbox.py
CHANGED
@@ -1,7 +1,8 @@
 import fitz
 from magic_pdf.config.constants import CROSS_PAGE
-from magic_pdf.config.ocr_content_type import BlockType, CategoryId, ContentType
+from magic_pdf.config.ocr_content_type import (BlockType, CategoryId,
+                                               ContentType)
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.model.magic_model import MagicModel


@@ -194,7 +195,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
     )

     # Save the PDF
-    pdf_docs.save(f'{out_path}/{filename}
+    pdf_docs.save(f'{out_path}/{filename}')


 def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
@@ -282,18 +283,17 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
         draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False)

     # Save the PDF
-    pdf_docs.save(f'{out_path}/{filename}
+    pdf_docs.save(f'{out_path}/{filename}')


-def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
+def draw_model_bbox(model_list, dataset: Dataset, out_path, filename):
     dropped_bbox_list = []
     tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
     imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
     titles_list = []
     texts_list = []
     interequations_list = []
-    magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
+    magic_model = MagicModel(model_list, dataset)
     for i in range(len(model_list)):
         page_dropped_list = []
         tables_body, tables_caption, tables_footnote = [], [], []
@@ -337,7 +337,8 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
         dropped_bbox_list.append(page_dropped_list)
         imgs_footnote_list.append(imgs_footnote)

-    for i, page in enumerate(pdf_docs):
+    for i in range(len(dataset)):
+        page = dataset.get_page(i)
         draw_bbox_with_number(
             i, dropped_bbox_list, page, [158, 158, 158], True
         ) # color !
@@ -352,7 +353,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
         draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)

     # Save the PDF
+    dataset.dump_to_file(f'{out_path}/{filename}')


 def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
@@ -390,20 +391,16 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
     for i, page in enumerate(pdf_docs):
         draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)

-    pdf_docs.save(f'{out_path}/{filename}
+    pdf_docs.save(f'{out_path}/{filename}')

-def draw_layout_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
-    layout_bbox_list = []

-        page_block_list = []
-        for block in page['para_blocks']:
-            bbox = block['bbox']
-            page_block_list.append(bbox)
-        layout_bbox_list.append(page_block_list)
+def draw_char_bbox(pdf_bytes, out_path, filename):
     pdf_docs = fitz.open('pdf', pdf_bytes)
     for i, page in enumerate(pdf_docs):
+        for block in page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']:
+            for line in block['lines']:
+                for span in line['spans']:
+                    for char in span['chars']:
+                        char_bbox = char['bbox']
+                        page.draw_rect(char_bbox, color=[1, 0, 0], fill=None, fill_opacity=1, width=0.3, overlay=True,)
+    pdf_docs.save(f'{out_path}/{filename}')
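In 1.0.0, draw_model_bbox no longer opens the PDF itself: it takes a Dataset and writes the annotated copy via dataset.dump_to_file. A minimal calling sketch, assuming a PymuDocDataset built from raw bytes (the constructor the removed 0.10.5 code used internally); the file names and output path are illustrative, and model_list is the per-page inference output from a previous run:

    import json

    from magic_pdf.data.dataset import PymuDocDataset
    from magic_pdf.libs.draw_bbox import draw_model_bbox

    with open('demo.pdf', 'rb') as f:
        pdf_bytes = f.read()
    with open('model.json') as f:      # per-page inference output saved earlier (assumed)
        model_list = json.load(f)

    ds = PymuDocDataset(pdf_bytes)     # concrete Dataset, as the removed 0.10.5 code built internally
    draw_model_bbox(model_list, ds, './output', 'demo_model.pdf')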
magic_pdf/libs/language.py
CHANGED
@@ -16,11 +16,14 @@ def detect_lang(text: str) -> str:

     if len(text) == 0:
         return ""
+
+    text = text.replace("\n", "")
     try:
         lang_upper = detect_language(text)
     except:
         html_no_ctrl_chars = ''.join([l for l in text if unicodedata.category(l)[0] not in ['C', ]])
         lang_upper = detect_language(html_no_ctrl_chars)
+
     try:
         lang = lang_upper.lower()
     except:
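The only functional change here is that newlines are stripped before the text reaches the language detector, so multi-line snippets are classified as a single string. A tiny sketch of the effect (the returned code depends on the underlying detector):

    from magic_pdf.libs.language import detect_lang

    lang = detect_lang('Hello\nworld')   # detection now runs on 'Helloworld'
    print(lang)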
magic_pdf/libs/pdf_check.py
CHANGED
@@ -1,9 +1,9 @@
 import fitz
 import numpy as np
 from loguru import logger
+import re
+from io import BytesIO
+from pdfminer.high_level import extract_text


 def calculate_sample_count(total_page: int):
@@ -33,33 +33,33 @@ def extract_pages(src_pdf_bytes: bytes) -> fitz.Document:
     return sample_docs


+def detect_invalid_chars(src_pdf_bytes: bytes) -> bool:
+    """"
+    检测PDF中是否包含非法字符
+    """
+    '''pdfminer比较慢,需要先随机抽取10页左右的sample'''
+    sample_docs = extract_pages(src_pdf_bytes)
+    sample_pdf_bytes = sample_docs.tobytes()
+    sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
+    text = extract_text(sample_pdf_file_like_object)
+    text = text.replace("\n", "")
+    # logger.info(text)
+    '''乱码文本用pdfminer提取出来的文本特征是(cid:xxx)'''
+    cid_pattern = re.compile(r'\(cid:\d+\)')
+    matches = cid_pattern.findall(text)
+    cid_count = len(matches)
+    cid_len = sum(len(match) for match in matches)
+    text_len = len(text)
+    if text_len == 0:
+        cid_chars_radio = 0
+    else:
+        cid_chars_radio = cid_count/(cid_count + text_len - cid_len)
+    logger.info(f"cid_count: {cid_count}, text_len: {text_len}, cid_chars_radio: {cid_chars_radio}")
+    '''当一篇文章存在5%以上的文本是乱码时,认为该文档为乱码文档'''
+    if cid_chars_radio > 0.05:
+        return False  # 乱码文档
+    else:
+        return True  # 正常文档


 def count_replacement_characters(text: str) -> int:
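The new detect_invalid_chars samples a few pages with extract_pages, runs pdfminer's extract_text over them, and counts (cid:NNN) tokens. The ratio cid_count / (cid_count + text_len - cid_len) treats each match as a single garbled character (the subtraction removes the extra characters each token occupies in the extracted text), and anything above 5% marks the document as garbled, signalled by a return value of False. A hedged usage sketch; the file path is illustrative:

    from magic_pdf.libs.pdf_check import detect_invalid_chars

    with open('suspect.pdf', 'rb') as f:
        pdf_bytes = f.read()

    if not detect_invalid_chars(pdf_bytes):
        # more than 5% of the sampled text was (cid:...) noise, so the text layer is unusable
        print('garbled text layer, prefer OCR parsing')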
magic_pdf/libs/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.10.5"
+__version__ = "1.0.0"
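A quick check that the installed wheel matches this release, using the module shown above:

    from magic_pdf.libs.version import __version__

    print(__version__)   # '1.0.0' for this wheel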
magic_pdf/model/__init__.py
CHANGED
@@ -1,2 +1,2 @@
 __use_inside_model__ = True
-__model_mode__ =
+__model_mode__ = 'full'
magic_pdf/model/batch_analyze.py
ADDED
@@ -0,0 +1,275 @@
+import time
+
+import cv2
+import numpy as np
+import torch
+from loguru import logger
+from PIL import Image
+
+from magic_pdf.config.constants import MODEL_NAME
+from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
+from magic_pdf.data.dataset import Dataset
+from magic_pdf.libs.clean_memory import clean_memory
+from magic_pdf.libs.config_reader import get_device
+from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
+from magic_pdf.model.pdf_extract_kit import CustomPEKModel
+from magic_pdf.model.sub_modules.model_utils import (
+    clean_vram, crop_img, get_res_list_from_layout_res)
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
+    get_adjusted_mfdetrec_res, get_ocr_result_list)
+from magic_pdf.operators.models import InferenceResult
+
+YOLO_LAYOUT_BASE_BATCH_SIZE = 4
+MFD_BASE_BATCH_SIZE = 1
+MFR_BASE_BATCH_SIZE = 16
+
+
+class BatchAnalyze:
+    def __init__(self, model: CustomPEKModel, batch_ratio: int):
+        self.model = model
+        self.batch_ratio = batch_ratio
+
+    def __call__(self, images: list) -> list:
+        images_layout_res = []
+
+        layout_start_time = time.time()
+        if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
+            # layoutlmv3
+            for image in images:
+                layout_res = self.model.layout_model(image, ignore_catids=[])
+                images_layout_res.append(layout_res)
+        elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
+            # doclayout_yolo
+            layout_images = []
+            modified_images = []
+            for image_index, image in enumerate(images):
+                pil_img = Image.fromarray(image)
+                width, height = pil_img.size
+                if height > width:
+                    input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
+                    new_image, useful_list = crop_img(
+                        input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
+                    )
+                    layout_images.append(new_image)
+                    modified_images.append([image_index, useful_list])
+                else:
+                    layout_images.append(pil_img)
+
+            images_layout_res += self.model.layout_model.batch_predict(
+                layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+            )
+
+            for image_index, useful_list in modified_images:
+                for res in images_layout_res[image_index]:
+                    for i in range(len(res['poly'])):
+                        if i % 2 == 0:
+                            res['poly'][i] = (
+                                res['poly'][i] - useful_list[0] + useful_list[2]
+                            )
+                        else:
+                            res['poly'][i] = (
+                                res['poly'][i] - useful_list[1] + useful_list[3]
+                            )
+        logger.info(
+            f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
+        )
+
+        if self.model.apply_formula:
+            # 公式检测
+            mfd_start_time = time.time()
+            images_mfd_res = self.model.mfd_model.batch_predict(
+                images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+            )
+            logger.info(
+                f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
+            )
+
+            # 公式识别
+            mfr_start_time = time.time()
+            images_formula_list = self.model.mfr_model.batch_predict(
+                images_mfd_res,
+                images,
+                batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
+            )
+            for image_index in range(len(images)):
+                images_layout_res[image_index] += images_formula_list[image_index]
+            logger.info(
+                f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
+            )
+
+        # 清理显存
+        clean_vram(self.model.device, vram_threshold=8)
+
+        ocr_time = 0
+        ocr_count = 0
+        table_time = 0
+        table_count = 0
+        # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
+        for index in range(len(images)):
+            layout_res = images_layout_res[index]
+            pil_img = Image.fromarray(images[index])
+
+            ocr_res_list, table_res_list, single_page_mfdetrec_res = (
+                get_res_list_from_layout_res(layout_res)
+            )
+            # ocr识别
+            ocr_start = time.time()
+            # Process each area that requires OCR processing
+            for res in ocr_res_list:
+                new_image, useful_list = crop_img(
+                    res, pil_img, crop_paste_x=50, crop_paste_y=50
+                )
+                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
+                    single_page_mfdetrec_res, useful_list
+                )
+
+                # OCR recognition
+                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+
+                if self.model.apply_ocr:
+                    ocr_res = self.model.ocr_model.ocr(
+                        new_image, mfd_res=adjusted_mfdetrec_res
+                    )[0]
+                else:
+                    ocr_res = self.model.ocr_model.ocr(
+                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
+                    )[0]
+
+                # Integration results
+                if ocr_res:
+                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
+                    layout_res.extend(ocr_result_list)
+            ocr_time += time.time() - ocr_start
+            ocr_count += len(ocr_res_list)
+
+            # 表格识别 table recognition
+            if self.model.apply_table:
+                table_start = time.time()
+                for res in table_res_list:
+                    new_image, _ = crop_img(res, pil_img)
+                    single_table_start_time = time.time()
+                    html_code = None
+                    if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
+                        with torch.no_grad():
+                            table_result = self.model.table_model.predict(
+                                new_image, 'html'
+                            )
+                            if len(table_result) > 0:
+                                html_code = table_result[0]
+                    elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
+                        html_code = self.model.table_model.img2html(new_image)
+                    elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
+                        html_code, table_cell_bboxes, elapse = (
+                            self.model.table_model.predict(new_image)
+                        )
+                    run_time = time.time() - single_table_start_time
+                    if run_time > self.model.table_max_time:
+                        logger.warning(
+                            f'table recognition processing exceeds max time {self.model.table_max_time}s'
+                        )
+                    # 判断是否返回正常
+                    if html_code:
+                        expected_ending = html_code.strip().endswith(
+                            '</html>'
+                        ) or html_code.strip().endswith('</table>')
+                        if expected_ending:
+                            res['html'] = html_code
+                        else:
+                            logger.warning(
+                                'table recognition processing fails, not found expected HTML table end'
+                            )
+                    else:
+                        logger.warning(
+                            'table recognition processing fails, not get html return'
+                        )
+                table_time += time.time() - table_start
+                table_count += len(table_res_list)
+
+        if self.model.apply_ocr:
+            logger.info(f'ocr time: {round(ocr_time, 2)}, image num: {ocr_count}')
+        else:
+            logger.info(f'det time: {round(ocr_time, 2)}, image num: {ocr_count}')
+        if self.model.apply_table:
+            logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
+
+        return images_layout_res
+
+
+def doc_batch_analyze(
+    dataset: Dataset,
+    ocr: bool = False,
+    show_log: bool = False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+    batch_ratio: int | None = None,
+) -> InferenceResult:
+    """Perform batch analysis on a document dataset.
+
+    Args:
+        dataset (Dataset): The dataset containing document pages to be analyzed.
+        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
+        show_log (bool, optional): Flag to enable logging. Defaults to False.
+        start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
+        end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
+        lang (str, optional): Language for OCR. Defaults to None.
+        layout_model (optional): Layout model to be used for analysis. Defaults to None.
+        formula_enable (optional): Flag to enable formula detection. Defaults to None.
+        table_enable (optional): Flag to enable table detection. Defaults to None.
+        batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
+
+    Raises:
+        CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
+
+    Returns:
+        InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
+    """
+
+    if not torch.cuda.is_available():
+        raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
+
+    lang = None if lang == '' else lang
+    # TODO: auto detect batch size
+    batch_ratio = 1 if batch_ratio is None else batch_ratio
+    end_page_id = end_page_id if end_page_id else len(dataset)
+
+    model_manager = ModelSingleton()
+    custom_model: CustomPEKModel = model_manager.get_model(
+        ocr, show_log, lang, layout_model, formula_enable, table_enable
+    )
+    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+
+    model_json = []
+
+    # batch analyze
+    images = []
+    for index in range(len(dataset)):
+        if start_page_id <= index <= end_page_id:
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            images.append(img_dict['img'])
+    analyze_result = batch_model(images)
+
+    for index in range(len(dataset)):
+        page_data = dataset.get_page(index)
+        img_dict = page_data.get_image()
+        page_width = img_dict['width']
+        page_height = img_dict['height']
+        if start_page_id <= index <= end_page_id:
+            result = analyze_result.pop(0)
+        else:
+            result = []
+
+        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+        page_dict = {'layout_dets': result, 'page_info': page_info}
+        model_json.append(page_dict)
+
+    # TODO: clean memory when gpu memory is not enough
+    clean_memory_start_time = time.time()
+    clean_memory(get_device())
+    logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
+
+    return InferenceResult(model_json, dataset)
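doc_batch_analyze is the new GPU-only batch entry point: it collects page images from the Dataset, runs BatchAnalyze over them, and wraps the per-page results together with the dataset in an InferenceResult. A calling sketch, assuming a CUDA device is available (the function raises CUDA_NOT_AVAILABLE otherwise) and using PymuDocDataset as the concrete Dataset seen elsewhere in this diff; the file path and batch_ratio are illustrative:

    from magic_pdf.data.dataset import PymuDocDataset
    from magic_pdf.model.batch_analyze import doc_batch_analyze

    with open('demo.pdf', 'rb') as f:
        pdf_bytes = f.read()

    ds = PymuDocDataset(pdf_bytes)
    infer_result = doc_batch_analyze(ds, ocr=True, batch_ratio=2)
    # infer_result pairs the raw per-page model output (model_json) with the dataset,
    # ready for the downstream stages added under magic_pdf/operators/.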