magic-pdf 1.0.1__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/dict2md/ocr_mkcontent.py +24 -0
- magic_pdf/filter/__init__.py +1 -1
- magic_pdf/filter/pdf_classify_by_type.py +6 -4
- magic_pdf/filter/pdf_meta_scan.py +4 -4
- magic_pdf/libs/boxbase.py +5 -2
- magic_pdf/libs/draw_bbox.py +14 -2
- magic_pdf/libs/language.py +9 -0
- magic_pdf/libs/pdf_check.py +11 -1
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/batch_analyze.py +103 -99
- magic_pdf/model/doc_analyze_by_custom_model.py +87 -36
- magic_pdf/model/magic_model.py +161 -4
- magic_pdf/model/pdf_extract_kit.py +23 -28
- magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +4 -3
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +7 -3
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +1 -1
- magic_pdf/model/sub_modules/model_init.py +34 -19
- magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +33 -26
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +25 -6
- magic_pdf/pdf_parse_union_core_v2.py +176 -61
- magic_pdf/post_proc/llm_aided.py +55 -24
- magic_pdf/pre_proc/ocr_dict_merge.py +14 -2
- magic_pdf/pre_proc/ocr_span_list_modify.py +1 -1
- magic_pdf/resources/model_config/model_configs.yaml +2 -2
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/METADATA +36 -19
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/RECORD +30 -30
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/WHEEL +0 -0
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,21 @@
|
|
1
1
|
import copy
|
2
|
+
import math
|
2
3
|
import os
|
3
4
|
import re
|
4
5
|
import statistics
|
5
6
|
import time
|
6
7
|
from typing import List
|
7
8
|
|
9
|
+
import cv2
|
8
10
|
import fitz
|
9
11
|
import torch
|
12
|
+
import numpy as np
|
10
13
|
from loguru import logger
|
11
14
|
|
12
15
|
from magic_pdf.config.enums import SupportedPdfParseMethod
|
13
16
|
from magic_pdf.config.ocr_content_type import BlockType, ContentType
|
14
17
|
from magic_pdf.data.dataset import Dataset, PageableData
|
15
|
-
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
|
18
|
+
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, __is_overlaps_y_exceeds_threshold
|
16
19
|
from magic_pdf.libs.clean_memory import clean_memory
|
17
20
|
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
|
18
21
|
from magic_pdf.libs.convert_utils import dict_to_list
|
@@ -117,24 +120,24 @@ def fill_char_in_spans(spans, all_chars):
|
|
117
120
|
|
118
121
|
for char in all_chars:
|
119
122
|
# 跳过非法bbox的char
|
120
|
-
x1, y1, x2, y2 = char['bbox']
|
121
|
-
if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
|
122
|
-
|
123
|
+
# x1, y1, x2, y2 = char['bbox']
|
124
|
+
# if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
|
125
|
+
# continue
|
126
|
+
|
123
127
|
for span in spans:
|
124
128
|
if calculate_char_in_span(char['bbox'], span['bbox'], char['c']):
|
125
129
|
span['chars'].append(char)
|
126
130
|
break
|
127
131
|
|
128
|
-
|
129
|
-
|
132
|
+
need_ocr_spans = []
|
130
133
|
for span in spans:
|
131
134
|
chars_to_content(span)
|
132
135
|
# 有的span中虽然没有字但有一两个空的占位符,用宽高和content长度过滤
|
133
136
|
if len(span['content']) * span['height'] < span['width'] * 0.5:
|
134
137
|
# logger.info(f"maybe empty span: {len(span['content'])}, {span['height']}, {span['width']}")
|
135
|
-
|
138
|
+
need_ocr_spans.append(span)
|
136
139
|
del span['height'], span['width']
|
137
|
-
return
|
140
|
+
return need_ocr_spans
|
138
141
|
|
139
142
|
|
140
143
|
# 使用鲁棒性更强的中心点坐标判断
|
@@ -173,12 +176,60 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
|
|
173
176
|
return False
|
174
177
|
|
175
178
|
|
179
|
+
def remove_tilted_line(text_blocks, tolerance=2):
    """Remove, in place, every line whose writing direction is not axis-aligned.

    Each block is a dict such as those returned by
    ``page.get_text('rawdict')['blocks']``; every line in ``block['lines']``
    carries a ``dir`` entry ``(cosine, sine)``. A line is kept when its angle
    lies within ``tolerance`` degrees of a multiple of 90 (0, ±90, 180);
    otherwise it is dropped. The previous check (``2 < abs(angle) < 88``)
    only caught tilts in the first/fourth quadrants and silently kept lines
    tilted between 92 and 178 degrees.

    :param text_blocks: list of block dicts; their ``lines`` lists are mutated.
        Blocks without a ``lines`` entry are skipped.
    :param tolerance: allowed deviation in degrees from the nearest multiple
        of 90; the default of 2 matches the previous hard-coded bounds.
    """
    for block in text_blocks:
        lines = block.get('lines')
        if not lines:
            continue

        kept = []
        for line in lines:
            cosine, sine = line['dir']
            angle_degrees = math.degrees(math.atan2(sine, cosine))
            # Offset past the previous multiple of 90 degrees, in [0, 90).
            # The line is tilted when it is more than `tolerance` away from
            # both that multiple and the next one.
            offset = abs(angle_degrees) % 90
            if tolerance < offset < 90 - tolerance:
                continue
            kept.append(line)

        # Slice assignment keeps the same list object, matching the old
        # in-place list.remove() behaviour, but in a single O(n) pass.
        lines[:] = kept
|
192
|
+
|
193
|
+
|
194
|
+
def calculate_contrast(img, img_mode) -> float:
    """
    Measure image contrast as the coefficient of variation of its grayscale pixels.

    :param img: image as a numpy.ndarray whose channel order is given by img_mode
    :param img_mode: channel order of img, either 'rgb' or 'bgr'
    :return: standard deviation divided by mean of the grayscale intensities,
        rounded to 2 decimal places
    :raises ValueError: if img_mode is neither 'rgb' nor 'bgr'
    """
    if img_mode not in ('rgb', 'bgr'):
        raise ValueError("Invalid image mode. Please provide 'rgb' or 'bgr'.")

    conversion_code = cv2.COLOR_RGB2GRAY if img_mode == 'rgb' else cv2.COLOR_BGR2GRAY
    grayscale = cv2.cvtColor(img, conversion_code)

    # std / mean; the tiny epsilon keeps an all-black image from dividing by zero.
    average = np.mean(grayscale)
    spread = np.std(grayscale)
    return round(spread / (average + 1e-6), 2)
|
217
|
+
|
218
|
+
|
176
219
|
def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
|
177
220
|
# cid用0xfffd表示,连字符拆开
|
178
221
|
# text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
|
179
222
|
|
180
223
|
# cid用0xfffd表示,连字符不拆开
|
181
|
-
text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
|
224
|
+
#text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
|
225
|
+
|
226
|
+
# 自定义flags出现较多0xfffd,可能是pymupdf可以自行处理内置字典的pdf,不再使用
|
227
|
+
text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
|
228
|
+
# text_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
|
229
|
+
|
230
|
+
# 移除所有角度不为0或90的line
|
231
|
+
remove_tilted_line(text_blocks_raw)
|
232
|
+
|
182
233
|
all_pymu_chars = []
|
183
234
|
for block in text_blocks_raw:
|
184
235
|
for line in block['lines']:
|
@@ -249,9 +300,9 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
|
|
249
300
|
span['chars'] = []
|
250
301
|
new_spans.append(span)
|
251
302
|
|
252
|
-
|
303
|
+
need_ocr_spans = fill_char_in_spans(new_spans, all_pymu_chars)
|
253
304
|
|
254
|
-
if len(
|
305
|
+
if len(need_ocr_spans) > 0:
|
255
306
|
|
256
307
|
# 初始化ocr模型
|
257
308
|
atom_model_manager = AtomModelSingleton()
|
@@ -262,9 +313,15 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
|
|
262
313
|
lang=lang
|
263
314
|
)
|
264
315
|
|
265
|
-
for span in
|
316
|
+
for span in need_ocr_spans:
|
266
317
|
# 对span的bbox截图再ocr
|
267
318
|
span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode='cv2')
|
319
|
+
|
320
|
+
# 计算span的对比度,低于0.20的span不进行ocr
|
321
|
+
if calculate_contrast(span_img, img_mode='bgr') <= 0.20:
|
322
|
+
spans.remove(span)
|
323
|
+
continue
|
324
|
+
|
268
325
|
ocr_res = ocr_model.ocr(span_img, det=False)
|
269
326
|
if ocr_res and len(ocr_res) > 0:
|
270
327
|
if len(ocr_res[0]) > 0:
|
@@ -281,24 +338,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
|
|
281
338
|
|
282
339
|
def model_init(model_name: str):
|
283
340
|
from transformers import LayoutLMv3ForTokenClassification
|
284
|
-
device = get_device()
|
285
|
-
if torch.cuda.is_available():
|
286
|
-
device = torch.device('cuda')
|
287
|
-
if torch.cuda.is_bf16_supported():
|
288
|
-
supports_bfloat16 = True
|
289
|
-
else:
|
290
|
-
supports_bfloat16 = False
|
291
|
-
elif str(device).startswith("npu"):
|
292
|
-
import torch_npu
|
293
|
-
if torch_npu.npu.is_available():
|
294
|
-
device = torch.device('npu')
|
295
|
-
supports_bfloat16 = False
|
296
|
-
else:
|
297
|
-
device = torch.device('cpu')
|
298
|
-
supports_bfloat16 = False
|
299
|
-
else:
|
300
|
-
device = torch.device('cpu')
|
301
|
-
supports_bfloat16 = False
|
341
|
+
device = torch.device(get_device())
|
302
342
|
|
303
343
|
if model_name == 'layoutreader':
|
304
344
|
# 检测modelscope的缓存目录是否存在
|
@@ -314,9 +354,6 @@ def model_init(model_name: str):
|
|
314
354
|
model = LayoutLMv3ForTokenClassification.from_pretrained(
|
315
355
|
'hantian/layoutreader'
|
316
356
|
)
|
317
|
-
# 检查设备是否支持 bfloat16
|
318
|
-
if supports_bfloat16:
|
319
|
-
model.bfloat16()
|
320
357
|
model.to(device).eval()
|
321
358
|
else:
|
322
359
|
logger.error('model name not allow')
|
@@ -365,10 +402,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
|
|
365
402
|
block['index'] = median_value
|
366
403
|
|
367
404
|
# 删除图表body block中的虚拟line信息, 并用real_lines信息回填
|
368
|
-
if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
|
369
|
-
|
370
|
-
|
371
|
-
|
405
|
+
if block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.Title, BlockType.InterlineEquation]:
|
406
|
+
if 'real_lines' in block:
|
407
|
+
block['virtual_lines'] = copy.deepcopy(block['lines'])
|
408
|
+
block['lines'] = copy.deepcopy(block['real_lines'])
|
409
|
+
del block['real_lines']
|
372
410
|
else:
|
373
411
|
# 使用xycut排序
|
374
412
|
block_bboxes = []
|
@@ -378,10 +416,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
|
|
378
416
|
block_bboxes.append(block['bbox'])
|
379
417
|
|
380
418
|
# 删除图表body block中的虚拟line信息, 并用real_lines信息回填
|
381
|
-
if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
|
382
|
-
|
383
|
-
|
384
|
-
|
419
|
+
if block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.Title, BlockType.InterlineEquation]:
|
420
|
+
if 'real_lines' in block:
|
421
|
+
block['virtual_lines'] = copy.deepcopy(block['lines'])
|
422
|
+
block['lines'] = copy.deepcopy(block['real_lines'])
|
423
|
+
del block['real_lines']
|
385
424
|
|
386
425
|
import numpy as np
|
387
426
|
|
@@ -417,7 +456,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
|
|
417
456
|
block_weight = x1 - x0
|
418
457
|
|
419
458
|
# 如果block高度小于n行正文,则直接返回block的bbox
|
420
|
-
if line_height *
|
459
|
+
if line_height * 2 < block_height:
|
421
460
|
if (
|
422
461
|
block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
|
423
462
|
): # 可能是双列结构,可以切细点
|
@@ -425,16 +464,16 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
|
|
425
464
|
else:
|
426
465
|
# 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
|
427
466
|
if block_weight > page_w * 0.4:
|
428
|
-
line_height = (y1 - y0) / 3
|
429
467
|
lines = 3
|
468
|
+
line_height = (y1 - y0) / lines
|
430
469
|
elif block_weight > page_w * 0.25: # (可能是三列结构,也切细点)
|
431
470
|
lines = int(block_height / line_height) + 1
|
432
471
|
else: # 判断长宽比
|
433
472
|
if block_height / block_weight > 1.2: # 细长的不分
|
434
473
|
return [[x0, y0, x1, y1]]
|
435
474
|
else: # 不细长的还是分成两行
|
436
|
-
line_height = (y1 - y0) / 2
|
437
475
|
lines = 2
|
476
|
+
line_height = (y1 - y0) / lines
|
438
477
|
|
439
478
|
# 确定从哪个y位置开始绘制线条
|
440
479
|
current_y = y0
|
@@ -453,30 +492,32 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
|
|
453
492
|
|
454
493
|
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
|
455
494
|
page_line_list = []
|
495
|
+
|
496
|
+
def add_lines_to_block(b):
    """Replace b['lines'] with empty virtual lines cut from the block's bbox.

    The line bboxes are produced by insert_lines_into_block using the
    enclosing function's line_height, page_w and page_h, and are also
    appended to the enclosing page_line_list.
    """
    virtual_bboxes = insert_lines_into_block(b['bbox'], line_height, page_w, page_h)
    b['lines'] = [{'bbox': bbox, 'spans': []} for bbox in virtual_bboxes]
    page_line_list.extend(virtual_bboxes)
|
502
|
+
|
456
503
|
for block in fix_blocks:
|
457
504
|
if block['type'] in [
|
458
|
-
BlockType.Text, BlockType.Title,
|
505
|
+
BlockType.Text, BlockType.Title,
|
459
506
|
BlockType.ImageCaption, BlockType.ImageFootnote,
|
460
507
|
BlockType.TableCaption, BlockType.TableFootnote
|
461
508
|
]:
|
462
509
|
if len(block['lines']) == 0:
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
page_line_list.extend(lines)
|
510
|
+
add_lines_to_block(block)
|
511
|
+
elif block['type'] in [BlockType.Title] and len(block['lines']) == 1 and (block['bbox'][3] - block['bbox'][1]) > line_height * 2:
|
512
|
+
block['real_lines'] = copy.deepcopy(block['lines'])
|
513
|
+
add_lines_to_block(block)
|
468
514
|
else:
|
469
515
|
for line in block['lines']:
|
470
516
|
bbox = line['bbox']
|
471
517
|
page_line_list.append(bbox)
|
472
|
-
elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
|
473
|
-
bbox = block['bbox']
|
518
|
+
elif block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.InterlineEquation]:
|
474
519
|
block['real_lines'] = copy.deepcopy(block['lines'])
|
475
|
-
|
476
|
-
block['lines'] = []
|
477
|
-
for line in lines:
|
478
|
-
block['lines'].append({'bbox': line, 'spans': []})
|
479
|
-
page_line_list.extend(lines)
|
520
|
+
add_lines_to_block(block)
|
480
521
|
|
481
522
|
if len(page_line_list) > 200: # layoutreader最高支持512line
|
482
523
|
return None
|
@@ -663,12 +704,77 @@ def parse_page_core(
|
|
663
704
|
discarded_blocks = magic_model.get_discarded(page_id)
|
664
705
|
text_blocks = magic_model.get_text_blocks(page_id)
|
665
706
|
title_blocks = magic_model.get_title_blocks(page_id)
|
666
|
-
inline_equations, interline_equations, interline_equation_blocks = (
|
667
|
-
magic_model.get_equations(page_id)
|
668
|
-
)
|
669
|
-
|
707
|
+
inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)
|
670
708
|
page_w, page_h = magic_model.get_page_size(page_id)
|
671
709
|
|
710
|
+
def merge_title_blocks(blocks, x_distance_threshold=0.1*page_w):
    """Merge title blocks that were split apart on the same visual row.

    Title blocks are first grouped into rows: a block joins the row of an
    earlier title when both have exactly one line and
    __is_overlaps_y_exceeds_threshold(..., 0.9) holds for their bboxes.
    Each row is then walked left to right. A block is folded into the current
    accumulator when the horizontal gap is below x_distance_threshold and its
    height is within +/-5% of the accumulator's height; otherwise it becomes
    the new accumulator. Folded blocks are removed from `blocks` in place.

    Note: the merged bboxes are stored as tuples.
    """

    def union_bbox(first, second):
        """Return the (x0, y0, x1, y1) tuple enclosing both items' bboxes."""
        box_a, box_b = first['bbox'], second['bbox']
        return (
            min(box_a[0], box_b[0]),
            min(box_a[1], box_b[1]),
            max(box_a[2], box_b[2]),
            max(box_a[3], box_b[3]),
        )

    def absorb(target, other):
        """Grow `target` (bbox, first line bbox, spans) to include `other`."""
        target['bbox'] = union_bbox(target, other)
        target_line = target['lines'][0]
        other_line = other['lines'][0]
        target_line['bbox'] = union_bbox(target_line, other_line)
        target_line['spans'].extend(other_line['spans'])
        return target, other

    # Group title blocks into rows by vertical overlap.
    rows = []
    pending = [blk for blk in blocks if blk['type'] == BlockType.Title]
    while pending:
        anchor = pending.pop(0)
        matched = [
            candidate for candidate in pending
            if __is_overlaps_y_exceeds_threshold(anchor['bbox'], candidate['bbox'], 0.9)
            and len(anchor['lines']) == 1
            and len(candidate['lines']) == 1
        ]
        for candidate in matched:
            pending.remove(candidate)
        rows.append([anchor] + matched)

    # Within each row, fold horizontally adjacent, equal-height blocks together.
    absorbed = []
    for row in rows:
        if len(row) < 2:
            continue
        row.sort(key=lambda blk: blk['bbox'][0])
        current = row[0]
        for neighbour in row[1:]:
            current_height = current['bbox'][3] - current['bbox'][1]
            neighbour_height = neighbour['bbox'][3] - neighbour['bbox'][1]
            gap = neighbour['bbox'][0] - current['bbox'][2]
            if gap < x_distance_threshold and current_height * 0.95 < neighbour_height < current_height * 1.05:
                current, dropped = absorb(current, neighbour)
                absorbed.append(dropped)
            else:
                current = neighbour

    for blk in absorbed:
        blocks.remove(blk)
|
777
|
+
|
672
778
|
"""将所有区块的bbox整理到一起"""
|
673
779
|
# interline_equation_blocks参数不够准,后面切换到interline_equations上
|
674
780
|
interline_equation_blocks = []
|
@@ -753,6 +859,9 @@ def parse_page_core(
|
|
753
859
|
"""对block进行fix操作"""
|
754
860
|
fix_blocks = fix_block_spans_v2(block_with_spans)
|
755
861
|
|
862
|
+
"""同一行被断开的titile合并"""
|
863
|
+
merge_title_blocks(fix_blocks)
|
864
|
+
|
756
865
|
"""获取所有line并计算正文line的高度"""
|
757
866
|
line_height = get_line_height(fix_blocks)
|
758
867
|
|
@@ -861,17 +970,23 @@ def pdf_parse_union(
|
|
861
970
|
formula_aided_config = llm_aided_config.get('formula_aided', None)
|
862
971
|
if formula_aided_config is not None:
|
863
972
|
if formula_aided_config.get('enable', False):
|
973
|
+
llm_aided_formula_start_time = time.time()
|
864
974
|
llm_aided_formula(pdf_info_dict, formula_aided_config)
|
975
|
+
logger.info(f'llm aided formula time: {round(time.time() - llm_aided_formula_start_time, 2)}')
|
865
976
|
"""文本优化"""
|
866
977
|
text_aided_config = llm_aided_config.get('text_aided', None)
|
867
978
|
if text_aided_config is not None:
|
868
979
|
if text_aided_config.get('enable', False):
|
980
|
+
llm_aided_text_start_time = time.time()
|
869
981
|
llm_aided_text(pdf_info_dict, text_aided_config)
|
982
|
+
logger.info(f'llm aided text time: {round(time.time() - llm_aided_text_start_time, 2)}')
|
870
983
|
"""标题优化"""
|
871
984
|
title_aided_config = llm_aided_config.get('title_aided', None)
|
872
985
|
if title_aided_config is not None:
|
873
986
|
if title_aided_config.get('enable', False):
|
987
|
+
llm_aided_title_start_time = time.time()
|
874
988
|
llm_aided_title(pdf_info_dict, title_aided_config)
|
989
|
+
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
|
875
990
|
|
876
991
|
"""dict转list"""
|
877
992
|
pdf_info_list = dict_to_list(pdf_info_dict)
|
magic_pdf/post_proc/llm_aided.py
CHANGED
@@ -3,6 +3,7 @@ import json
|
|
3
3
|
from loguru import logger
|
4
4
|
from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
|
5
5
|
from openai import OpenAI
|
6
|
+
import ast
|
6
7
|
|
7
8
|
|
8
9
|
#@todo: 有的公式以"\"结尾,这样会导致尾部拼接的"$"被转义,也需要修复
|
@@ -83,26 +84,48 @@ def llm_aided_title(pdf_info_dict, title_aided_config):
|
|
83
84
|
if block["type"] == "title":
|
84
85
|
origin_title_list.append(block)
|
85
86
|
title_text = merge_para_with_text(block)
|
86
|
-
|
87
|
+
page_line_height_list = []
|
88
|
+
for line in block['lines']:
|
89
|
+
bbox = line['bbox']
|
90
|
+
page_line_height_list.append(int(bbox[3] - bbox[1]))
|
91
|
+
if len(page_line_height_list) > 0:
|
92
|
+
line_avg_height = sum(page_line_height_list) / len(page_line_height_list)
|
93
|
+
else:
|
94
|
+
line_avg_height = int(block['bbox'][3] - block['bbox'][1])
|
95
|
+
title_dict[f"{i}"] = [title_text, line_avg_height, int(page_num[5:])+1]
|
87
96
|
i += 1
|
88
97
|
# logger.info(f"Title list: {title_dict}")
|
89
98
|
|
90
99
|
title_optimize_prompt = f"""输入的内容是一篇文档中所有标题组成的字典,请根据以下指南优化标题的结果,使结果符合正常文档的层次结构:
|
91
100
|
|
92
|
-
1.
|
101
|
+
1. 字典中每个value均为一个list,包含以下元素:
|
102
|
+
- 标题文本
|
103
|
+
- 文本行高是标题所在块的平均行高
|
104
|
+
- 标题所在的页码
|
105
|
+
|
106
|
+
2. 保留原始内容:
|
93
107
|
- 输入的字典中所有元素都是有效的,不能删除字典中的任何元素
|
94
108
|
- 请务必保证输出的字典中元素的数量和输入的数量一致
|
95
109
|
|
96
|
-
|
110
|
+
3. 保持字典内key-value的对应关系不变
|
97
111
|
|
98
|
-
|
112
|
+
4. 优化层次结构:
|
99
113
|
- 为每个标题元素添加适当的层次结构
|
100
|
-
-
|
114
|
+
- 行高较大的标题一般是更高级别的标题
|
115
|
+
- 标题从前至后的层级必须是连续的,不能跳过层级
|
101
116
|
- 标题层级最多为4级,不要添加过多的层级
|
102
|
-
-
|
103
|
-
|
117
|
+
- 优化后的标题只保留代表该标题的层级的整数,不要保留其他信息
|
118
|
+
|
119
|
+
5. 合理性检查与微调:
|
120
|
+
- 在完成初步分级后,仔细检查分级结果的合理性
|
121
|
+
- 根据上下文关系和逻辑顺序,对不合理的分级进行微调
|
122
|
+
- 确保最终的分级结果符合文档的实际结构和逻辑
|
123
|
+
- 字典中可能包含被误当成标题的正文,你可以通过将其层级标记为 0 来排除它们
|
124
|
+
|
104
125
|
IMPORTANT:
|
105
|
-
|
126
|
+
请直接返回优化过的由标题层级组成的字典,格式为{{标题id:标题层级}},如下:
|
127
|
+
{{0:1,1:2,2:2,3:3}}
|
128
|
+
不需要对字典格式化,不需要返回任何其他信息。
|
106
129
|
|
107
130
|
Input title list:
|
108
131
|
{title_dict}
|
@@ -110,24 +133,32 @@ Input title list:
|
|
110
133
|
Corrected title list:
|
111
134
|
"""
|
112
135
|
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
{'role': 'user', 'content': title_optimize_prompt}],
|
117
|
-
temperature=0.7,
|
118
|
-
)
|
119
|
-
|
120
|
-
json_completion = json.loads(completion.choices[0].message.content)
|
121
|
-
|
122
|
-
# logger.info(f"Title completion: {json_completion}")
|
136
|
+
retry_count = 0
|
137
|
+
max_retries = 3
|
138
|
+
dict_completion = None
|
123
139
|
|
124
|
-
|
125
|
-
if len(json_completion) == len(title_dict):
|
140
|
+
while retry_count < max_retries:
|
126
141
|
try:
|
127
|
-
|
128
|
-
|
142
|
+
completion = client.chat.completions.create(
|
143
|
+
model=title_aided_config["model"],
|
144
|
+
messages=[
|
145
|
+
{'role': 'user', 'content': title_optimize_prompt}],
|
146
|
+
temperature=0.7,
|
147
|
+
)
|
148
|
+
# logger.info(f"Title completion: {completion.choices[0].message.content}")
|
149
|
+
dict_completion = ast.literal_eval(completion.choices[0].message.content)
|
150
|
+
# logger.info(f"len(dict_completion): {len(dict_completion)}, len(title_dict): {len(title_dict)}")
|
151
|
+
|
152
|
+
if len(dict_completion) == len(title_dict):
|
153
|
+
for i, origin_title_block in enumerate(origin_title_list):
|
154
|
+
origin_title_block["level"] = int(dict_completion[i])
|
155
|
+
break
|
156
|
+
else:
|
157
|
+
logger.warning("The number of titles in the optimized result is not equal to the number of titles in the input.")
|
158
|
+
retry_count += 1
|
129
159
|
except Exception as e:
|
130
160
|
logger.exception(e)
|
131
|
-
|
132
|
-
logger.error("The number of titles in the optimized result is not equal to the number of titles in the input.")
|
161
|
+
retry_count += 1
|
133
162
|
|
163
|
+
if dict_completion is None:
|
164
|
+
logger.error("Failed to decode dict after maximum retries.")
|
@@ -60,6 +60,19 @@ def merge_spans_to_line(spans, threshold=0.6):
|
|
60
60
|
return lines
|
61
61
|
|
62
62
|
|
63
|
+
def span_block_type_compatible(span_type, block_type):
    """Tell whether a span of `span_type` may be placed in a block of `block_type`.

    Text and inline-equation spans fit text-like blocks (text, title, and
    image/table captions and footnotes). Interline-equation, image and table
    spans each fit only their own body block type. Any other span type is
    incompatible with every block.

    NOTE(review): BlockType.Discarded is not in any allowed list. If this
    check is applied when filling discarded blocks, their spans will be
    rejected; confirm that is intended.
    """
    rules = (
        (
            (ContentType.Text, ContentType.InlineEquation),
            (
                BlockType.Text,
                BlockType.Title,
                BlockType.ImageCaption,
                BlockType.ImageFootnote,
                BlockType.TableCaption,
                BlockType.TableFootnote,
            ),
        ),
        ((ContentType.InterlineEquation,), (BlockType.InterlineEquation,)),
        ((ContentType.Image,), (BlockType.ImageBody,)),
        ((ContentType.Table,), (BlockType.TableBody,)),
    )
    for span_types, allowed_block_types in rules:
        if span_type in span_types:
            return block_type in allowed_block_types
    return False
|
74
|
+
|
75
|
+
|
63
76
|
def fill_spans_in_blocks(blocks, spans, radio):
|
64
77
|
"""将allspans中的span按位置关系,放入blocks中."""
|
65
78
|
block_with_spans = []
|
@@ -78,8 +91,7 @@ def fill_spans_in_blocks(blocks, spans, radio):
|
|
78
91
|
block_spans = []
|
79
92
|
for span in spans:
|
80
93
|
span_bbox = span['bbox']
|
81
|
-
if calculate_overlap_area_in_bbox1_area_ratio(
|
82
|
-
span_bbox, block_bbox) > radio:
|
94
|
+
if calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > radio and span_block_type_compatible(span['type'], block_type):
|
83
95
|
block_spans.append(span)
|
84
96
|
|
85
97
|
block_dict['spans'] = block_spans
|
@@ -36,7 +36,7 @@ def remove_overlaps_low_confidence_spans(spans):
|
|
36
36
|
def check_chars_is_overlap_in_span(chars):
    """Return True if any two chars in the span overlap with IoU above 0.35.

    Every unordered pair of chars is compared by the IoU of their 'bbox'
    entries, in index order; the scan stops at the first pair over the
    threshold.
    """
    for idx, first in enumerate(chars):
        for second in chars[idx + 1:]:
            if calculate_iou(first['bbox'], second['bbox']) > 0.35:
                return True
    return False
|
42
42
|
|
@@ -1,8 +1,8 @@
|
|
1
1
|
weights:
|
2
2
|
layoutlmv3: Layout/LayoutLMv3/model_final.pth
|
3
|
-
doclayout_yolo: Layout/YOLO/
|
3
|
+
doclayout_yolo: Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt
|
4
4
|
yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
|
5
|
-
unimernet_small: MFR/
|
5
|
+
unimernet_small: MFR/unimernet_small_2501
|
6
6
|
struct_eqtable: TabRec/StructEqTable
|
7
7
|
tablemaster: TabRec/TableMaster
|
8
8
|
rapid_table: TabRec/RapidTable
|