magic-pdf 0.10.6__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. magic_pdf/config/constants.py +2 -0
  2. magic_pdf/config/exceptions.py +7 -0
  3. magic_pdf/data/data_reader_writer/filebase.py +1 -1
  4. magic_pdf/data/data_reader_writer/multi_bucket_s3.py +8 -6
  5. magic_pdf/data/dataset.py +13 -1
  6. magic_pdf/data/read_api.py +59 -12
  7. magic_pdf/data/utils.py +35 -0
  8. magic_pdf/dict2md/ocr_mkcontent.py +14 -13
  9. magic_pdf/libs/clean_memory.py +11 -4
  10. magic_pdf/libs/config_reader.py +9 -0
  11. magic_pdf/libs/draw_bbox.py +8 -12
  12. magic_pdf/libs/language.py +3 -0
  13. magic_pdf/libs/version.py +1 -1
  14. magic_pdf/model/__init__.py +1 -125
  15. magic_pdf/model/batch_analyze.py +275 -0
  16. magic_pdf/model/doc_analyze_by_custom_model.py +4 -51
  17. magic_pdf/model/magic_model.py +4 -435
  18. magic_pdf/model/model_list.py +1 -0
  19. magic_pdf/model/pdf_extract_kit.py +33 -22
  20. magic_pdf/model/sub_modules/language_detection/__init__.py +1 -0
  21. magic_pdf/model/sub_modules/language_detection/utils.py +82 -0
  22. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +139 -0
  23. magic_pdf/model/sub_modules/language_detection/yolov11/__init__.py +1 -0
  24. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +44 -7
  25. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +21 -2
  26. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +70 -27
  27. magic_pdf/model/sub_modules/model_init.py +30 -4
  28. magic_pdf/model/sub_modules/model_utils.py +8 -2
  29. magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +51 -1
  30. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +32 -6
  31. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +42 -7
  32. magic_pdf/operators/__init__.py +94 -0
  33. magic_pdf/{model/operators.py → operators/models.py} +2 -38
  34. magic_pdf/{pipe/operators.py → operators/pipes.py} +70 -17
  35. magic_pdf/pdf_parse_union_core_v2.py +68 -17
  36. magic_pdf/post_proc/__init__.py +1 -0
  37. magic_pdf/post_proc/llm_aided.py +133 -0
  38. magic_pdf/pre_proc/ocr_span_list_modify.py +8 -0
  39. magic_pdf/pre_proc/remove_bbox_overlap.py +1 -1
  40. magic_pdf/resources/yolov11-langdetect/yolo_v11_ft.pt +0 -0
  41. magic_pdf/tools/cli.py +36 -11
  42. magic_pdf/tools/common.py +28 -18
  43. magic_pdf/utils/office_to_pdf.py +29 -0
  44. {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.0.dist-info}/METADATA +73 -23
  45. {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.0.dist-info}/RECORD +50 -53
  46. magic_pdf/para/__init__.py +0 -0
  47. magic_pdf/pdf_parse_by_ocr.py +0 -22
  48. magic_pdf/pdf_parse_by_txt.py +0 -23
  49. magic_pdf/pipe/AbsPipe.py +0 -99
  50. magic_pdf/pipe/OCRPipe.py +0 -80
  51. magic_pdf/pipe/TXTPipe.py +0 -42
  52. magic_pdf/pipe/UNIPipe.py +0 -150
  53. magic_pdf/pipe/__init__.py +0 -0
  54. magic_pdf/rw/AbsReaderWriter.py +0 -17
  55. magic_pdf/rw/DiskReaderWriter.py +0 -74
  56. magic_pdf/rw/S3ReaderWriter.py +0 -142
  57. magic_pdf/rw/__init__.py +0 -0
  58. magic_pdf/user_api.py +0 -144
  59. /magic_pdf/{para → post_proc}/para_split_v3.py +0 -0
  60. {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.0.dist-info}/LICENSE.md +0 -0
  61. {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.0.dist-info}/WHEEL +0 -0
  62. {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.0.dist-info}/entry_points.txt +0 -0
  63. {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.0.dist-info}/top_level.txt +0 -0
magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py
@@ -1,16 +1,51 @@
+import cv2
 import numpy as np
+import torch
+from loguru import logger
 from rapid_table import RapidTable
-from rapidocr_paddle import RapidOCR
 
 
 class RapidTableModel(object):
-    def __init__(self):
+    def __init__(self, ocr_engine):
         self.table_model = RapidTable()
-        self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
+        # if ocr_engine is None:
+        #     self.ocr_model_name = "RapidOCR"
+        #     if torch.cuda.is_available():
+        #         from rapidocr_paddle import RapidOCR
+        #         self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
+        #     else:
+        #         from rapidocr_onnxruntime import RapidOCR
+        #         self.ocr_engine = RapidOCR()
+        # else:
+        #     self.ocr_model_name = "PaddleOCR"
+        #     self.ocr_engine = ocr_engine
+
+        self.ocr_model_name = "RapidOCR"
+        if torch.cuda.is_available():
+            from rapidocr_paddle import RapidOCR
+            self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
+        else:
+            from rapidocr_onnxruntime import RapidOCR
+            self.ocr_engine = RapidOCR()
 
     def predict(self, image):
-        ocr_result, _ = self.ocr_engine(np.asarray(image))
-        if ocr_result is None:
+
+        if self.ocr_model_name == "RapidOCR":
+            ocr_result, _ = self.ocr_engine(np.asarray(image))
+        elif self.ocr_model_name == "PaddleOCR":
+            bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+            ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+            if ocr_result:
+                ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
+                              len(item) == 2 and isinstance(item[1], tuple)]
+            else:
+                ocr_result = None
+        else:
+            logger.error("OCR model not supported")
+            ocr_result = None
+
+        if ocr_result:
+            html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
+            return html_code, table_cell_bboxes, elapse
+        else:
             return None, None, None
-        html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
-        return html_code, table_cell_bboxes, elapse
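For orientation, a minimal usage sketch of the reworked table model (not part of the diff; it assumes a cropped table image loaded with PIL and takes the RapidOCR fallback path shown above, since no PaddleOCR engine is passed in):

# Illustrative sketch only; the file name and PIL-based loading are assumptions.
from PIL import Image
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel

table_model = RapidTableModel(ocr_engine=None)        # no PaddleOCR engine supplied, so RapidOCR is used
image = Image.open('table_crop.png').convert('RGB')   # hypothetical cropped table region
html_code, table_cell_bboxes, elapse = table_model.predict(image)
if html_code is not None:
    print(html_code)                                  # the table reconstructed as HTML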
magic_pdf/operators/__init__.py
@@ -0,0 +1,94 @@
+from abc import ABC, abstractmethod
+from typing import Callable
+
+from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
+from magic_pdf.operators.pipes import PipeResult
+
+
+class InferenceResultBase(ABC):
+
+    @abstractmethod
+    def __init__(self, inference_results: list, dataset: Dataset):
+        """Initialized method.
+
+        Args:
+            inference_results (list): the inference result generated by model
+            dataset (Dataset): the dataset related with model inference result
+        """
+        pass
+
+    @abstractmethod
+    def draw_model(self, file_path: str) -> None:
+        """Draw model inference result.
+
+        Args:
+            file_path (str): the output file path
+        """
+        pass
+
+    @abstractmethod
+    def dump_model(self, writer: DataWriter, file_path: str):
+        """Dump model inference result to file.
+
+        Args:
+            writer (DataWriter): writer handle
+            file_path (str): the location of target file
+        """
+        pass
+
+    @abstractmethod
+    def get_infer_res(self):
+        """Get the inference result.
+
+        Returns:
+            list: the inference result generated by model
+        """
+        pass
+
+    @abstractmethod
+    def apply(self, proc: Callable, *args, **kwargs):
+        """Apply callable method which.
+
+        Args:
+            proc (Callable): invoke proc as follows:
+                proc(inference_result, *args, **kwargs)
+
+        Returns:
+            Any: return the result generated by proc
+        """
+        pass
+
+    def pipe_txt_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        """Post-proc the model inference result, Extract the text using the
+        third library, such as `pymupdf`
+
+        Args:
+            imageWriter (DataWriter): the image writer handle
+            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
+            end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
+            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
+            lang (str, optional): Defaults to None.
+
+        Returns:
+            PipeResult: the result
+        """
+        pass
+
+    @abstractmethod
+    def pipe_ocr_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        pass
magic_pdf/{model/operators.py → operators/models.py}
@@ -7,13 +7,11 @@ from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.data.data_reader_writer import DataWriter
 from magic_pdf.data.dataset import Dataset
-from magic_pdf.filter import classify
 from magic_pdf.libs.draw_bbox import draw_model_bbox
 from magic_pdf.libs.version import __version__
-from magic_pdf.model import InferenceResultBase
+from magic_pdf.operators.pipes import PipeResult
 from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
-from magic_pdf.pipe.operators import PipeResult
-
+from magic_pdf.operators import InferenceResultBase
 
 class InferenceResult(InferenceResultBase):
     def __init__(self, inference_results: list, dataset: Dataset):
@@ -71,40 +69,6 @@ class InferenceResult(InferenceResultBase):
         """
         return proc(copy.deepcopy(self._infer_res), *args, **kwargs)
 
-    def pipe_auto_mode(
-        self,
-        imageWriter: DataWriter,
-        start_page_id=0,
-        end_page_id=None,
-        debug_mode=False,
-        lang=None,
-    ) -> PipeResult:
-        """Post-proc the model inference result.
-        step1: classify the dataset type
-        step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
-
-        Args:
-            imageWriter (DataWriter): the image writer handle
-            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
-            end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
-            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
-            lang (str, optional): Defaults to None.
-
-        Returns:
-            PipeResult: the result
-        """
-
-        pdf_proc_method = classify(self._dataset.data_bits())
-
-        if pdf_proc_method == SupportedPdfParseMethod.TXT:
-            return self.pipe_txt_mode(
-                imageWriter, start_page_id, end_page_id, debug_mode, lang
-            )
-        else:
-            return self.pipe_ocr_mode(
-                imageWriter, start_page_id, end_page_id, debug_mode, lang
-            )
-
     def pipe_txt_mode(
         self,
         imageWriter: DataWriter,
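Note that pipe_auto_mode is removed without a direct replacement. Callers that relied on automatic mode selection can reproduce it with roughly the following sketch, reconstructed from the deleted lines above; `ds`, `infer_result` and `image_writer` are assumed to exist, and it assumes `magic_pdf.filter.classify` remains importable in 1.0.0:

# Rough equivalent of the removed pipe_auto_mode, now left to the caller (assumptions noted above).
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.filter import classify

if classify(ds.data_bits()) == SupportedPdfParseMethod.TXT:
    pipe_result = infer_result.pipe_txt_mode(image_writer)
else:
    pipe_result = infer_result.pipe_ocr_mode(image_writer)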
magic_pdf/{pipe/operators.py → operators/pipes.py}
@@ -1,7 +1,7 @@
+import copy
 import json
 import os
 from typing import Callable
-import copy
 
 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import DataWriter
@@ -23,12 +23,34 @@ class PipeResult:
         self._pipe_res = pipe_res
         self._dataset = dataset
 
+    def get_markdown(
+        self,
+        img_dir_or_bucket_prefix: str,
+        drop_mode=DropMode.NONE,
+        md_make_mode=MakeMode.MM_MD,
+    ) -> str:
+        """Get markdown content.
+
+        Args:
+            img_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
+            drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.
+            md_make_mode (str, optional): The content Type of Markdown be made. Defaults to MakeMode.MM_MD.
+
+        Returns:
+            str: return markdown content
+        """
+        pdf_info_list = self._pipe_res['pdf_info']
+        md_content = union_make(
+            pdf_info_list, md_make_mode, drop_mode, img_dir_or_bucket_prefix
+        )
+        return md_content
+
     def dump_md(
         self,
         writer: DataWriter,
         file_path: str,
         img_dir_or_bucket_prefix: str,
-        drop_mode=DropMode.WHOLE_PDF,
+        drop_mode=DropMode.NONE,
         md_make_mode=MakeMode.MM_MD,
     ):
         """Dump The Markdown.
@@ -37,36 +59,68 @@ class PipeResult:
             writer (DataWriter): File writer handle
            file_path (str): The file location of markdown
             img_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
-            drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.WHOLE_PDF.
+            drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.
             md_make_mode (str, optional): The content Type of Markdown be made. Defaults to MakeMode.MM_MD.
         """
-        pdf_info_list = self._pipe_res['pdf_info']
-        md_content = union_make(
-            pdf_info_list, md_make_mode, drop_mode, img_dir_or_bucket_prefix
+
+        md_content = self.get_markdown(
+            img_dir_or_bucket_prefix, drop_mode=drop_mode, md_make_mode=md_make_mode
         )
         writer.write_string(file_path, md_content)
 
-    def dump_content_list(
-        self, writer: DataWriter, file_path: str, image_dir_or_bucket_prefix: str
-    ):
-        """Dump Content List.
+    def get_content_list(
+        self,
+        image_dir_or_bucket_prefix: str,
+        drop_mode=DropMode.NONE,
+    ) -> str:
+        """Get Content List.
 
         Args:
-            writer (DataWriter): File writer handle
-            file_path (str): The file location of content list
             image_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
+            drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.
+
+        Returns:
+            str: content list content
         """
         pdf_info_list = self._pipe_res['pdf_info']
         content_list = union_make(
             pdf_info_list,
             MakeMode.STANDARD_FORMAT,
-            DropMode.NONE,
+            drop_mode,
            image_dir_or_bucket_prefix,
         )
+        return content_list
+
+    def dump_content_list(
+        self,
+        writer: DataWriter,
+        file_path: str,
+        image_dir_or_bucket_prefix: str,
+        drop_mode=DropMode.NONE,
+    ):
+        """Dump Content List.
+
+        Args:
+            writer (DataWriter): File writer handle
+            file_path (str): The file location of content list
+            image_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
+            drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.
+        """
+        content_list = self.get_content_list(
+            image_dir_or_bucket_prefix, drop_mode=drop_mode,
+        )
         writer.write_string(
             file_path, json.dumps(content_list, ensure_ascii=False, indent=4)
         )
 
+    def get_middle_json(self) -> str:
+        """Get middle json.
+
+        Returns:
+            str: The content of middle json
+        """
+        return json.dumps(self._pipe_res, ensure_ascii=False, indent=4)
+
     def dump_middle_json(self, writer: DataWriter, file_path: str):
         """Dump the result of pipeline.
 
@@ -74,9 +128,8 @@ class PipeResult:
             writer (DataWriter): File writer handler
             file_path (str): The file location of middle json
         """
-        writer.write_string(
-            file_path, json.dumps(self._pipe_res, ensure_ascii=False, indent=4)
-        )
+        middle_json = self.get_middle_json()
+        writer.write_string(file_path, middle_json)
 
     def draw_layout(self, file_path: str) -> None:
         """Draw the layout.
@@ -123,7 +176,7 @@ class PipeResult:
         Returns:
            str: compress the pipeline result and return
         """
-        return JsonCompressor.compress_json(self.pdf_mid_data)
+        return JsonCompressor.compress_json(self._pipe_res)
 
     def apply(self, proc: Callable, *args, **kwargs):
         """Apply callable method which.
magic_pdf/pdf_parse_union_core_v2.py
@@ -1,5 +1,6 @@
 import copy
 import os
+import re
 import statistics
 import time
 from typing import List
@@ -13,11 +14,12 @@ from magic_pdf.config.ocr_content_type import BlockType, ContentType
 from magic_pdf.data.dataset import Dataset, PageableData
 from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
 from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
+from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
 from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
 from magic_pdf.model.magic_model import MagicModel
+from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text, llm_aided_title
 
 try:
     import torchtext
@@ -28,15 +30,15 @@ except ImportError:
     pass
 
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
-from magic_pdf.para.para_split_v3 import para_split
+from magic_pdf.post_proc.para_split_v3 import para_split
 from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
 from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
 from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split_v2
 from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans_v2, fix_discarded_block
-from magic_pdf.pre_proc.ocr_span_list_modify import get_qa_need_list_v2, remove_overlaps_low_confidence_spans, remove_overlaps_min_spans
+from magic_pdf.pre_proc.ocr_span_list_modify import get_qa_need_list_v2, remove_overlaps_low_confidence_spans, \
+    remove_overlaps_min_spans, check_chars_is_overlap_in_span
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates
-os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
 
 
 def __replace_STX_ETX(text_str: str):
@@ -63,11 +65,22 @@ def __replace_0xfffd(text_str: str):
         return s
     return text_str
 
+
+# Split ligature characters
+def __replace_ligatures(text: str):
+    ligatures = {
+        'ﬁ': 'fi', 'ﬂ': 'fl', 'ﬀ': 'ff', 'ﬃ': 'ffi', 'ﬄ': 'ffl', 'ﬅ': 'ft', 'ﬆ': 'st'
+    }
+    return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
+
+
 def chars_to_content(span):
     # Check whether the chars in the span are empty
     if len(span['chars']) == 0:
         pass
         # span['content'] = ''
+    elif check_chars_is_overlap_in_span(span['chars']):
+        pass
     else:
         # First sort the chars by the x coordinate of the center of char['bbox']
         span['chars'] = sorted(span['chars'], key=lambda x: (x['bbox'][0] + x['bbox'][2]) / 2)
@@ -78,11 +91,16 @@ def chars_to_content(span):
 
         content = ''
         for char in span['chars']:
-            # If the next char's x0 is more than one average char width away from the previous char's x1, insert a space in between
-            if char['bbox'][0] - span['chars'][span['chars'].index(char) - 1]['bbox'][2] > char_avg_width:
-                content += ' '
-            content += char['c']
 
+            # If the next char's x0 is more than 0.25 of the average char width away from the previous char's x1, insert a space in between
+            char1 = char
+            char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None
+            if char2 and char2['bbox'][0] - char1['bbox'][2] > char_avg_width * 0.25 and char['c'] != ' ' and char2['c'] != ' ':
+                content += f"{char['c']} "
+            else:
+                content += char['c']
+
+        content = __replace_ligatures(content)
         span['content'] = __replace_0xfffd(content)
 
         del span['chars']
@@ -98,6 +116,10 @@ def fill_char_in_spans(spans, all_chars):
     spans = sorted(spans, key=lambda x: x['bbox'][1])
 
     for char in all_chars:
+        # Skip chars with an invalid bbox
+        x1, y1, x2, y2 = char['bbox']
+        if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
+            continue
         for span in spans:
             if calculate_char_in_span(char['bbox'], span['bbox'], char['c']):
                 span['chars'].append(char)
@@ -152,14 +174,16 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
 
 
 def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
+    # cid is rendered as 0xfffd, ligatures are split
+    # text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
 
-    text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
-
+    # cid is rendered as 0xfffd, ligatures are kept intact
+    text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
     all_pymu_chars = []
     for block in text_blocks_raw:
         for line in block['lines']:
             cosine, sine = line['dir']
-            if abs (cosine) < 0.9 or abs(sine) > 0.1:
+            if abs(cosine) < 0.9 or abs(sine) > 0.1:
                 continue
             for span in line['spans']:
                 all_pymu_chars.extend(span['chars'])
@@ -255,19 +279,23 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
     return spans
 
 
-def replace_text_span(pymu_spans, ocr_spans):
-    return list(filter(lambda x: x['type'] != ContentType.Text, ocr_spans)) + pymu_spans
-
-
 def model_init(model_name: str):
     from transformers import LayoutLMv3ForTokenClassification
-
+    device = get_device()
     if torch.cuda.is_available():
         device = torch.device('cuda')
         if torch.cuda.is_bf16_supported():
            supports_bfloat16 = True
        else:
            supports_bfloat16 = False
+    elif str(device).startswith("npu"):
+        import torch_npu
+        if torch_npu.npu.is_available():
+            device = torch.device('npu')
+            supports_bfloat16 = False
+        else:
+            device = torch.device('cpu')
+            supports_bfloat16 = False
     else:
         device = torch.device('cpu')
         supports_bfloat16 = False
@@ -345,6 +373,8 @@ def cal_block_index(fix_blocks, sorted_bboxes):
         # Sort using xycut
         block_bboxes = []
         for block in fix_blocks:
+            # If any value in block['bbox'] is less than 0, set it to 0
+            block['bbox'] = [max(0, x) for x in block['bbox']]
             block_bboxes.append(block['bbox'])
 
         # Remove the virtual line info from figure/table body blocks and backfill it with the real_lines info
@@ -738,6 +768,11 @@ def parse_page_core(
     """Reorder blocks"""
     sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
 
+    """Reorder within blocks (sorting multiple captions or footnotes inside img and table blocks)"""
+    for block in sorted_blocks:
+        if block['type'] in [BlockType.Image, BlockType.Table]:
+            block['blocks'] = sorted(block['blocks'], key=lambda b: b['index'])
+
     """Get the lists that QA needs to keep externally"""
     images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks)
 
@@ -819,13 +854,29 @@ def pdf_parse_union(
     """Paragraph splitting"""
     para_split(pdf_info_dict)
 
+    """LLM-aided optimization"""
+    llm_aided_config = get_llm_aided_config()
+    if llm_aided_config is not None:
+        """Formula optimization"""
+        formula_aided_config = llm_aided_config.get('formula_aided', None)
+        if formula_aided_config is not None:
+            llm_aided_formula(pdf_info_dict, formula_aided_config)
+        """Text optimization"""
+        text_aided_config = llm_aided_config.get('text_aided', None)
+        if text_aided_config is not None:
+            llm_aided_text(pdf_info_dict, text_aided_config)
+        """Title optimization"""
+        title_aided_config = llm_aided_config.get('title_aided', None)
+        if title_aided_config is not None:
+            llm_aided_title(pdf_info_dict, title_aided_config)
+
     """Convert dict to list"""
     pdf_info_list = dict_to_list(pdf_info_dict)
     new_pdf_info_dict = {
         'pdf_info': pdf_info_list,
     }
 
-    clean_memory()
+    clean_memory(get_device())
 
     return new_pdf_info_dict
 
magic_pdf/post_proc/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Opendatalab. All rights reserved.
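The LLM-aided pass added to pdf_parse_union above is driven entirely by the dictionary returned from get_llm_aided_config(). As a rough illustration (the key names come from the code; the surrounding magic-pdf.json layout and the example values are assumptions), the structure it consumes looks like:

# Hypothetical example of the config consumed by the LLM-aided post-processing above.
llm_aided_config = {
    'formula_aided': None,     # left unconfigured: llm_aided_formula is skipped
    'text_aided': None,        # left unconfigured: llm_aided_text is skipped
    'title_aided': {           # keys read by llm_aided_title in post_proc/llm_aided.py below
        'api_key': 'sk-...',
        'base_url': 'https://api.example.com/v1',  # placeholder endpoint
        'model': 'some-chat-model',                # placeholder model name
    },
}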
magic_pdf/post_proc/llm_aided.py
@@ -0,0 +1,133 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import json
+from loguru import logger
+from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
+from openai import OpenAI
+
+
+# @todo: some formulas end with "\", which escapes the trailing "$" that gets appended; this also needs fixing
+formula_optimize_prompt = """Please fix the errors in the LaTeX formula according to the following guidelines, making sure the formula renders and stays faithful to the original content:
+
+1. Fix rendering or compilation errors:
+    - Some syntax errors such as mismatched/missing/extra tokens. Your task is to fix these syntax errors and make sure corrected results conform to latex math syntax principles.
+    - Errors that prevent compilation or rendering, e.g. keywords not supported by KaTeX
+
+2. Preserve the original information:
+    - Keep all important information from the original formula
+    - Do not add any new information that is not in the original formula
+
+IMPORTANT: Return only the corrected formula, without any introduction, explanation or metadata.
+
+LaTeX recognition result:
+$FORMULA
+
+Your corrected result:
+"""
+
+text_optimize_prompt = f"""Please fix OCR-induced errors according to the following guidelines, making sure the text is coherent and faithful to the original content:
+
+1. Fix OCR-induced typos and errors:
+    - Fix common OCR errors (e.g. 'rn' misread as 'm')
+    - Use context and common sense to make corrections
+    - Only fix obvious errors, do not change the content unnecessarily
+    - Do not add extra periods or other unnecessary punctuation
+
+2. Keep the original structure:
+    - Keep all headings and subheadings
+
+3. Preserve the original content:
+    - Keep all important information from the original text
+    - Do not add any new information that is not in the original text
+    - Keep the line breaks between paragraphs
+
+4. Maintain coherence:
+    - Make sure the content connects smoothly with the preceding text
+    - Handle text that starts or ends in the middle of a sentence appropriately
+
+5. Fix inline formulas:
+    - Remove extra spaces before and after inline formulas
+    - Fix OCR errors inside formulas
+    - Make sure formulas render with KaTeX
+
+6. Fix full-width characters
+    - Convert full-width punctuation to half-width punctuation
+    - Convert full-width letters to half-width letters
+    - Convert full-width digits to half-width digits
+
+IMPORTANT: Return only the corrected text, preserving all original formatting, including line breaks. Do not include any introduction, explanation or metadata.
+
+Previous context:
+
+Current chunk to process:
+
+Corrected text:
+"""
+
+def llm_aided_formula(pdf_info_dict, formula_aided_config):
+    pass
+
+def llm_aided_text(pdf_info_dict, text_aided_config):
+    pass
+
+def llm_aided_title(pdf_info_dict, title_aided_config):
+    client = OpenAI(
+        api_key=title_aided_config["api_key"],
+        base_url=title_aided_config["base_url"],
+    )
+    title_dict = {}
+    origin_title_list = []
+    i = 0
+    for page_num, page in pdf_info_dict.items():
+        blocks = page["para_blocks"]
+        for block in blocks:
+            if block["type"] == "title":
+                origin_title_list.append(block)
+                title_text = merge_para_with_text(block)
+                title_dict[f"{i}"] = title_text
+                i += 1
+    # logger.info(f"Title list: {title_dict}")
+
+    title_optimize_prompt = f"""The input is a dictionary made up of all the titles in a document. Please optimize the title results according to the following guidelines so that the result matches the hierarchy of a normal document:
+
+1. Preserve the original content:
+    - Every element in the input dictionary is valid; do not delete any element from the dictionary
+    - Make sure the number of elements in the output dictionary is exactly the same as in the input
+
+2. Keep the key-value correspondence inside the dictionary unchanged
+
+3. Optimize the hierarchy:
+    - Assign an appropriate hierarchy level to each title element
+    - Title levels must be continuous; do not skip a level
+    - Use at most 4 title levels; do not add too many levels
+    - The optimized title is an integer representing the level of that title
+
+IMPORTANT:
+Return the optimized json made up of the title levels directly; the returned json does not need to be formatted.
+
+Input title list:
+{title_dict}
+
+Corrected title list:
+"""
+
+    completion = client.chat.completions.create(
+        model=title_aided_config["model"],
+        messages=[
+            {'role': 'user', 'content': title_optimize_prompt}],
+        temperature=0.7,
+    )
+
+    json_completion = json.loads(completion.choices[0].message.content)
+
+    # logger.info(f"Title completion: {json_completion}")
+
+    # logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}")
+    if len(json_completion) == len(title_dict):
+        try:
+            for i, origin_title_block in enumerate(origin_title_list):
+                origin_title_block["level"] = int(json_completion[str(i)])
+        except Exception as e:
+            logger.exception(e)
+    else:
+        logger.error("The number of titles in the optimized result is not equal to the number of titles in the input.")
+
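To make the title pass concrete, the contract implied by the code above is: the prompt sends an index-to-title-text dictionary and the chat model is expected to answer with an index-to-level JSON object of the same length, which is then written back onto the title blocks. A tiny illustration with made-up data:

# Hypothetical data illustrating the llm_aided_title input/output contract shown above.
title_dict = {'0': '1 Introduction', '1': '1.1 Background', '2': '2 Method'}
json_completion = {'0': 1, '1': 2, '2': 1}   # expected shape of the model's JSON reply
for i in range(len(title_dict)):
    # each original title block would get: origin_title_block["level"] = int(json_completion[str(i)])
    print(i, title_dict[str(i)], '-> level', int(json_completion[str(i)]))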