magic-pdf 1.0.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. magic_pdf/dict2md/ocr_mkcontent.py +24 -0
  2. magic_pdf/filter/__init__.py +1 -1
  3. magic_pdf/filter/pdf_classify_by_type.py +6 -4
  4. magic_pdf/filter/pdf_meta_scan.py +4 -4
  5. magic_pdf/libs/boxbase.py +5 -2
  6. magic_pdf/libs/draw_bbox.py +14 -2
  7. magic_pdf/libs/language.py +9 -0
  8. magic_pdf/libs/pdf_check.py +11 -1
  9. magic_pdf/libs/version.py +1 -1
  10. magic_pdf/model/batch_analyze.py +103 -99
  11. magic_pdf/model/doc_analyze_by_custom_model.py +87 -36
  12. magic_pdf/model/magic_model.py +161 -4
  13. magic_pdf/model/pdf_extract_kit.py +23 -28
  14. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +4 -3
  15. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +7 -3
  16. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +1 -1
  17. magic_pdf/model/sub_modules/model_init.py +34 -19
  18. magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +33 -26
  19. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +25 -6
  20. magic_pdf/pdf_parse_union_core_v2.py +176 -61
  21. magic_pdf/post_proc/llm_aided.py +55 -24
  22. magic_pdf/pre_proc/ocr_dict_merge.py +14 -2
  23. magic_pdf/pre_proc/ocr_span_list_modify.py +1 -1
  24. magic_pdf/resources/model_config/model_configs.yaml +2 -2
  25. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/METADATA +36 -19
  26. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/RECORD +30 -30
  27. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/LICENSE.md +0 -0
  28. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/WHEEL +0 -0
  29. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/entry_points.txt +0 -0
  30. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,21 @@
1
1
  import copy
2
+ import math
2
3
  import os
3
4
  import re
4
5
  import statistics
5
6
  import time
6
7
  from typing import List
7
8
 
9
+ import cv2
8
10
  import fitz
9
11
  import torch
12
+ import numpy as np
10
13
  from loguru import logger
11
14
 
12
15
  from magic_pdf.config.enums import SupportedPdfParseMethod
13
16
  from magic_pdf.config.ocr_content_type import BlockType, ContentType
14
17
  from magic_pdf.data.dataset import Dataset, PageableData
15
- from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
18
+ from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, __is_overlaps_y_exceeds_threshold
16
19
  from magic_pdf.libs.clean_memory import clean_memory
17
20
  from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
18
21
  from magic_pdf.libs.convert_utils import dict_to_list
@@ -117,24 +120,24 @@ def fill_char_in_spans(spans, all_chars):
117
120
 
118
121
  for char in all_chars:
119
122
  # 跳过非法bbox的char
120
- x1, y1, x2, y2 = char['bbox']
121
- if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
122
- continue
123
+ # x1, y1, x2, y2 = char['bbox']
124
+ # if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
125
+ # continue
126
+
123
127
  for span in spans:
124
128
  if calculate_char_in_span(char['bbox'], span['bbox'], char['c']):
125
129
  span['chars'].append(char)
126
130
  break
127
131
 
128
- empty_spans = []
129
-
132
+ need_ocr_spans = []
130
133
  for span in spans:
131
134
  chars_to_content(span)
132
135
  # 有的span中虽然没有字但有一两个空的占位符,用宽高和content长度过滤
133
136
  if len(span['content']) * span['height'] < span['width'] * 0.5:
134
137
  # logger.info(f"maybe empty span: {len(span['content'])}, {span['height']}, {span['width']}")
135
- empty_spans.append(span)
138
+ need_ocr_spans.append(span)
136
139
  del span['height'], span['width']
137
- return empty_spans
140
+ return need_ocr_spans
138
141
 
139
142
 
140
143
  # 使用鲁棒性更强的中心点坐标判断
@@ -173,12 +176,60 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
173
176
  return False
174
177
 
175
178
 
179
+ def remove_tilted_line(text_blocks):
180
+ for block in text_blocks:
181
+ remove_lines = []
182
+ for line in block['lines']:
183
+ cosine, sine = line['dir']
184
+ # 计算弧度值
185
+ angle_radians = math.atan2(sine, cosine)
186
+ # 将弧度值转换为角度值
187
+ angle_degrees = math.degrees(angle_radians)
188
+ if 2 < abs(angle_degrees) < 88:
189
+ remove_lines.append(line)
190
+ for line in remove_lines:
191
+ block['lines'].remove(line)
192
+
193
+
194
+ def calculate_contrast(img, img_mode) -> float:
195
+ """
196
+ 计算给定图像的对比度。
197
+ :param img: 图像,类型为numpy.ndarray
198
+ :Param img_mode = 图像的色彩通道,'rgb' 或 'bgr'
199
+ :return: 图像的对比度值
200
+ """
201
+ if img_mode == 'rgb':
202
+ # 将RGB图像转换为灰度图
203
+ gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
204
+ elif img_mode == 'bgr':
205
+ # 将BGR图像转换为灰度图
206
+ gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
207
+ else:
208
+ raise ValueError("Invalid image mode. Please provide 'rgb' or 'bgr'.")
209
+
210
+ # 计算均值和标准差
211
+ mean_value = np.mean(gray_img)
212
+ std_dev = np.std(gray_img)
213
+ # 对比度定义为标准差除以平均值(加上小常数避免除零错误)
214
+ contrast = std_dev / (mean_value + 1e-6)
215
+ # logger.info(f"contrast: {contrast}")
216
+ return round(contrast, 2)
217
+
218
+
176
219
  def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
177
220
  # cid用0xfffd表示,连字符拆开
178
221
  # text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
179
222
 
180
223
  # cid用0xfffd表示,连字符不拆开
181
- text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
224
+ #text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
225
+
226
+ # 自定义flags出现较多0xfffd,可能是pymupdf可以自行处理内置字典的pdf,不再使用
227
+ text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
228
+ # text_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
229
+
230
+ # 移除所有角度不为0或90的line
231
+ remove_tilted_line(text_blocks_raw)
232
+
182
233
  all_pymu_chars = []
183
234
  for block in text_blocks_raw:
184
235
  for line in block['lines']:
@@ -249,9 +300,9 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
249
300
  span['chars'] = []
250
301
  new_spans.append(span)
251
302
 
252
- empty_spans = fill_char_in_spans(new_spans, all_pymu_chars)
303
+ need_ocr_spans = fill_char_in_spans(new_spans, all_pymu_chars)
253
304
 
254
- if len(empty_spans) > 0:
305
+ if len(need_ocr_spans) > 0:
255
306
 
256
307
  # 初始化ocr模型
257
308
  atom_model_manager = AtomModelSingleton()
@@ -262,9 +313,15 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
262
313
  lang=lang
263
314
  )
264
315
 
265
- for span in empty_spans:
316
+ for span in need_ocr_spans:
266
317
  # 对span的bbox截图再ocr
267
318
  span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode='cv2')
319
+
320
+ # 计算span的对比度,低于0.20的span不进行ocr
321
+ if calculate_contrast(span_img, img_mode='bgr') <= 0.20:
322
+ spans.remove(span)
323
+ continue
324
+
268
325
  ocr_res = ocr_model.ocr(span_img, det=False)
269
326
  if ocr_res and len(ocr_res) > 0:
270
327
  if len(ocr_res[0]) > 0:
@@ -281,24 +338,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
281
338
 
282
339
  def model_init(model_name: str):
283
340
  from transformers import LayoutLMv3ForTokenClassification
284
- device = get_device()
285
- if torch.cuda.is_available():
286
- device = torch.device('cuda')
287
- if torch.cuda.is_bf16_supported():
288
- supports_bfloat16 = True
289
- else:
290
- supports_bfloat16 = False
291
- elif str(device).startswith("npu"):
292
- import torch_npu
293
- if torch_npu.npu.is_available():
294
- device = torch.device('npu')
295
- supports_bfloat16 = False
296
- else:
297
- device = torch.device('cpu')
298
- supports_bfloat16 = False
299
- else:
300
- device = torch.device('cpu')
301
- supports_bfloat16 = False
341
+ device = torch.device(get_device())
302
342
 
303
343
  if model_name == 'layoutreader':
304
344
  # 检测modelscope的缓存目录是否存在
@@ -314,9 +354,6 @@ def model_init(model_name: str):
314
354
  model = LayoutLMv3ForTokenClassification.from_pretrained(
315
355
  'hantian/layoutreader'
316
356
  )
317
- # 检查设备是否支持 bfloat16
318
- if supports_bfloat16:
319
- model.bfloat16()
320
357
  model.to(device).eval()
321
358
  else:
322
359
  logger.error('model name not allow')
@@ -365,10 +402,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
365
402
  block['index'] = median_value
366
403
 
367
404
  # 删除图表body block中的虚拟line信息, 并用real_lines信息回填
368
- if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
369
- block['virtual_lines'] = copy.deepcopy(block['lines'])
370
- block['lines'] = copy.deepcopy(block['real_lines'])
371
- del block['real_lines']
405
+ if block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.Title, BlockType.InterlineEquation]:
406
+ if 'real_lines' in block:
407
+ block['virtual_lines'] = copy.deepcopy(block['lines'])
408
+ block['lines'] = copy.deepcopy(block['real_lines'])
409
+ del block['real_lines']
372
410
  else:
373
411
  # 使用xycut排序
374
412
  block_bboxes = []
@@ -378,10 +416,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
378
416
  block_bboxes.append(block['bbox'])
379
417
 
380
418
  # 删除图表body block中的虚拟line信息, 并用real_lines信息回填
381
- if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
382
- block['virtual_lines'] = copy.deepcopy(block['lines'])
383
- block['lines'] = copy.deepcopy(block['real_lines'])
384
- del block['real_lines']
419
+ if block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.Title, BlockType.InterlineEquation]:
420
+ if 'real_lines' in block:
421
+ block['virtual_lines'] = copy.deepcopy(block['lines'])
422
+ block['lines'] = copy.deepcopy(block['real_lines'])
423
+ del block['real_lines']
385
424
 
386
425
  import numpy as np
387
426
 
@@ -417,7 +456,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
417
456
  block_weight = x1 - x0
418
457
 
419
458
  # 如果block高度小于n行正文,则直接返回block的bbox
420
- if line_height * 3 < block_height:
459
+ if line_height * 2 < block_height:
421
460
  if (
422
461
  block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
423
462
  ): # 可能是双列结构,可以切细点
@@ -425,16 +464,16 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
425
464
  else:
426
465
  # 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
427
466
  if block_weight > page_w * 0.4:
428
- line_height = (y1 - y0) / 3
429
467
  lines = 3
468
+ line_height = (y1 - y0) / lines
430
469
  elif block_weight > page_w * 0.25: # (可能是三列结构,也切细点)
431
470
  lines = int(block_height / line_height) + 1
432
471
  else: # 判断长宽比
433
472
  if block_height / block_weight > 1.2: # 细长的不分
434
473
  return [[x0, y0, x1, y1]]
435
474
  else: # 不细长的还是分成两行
436
- line_height = (y1 - y0) / 2
437
475
  lines = 2
476
+ line_height = (y1 - y0) / lines
438
477
 
439
478
  # 确定从哪个y位置开始绘制线条
440
479
  current_y = y0
@@ -453,30 +492,32 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
453
492
 
454
493
  def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
455
494
  page_line_list = []
495
+
496
+ def add_lines_to_block(b):
497
+ line_bboxes = insert_lines_into_block(b['bbox'], line_height, page_w, page_h)
498
+ b['lines'] = []
499
+ for line_bbox in line_bboxes:
500
+ b['lines'].append({'bbox': line_bbox, 'spans': []})
501
+ page_line_list.extend(line_bboxes)
502
+
456
503
  for block in fix_blocks:
457
504
  if block['type'] in [
458
- BlockType.Text, BlockType.Title, BlockType.InterlineEquation,
505
+ BlockType.Text, BlockType.Title,
459
506
  BlockType.ImageCaption, BlockType.ImageFootnote,
460
507
  BlockType.TableCaption, BlockType.TableFootnote
461
508
  ]:
462
509
  if len(block['lines']) == 0:
463
- bbox = block['bbox']
464
- lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
465
- for line in lines:
466
- block['lines'].append({'bbox': line, 'spans': []})
467
- page_line_list.extend(lines)
510
+ add_lines_to_block(block)
511
+ elif block['type'] in [BlockType.Title] and len(block['lines']) == 1 and (block['bbox'][3] - block['bbox'][1]) > line_height * 2:
512
+ block['real_lines'] = copy.deepcopy(block['lines'])
513
+ add_lines_to_block(block)
468
514
  else:
469
515
  for line in block['lines']:
470
516
  bbox = line['bbox']
471
517
  page_line_list.append(bbox)
472
- elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
473
- bbox = block['bbox']
518
+ elif block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.InterlineEquation]:
474
519
  block['real_lines'] = copy.deepcopy(block['lines'])
475
- lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
476
- block['lines'] = []
477
- for line in lines:
478
- block['lines'].append({'bbox': line, 'spans': []})
479
- page_line_list.extend(lines)
520
+ add_lines_to_block(block)
480
521
 
481
522
  if len(page_line_list) > 200: # layoutreader最高支持512line
482
523
  return None
@@ -663,12 +704,77 @@ def parse_page_core(
663
704
  discarded_blocks = magic_model.get_discarded(page_id)
664
705
  text_blocks = magic_model.get_text_blocks(page_id)
665
706
  title_blocks = magic_model.get_title_blocks(page_id)
666
- inline_equations, interline_equations, interline_equation_blocks = (
667
- magic_model.get_equations(page_id)
668
- )
669
-
707
+ inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)
670
708
  page_w, page_h = magic_model.get_page_size(page_id)
671
709
 
710
+ def merge_title_blocks(blocks, x_distance_threshold=0.1*page_w):
711
+ def merge_two_bbox(b1, b2):
712
+ x_min = min(b1['bbox'][0], b2['bbox'][0])
713
+ y_min = min(b1['bbox'][1], b2['bbox'][1])
714
+ x_max = max(b1['bbox'][2], b2['bbox'][2])
715
+ y_max = max(b1['bbox'][3], b2['bbox'][3])
716
+ return x_min, y_min, x_max, y_max
717
+
718
+ def merge_two_blocks(b1, b2):
719
+ # 合并两个标题块的边界框
720
+ b1['bbox'] = merge_two_bbox(b1, b2)
721
+
722
+ # 合并两个标题块的文本内容
723
+ line1 = b1['lines'][0]
724
+ line2 = b2['lines'][0]
725
+ line1['bbox'] = merge_two_bbox(line1, line2)
726
+ line1['spans'].extend(line2['spans'])
727
+
728
+ return b1, b2
729
+
730
+ # 按 y 轴重叠度聚集标题块
731
+ y_overlapping_blocks = []
732
+ title_bs = [b for b in blocks if b['type'] == BlockType.Title]
733
+ while title_bs:
734
+ block1 = title_bs.pop(0)
735
+ current_row = [block1]
736
+ to_remove = []
737
+ for block2 in title_bs:
738
+ if (
739
+ __is_overlaps_y_exceeds_threshold(block1['bbox'], block2['bbox'], 0.9)
740
+ and len(block1['lines']) == 1
741
+ and len(block2['lines']) == 1
742
+ ):
743
+ current_row.append(block2)
744
+ to_remove.append(block2)
745
+ for b in to_remove:
746
+ title_bs.remove(b)
747
+ y_overlapping_blocks.append(current_row)
748
+
749
+ # 按x轴坐标排序并合并标题块
750
+ to_remove_blocks = []
751
+ for row in y_overlapping_blocks:
752
+ if len(row) == 1:
753
+ continue
754
+
755
+ # 按x轴坐标排序
756
+ row.sort(key=lambda x: x['bbox'][0])
757
+
758
+ merged_block = row[0]
759
+ for i in range(1, len(row)):
760
+ left_block = merged_block
761
+ right_block = row[i]
762
+
763
+ left_height = left_block['bbox'][3] - left_block['bbox'][1]
764
+ right_height = right_block['bbox'][3] - right_block['bbox'][1]
765
+
766
+ if (
767
+ right_block['bbox'][0] - left_block['bbox'][2] < x_distance_threshold
768
+ and left_height * 0.95 < right_height < left_height * 1.05
769
+ ):
770
+ merged_block, to_remove_block = merge_two_blocks(merged_block, right_block)
771
+ to_remove_blocks.append(to_remove_block)
772
+ else:
773
+ merged_block = right_block
774
+
775
+ for b in to_remove_blocks:
776
+ blocks.remove(b)
777
+
672
778
  """将所有区块的bbox整理到一起"""
673
779
  # interline_equation_blocks参数不够准,后面切换到interline_equations上
674
780
  interline_equation_blocks = []
@@ -753,6 +859,9 @@ def parse_page_core(
753
859
  """对block进行fix操作"""
754
860
  fix_blocks = fix_block_spans_v2(block_with_spans)
755
861
 
862
+ """同一行被断开的titile合并"""
863
+ merge_title_blocks(fix_blocks)
864
+
756
865
  """获取所有line并计算正文line的高度"""
757
866
  line_height = get_line_height(fix_blocks)
758
867
 
@@ -861,17 +970,23 @@ def pdf_parse_union(
861
970
  formula_aided_config = llm_aided_config.get('formula_aided', None)
862
971
  if formula_aided_config is not None:
863
972
  if formula_aided_config.get('enable', False):
973
+ llm_aided_formula_start_time = time.time()
864
974
  llm_aided_formula(pdf_info_dict, formula_aided_config)
975
+ logger.info(f'llm aided formula time: {round(time.time() - llm_aided_formula_start_time, 2)}')
865
976
  """文本优化"""
866
977
  text_aided_config = llm_aided_config.get('text_aided', None)
867
978
  if text_aided_config is not None:
868
979
  if text_aided_config.get('enable', False):
980
+ llm_aided_text_start_time = time.time()
869
981
  llm_aided_text(pdf_info_dict, text_aided_config)
982
+ logger.info(f'llm aided text time: {round(time.time() - llm_aided_text_start_time, 2)}')
870
983
  """标题优化"""
871
984
  title_aided_config = llm_aided_config.get('title_aided', None)
872
985
  if title_aided_config is not None:
873
986
  if title_aided_config.get('enable', False):
987
+ llm_aided_title_start_time = time.time()
874
988
  llm_aided_title(pdf_info_dict, title_aided_config)
989
+ logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
875
990
 
876
991
  """dict转list"""
877
992
  pdf_info_list = dict_to_list(pdf_info_dict)
@@ -3,6 +3,7 @@ import json
3
3
  from loguru import logger
4
4
  from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
5
5
  from openai import OpenAI
6
+ import ast
6
7
 
7
8
 
8
9
  #@todo: 有的公式以"\"结尾,这样会导致尾部拼接的"$"被转义,也需要修复
@@ -83,26 +84,48 @@ def llm_aided_title(pdf_info_dict, title_aided_config):
83
84
  if block["type"] == "title":
84
85
  origin_title_list.append(block)
85
86
  title_text = merge_para_with_text(block)
86
- title_dict[f"{i}"] = title_text
87
+ page_line_height_list = []
88
+ for line in block['lines']:
89
+ bbox = line['bbox']
90
+ page_line_height_list.append(int(bbox[3] - bbox[1]))
91
+ if len(page_line_height_list) > 0:
92
+ line_avg_height = sum(page_line_height_list) / len(page_line_height_list)
93
+ else:
94
+ line_avg_height = int(block['bbox'][3] - block['bbox'][1])
95
+ title_dict[f"{i}"] = [title_text, line_avg_height, int(page_num[5:])+1]
87
96
  i += 1
88
97
  # logger.info(f"Title list: {title_dict}")
89
98
 
90
99
  title_optimize_prompt = f"""输入的内容是一篇文档中所有标题组成的字典,请根据以下指南优化标题的结果,使结果符合正常文档的层次结构:
91
100
 
92
- 1. 保留原始内容:
101
+ 1. 字典中每个value均为一个list,包含以下元素:
102
+ - 标题文本
103
+ - 文本行高是标题所在块的平均行高
104
+ - 标题所在的页码
105
+
106
+ 2. 保留原始内容:
93
107
  - 输入的字典中所有元素都是有效的,不能删除字典中的任何元素
94
108
  - 请务必保证输出的字典中元素的数量和输入的数量一致
95
109
 
96
- 2. 保持字典内key-value的对应关系不变
110
+ 3. 保持字典内key-value的对应关系不变
97
111
 
98
- 3. 优化层次结构:
112
+ 4. 优化层次结构:
99
113
  - 为每个标题元素添加适当的层次结构
100
- - 标题层级应具有连续性,不能跳过某一层级
114
+ - 行高较大的标题一般是更高级别的标题
115
+ - 标题从前至后的层级必须是连续的,不能跳过层级
101
116
  - 标题层级最多为4级,不要添加过多的层级
102
- - 优化后的标题为一个整数,代表该标题的层级
103
-
117
+ - 优化后的标题只保留代表该标题的层级的整数,不要保留其他信息
118
+
119
+ 5. 合理性检查与微调:
120
+ - 在完成初步分级后,仔细检查分级结果的合理性
121
+ - 根据上下文关系和逻辑顺序,对不合理的分级进行微调
122
+ - 确保最终的分级结果符合文档的实际结构和逻辑
123
+ - 字典中可能包含被误当成标题的正文,你可以通过将其层级标记为 0 来排除它们
124
+
104
125
  IMPORTANT:
105
- 请直接返回优化过的由标题层级组成的json,返回的json不需要格式化。
126
+ 请直接返回优化过的由标题层级组成的字典,格式为{{标题id:标题层级}},如下:
127
+ {{0:1,1:2,2:2,3:3}}
128
+ 不需要对字典格式化,不需要返回任何其他信息。
106
129
 
107
130
  Input title list:
108
131
  {title_dict}
@@ -110,24 +133,32 @@ Input title list:
110
133
  Corrected title list:
111
134
  """
112
135
 
113
- completion = client.chat.completions.create(
114
- model=title_aided_config["model"],
115
- messages=[
116
- {'role': 'user', 'content': title_optimize_prompt}],
117
- temperature=0.7,
118
- )
119
-
120
- json_completion = json.loads(completion.choices[0].message.content)
121
-
122
- # logger.info(f"Title completion: {json_completion}")
136
+ retry_count = 0
137
+ max_retries = 3
138
+ dict_completion = None
123
139
 
124
- # logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}")
125
- if len(json_completion) == len(title_dict):
140
+ while retry_count < max_retries:
126
141
  try:
127
- for i, origin_title_block in enumerate(origin_title_list):
128
- origin_title_block["level"] = int(json_completion[str(i)])
142
+ completion = client.chat.completions.create(
143
+ model=title_aided_config["model"],
144
+ messages=[
145
+ {'role': 'user', 'content': title_optimize_prompt}],
146
+ temperature=0.7,
147
+ )
148
+ # logger.info(f"Title completion: {completion.choices[0].message.content}")
149
+ dict_completion = ast.literal_eval(completion.choices[0].message.content)
150
+ # logger.info(f"len(dict_completion): {len(dict_completion)}, len(title_dict): {len(title_dict)}")
151
+
152
+ if len(dict_completion) == len(title_dict):
153
+ for i, origin_title_block in enumerate(origin_title_list):
154
+ origin_title_block["level"] = int(dict_completion[i])
155
+ break
156
+ else:
157
+ logger.warning("The number of titles in the optimized result is not equal to the number of titles in the input.")
158
+ retry_count += 1
129
159
  except Exception as e:
130
160
  logger.exception(e)
131
- else:
132
- logger.error("The number of titles in the optimized result is not equal to the number of titles in the input.")
161
+ retry_count += 1
133
162
 
163
+ if dict_completion is None:
164
+ logger.error("Failed to decode dict after maximum retries.")
@@ -60,6 +60,19 @@ def merge_spans_to_line(spans, threshold=0.6):
60
60
  return lines
61
61
 
62
62
 
63
+ def span_block_type_compatible(span_type, block_type):
64
+ if span_type in [ContentType.Text, ContentType.InlineEquation]:
65
+ return block_type in [BlockType.Text, BlockType.Title, BlockType.ImageCaption, BlockType.ImageFootnote, BlockType.TableCaption, BlockType.TableFootnote]
66
+ elif span_type == ContentType.InterlineEquation:
67
+ return block_type in [BlockType.InterlineEquation]
68
+ elif span_type == ContentType.Image:
69
+ return block_type in [BlockType.ImageBody]
70
+ elif span_type == ContentType.Table:
71
+ return block_type in [BlockType.TableBody]
72
+ else:
73
+ return False
74
+
75
+
63
76
  def fill_spans_in_blocks(blocks, spans, radio):
64
77
  """将allspans中的span按位置关系,放入blocks中."""
65
78
  block_with_spans = []
@@ -78,8 +91,7 @@ def fill_spans_in_blocks(blocks, spans, radio):
78
91
  block_spans = []
79
92
  for span in spans:
80
93
  span_bbox = span['bbox']
81
- if calculate_overlap_area_in_bbox1_area_ratio(
82
- span_bbox, block_bbox) > radio:
94
+ if calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > radio and span_block_type_compatible(span['type'], block_type):
83
95
  block_spans.append(span)
84
96
 
85
97
  block_dict['spans'] = block_spans
@@ -36,7 +36,7 @@ def remove_overlaps_low_confidence_spans(spans):
36
36
  def check_chars_is_overlap_in_span(chars):
37
37
  for i in range(len(chars)):
38
38
  for j in range(i + 1, len(chars)):
39
- if calculate_iou(chars[i]['bbox'], chars[j]['bbox']) > 0.9:
39
+ if calculate_iou(chars[i]['bbox'], chars[j]['bbox']) > 0.35:
40
40
  return True
41
41
  return False
42
42
 
@@ -1,8 +1,8 @@
1
1
  weights:
2
2
  layoutlmv3: Layout/LayoutLMv3/model_final.pth
3
- doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt
3
+ doclayout_yolo: Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt
4
4
  yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
5
- unimernet_small: MFR/unimernet_small
5
+ unimernet_small: MFR/unimernet_small_2501
6
6
  struct_eqtable: TabRec/StructEqTable
7
7
  tablemaster: TabRec/TableMaster
8
8
  rapid_table: TabRec/RapidTable