magic-pdf 1.0.1__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in the public registry.
- magic_pdf/dict2md/ocr_mkcontent.py +24 -0
- magic_pdf/filter/__init__.py +1 -1
- magic_pdf/filter/pdf_classify_by_type.py +6 -4
- magic_pdf/filter/pdf_meta_scan.py +4 -4
- magic_pdf/libs/boxbase.py +5 -2
- magic_pdf/libs/draw_bbox.py +14 -2
- magic_pdf/libs/language.py +9 -0
- magic_pdf/libs/pdf_check.py +11 -1
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/batch_analyze.py +103 -99
- magic_pdf/model/doc_analyze_by_custom_model.py +87 -36
- magic_pdf/model/magic_model.py +161 -4
- magic_pdf/model/pdf_extract_kit.py +23 -28
- magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +4 -3
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +7 -3
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +1 -1
- magic_pdf/model/sub_modules/model_init.py +34 -19
- magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +33 -26
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +25 -6
- magic_pdf/pdf_parse_union_core_v2.py +176 -61
- magic_pdf/post_proc/llm_aided.py +55 -24
- magic_pdf/pre_proc/ocr_dict_merge.py +14 -2
- magic_pdf/pre_proc/ocr_span_list_modify.py +1 -1
- magic_pdf/resources/model_config/model_configs.yaml +2 -2
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/METADATA +36 -19
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/RECORD +30 -30
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/WHEEL +0 -0
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/top_level.txt +0 -0
magic_pdf/dict2md/ocr_mkcontent.py
CHANGED
@@ -126,11 +126,35 @@ def detect_language(text):
         return 'empty'
 
 
+def full_to_half(text: str) -> str:
+    """Convert full-width characters to half-width characters using code point manipulation.
+
+    Args:
+        text: String containing full-width characters
+
+    Returns:
+        String with full-width characters converted to half-width
+    """
+    result = []
+    for char in text:
+        code = ord(char)
+        # Full-width ASCII variants (FF01-FF5E)
+        if 0xFF01 <= code <= 0xFF5E:
+            result.append(chr(code - 0xFEE0))  # Shift to ASCII range
+        # Full-width space
+        elif code == 0x3000:
+            result.append(' ')
+        else:
+            result.append(char)
+    return ''.join(result)
+
+
 def merge_para_with_text(para_block):
     block_text = ''
     for line in para_block['lines']:
         for span in line['spans']:
             if span['type'] in [ContentType.Text]:
+                span['content'] = full_to_half(span['content'])
                 block_text += span['content']
     block_lang = detect_lang(block_text)
 
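For reference, a standalone sketch (not part of the package) of the mapping the new helper applies: the 0xFEE0 offset shifts the full-width ASCII block U+FF01-U+FF5E onto U+0021-U+007E, and the ideographic space U+3000 becomes a plain space.

def full_to_half_demo(text: str) -> str:
    # Same mapping as the function added above, condensed into one expression.
    return ''.join(
        chr(ord(c) - 0xFEE0) if 0xFF01 <= ord(c) <= 0xFF5E
        else ' ' if ord(c) == 0x3000
        else c
        for c in text
    )

# The input below uses full-width letters and the U+3000 ideographic space.
print(full_to_half_demo('Ｈｅｌｌｏ，　Ｗｏｒｌｄ！１２３'))  # -> 'Hello, World!123'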
magic_pdf/filter/__init__.py
CHANGED
@@ -23,7 +23,7 @@ def classify(pdf_bytes: bytes) -> SupportedPdfParseMethod:
         pdf_meta['image_info_per_page'],
         pdf_meta['text_len_per_page'],
         pdf_meta['imgs_per_page'],
-        pdf_meta['text_layout_per_page'],
+        # pdf_meta['text_layout_per_page'],
         pdf_meta['invalid_chars'],
     )
     if is_text_pdf:
magic_pdf/filter/pdf_classify_by_type.py
CHANGED
@@ -305,7 +305,8 @@ def classify_by_img_narrow_strips(page_width, page_height, img_sz_list):
 
 
 def classify(total_page: int, page_width, page_height, img_sz_list: list, text_len_list: list, img_num_list: list,
-             text_layout_list: list, invalid_chars: bool):
+             # text_layout_list: list,
+             invalid_chars: bool):
     """
     The image and page dimensions here are in pts.
    :param total_page:
@@ -321,7 +322,7 @@ def classify(total_page: int, page_width, page_height, img_sz_list: list, text_len_list: list, img_num_list: list,
         'by_text_len': classify_by_text_len(text_len_list, total_page),
         'by_avg_words': classify_by_avg_words(text_len_list),
         'by_img_num': classify_by_img_num(img_sz_list, img_num_list),
-        'by_text_layout': classify_by_text_layout(text_layout_list),
+        # 'by_text_layout': classify_by_text_layout(text_layout_list),
         'by_img_narrow_strips': classify_by_img_narrow_strips(page_width, page_height, img_sz_list),
         'by_invalid_chars': invalid_chars,
     }
@@ -332,9 +333,10 @@ def classify(total_page: int, page_width, page_height, img_sz_list: list, text_len_list: list, img_num_list: list,
         return False, results
     else:
         logger.warning(
-            f"
+            f"OCR needed based on classification result, by_image_area: {results['by_image_area']},"
             f" by_text: {results['by_text_len']}, by_avg_words: {results['by_avg_words']}, by_img_num: {results['by_img_num']},"
-            f" by_text_layout: {results['by_text_layout']},"
+            # f" by_text_layout: {results['by_text_layout']},"
+            f" by_img_narrow_strips: {results['by_img_narrow_strips']},"
             f" by_invalid_chars: {results['by_invalid_chars']}",
             file=sys.stderr)  # these cases help quickly spot unusual PDFs and refine the classification algorithm accordingly
         return False, results
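Because the text_layout positional parameter is now commented out of the signature, positional callers must drop that argument too; the updated call in magic_pdf/filter/__init__.py above shows the pattern. A minimal sketch of an adapted call (variable names are illustrative):

# classify() no longer takes text_layout_list; invalid_chars moves up one slot.
is_text_pdf, results = classify(
    total_page, page_width, page_height,
    img_sz_list, text_len_list, img_num_list,
    invalid_chars,
)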
magic_pdf/filter/pdf_meta_scan.py
CHANGED
@@ -356,9 +356,9 @@ def pdf_meta_scan(pdf_bytes: bytes):
     # logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}")
     text_len_per_page = get_pdf_textlen_per_page(doc)
     # logger.info(f"text_len_per_page: {text_len_per_page}")
-    text_layout_per_page = get_pdf_text_layout_per_page(doc)
+    # text_layout_per_page = get_pdf_text_layout_per_page(doc)
     # logger.info(f"text_layout_per_page: {text_layout_per_page}")
-    text_language = get_language(doc)
+    # text_language = get_language(doc)
     # logger.info(f"text_language: {text_language}")
     invalid_chars = check_invalid_chars(pdf_bytes)
     # logger.info(f"invalid_chars: {invalid_chars}")
@@ -372,8 +372,8 @@ def pdf_meta_scan(pdf_bytes: bytes):
         'page_height_pts': int(page_height_pts),
         'image_info_per_page': image_info_per_page,
         'text_len_per_page': text_len_per_page,
-        'text_layout_per_page': text_layout_per_page,
-        'text_language': text_language,
+        # 'text_layout_per_page': text_layout_per_page,
+        # 'text_language': text_language,
         # "svgs_per_page": svgs_per_page,
         'imgs_per_page': imgs_per_page,  # list of per-page image counts
         'junk_img_bojids': junk_img_bojids,  # list of junk-image object ids
magic_pdf/libs/boxbase.py
CHANGED
@@ -185,10 +185,13 @@ def calculate_iou(bbox1, bbox2):
     bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
     bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
 
+    if any([bbox1_area == 0, bbox2_area == 0]):
+        return 0
+
     # Compute the intersection over union by taking the intersection area
     # and dividing it by the sum of both areas minus the intersection area
-    iou = intersection_area / float(bbox1_area + bbox2_area -
-                                    intersection_area)
+    iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area)
+
     return iou
 
 
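The new guard matters when both boxes are degenerate: intersection and union are then both zero and the division would raise ZeroDivisionError. A self-contained sketch (the intersection is computed inline here; in boxbase.py it is derived earlier in calculate_iou):

def iou(bbox1, bbox2):
    # Intersection rectangle, clamped to zero when the boxes do not overlap.
    inter_w = max(0, min(bbox1[2], bbox2[2]) - max(bbox1[0], bbox2[0]))
    inter_h = max(0, min(bbox1[3], bbox2[3]) - max(bbox1[1], bbox2[1]))
    intersection_area = inter_w * inter_h
    bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
    if bbox1_area == 0 or bbox2_area == 0:
        return 0  # degenerate box: bail out before the division below
    return intersection_area / float(bbox1_area + bbox2_area - intersection_area)

print(iou((0, 0, 10, 10), (5, 5, 15, 15)))  # 25 / 175 = 0.1428...
print(iou((5, 5, 5, 5), (5, 5, 5, 5)))      # 0 instead of ZeroDivisionError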
magic_pdf/libs/draw_bbox.py
CHANGED
@@ -362,12 +362,24 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
     for page in pdf_info:
         page_line_list = []
         for block in page['preproc_blocks']:
-            if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]:
+            if block['type'] in [BlockType.Text]:
                 for line in block['lines']:
                     bbox = line['bbox']
                     index = line['index']
                     page_line_list.append({'index': index, 'bbox': bbox})
-            elif block['type'] in [BlockType.Image, BlockType.Table]:
+            elif block['type'] in [BlockType.Title, BlockType.InterlineEquation]:
+                if 'virtual_lines' in block:
+                    if len(block['virtual_lines']) > 0 and block['virtual_lines'][0].get('index', None) is not None:
+                        for line in block['virtual_lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
+                else:
+                    for line in block['lines']:
+                        bbox = line['bbox']
+                        index = line['index']
+                        page_line_list.append({'index': index, 'bbox': bbox})
+            elif block['type'] in [BlockType.Image, BlockType.Table]:
                 for sub_block in block['blocks']:
                     if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
                         if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
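The net effect for Title and InterlineEquation blocks: prefer virtual_lines when they exist and carry sort indexes, otherwise fall back to the block's real lines. A condensed paraphrase as a hypothetical helper (not in the package):

def pick_lines(block):
    # Prefer indexed virtual_lines; fall back to real lines when absent.
    if 'virtual_lines' in block:
        vlines = block['virtual_lines']
        if len(vlines) > 0 and vlines[0].get('index', None) is not None:
            return vlines
        return []
    return block['lines']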
magic_pdf/libs/language.py
CHANGED
@@ -12,12 +12,20 @@ if not os.getenv("FTLANG_CACHE"):
 from fast_langdetect import detect_language
 
 
+def remove_invalid_surrogates(text):
+    # remove invalid UTF-16 surrogate code points
+    return ''.join(c for c in text if not (0xD800 <= ord(c) <= 0xDFFF))
+
+
 def detect_lang(text: str) -> str:
 
     if len(text) == 0:
         return ""
 
     text = text.replace("\n", "")
+    text = remove_invalid_surrogates(text)
+
+    # print(text)
     try:
         lang_upper = detect_language(text)
     except:
@@ -37,3 +45,4 @@ if __name__ == '__main__':
     print(detect_lang("<html>This is a test</html>"))
     print(detect_lang("这个是中文测试。"))
     print(detect_lang("<html>这个是中文测试。</html>"))
+    print(detect_lang("〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试"))
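Why this matters: lone UTF-16 surrogate code points (U+D800-U+DFFF) cannot be encoded to UTF-8, so they would crash fast_langdetect before detection even starts. A standalone sketch using the same test string as above:

def remove_invalid_surrogates(text):
    # Strip surrogate code points; lone surrogates break str.encode('utf-8').
    return ''.join(c for c in text if not (0xD800 <= ord(c) <= 0xDFFF))

s = '〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试'
try:
    s.encode('utf-8')
except UnicodeEncodeError:
    print('lone surrogate present')   # this branch runs
print(remove_invalid_surrogates(s))   # surrogates stripped; now safely encodable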
magic_pdf/libs/pdf_check.py
CHANGED
@@ -4,6 +4,7 @@ from loguru import logger
 import re
 from io import BytesIO
 from pdfminer.high_level import extract_text
+from pdfminer.layout import LAParams
 
 
 def calculate_sample_count(total_page: int):
@@ -41,7 +42,16 @@ def detect_invalid_chars(src_pdf_bytes: bytes) -> bool:
     sample_docs = extract_pages(src_pdf_bytes)
     sample_pdf_bytes = sample_docs.tobytes()
     sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
-    text = extract_text(pdf_file=sample_pdf_file_like_object)
+    laparams = LAParams(
+        line_overlap=0.5,
+        char_margin=2.0,
+        line_margin=0.5,
+        word_margin=0.1,
+        boxes_flow=None,
+        detect_vertical=False,
+        all_texts=False,
+    )
+    text = extract_text(pdf_file=sample_pdf_file_like_object, laparams=laparams)
     text = text.replace("\n", "")
     # logger.info(text)
     '''Garbled text extracted by pdfminer shows up as (cid:xxx) markers'''
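For context, (cid:xxx) markers appear when pdfminer cannot map glyphs back to Unicode, and boxes_flow=None tells pdfminer to group text by position only, skipping the more expensive layout analysis. A hedged sketch of the detection idea (the threshold and regex here are illustrative, not the package's exact rule):

import re
from io import BytesIO
from pdfminer.high_level import extract_text
from pdfminer.layout import LAParams

def looks_garbled(pdf_bytes: bytes, max_cid: int = 10) -> bool:
    # boxes_flow=None: positional grouping only, no advanced layout analysis.
    text = extract_text(BytesIO(pdf_bytes), laparams=LAParams(boxes_flow=None))
    return len(re.findall(r'\(cid:\d+\)', text.replace('\n', ''))) > max_cid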
magic_pdf/libs/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "1.0.1"
+__version__ = "1.2.0"
magic_pdf/model/batch_analyze.py
CHANGED
@@ -7,19 +7,19 @@ from loguru import logger
 from PIL import Image
 
 from magic_pdf.config.constants import MODEL_NAME
-from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
-from magic_pdf.data.dataset import Dataset
-from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_device
-from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
+# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
+# from magic_pdf.data.dataset import Dataset
+# from magic_pdf.libs.clean_memory import clean_memory
+# from magic_pdf.libs.config_reader import get_device
+# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
-from magic_pdf.operators.models import InferenceResult
+# from magic_pdf.operators.models import InferenceResult
 
-YOLO_LAYOUT_BASE_BATCH_SIZE =
+YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
 MFR_BASE_BATCH_SIZE = 16
 
@@ -44,19 +44,20 @@ class BatchAnalyze:
         modified_images = []
         for image_index, image in enumerate(images):
            pil_img = Image.fromarray(image)
-            width, height = pil_img.size
-            if height > width:
-                input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
-                new_image, useful_list = crop_img(
-                    input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
-                )
-                layout_images.append(new_image)
-                modified_images.append([image_index, useful_list])
-            else:
-                layout_images.append(pil_img)
+            # width, height = pil_img.size
+            # if height > width:
+            #     input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
+            #     new_image, useful_list = crop_img(
+            #         input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
+            #     )
+            #     layout_images.append(new_image)
+            #     modified_images.append([image_index, useful_list])
+            # else:
+            layout_images.append(pil_img)
 
         images_layout_res += self.model.layout_model.batch_predict(
-            layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+            # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+            layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
         )
 
         for image_index, useful_list in modified_images:
@@ -78,7 +79,8 @@ class BatchAnalyze:
         # formula detection
         mfd_start_time = time.time()
         images_mfd_res = self.model.mfd_model.batch_predict(
-            images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+            # images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+            images, MFD_BASE_BATCH_SIZE
         )
         logger.info(
             f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
@@ -91,10 +93,12 @@ class BatchAnalyze:
             images,
             batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
         )
+        mfr_count = 0
         for image_index in range(len(images)):
             images_layout_res[image_index] += images_formula_list[image_index]
+            mfr_count += len(images_formula_list[image_index])
         logger.info(
-            f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
+            f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
         )
 
         # free GPU memory
@@ -159,7 +163,7 @@ class BatchAnalyze:
                 elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
                     html_code = self.model.table_model.img2html(new_image)
                 elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
-                    html_code, table_cell_bboxes, elapse = (
+                    html_code, table_cell_bboxes, logic_points, elapse = (
                         self.model.table_model.predict(new_image)
                     )
                 run_time = time.time() - single_table_start_time
@@ -195,81 +199,81 @@ class BatchAnalyze:
         return images_layout_res
 
 
-def doc_batch_analyze(
-    dataset: Dataset,
-    ocr: bool = False,
-    show_log: bool = False,
-    start_page_id=0,
-    end_page_id=None,
-    lang=None,
-    layout_model=None,
-    formula_enable=None,
-    table_enable=None,
-    batch_ratio: int | None = None,
-) -> InferenceResult:
-    """Perform batch analysis on a document dataset.
-
-    Args:
-        dataset (Dataset): The dataset containing document pages to be analyzed.
-        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
-        show_log (bool, optional): Flag to enable logging. Defaults to False.
-        start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
-        end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
-        lang (str, optional): Language for OCR. Defaults to None.
-        layout_model (optional): Layout model to be used for analysis. Defaults to None.
-        formula_enable (optional): Flag to enable formula detection. Defaults to None.
-        table_enable (optional): Flag to enable table detection. Defaults to None.
-        batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
-
-    Raises:
-        CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
-
-    Returns:
-        InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
-    """
-
-    if not torch.cuda.is_available():
-        raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
-
-    lang = None if lang == '' else lang
-    # TODO: auto detect batch size
-    batch_ratio = 1 if batch_ratio is None else batch_ratio
-    end_page_id = end_page_id if end_page_id else len(dataset)
-
-    model_manager = ModelSingleton()
-    custom_model: CustomPEKModel = model_manager.get_model(
-        ocr, show_log, lang, layout_model, formula_enable, table_enable
-    )
-    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-
-    model_json = []
-
-    # batch analyze
-    images = []
-    for index in range(len(dataset)):
-        if start_page_id <= index <= end_page_id:
-            page_data = dataset.get_page(index)
-            img_dict = page_data.get_image()
-            images.append(img_dict['img'])
-    analyze_result = batch_model(images)
-
-    for index in range(len(dataset)):
-        page_data = dataset.get_page(index)
-        img_dict = page_data.get_image()
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        if start_page_id <= index <= end_page_id:
-            result = analyze_result.pop(0)
-        else:
-            result = []
-
-        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
-
-    # TODO: clean memory when gpu memory is not enough
-    clean_memory_start_time = time.time()
-    clean_memory(get_device())
-    logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
-
-    return InferenceResult(model_json, dataset)
+# def doc_batch_analyze(
+#     dataset: Dataset,
+#     ocr: bool = False,
+#     show_log: bool = False,
+#     start_page_id=0,
+#     end_page_id=None,
+#     lang=None,
+#     layout_model=None,
+#     formula_enable=None,
+#     table_enable=None,
+#     batch_ratio: int | None = None,
+# ) -> InferenceResult:
+#     """Perform batch analysis on a document dataset.
#
+#     Args:
+#         dataset (Dataset): The dataset containing document pages to be analyzed.
+#         ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
+#         show_log (bool, optional): Flag to enable logging. Defaults to False.
+#         start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
+#         end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
+#         lang (str, optional): Language for OCR. Defaults to None.
+#         layout_model (optional): Layout model to be used for analysis. Defaults to None.
+#         formula_enable (optional): Flag to enable formula detection. Defaults to None.
+#         table_enable (optional): Flag to enable table detection. Defaults to None.
+#         batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
+#
+#     Raises:
+#         CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
+#
+#     Returns:
+#         InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
+#     """
+#
+#     if not torch.cuda.is_available():
+#         raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
+#
+#     lang = None if lang == '' else lang
+#     # TODO: auto detect batch size
+#     batch_ratio = 1 if batch_ratio is None else batch_ratio
+#     end_page_id = end_page_id if end_page_id else len(dataset)
+#
+#     model_manager = ModelSingleton()
+#     custom_model: CustomPEKModel = model_manager.get_model(
+#         ocr, show_log, lang, layout_model, formula_enable, table_enable
+#     )
+#     batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+#
+#     model_json = []
+#
+#     # batch analyze
+#     images = []
+#     for index in range(len(dataset)):
+#         if start_page_id <= index <= end_page_id:
+#             page_data = dataset.get_page(index)
+#             img_dict = page_data.get_image()
+#             images.append(img_dict['img'])
+#     analyze_result = batch_model(images)
+#
+#     for index in range(len(dataset)):
+#         page_data = dataset.get_page(index)
+#         img_dict = page_data.get_image()
+#         page_width = img_dict['width']
+#         page_height = img_dict['height']
+#         if start_page_id <= index <= end_page_id:
+#             result = analyze_result.pop(0)
+#         else:
+#             result = []
+#
+#         page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+#         page_dict = {'layout_dets': result, 'page_info': page_info}
+#         model_json.append(page_dict)
+#
+#     # TODO: clean memory when gpu memory is not enough
+#     clean_memory_start_time = time.time()
+#     clean_memory(get_device())
+#     logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
+#
+#     return InferenceResult(model_json, dataset)
magic_pdf/model/doc_analyze_by_custom_model.py
CHANGED
@@ -1,17 +1,22 @@
 import os
 import time
+import torch
 
+os.environ['FLAGS_npu_jit_compile'] = '0'  # disable Paddle JIT compilation
+os.environ['FLAGS_use_stride_kernel'] = '0'
+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS ops to fall back to CPU
+os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable the albumentations update check
 # disable Paddle's signal handler
 import paddle
-from loguru import logger
-
 paddle.disable_signal_handler()
 
-
+from loguru import logger
+
+from magic_pdf.model.batch_analyze import BatchAnalyze
+from magic_pdf.model.sub_modules.model_utils import get_vram
 
 try:
     import torchtext
-
     if torchtext.__version__ >= '0.18.0':
         torchtext.disable_torchtext_deprecation_warning()
 except ImportError:
@@ -28,20 +33,6 @@ from magic_pdf.model.model_list import MODEL
 from magic_pdf.operators.models import InferenceResult
 
 
-def dict_compare(d1, d2):
-    return d1.items() == d2.items()
-
-
-def remove_duplicates_dicts(lst):
-    unique_dicts = []
-    for dict_item in lst:
-        if not any(
-            dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
-        ):
-            unique_dicts.append(dict_item)
-    return unique_dicts
-
-
 class ModelSingleton:
     _instance = None
     _models = {}
@@ -154,33 +145,93 @@ def doc_analyze(
     table_enable=None,
 ) -> InferenceResult:
 
+    end_page_id = (
+        end_page_id
+        if end_page_id is not None and end_page_id >= 0
+        else len(dataset) - 1
+    )
+
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(
         ocr, show_log, lang, layout_model, formula_enable, table_enable
     )
 
+    batch_analyze = False
+    batch_ratio = 1
+    device = get_device()
+
+    npu_support = False
+    if str(device).startswith("npu"):
+        import torch_npu
+        if torch_npu.npu.is_available():
+            npu_support = True
+
+    if torch.cuda.is_available() and device != 'cpu' or npu_support:
+        gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
+        if gpu_memory is not None and gpu_memory >= 8:
+
+            if gpu_memory >= 40:
+                batch_ratio = 32
+            elif gpu_memory >= 20:
+                batch_ratio = 16
+            elif gpu_memory >= 16:
+                batch_ratio = 8
+            elif gpu_memory >= 10:
+                batch_ratio = 4
+            else:
+                batch_ratio = 2
+
+            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
+            batch_analyze = True
+
     model_json = []
     doc_analyze_start = time.time()
 
-    if end_page_id is None or end_page_id < 0:
-        end_page_id = len(dataset) - 1
-
-    for index in range(len(dataset)):
-        page_data = dataset.get_page(index)
-        img_dict = page_data.get_image()
-        img = img_dict['img']
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        if start_page_id <= index <= end_page_id:
-            page_start = time.time()
-            result = custom_model(img)
-            logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
-        else:
-            result = []
+    if batch_analyze:
+        # batch analyze
+        images = []
+        page_wh_list = []
+        for index in range(len(dataset)):
+            if start_page_id <= index <= end_page_id:
+                page_data = dataset.get_page(index)
+                img_dict = page_data.get_image()
+                images.append(img_dict['img'])
+                page_wh_list.append((img_dict['width'], img_dict['height']))
+        batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+        analyze_result = batch_model(images)
+
+        for index in range(len(dataset)):
+            if start_page_id <= index <= end_page_id:
+                result = analyze_result.pop(0)
+                page_width, page_height = page_wh_list.pop(0)
+            else:
+                result = []
+                page_height = 0
+                page_width = 0
+
+            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
 
-        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
+    else:
+        # single analyze
+
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            img = img_dict['img']
+            page_width = img_dict['width']
+            page_height = img_dict['height']
+            if start_page_id <= index <= end_page_id:
+                page_start = time.time()
+                result = custom_model(img)
+                logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
+            else:
+                result = []
+
+            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
 
     gc_start = time.time()
     clean_memory(get_device())