magic-pdf 0.10.6__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/config/constants.py +2 -0
- magic_pdf/config/exceptions.py +7 -0
- magic_pdf/data/data_reader_writer/filebase.py +1 -1
- magic_pdf/data/data_reader_writer/multi_bucket_s3.py +8 -6
- magic_pdf/data/dataset.py +13 -1
- magic_pdf/data/read_api.py +59 -12
- magic_pdf/data/utils.py +35 -0
- magic_pdf/dict2md/ocr_mkcontent.py +14 -13
- magic_pdf/libs/clean_memory.py +11 -4
- magic_pdf/libs/config_reader.py +9 -0
- magic_pdf/libs/draw_bbox.py +8 -12
- magic_pdf/libs/language.py +3 -0
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/__init__.py +1 -125
- magic_pdf/model/batch_analyze.py +275 -0
- magic_pdf/model/doc_analyze_by_custom_model.py +4 -51
- magic_pdf/model/magic_model.py +4 -435
- magic_pdf/model/model_list.py +1 -0
- magic_pdf/model/pdf_extract_kit.py +33 -22
- magic_pdf/model/sub_modules/language_detection/__init__.py +1 -0
- magic_pdf/model/sub_modules/language_detection/utils.py +82 -0
- magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +139 -0
- magic_pdf/model/sub_modules/language_detection/yolov11/__init__.py +1 -0
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +44 -7
- magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +21 -2
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +70 -27
- magic_pdf/model/sub_modules/model_init.py +30 -4
- magic_pdf/model/sub_modules/model_utils.py +8 -2
- magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +51 -1
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +32 -6
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +42 -7
- magic_pdf/operators/__init__.py +94 -0
- magic_pdf/{model/operators.py → operators/models.py} +2 -38
- magic_pdf/{pipe/operators.py → operators/pipes.py} +70 -17
- magic_pdf/pdf_parse_union_core_v2.py +68 -17
- magic_pdf/post_proc/__init__.py +1 -0
- magic_pdf/post_proc/llm_aided.py +133 -0
- magic_pdf/pre_proc/ocr_span_list_modify.py +8 -0
- magic_pdf/pre_proc/remove_bbox_overlap.py +1 -1
- magic_pdf/resources/yolov11-langdetect/yolo_v11_ft.pt +0 -0
- magic_pdf/tools/cli.py +36 -11
- magic_pdf/tools/common.py +28 -18
- magic_pdf/utils/office_to_pdf.py +29 -0
- {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.0.dist-info}/METADATA +73 -23
- {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.0.dist-info}/RECORD +50 -53
- magic_pdf/para/__init__.py +0 -0
- magic_pdf/pdf_parse_by_ocr.py +0 -22
- magic_pdf/pdf_parse_by_txt.py +0 -23
- magic_pdf/pipe/AbsPipe.py +0 -99
- magic_pdf/pipe/OCRPipe.py +0 -80
- magic_pdf/pipe/TXTPipe.py +0 -42
- magic_pdf/pipe/UNIPipe.py +0 -150
- magic_pdf/pipe/__init__.py +0 -0
- magic_pdf/rw/AbsReaderWriter.py +0 -17
- magic_pdf/rw/DiskReaderWriter.py +0 -74
- magic_pdf/rw/S3ReaderWriter.py +0 -142
- magic_pdf/rw/__init__.py +0 -0
- magic_pdf/user_api.py +0 -144
- /magic_pdf/{para → post_proc}/para_split_v3.py +0 -0
- {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.0.dist-info}/WHEEL +0 -0
- {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.10.6.dist-info → magic_pdf-1.0.0.dist-info}/top_level.txt +0 -0
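
Reading guide for the hunks below: the headline change in 1.0.0 is that the old pipeline entry points (UNIPipe/OCRPipe/TXTPipe, user_api, and the rw reader/writers) are deleted and replaced by the new magic_pdf.operators package, which splits the workflow into an InferenceResult (model output) and a PipeResult (pipeline output). A minimal sketch of the resulting flow; the class and method names come from the hunks below, while the reader/writer and doc_analyze imports are assumptions about the surrounding 1.0.0 API:

    # Sketch only, not official documentation. 'demo.pdf' and the output
    # directories are placeholders.
    from magic_pdf.data.data_reader_writer import FileBasedDataReader, FileBasedDataWriter
    from magic_pdf.data.dataset import PymuDocDataset
    from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

    pdf_bytes = FileBasedDataReader('').read('demo.pdf')
    image_writer = FileBasedDataWriter('output/images')
    md_writer = FileBasedDataWriter('output')

    ds = PymuDocDataset(pdf_bytes)
    infer_result = doc_analyze(ds, ocr=True)                 # -> InferenceResult
    pipe_result = infer_result.pipe_ocr_mode(image_writer)   # -> PipeResult
    pipe_result.dump_md(md_writer, 'demo.md', 'images')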
magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py
@@ -1,16 +1,51 @@
+import cv2
 import numpy as np
+import torch
+from loguru import logger
 from rapid_table import RapidTable
-from rapidocr_paddle import RapidOCR


 class RapidTableModel(object):
-    def __init__(self):
+    def __init__(self, ocr_engine):
         self.table_model = RapidTable()
-        self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
+        # if ocr_engine is None:
+        #     self.ocr_model_name = "RapidOCR"
+        #     if torch.cuda.is_available():
+        #         from rapidocr_paddle import RapidOCR
+        #         self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
+        #     else:
+        #         from rapidocr_onnxruntime import RapidOCR
+        #         self.ocr_engine = RapidOCR()
+        # else:
+        #     self.ocr_model_name = "PaddleOCR"
+        #     self.ocr_engine = ocr_engine
+
+        self.ocr_model_name = "RapidOCR"
+        if torch.cuda.is_available():
+            from rapidocr_paddle import RapidOCR
+            self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
+        else:
+            from rapidocr_onnxruntime import RapidOCR
+            self.ocr_engine = RapidOCR()

     def predict(self, image):
-        ocr_result, _ = self.ocr_engine(np.asarray(image))
-        if ocr_result is None:
+
+        if self.ocr_model_name == "RapidOCR":
+            ocr_result, _ = self.ocr_engine(np.asarray(image))
+        elif self.ocr_model_name == "PaddleOCR":
+            bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+            ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+            if ocr_result:
+                ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
+                              len(item) == 2 and isinstance(item[1], tuple)]
+            else:
+                ocr_result = None
+        else:
+            logger.error("OCR model not supported")
+            ocr_result = None
+
+        if ocr_result:
+            html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
+            return html_code, table_cell_bboxes, elapse
+        else:
             return None, None, None
-        html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
-        return html_code, table_cell_bboxes, elapse
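
The predict path above now dispatches on self.ocr_model_name, but note that the constructor as committed hard-codes the RapidOCR branch (the PaddleOCR wiring is left commented out), so the new ocr_engine parameter is accepted and then ignored. A minimal usage sketch under that assumption; the image file name is a placeholder:

    from PIL import Image
    from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel

    model = RapidTableModel(ocr_engine=None)      # engine argument currently unused
    img = Image.open('table.png').convert('RGB')  # the PaddleOCR branch converts RGB->BGR,
                                                  # so an RGB image is the expected input
    html_code, table_cell_bboxes, elapse = model.predict(img)
    if html_code is None:
        print('OCR found no text; table skipped')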
magic_pdf/operators/__init__.py (new file)
@@ -0,0 +1,94 @@
+from abc import ABC, abstractmethod
+from typing import Callable
+
+from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
+from magic_pdf.operators.pipes import PipeResult
+
+
+class InferenceResultBase(ABC):
+
+    @abstractmethod
+    def __init__(self, inference_results: list, dataset: Dataset):
+        """Initialized method.
+
+        Args:
+            inference_results (list): the inference result generated by model
+            dataset (Dataset): the dataset related with model inference result
+        """
+        pass
+
+    @abstractmethod
+    def draw_model(self, file_path: str) -> None:
+        """Draw model inference result.
+
+        Args:
+            file_path (str): the output file path
+        """
+        pass
+
+    @abstractmethod
+    def dump_model(self, writer: DataWriter, file_path: str):
+        """Dump model inference result to file.
+
+        Args:
+            writer (DataWriter): writer handle
+            file_path (str): the location of target file
+        """
+        pass
+
+    @abstractmethod
+    def get_infer_res(self):
+        """Get the inference result.
+
+        Returns:
+            list: the inference result generated by model
+        """
+        pass
+
+    @abstractmethod
+    def apply(self, proc: Callable, *args, **kwargs):
+        """Apply callable method which.
+
+        Args:
+            proc (Callable): invoke proc as follows:
+                proc(inference_result, *args, **kwargs)
+
+        Returns:
+            Any: return the result generated by proc
+        """
+        pass
+
+    def pipe_txt_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        """Post-proc the model inference result, Extract the text using the
+        third library, such as `pymupdf`
+
+        Args:
+            imageWriter (DataWriter): the image writer handle
+            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
+            end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
+            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
+            lang (str, optional): Defaults to None.
+
+        Returns:
+            PipeResult: the result
+        """
+        pass
+
+    @abstractmethod
+    def pipe_ocr_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        pass
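
Of these methods, apply is the general escape hatch: the concrete implementation in operators/models.py (next hunk group) passes a deep copy of the raw per-page inference list to the callable. A small sketch, where count_boxes and the 'layout_dets' key are illustrative assumptions about the per-page result layout:

    # Continues the flow sketch above; infer_result is an InferenceResult.
    def count_boxes(infer_res, *args, **kwargs):
        # each page dict is assumed to carry its detections under 'layout_dets'
        return [len(page.get('layout_dets', [])) for page in infer_res]

    boxes_per_page = infer_result.apply(count_boxes)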
magic_pdf/{model/operators.py → operators/models.py}
@@ -7,13 +7,11 @@ from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.data.data_reader_writer import DataWriter
 from magic_pdf.data.dataset import Dataset
-from magic_pdf.filter import classify
 from magic_pdf.libs.draw_bbox import draw_model_bbox
 from magic_pdf.libs.version import __version__
-from magic_pdf.pipe.operators import PipeResult
+from magic_pdf.operators.pipes import PipeResult
 from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
-from magic_pdf.
-
+from magic_pdf.operators import InferenceResultBase

 class InferenceResult(InferenceResultBase):
     def __init__(self, inference_results: list, dataset: Dataset):
@@ -71,40 +69,6 @@ class InferenceResult(InferenceResultBase):
         """
         return proc(copy.deepcopy(self._infer_res), *args, **kwargs)
 
-    def pipe_auto_mode(
-        self,
-        imageWriter: DataWriter,
-        start_page_id=0,
-        end_page_id=None,
-        debug_mode=False,
-        lang=None,
-    ) -> PipeResult:
-        """Post-proc the model inference result.
-        step1: classify the dataset type
-        step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
-
-        Args:
-            imageWriter (DataWriter): the image writer handle
-            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
-            end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
-            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
-            lang (str, optional): Defaults to None.
-
-        Returns:
-            PipeResult: the result
-        """
-
-        pdf_proc_method = classify(self._dataset.data_bits())
-
-        if pdf_proc_method == SupportedPdfParseMethod.TXT:
-            return self.pipe_txt_mode(
-                imageWriter, start_page_id, end_page_id, debug_mode, lang
-            )
-        else:
-            return self.pipe_ocr_mode(
-                imageWriter, start_page_id, end_page_id, debug_mode, lang
-            )
-
     def pipe_txt_mode(
         self,
         imageWriter: DataWriter,
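
With pipe_auto_mode gone, the classify-then-dispatch step moves to the caller. A sketch of the equivalent logic, assuming the dataset's classify helper (the dataset.py changes are not shown here) plays the role of the removed magic_pdf.filter.classify call:

    from magic_pdf.config.enums import SupportedPdfParseMethod

    # ds and infer_result come from the flow sketch near the top of this diff
    if ds.classify() == SupportedPdfParseMethod.TXT:
        pipe_result = infer_result.pipe_txt_mode(image_writer)
    else:
        pipe_result = infer_result.pipe_ocr_mode(image_writer)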
magic_pdf/{pipe/operators.py → operators/pipes.py}
@@ -1,7 +1,7 @@
+import copy
 import json
 import os
 from typing import Callable
-import copy
 
 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import DataWriter
@@ -23,12 +23,34 @@ class PipeResult:
         self._pipe_res = pipe_res
         self._dataset = dataset
 
+    def get_markdown(
+        self,
+        img_dir_or_bucket_prefix: str,
+        drop_mode=DropMode.NONE,
+        md_make_mode=MakeMode.MM_MD,
+    ) -> str:
+        """Get markdown content.
+
+        Args:
+            img_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
+            drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.
+            md_make_mode (str, optional): The content Type of Markdown be made. Defaults to MakeMode.MM_MD.
+
+        Returns:
+            str: return markdown content
+        """
+        pdf_info_list = self._pipe_res['pdf_info']
+        md_content = union_make(
+            pdf_info_list, md_make_mode, drop_mode, img_dir_or_bucket_prefix
+        )
+        return md_content
+
     def dump_md(
         self,
         writer: DataWriter,
         file_path: str,
         img_dir_or_bucket_prefix: str,
-        drop_mode=DropMode.
+        drop_mode=DropMode.NONE,
         md_make_mode=MakeMode.MM_MD,
     ):
         """Dump The Markdown.
@@ -37,36 +59,68 @@ class PipeResult:
             writer (DataWriter): File writer handle
             file_path (str): The file location of markdown
             img_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
-            drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.
+            drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.
             md_make_mode (str, optional): The content Type of Markdown be made. Defaults to MakeMode.MM_MD.
         """
-        pdf_info_list = self._pipe_res['pdf_info']
-        md_content = union_make(
-            pdf_info_list, md_make_mode, drop_mode, img_dir_or_bucket_prefix
+
+        md_content = self.get_markdown(
+            img_dir_or_bucket_prefix, drop_mode=drop_mode, md_make_mode=md_make_mode
         )
         writer.write_string(file_path, md_content)
 
-    def
-        self,
-
-
+    def get_content_list(
+        self,
+        image_dir_or_bucket_prefix: str,
+        drop_mode=DropMode.NONE,
+    ) -> str:
+        """Get Content List.
 
         Args:
-            writer (DataWriter): File writer handle
-            file_path (str): The file location of content list
             image_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
+            drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.
+
+        Returns:
+            str: content list content
         """
         pdf_info_list = self._pipe_res['pdf_info']
         content_list = union_make(
             pdf_info_list,
             MakeMode.STANDARD_FORMAT,
-
+            drop_mode,
             image_dir_or_bucket_prefix,
         )
+        return content_list
+
+    def dump_content_list(
+        self,
+        writer: DataWriter,
+        file_path: str,
+        image_dir_or_bucket_prefix: str,
+        drop_mode=DropMode.NONE,
+    ):
+        """Dump Content List.
+
+        Args:
+            writer (DataWriter): File writer handle
+            file_path (str): The file location of content list
+            image_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
+            drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.
+        """
+        content_list = self.get_content_list(
+            image_dir_or_bucket_prefix, drop_mode=drop_mode,
+        )
         writer.write_string(
             file_path, json.dumps(content_list, ensure_ascii=False, indent=4)
         )
 
+    def get_middle_json(self) -> str:
+        """Get middle json.
+
+        Returns:
+            str: The content of middle json
+        """
+        return json.dumps(self._pipe_res, ensure_ascii=False, indent=4)
+
     def dump_middle_json(self, writer: DataWriter, file_path: str):
         """Dump the result of pipeline.
 
@@ -74,9 +128,8 @@ class PipeResult:
             writer (DataWriter): File writer handler
             file_path (str): The file location of middle json
         """
-        writer.write_string(
-            file_path, json.dumps(self._pipe_res, ensure_ascii=False, indent=4)
-        )
+        middle_json = self.get_middle_json()
+        writer.write_string(file_path, middle_json)
 
     def draw_layout(self, file_path: str) -> None:
         """Draw the layout.
@@ -123,7 +176,7 @@ class PipeResult:
         Returns:
             str: compress the pipeline result and return
         """
-        return JsonCompressor.compress_json(self.
+        return JsonCompressor.compress_json(self._pipe_res)
 
     def apply(self, proc: Callable, *args, **kwargs):
         """Apply callable method which.
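
Taken together, the PipeResult hunks above split every dump_* method into a get_* accessor that returns the content plus a thin dump_* wrapper that writes it. One detail worth noting: get_content_list is annotated -> str but actually returns the list that dump_content_list serializes with json.dumps. A sketch of in-memory use:

    # pipe_result as in the earlier sketches; 'images' is the figure directory/prefix
    md = pipe_result.get_markdown('images')                # str, same content dump_md writes
    content_list = pipe_result.get_content_list('images')  # list, despite the -> str annotation
    middle_json = pipe_result.get_middle_json()            # str: indent=4 JSON of the pipe result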
magic_pdf/pdf_parse_union_core_v2.py
@@ -1,5 +1,6 @@
 import copy
 import os
+import re
 import statistics
 import time
 from typing import List
@@ -13,11 +14,12 @@ from magic_pdf.config.ocr_content_type import BlockType, ContentType
 from magic_pdf.data.dataset import Dataset, PageableData
 from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
 from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
+from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
 from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
 from magic_pdf.model.magic_model import MagicModel
+from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text, llm_aided_title
 
 try:
     import torchtext
@@ -28,15 +30,15 @@ except ImportError:
     pass
 
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
-from magic_pdf.para.para_split_v3 import para_split
+from magic_pdf.post_proc.para_split_v3 import para_split
 from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
 from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
 from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split_v2
 from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans_v2, fix_discarded_block
-from magic_pdf.pre_proc.ocr_span_list_modify import get_qa_need_list_v2, remove_overlaps_low_confidence_spans, remove_overlaps_min_spans
+from magic_pdf.pre_proc.ocr_span_list_modify import get_qa_need_list_v2, remove_overlaps_low_confidence_spans, \
+    remove_overlaps_min_spans, check_chars_is_overlap_in_span
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable the albumentations update check
-os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
 
 
 def __replace_STX_ETX(text_str: str):
@@ -63,11 +65,22 @@ def __replace_0xfffd(text_str: str):
         return s
     return text_str
 
+
+# split ligature characters apart
+def __replace_ligatures(text: str):
+    ligatures = {
+        'ﬁ': 'fi', 'ﬂ': 'fl', 'ﬀ': 'ff', 'ﬃ': 'ffi', 'ﬄ': 'ffl', 'ﬅ': 'ft', 'ﬆ': 'st'
+    }
+    return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
+
+
 def chars_to_content(span):
     # check whether the chars in the span are empty
     if len(span['chars']) == 0:
         pass
         # span['content'] = ''
+    elif check_chars_is_overlap_in_span(span['chars']):
+        pass
     else:
         # first sort the chars by the x coordinate of each char['bbox'] center
         span['chars'] = sorted(span['chars'], key=lambda x: (x['bbox'][0] + x['bbox'][2]) / 2)
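
What __replace_ligatures buys: PyMuPDF is now asked to preserve ligatures (see the TEXT_PRESERVE_LIGATURES change further down), so single-codepoint ligatures arrive intact and are expanded back to plain letter sequences here. For example:

    assert __replace_ligatures('eﬃcient ﬁeld ﬂow') == 'efficient field flow'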
@@ -78,11 +91,16 @@ def chars_to_content(span):
 
         content = ''
         for char in span['chars']:
-            # if the gap between the next char's x0 and the previous char's x1 exceeds one char width, a space must be inserted in between
-            if char['bbox'][0] - span['chars'][span['chars'].index(char) - 1]['bbox'][2] > char_avg_width:
-                content += ' '
-            content += char['c']
 
+            # if the gap between the next char's x0 and this char's x1 exceeds 0.25 of a char width, a space must be inserted in between
+            char1 = char
+            char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None
+            if char2 and char2['bbox'][0] - char1['bbox'][2] > char_avg_width * 0.25 and char['c'] != ' ' and char2['c'] != ' ':
+                content += f"{char['c']} "
+            else:
+                content += char['c']
+
+        content = __replace_ligatures(content)
         span['content'] = __replace_0xfffd(content)
 
     del span['chars']
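
A worked example of the new spacing rule: a space is now inserted between two consecutive chars when the horizontal gap exceeds a quarter of the average char width (previously a full char width, measured against the previous char). With synthetic [x0, y0, x1, y1] boxes of width 10, assuming char_avg_width works out to 10 for these boxes:

    chars = [
        {'bbox': [0, 0, 10, 10], 'c': 'a'},
        {'bbox': [13, 0, 23, 10], 'c': 'b'},  # gap 3  > 10 * 0.25 -> space after 'a'
        {'bbox': [24, 0, 34, 10], 'c': 'c'},  # gap 1 <= 10 * 0.25 -> no space
    ]
    # chars_to_content would assemble content == 'a bc' before the
    # ligature and 0xfffd replacements run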
@@ -98,6 +116,10 @@ def fill_char_in_spans(spans, all_chars):
     spans = sorted(spans, key=lambda x: x['bbox'][1])
 
     for char in all_chars:
+        # skip chars whose bbox is invalid
+        x1, y1, x2, y2 = char['bbox']
+        if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
+            continue
         for span in spans:
             if calculate_char_in_span(char['bbox'], span['bbox'], char['c']):
                 span['chars'].append(char)
@@ -152,14 +174,16 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
 
 
 def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
+    # cid is rendered as 0xfffd; ligatures are split apart
+    # text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
 
-
-
+    # cid is rendered as 0xfffd; ligatures are kept intact
+    text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
     all_pymu_chars = []
     for block in text_blocks_raw:
         for line in block['lines']:
             cosine, sine = line['dir']
-            if abs
+            if abs(cosine) < 0.9 or abs(sine) > 0.1:
                 continue
             for span in line['spans']:
                 all_pymu_chars.extend(span['chars'])
@@ -255,19 +279,23 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
     return spans
 
 
-def replace_text_span(pymu_spans, ocr_spans):
-    return list(filter(lambda x: x['type'] != ContentType.Text, ocr_spans)) + pymu_spans
-
-
 def model_init(model_name: str):
     from transformers import LayoutLMv3ForTokenClassification
-
+    device = get_device()
     if torch.cuda.is_available():
         device = torch.device('cuda')
         if torch.cuda.is_bf16_supported():
             supports_bfloat16 = True
         else:
             supports_bfloat16 = False
+    elif str(device).startswith("npu"):
+        import torch_npu
+        if torch_npu.npu.is_available():
+            device = torch.device('npu')
+            supports_bfloat16 = False
+        else:
+            device = torch.device('cpu')
+            supports_bfloat16 = False
     else:
         device = torch.device('cpu')
         supports_bfloat16 = False
@@ -345,6 +373,8 @@ def cal_block_index(fix_blocks, sorted_bboxes):
     # sort using xycut
     block_bboxes = []
     for block in fix_blocks:
+        # if any value of block['bbox'] is below 0, clamp it to 0
+        block['bbox'] = [max(0, x) for x in block['bbox']]
         block_bboxes.append(block['bbox'])
 
     # drop the virtual line info from figure/table body blocks and backfill it with the real_lines info
@@ -738,6 +768,11 @@ def parse_page_core(
     """Re-order the blocks"""
     sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
 
+    """Re-order inside blocks (sort the multiple captions or footnotes inside img and table blocks)"""
+    for block in sorted_blocks:
+        if block['type'] in [BlockType.Image, BlockType.Table]:
+            block['blocks'] = sorted(block['blocks'], key=lambda b: b['index'])
+
     """Collect the lists that QA needs externalized"""
     images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks)
 
@@ -819,13 +854,29 @@ def pdf_parse_union(
     """Paragraph splitting"""
     para_split(pdf_info_dict)
 
+    """LLM-aided optimization"""
+    llm_aided_config = get_llm_aided_config()
+    if llm_aided_config is not None:
+        """Formula optimization"""
+        formula_aided_config = llm_aided_config.get('formula_aided', None)
+        if formula_aided_config is not None:
+            llm_aided_formula(pdf_info_dict, formula_aided_config)
+        """Text optimization"""
+        text_aided_config = llm_aided_config.get('text_aided', None)
+        if text_aided_config is not None:
+            llm_aided_text(pdf_info_dict, text_aided_config)
+        """Title optimization"""
+        title_aided_config = llm_aided_config.get('title_aided', None)
+        if title_aided_config is not None:
+            llm_aided_title(pdf_info_dict, title_aided_config)
+
     """Convert dict to list"""
     pdf_info_list = dict_to_list(pdf_info_dict)
     new_pdf_info_dict = {
         'pdf_info': pdf_info_list,
     }
 
-    clean_memory()
+    clean_memory(get_device())
 
     return new_pdf_info_dict
 
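
The llm-aided hooks are driven entirely by configuration. A hedged sketch of the corresponding magic-pdf.json section: the three sub-dict names and the api_key/base_url/model fields come straight from this hunk and from llm_aided.py below, while the top-level key name is whatever get_llm_aided_config (magic_pdf/libs/config_reader.py, +9 lines above) actually reads — assumed here to be "llm-aided-config". Since only llm_aided_title has a real implementation in this release, configuring title_aided alone is the useful case:

    "llm-aided-config": {
        "title_aided": {
            "api_key": "sk-...",
            "base_url": "https://api.example.com/v1",
            "model": "your-model-name"
        }
    }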
magic_pdf/post_proc/__init__.py (new file)
@@ -0,0 +1 @@
+# Copyright (c) Opendatalab. All rights reserved.
magic_pdf/post_proc/llm_aided.py (new file)
@@ -0,0 +1,133 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import json
+from loguru import logger
+from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
+from openai import OpenAI
+
+
+# @todo: some formulas end with "\", which causes the "$" appended at the end to be escaped; this also needs fixing
+formula_optimize_prompt = """Please fix the errors in the LaTeX formula according to the following guidelines, ensuring the formula can be rendered and matches the original content:
+
+1. Fix rendering or compilation errors:
+    - Some syntax errors such as mismatched/missing/extra tokens. Your task is to fix these syntax errors and make sure corrected results conform to latex math syntax principles.
+    - Errors that cannot be compiled or rendered, e.g. because the formula contains keywords unsupported by KaTeX
+
+2. Preserve the original information:
+    - Keep all important information from the original formula
+    - Do not add any new information that is not in the original formula
+
+IMPORTANT: Return only the corrected formula; do not include any introduction, explanation, or metadata.
+
+LaTeX recognition result:
+$FORMULA
+
+Your corrected result:
+"""
+
+text_optimize_prompt = f"""Please fix OCR-induced errors according to the following guidelines, ensuring the text is coherent and matches the original content:
+
+1. Fix OCR-induced typos and errors:
+    - Fix common OCR errors (e.g. 'rn' misread as 'm')
+    - Use context and common sense to make corrections
+    - Only fix obvious errors; do not modify the content unnecessarily
+    - Do not add extra periods or other unnecessary punctuation
+
+2. Keep the original structure:
+    - Keep all headings and subheadings
+
+3. Preserve the original content:
+    - Keep all important information from the original text
+    - Do not add any new information that is not in the original text
+    - Keep the line breaks between paragraphs
+
+4. Maintain coherence:
+    - Make sure the content connects smoothly with the preceding text
+    - Handle text that begins or ends mid-sentence appropriately
+
+5. Fix inline formulas:
+    - Remove extra spaces before and after inline formulas
+    - Fix OCR errors inside the formulas
+    - Make sure the formulas can be rendered by KaTeX
+
+6. Fix full-width characters
+    - Convert full-width punctuation to half-width punctuation
+    - Convert full-width letters to half-width letters
+    - Convert full-width digits to half-width digits
+
+IMPORTANT: Return only the corrected text and preserve all original formatting, including line breaks. Do not include any introduction, explanation, or metadata.
+
+Previous context:
+
+Current chunk to process:
+
+Corrected text:
+"""
+
+def llm_aided_formula(pdf_info_dict, formula_aided_config):
+    pass
+
+def llm_aided_text(pdf_info_dict, text_aided_config):
+    pass
+
+def llm_aided_title(pdf_info_dict, title_aided_config):
+    client = OpenAI(
+        api_key=title_aided_config["api_key"],
+        base_url=title_aided_config["base_url"],
+    )
+    title_dict = {}
+    origin_title_list = []
+    i = 0
+    for page_num, page in pdf_info_dict.items():
+        blocks = page["para_blocks"]
+        for block in blocks:
+            if block["type"] == "title":
+                origin_title_list.append(block)
+                title_text = merge_para_with_text(block)
+                title_dict[f"{i}"] = title_text
+                i += 1
+    # logger.info(f"Title list: {title_dict}")
+
+    title_optimize_prompt = f"""The input is a dict composed of all the titles in a document. Please optimize the title results according to the following guidelines so that the result matches the hierarchy of a normal document:
+
+1. Preserve the original content:
+    - Every element in the input dict is valid; do not delete any element from the dict
+    - Make sure the number of elements in the output dict is identical to the number in the input
+
+2. Keep the key-value correspondence inside the dict unchanged
+
+3. Optimize the hierarchy:
+    - Assign an appropriate hierarchy level to each title element
+    - Title levels must be continuous; do not skip a level
+    - Use at most 4 title levels; do not add too many levels
+    - Each optimized title is an integer representing the level of that title
+
+IMPORTANT:
+Return the optimized json composed of the title levels directly; the returned json does not need to be formatted.
+
+Input title list:
+{title_dict}
+
+Corrected title list:
+"""
+
+    completion = client.chat.completions.create(
+        model=title_aided_config["model"],
+        messages=[
+            {'role': 'user', 'content': title_optimize_prompt}],
+        temperature=0.7,
+    )
+
+    json_completion = json.loads(completion.choices[0].message.content)
+
+    # logger.info(f"Title completion: {json_completion}")
+
+    # logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}")
+    if len(json_completion) == len(title_dict):
+        try:
+            for i, origin_title_block in enumerate(origin_title_list):
+                origin_title_block["level"] = int(json_completion[str(i)])
+        except Exception as e:
+            logger.exception(e)
+    else:
+        logger.error("The number of titles in the optimized result is not equal to the number of titles in the input.")
+
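
The contract llm_aided_title enforces on the model response, as a runnable sketch with a fabricated completion: the reply must be bare JSON whose keys mirror title_dict and whose integer values become block['level'] on the original title blocks, and a length mismatch discards the whole result:

    import json

    response_content = '{"0": 1, "1": 2, "2": 2}'  # hypothetical completion content
    levels = json.loads(response_content)
    # for each i, origin_title_list[i]['level'] = int(levels[str(i)])
    assert len(levels) == 3  # must equal len(title_dict) or the update is skipped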