magic-pdf 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/libs/boxbase.py +5 -2
- magic_pdf/libs/draw_bbox.py +14 -2
- magic_pdf/libs/language.py +9 -0
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/batch_analyze.py +103 -99
- magic_pdf/model/doc_analyze_by_custom_model.py +77 -18
- magic_pdf/model/pdf_extract_kit.py +23 -21
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +7 -3
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +1 -1
- magic_pdf/model/sub_modules/model_init.py +4 -3
- magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +33 -26
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +25 -6
- magic_pdf/pdf_parse_union_core_v2.py +137 -32
- magic_pdf/post_proc/llm_aided.py +59 -26
- magic_pdf/post_proc/llm_aided_ocr.py +689 -0
- magic_pdf/pre_proc/ocr_span_list_modify.py +1 -1
- magic_pdf/resources/model_config/model_configs.yaml +2 -2
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/METADATA +50 -41
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/RECORD +23 -22
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/WHEEL +1 -1
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/top_level.txt +0 -0
@@ -7,6 +7,8 @@ import base64
|
|
7
7
|
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
|
8
8
|
from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
|
9
9
|
|
10
|
+
import importlib.resources
|
11
|
+
from paddleocr import PaddleOCR
|
10
12
|
from ppocr.utils.utility import check_and_read
|
11
13
|
|
12
14
|
|
@@ -327,30 +329,35 @@ class ONNXModelSingleton:
|
|
327
329
|
return self._models[key]
|
328
330
|
|
329
331
|
def onnx_model_init(key):
    """Build a PaddleOCR instance that uses the ONNX models shipped with rapidocr_onnxruntime.

    Args:
        key: Indexable with at least four items, read positionally as
            ``(lang, det_db_box_thresh, use_dilation, det_db_unclip_ratio)``.
            ``lang`` is forwarded to PaddleOCR only when it is not ``None``.

    Returns:
        The constructed ``PaddleOCR`` object.

    Raises:
        SystemExit: With status 1 when ``key`` has fewer than four items or
            when model construction raises; the failure is logged first.
    """
    # ``sys.exit`` replaces the bare ``exit`` helper: ``exit`` is injected by the
    # ``site`` module for interactive use and may be missing (``python -S``,
    # frozen or embedded interpreters), which would turn a clean shutdown into
    # a NameError. Imported locally so the block does not depend on a
    # module-level ``import sys``.
    import sys

    if len(key) < 4:
        logger.error('Invalid key length, expected at least 4 elements')
        sys.exit(1)

    try:
        # The model is created inside the ``with`` so the resolved resource
        # directory is still available while PaddleOCR loads the files.
        with importlib.resources.path('rapidocr_onnxruntime.models', '') as resource_path:
            additional_ocr_params = {
                "use_onnx": True,
                "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
                "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
                "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
                "det_db_box_thresh": key[1],
                "use_dilation": key[2],
                "det_db_unclip_ratio": key[3],
            }

            # Only override PaddleOCR's own language default when one was requested.
            if key[0] is not None:
                additional_ocr_params["lang"] = key[0]

            # logger.info(f"additional_ocr_params: {additional_ocr_params}")

            onnx_model = PaddleOCR(**additional_ocr_params)

            # Defensive check kept from the original code.
            # NOTE(review): a plain class constructor never returns None - confirm
            # whether this branch can be dropped.
            if onnx_model is None:
                logger.error('model init failed')
                sys.exit(1)
            return onnx_model

    except Exception as e:
        # SystemExit is not an Exception subclass, so the exits above pass through.
        logger.exception(f'Error initializing model: {e}')
        sys.exit(1)
|
@@ -2,12 +2,27 @@ import cv2
|
|
2
2
|
import numpy as np
|
3
3
|
import torch
|
4
4
|
from loguru import logger
|
5
|
-
from rapid_table import RapidTable
|
5
|
+
from rapid_table import RapidTable, RapidTableInput
|
6
|
+
from rapid_table.main import ModelType
|
7
|
+
|
8
|
+
from magic_pdf.libs.config_reader import get_device
|
6
9
|
|
7
10
|
|
8
11
|
class RapidTableModel(object):
|
9
|
-
def __init__(self, ocr_engine):
|
10
|
-
|
12
|
+
def __init__(self, ocr_engine, table_sub_model_name):
|
13
|
+
sub_model_list = [model.value for model in ModelType]
|
14
|
+
if table_sub_model_name is None:
|
15
|
+
input_args = RapidTableInput()
|
16
|
+
elif table_sub_model_name in sub_model_list:
|
17
|
+
if torch.cuda.is_available() and table_sub_model_name == "unitable":
|
18
|
+
input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
|
19
|
+
else:
|
20
|
+
input_args = RapidTableInput(model_type=table_sub_model_name)
|
21
|
+
else:
|
22
|
+
raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
|
23
|
+
|
24
|
+
self.table_model = RapidTable(input_args)
|
25
|
+
|
11
26
|
# if ocr_engine is None:
|
12
27
|
# self.ocr_model_name = "RapidOCR"
|
13
28
|
# if torch.cuda.is_available():
|
@@ -45,7 +60,11 @@ class RapidTableModel(object):
|
|
45
60
|
ocr_result = None
|
46
61
|
|
47
62
|
if ocr_result:
|
48
|
-
|
49
|
-
|
63
|
+
table_results = self.table_model(np.asarray(image), ocr_result)
|
64
|
+
html_code = table_results.pred_html
|
65
|
+
table_cell_bboxes = table_results.cell_bboxes
|
66
|
+
logic_points = table_results.logic_points
|
67
|
+
elapse = table_results.elapse
|
68
|
+
return html_code, table_cell_bboxes, logic_points, elapse
|
50
69
|
else:
|
51
|
-
return None, None, None
|
70
|
+
return None, None, None, None
|
@@ -1,4 +1,5 @@
|
|
1
1
|
import copy
|
2
|
+
import math
|
2
3
|
import os
|
3
4
|
import re
|
4
5
|
import statistics
|
@@ -12,7 +13,7 @@ from loguru import logger
|
|
12
13
|
from magic_pdf.config.enums import SupportedPdfParseMethod
|
13
14
|
from magic_pdf.config.ocr_content_type import BlockType, ContentType
|
14
15
|
from magic_pdf.data.dataset import Dataset, PageableData
|
15
|
-
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
|
16
|
+
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, __is_overlaps_y_exceeds_threshold
|
16
17
|
from magic_pdf.libs.clean_memory import clean_memory
|
17
18
|
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
|
18
19
|
from magic_pdf.libs.convert_utils import dict_to_list
|
@@ -117,9 +118,10 @@ def fill_char_in_spans(spans, all_chars):
|
|
117
118
|
|
118
119
|
for char in all_chars:
|
119
120
|
# 跳过非法bbox的char
|
120
|
-
x1, y1, x2, y2 = char['bbox']
|
121
|
-
if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
|
122
|
-
|
121
|
+
# x1, y1, x2, y2 = char['bbox']
|
122
|
+
# if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
|
123
|
+
# continue
|
124
|
+
|
123
125
|
for span in spans:
|
124
126
|
if calculate_char_in_span(char['bbox'], span['bbox'], char['c']):
|
125
127
|
span['chars'].append(char)
|
@@ -173,12 +175,35 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
|
|
173
175
|
return False
|
174
176
|
|
175
177
|
|
178
|
+
def remove_tilted_line(text_blocks):
    """Remove, in place, every line whose writing direction is not axis-aligned.

    Each line's ``dir`` entry is unpacked as ``(cosine, sine)`` and converted
    to an angle in degrees. A line is kept when that angle lies within 2
    degrees of a multiple of 90 (0, +/-90 or 180); all other lines are
    dropped from their block.

    The previous check ``2 < abs(angle) < 88`` only caught tilts in the first
    and fourth quadrants, so a line rotated by e.g. 135 degrees survived even
    though it is neither horizontal nor vertical. Measuring the deviation
    modulo 90 treats all four quadrants the same way.

    Args:
        text_blocks: Iterable of dicts, each with a ``'lines'`` list whose
            items carry a ``'dir'`` pair. The lists are modified in place.

    Returns:
        None.
    """
    for block in text_blocks:
        kept_lines = []
        for line in block['lines']:
            cosine, sine = line['dir']
            # Direction vector -> angle in degrees, in the range [-180, 180].
            angle_degrees = math.degrees(math.atan2(sine, cosine))
            # Distance past the previous multiple of 90; values near 0 or
            # near 90 both mean "close to an axis".
            axis_deviation = abs(angle_degrees) % 90
            if not 2 < axis_deviation < 88:
                kept_lines.append(line)
        # Slice assignment keeps the same list object, so other references to
        # block['lines'] observe the filtered content (as the original
        # in-place ``remove`` calls did).
        block['lines'][:] = kept_lines
|
191
|
+
|
192
|
+
|
176
193
|
def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
|
177
194
|
# cid用0xfffd表示,连字符拆开
|
178
195
|
# text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
|
179
196
|
|
180
197
|
# cid用0xfffd表示,连字符不拆开
|
181
|
-
text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
|
198
|
+
#text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
|
199
|
+
|
200
|
+
# 自定义flags出现较多0xfffd,可能是pymupdf可以自行处理内置字典的pdf,不再使用
|
201
|
+
text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
|
202
|
+
# text_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
|
203
|
+
|
204
|
+
# 移除所有角度不为0或90的line
|
205
|
+
remove_tilted_line(text_blocks_raw)
|
206
|
+
|
182
207
|
all_pymu_chars = []
|
183
208
|
for block in text_blocks_raw:
|
184
209
|
for line in block['lines']:
|
@@ -365,10 +390,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
|
|
365
390
|
block['index'] = median_value
|
366
391
|
|
367
392
|
# 删除图表body block中的虚拟line信息, 并用real_lines信息回填
|
368
|
-
if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
|
369
|
-
|
370
|
-
|
371
|
-
|
393
|
+
if block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.Title, BlockType.InterlineEquation]:
|
394
|
+
if 'real_lines' in block:
|
395
|
+
block['virtual_lines'] = copy.deepcopy(block['lines'])
|
396
|
+
block['lines'] = copy.deepcopy(block['real_lines'])
|
397
|
+
del block['real_lines']
|
372
398
|
else:
|
373
399
|
# 使用xycut排序
|
374
400
|
block_bboxes = []
|
@@ -417,7 +443,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
|
|
417
443
|
block_weight = x1 - x0
|
418
444
|
|
419
445
|
# 如果block高度小于n行正文,则直接返回block的bbox
|
420
|
-
if line_height *
|
446
|
+
if line_height * 2 < block_height:
|
421
447
|
if (
|
422
448
|
block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
|
423
449
|
): # 可能是双列结构,可以切细点
|
@@ -425,16 +451,16 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
|
|
425
451
|
else:
|
426
452
|
# 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
|
427
453
|
if block_weight > page_w * 0.4:
|
428
|
-
line_height = (y1 - y0) / 3
|
429
454
|
lines = 3
|
455
|
+
line_height = (y1 - y0) / lines
|
430
456
|
elif block_weight > page_w * 0.25: # (可能是三列结构,也切细点)
|
431
457
|
lines = int(block_height / line_height) + 1
|
432
458
|
else: # 判断长宽比
|
433
459
|
if block_height / block_weight > 1.2: # 细长的不分
|
434
460
|
return [[x0, y0, x1, y1]]
|
435
461
|
else: # 不细长的还是分成两行
|
436
|
-
line_height = (y1 - y0) / 2
|
437
462
|
lines = 2
|
463
|
+
line_height = (y1 - y0) / lines
|
438
464
|
|
439
465
|
# 确定从哪个y位置开始绘制线条
|
440
466
|
current_y = y0
|
@@ -453,30 +479,32 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
|
|
453
479
|
|
454
480
|
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
|
455
481
|
page_line_list = []
|
482
|
+
|
483
|
+
def add_lines_to_block(b):
    """Replace ``b['lines']`` with span-less placeholder lines cut from the block bbox.

    The line bboxes are produced by ``insert_lines_into_block`` from ``b['bbox']``
    and the enclosing ``line_height`` / ``page_w`` / ``page_h``. The same bboxes
    are also appended to the enclosing ``page_line_list``.
    """
    virtual_bboxes = insert_lines_into_block(b['bbox'], line_height, page_w, page_h)
    b['lines'] = [{'bbox': virtual_bbox, 'spans': []} for virtual_bbox in virtual_bboxes]
    page_line_list.extend(virtual_bboxes)
|
489
|
+
|
456
490
|
for block in fix_blocks:
|
457
491
|
if block['type'] in [
|
458
|
-
BlockType.Text, BlockType.Title,
|
492
|
+
BlockType.Text, BlockType.Title,
|
459
493
|
BlockType.ImageCaption, BlockType.ImageFootnote,
|
460
494
|
BlockType.TableCaption, BlockType.TableFootnote
|
461
495
|
]:
|
462
496
|
if len(block['lines']) == 0:
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
page_line_list.extend(lines)
|
497
|
+
add_lines_to_block(block)
|
498
|
+
elif block['type'] in [BlockType.Title] and len(block['lines']) == 1 and (block['bbox'][3] - block['bbox'][1]) > line_height * 2:
|
499
|
+
block['real_lines'] = copy.deepcopy(block['lines'])
|
500
|
+
add_lines_to_block(block)
|
468
501
|
else:
|
469
502
|
for line in block['lines']:
|
470
503
|
bbox = line['bbox']
|
471
504
|
page_line_list.append(bbox)
|
472
|
-
elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
|
473
|
-
bbox = block['bbox']
|
505
|
+
elif block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.InterlineEquation]:
|
474
506
|
block['real_lines'] = copy.deepcopy(block['lines'])
|
475
|
-
|
476
|
-
block['lines'] = []
|
477
|
-
for line in lines:
|
478
|
-
block['lines'].append({'bbox': line, 'spans': []})
|
479
|
-
page_line_list.extend(lines)
|
507
|
+
add_lines_to_block(block)
|
480
508
|
|
481
509
|
if len(page_line_list) > 200: # layoutreader最高支持512line
|
482
510
|
return None
|
@@ -663,12 +691,77 @@ def parse_page_core(
|
|
663
691
|
discarded_blocks = magic_model.get_discarded(page_id)
|
664
692
|
text_blocks = magic_model.get_text_blocks(page_id)
|
665
693
|
title_blocks = magic_model.get_title_blocks(page_id)
|
666
|
-
inline_equations, interline_equations, interline_equation_blocks = (
|
667
|
-
magic_model.get_equations(page_id)
|
668
|
-
)
|
669
|
-
|
694
|
+
inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)
|
670
695
|
page_w, page_h = magic_model.get_page_size(page_id)
|
671
696
|
|
697
|
+
def merge_title_blocks(blocks, x_distance_threshold=0.1*page_w):
    """Merge single-line title blocks that sit side by side on the same row.

    Title blocks are first grouped into rows by vertical overlap, then each
    row is walked left to right and neighbouring blocks are fused when they
    are close horizontally and have nearly the same height. ``blocks`` is
    modified in place: the left block of every merge absorbs the right one,
    and the absorbed blocks are removed from ``blocks``.

    Args:
        blocks: List of block dicts; only entries whose ``'type'`` equals
            ``BlockType.Title`` are considered.
        x_distance_threshold: Upper bound (exclusive) on the horizontal gap
            between two blocks for them to be merged. The default is
            computed once, when this nested ``def`` executes, from the
            enclosing ``page_w``.

    Returns:
        None.
    """
    def merge_two_bbox(b1, b2):
        # Smallest box covering both inputs' 'bbox' entries; returned as a tuple.
        x_min = min(b1['bbox'][0], b2['bbox'][0])
        y_min = min(b1['bbox'][1], b2['bbox'][1])
        x_max = max(b1['bbox'][2], b2['bbox'][2])
        y_max = max(b1['bbox'][3], b2['bbox'][3])
        return x_min, y_min, x_max, y_max

    def merge_two_blocks(b1, b2):
        # Merge the bounding boxes of the two title blocks
        b1['bbox'] = merge_two_bbox(b1, b2)

        # Merge the text content of the two title blocks: the first line of b1
        # grows to cover b2's first line and takes over its spans
        line1 = b1['lines'][0]
        line2 = b2['lines'][0]
        line1['bbox'] = merge_two_bbox(line1, line2)
        line1['spans'].extend(line2['spans'])

        # b1 is the merged result, b2 is the block the caller should discard
        return b1, b2

    # Group title blocks into rows by y-axis overlap
    # NOTE(review): __is_overlaps_y_exceeds_threshold is defined elsewhere;
    # presumably it reports whether the vertical overlap ratio exceeds 0.9 - confirm.
    y_overlapping_blocks = []
    title_bs = [b for b in blocks if b['type'] == BlockType.Title]
    while title_bs:
        block1 = title_bs.pop(0)
        current_row = [block1]
        to_remove = []
        for block2 in title_bs:
            if (
                __is_overlaps_y_exceeds_threshold(block1['bbox'], block2['bbox'], 0.9)
                and len(block1['lines']) == 1
                and len(block2['lines']) == 1
            ):
                current_row.append(block2)
                to_remove.append(block2)
        # Removal is deferred so title_bs is not mutated while being iterated
        for b in to_remove:
            title_bs.remove(b)
        y_overlapping_blocks.append(current_row)

    # Sort each row by x coordinate and merge its title blocks
    to_remove_blocks = []
    for row in y_overlapping_blocks:
        if len(row) == 1:
            continue

        # Sort by the left edge (x0)
        row.sort(key=lambda x: x['bbox'][0])

        merged_block = row[0]
        for i in range(1, len(row)):
            left_block = merged_block
            right_block = row[i]

            left_height = left_block['bbox'][3] - left_block['bbox'][1]
            right_height = right_block['bbox'][3] - right_block['bbox'][1]

            # Merge only when the gap is small and the heights differ by less than 5%
            if (
                right_block['bbox'][0] - left_block['bbox'][2] < x_distance_threshold
                and left_height * 0.95 < right_height < left_height * 1.05
            ):
                merged_block, to_remove_block = merge_two_blocks(merged_block, right_block)
                to_remove_blocks.append(to_remove_block)
            else:
                # No merge: the right block becomes the new left-hand candidate
                merged_block = right_block

    # Drop the absorbed blocks from the caller's list
    for b in to_remove_blocks:
        blocks.remove(b)
|
764
|
+
|
672
765
|
"""将所有区块的bbox整理到一起"""
|
673
766
|
# interline_equation_blocks参数不够准,后面切换到interline_equations上
|
674
767
|
interline_equation_blocks = []
|
@@ -753,6 +846,9 @@ def parse_page_core(
|
|
753
846
|
"""对block进行fix操作"""
|
754
847
|
fix_blocks = fix_block_spans_v2(block_with_spans)
|
755
848
|
|
849
|
+
"""同一行被断开的titile合并"""
|
850
|
+
merge_title_blocks(fix_blocks)
|
851
|
+
|
756
852
|
"""获取所有line并计算正文line的高度"""
|
757
853
|
line_height = get_line_height(fix_blocks)
|
758
854
|
|
@@ -860,15 +956,24 @@ def pdf_parse_union(
|
|
860
956
|
"""公式优化"""
|
861
957
|
formula_aided_config = llm_aided_config.get('formula_aided', None)
|
862
958
|
if formula_aided_config is not None:
|
863
|
-
|
959
|
+
if formula_aided_config.get('enable', False):
|
960
|
+
llm_aided_formula_start_time = time.time()
|
961
|
+
llm_aided_formula(pdf_info_dict, formula_aided_config)
|
962
|
+
logger.info(f'llm aided formula time: {round(time.time() - llm_aided_formula_start_time, 2)}')
|
864
963
|
"""文本优化"""
|
865
964
|
text_aided_config = llm_aided_config.get('text_aided', None)
|
866
965
|
if text_aided_config is not None:
|
867
|
-
|
966
|
+
if text_aided_config.get('enable', False):
|
967
|
+
llm_aided_text_start_time = time.time()
|
968
|
+
llm_aided_text(pdf_info_dict, text_aided_config)
|
969
|
+
logger.info(f'llm aided text time: {round(time.time() - llm_aided_text_start_time, 2)}')
|
868
970
|
"""标题优化"""
|
869
971
|
title_aided_config = llm_aided_config.get('title_aided', None)
|
870
972
|
if title_aided_config is not None:
|
871
|
-
|
973
|
+
if title_aided_config.get('enable', False):
|
974
|
+
llm_aided_title_start_time = time.time()
|
975
|
+
llm_aided_title(pdf_info_dict, title_aided_config)
|
976
|
+
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
|
872
977
|
|
873
978
|
"""dict转list"""
|
874
979
|
pdf_info_list = dict_to_list(pdf_info_dict)
|
magic_pdf/post_proc/llm_aided.py
CHANGED
@@ -83,26 +83,47 @@ def llm_aided_title(pdf_info_dict, title_aided_config):
|
|
83
83
|
if block["type"] == "title":
|
84
84
|
origin_title_list.append(block)
|
85
85
|
title_text = merge_para_with_text(block)
|
86
|
-
|
86
|
+
page_line_height_list = []
|
87
|
+
for line in block['lines']:
|
88
|
+
bbox = line['bbox']
|
89
|
+
page_line_height_list.append(int(bbox[3] - bbox[1]))
|
90
|
+
if len(page_line_height_list) > 0:
|
91
|
+
line_avg_height = sum(page_line_height_list) / len(page_line_height_list)
|
92
|
+
else:
|
93
|
+
line_avg_height = int(block['bbox'][3] - block['bbox'][1])
|
94
|
+
title_dict[f"{i}"] = [title_text, line_avg_height, int(page_num[5:])+1]
|
87
95
|
i += 1
|
88
96
|
# logger.info(f"Title list: {title_dict}")
|
89
97
|
|
90
98
|
title_optimize_prompt = f"""输入的内容是一篇文档中所有标题组成的字典,请根据以下指南优化标题的结果,使结果符合正常文档的层次结构:
|
91
99
|
|
92
|
-
1.
|
100
|
+
1. 字典中每个value均为一个list,包含以下元素:
|
101
|
+
- 标题文本
|
102
|
+
- 文本行高是标题所在块的平均行高
|
103
|
+
- 标题所在的页码
|
104
|
+
|
105
|
+
2. 保留原始内容:
|
93
106
|
- 输入的字典中所有元素都是有效的,不能删除字典中的任何元素
|
94
107
|
- 请务必保证输出的字典中元素的数量和输入的数量一致
|
95
108
|
|
96
|
-
|
109
|
+
3. 保持字典内key-value的对应关系不变
|
97
110
|
|
98
|
-
|
111
|
+
4. 优化层次结构:
|
99
112
|
- 为每个标题元素添加适当的层次结构
|
100
|
-
-
|
113
|
+
- 行高较大的标题一般是更高级别的标题
|
114
|
+
- 标题从前至后的层级必须是连续的,不能跳过层级
|
101
115
|
- 标题层级最多为4级,不要添加过多的层级
|
102
|
-
-
|
103
|
-
|
116
|
+
- 优化后的标题只保留代表该标题的层级的整数,不要保留其他信息
|
117
|
+
|
118
|
+
5. 合理性检查与微调:
|
119
|
+
- 在完成初步分级后,仔细检查分级结果的合理性
|
120
|
+
- 根据上下文关系和逻辑顺序,对不合理的分级进行微调
|
121
|
+
- 确保最终的分级结果符合文档的实际结构和逻辑
|
122
|
+
|
104
123
|
IMPORTANT:
|
105
|
-
请直接返回优化过的由标题层级组成的json
|
124
|
+
请直接返回优化过的由标题层级组成的json,格式如下:
|
125
|
+
{{"0":1,"1":2,"2":2,"3":3}}
|
126
|
+
返回的json不需要格式化。
|
106
127
|
|
107
128
|
Input title list:
|
108
129
|
{title_dict}
|
@@ -110,24 +131,36 @@ Input title list:
|
|
110
131
|
Corrected title list:
|
111
132
|
"""
|
112
133
|
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
{'role': 'user', 'content': title_optimize_prompt}],
|
117
|
-
temperature=0.7,
|
118
|
-
)
|
119
|
-
|
120
|
-
json_completion = json.loads(completion.choices[0].message.content)
|
121
|
-
|
122
|
-
# logger.info(f"Title completion: {json_completion}")
|
134
|
+
retry_count = 0
|
135
|
+
max_retries = 3
|
136
|
+
json_completion = None
|
123
137
|
|
124
|
-
|
125
|
-
if len(json_completion) == len(title_dict):
|
138
|
+
while retry_count < max_retries:
|
126
139
|
try:
|
127
|
-
|
128
|
-
|
140
|
+
completion = client.chat.completions.create(
|
141
|
+
model=title_aided_config["model"],
|
142
|
+
messages=[
|
143
|
+
{'role': 'user', 'content': title_optimize_prompt}],
|
144
|
+
temperature=0.7,
|
145
|
+
)
|
146
|
+
json_completion = json.loads(completion.choices[0].message.content)
|
147
|
+
|
148
|
+
# logger.info(f"Title completion: {json_completion}")
|
149
|
+
# logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}")
|
150
|
+
|
151
|
+
if len(json_completion) == len(title_dict):
|
152
|
+
for i, origin_title_block in enumerate(origin_title_list):
|
153
|
+
origin_title_block["level"] = int(json_completion[str(i)])
|
154
|
+
break
|
155
|
+
else:
|
156
|
+
logger.warning("The number of titles in the optimized result is not equal to the number of titles in the input.")
|
157
|
+
retry_count += 1
|
129
158
|
except Exception as e:
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
159
|
+
if isinstance(e, json.decoder.JSONDecodeError):
|
160
|
+
logger.warning(f"JSON decode error on attempt {retry_count + 1}: {e}")
|
161
|
+
else:
|
162
|
+
logger.exception(e)
|
163
|
+
retry_count += 1
|
164
|
+
|
165
|
+
if json_completion is None:
|
166
|
+
logger.error("Failed to decode JSON after maximum retries.")
|