magic-pdf 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/data/batch_build_dataset.py +156 -0
- magic_pdf/data/dataset.py +44 -24
- magic_pdf/data/utils.py +108 -9
- magic_pdf/dict2md/ocr_mkcontent.py +4 -3
- magic_pdf/libs/pdf_image_tools.py +11 -6
- magic_pdf/libs/performance_stats.py +12 -1
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/batch_analyze.py +175 -201
- magic_pdf/model/doc_analyze_by_custom_model.py +137 -92
- magic_pdf/model/pdf_extract_kit.py +5 -38
- magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
- magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
- magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
- magic_pdf/model/sub_modules/model_init.py +50 -37
- magic_pdf/model/sub_modules/model_utils.py +17 -11
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +10 -18
- magic_pdf/pdf_parse_union_core_v2.py +112 -74
- magic_pdf/post_proc/para_split_v3.py +16 -13
- magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
- magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
- magic_pdf/resources/model_config/model_configs.yaml +1 -1
- magic_pdf/tools/cli.py +30 -12
- magic_pdf/tools/common.py +90 -12
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/METADATA +51 -41
- magic_pdf-1.3.0.dist-info/RECORD +202 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
- magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
- magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
- magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
- magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
- magic_pdf-1.2.1.dist-info/RECORD +0 -147
- /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
- /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
- /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/WHEEL +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,227 @@
|
|
1
|
+
import os
|
2
|
+
import math
|
3
|
+
from pathlib import Path
|
4
|
+
import numpy as np
|
5
|
+
import cv2
|
6
|
+
import argparse
|
7
|
+
|
8
|
+
|
9
|
+
# Root of the paddleocr2pytorch package: the directory three levels above
# this module (tools/infer/<this file> -> paddleocr2pytorch).
root_dir = Path(__file__).resolve().parent.parent.parent
# Architecture definitions keyed by model name; read by get_arch_config().
DEFAULT_CFG_PATH = root_dir.joinpath("pytorchocr", "utils", "resources", "arch_config.yaml")
|
11
|
+
|
12
|
+
|
13
|
+
def init_args():
    """Build the argparse parser holding every PytorchOCR inference option.

    Nothing is parsed here; call ``parse_args()`` on the returned parser
    (optionally with an explicit argv list) to obtain a Namespace.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """

    def str2bool(v):
        # "true" / "t" / "1" (case-insensitive) mean True; anything else is False.
        return v.lower() in ("true", "t", "1")

    # Package root (paddleocr2pytorch/): three directory levels above this module.
    pkg_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

    parser = argparse.ArgumentParser()

    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=False)
    parser.add_argument("--det", type=str2bool, default=True)
    parser.add_argument("--rec", type=str2bool, default=True)
    parser.add_argument("--device", type=str, default='cpu')
    parser.add_argument("--gpu_mem", type=int, default=500)
    parser.add_argument("--warmup", type=str2bool, default=False)

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_path", type=str)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default='max')

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
    parser.add_argument("--max_batch_size", type=int, default=10)
    parser.add_argument("--use_dilation", type=str2bool, default=False)
    parser.add_argument("--det_db_score_mode", type=str, default="fast")

    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # SAST params
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
    parser.add_argument("--det_sast_polygon", type=str2bool, default=False)

    # PSE params
    parser.add_argument("--det_pse_thresh", type=float, default=0)
    parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
    parser.add_argument("--det_pse_min_area", type=float, default=16)
    parser.add_argument("--det_pse_box_type", type=str, default='box')
    parser.add_argument("--det_pse_scale", type=int, default=1)

    # FCE params
    # `type=list` would split a CLI value into single characters; accept
    # space-separated integers instead (e.g. `--scales 8 16 32`). Default unchanged.
    parser.add_argument("--scales", type=int, nargs='+', default=[8, 16, 32])
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--fourier_degree", type=int, default=5)
    parser.add_argument("--det_fce_box_type", type=str, default='poly')

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
    parser.add_argument("--rec_model_path", type=str)
    parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
    parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)

    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument("--drop_score", type=float, default=0.5)
    parser.add_argument("--limited_max_width", type=int, default=1280)
    parser.add_argument("--limited_min_width", type=int, default=16)

    # NOTE(review): no doc/fonts/simfang.ttf appears in this package's file
    # list — confirm this default exists before relying on it.
    parser.add_argument(
        "--vis_font_path", type=str,
        default=os.path.join(pkg_root, 'doc/fonts/simfang.ttf'))
    # The packaged character dictionaries live under pytorchocr/utils/resources/dict/.
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default=os.path.join(pkg_root, 'pytorchocr/utils/resources/dict/ppocr_keys_v1.txt'))

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_path", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    # Space-separated labels on the CLI (e.g. `--label_list 0 180`); default unchanged.
    parser.add_argument("--label_list", type=str, nargs='+', default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)

    # params for e2e
    parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
    parser.add_argument("--e2e_model_path", type=str)
    parser.add_argument("--e2e_limit_side_len", type=float, default=768)
    parser.add_argument("--e2e_limit_type", type=str, default='max')

    # PGNet params
    parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
    # NOTE(review): no ic15_dict.txt appears in this package's file list — confirm.
    parser.add_argument(
        "--e2e_char_dict_path", type=str,
        default=os.path.join(pkg_root, 'pytorchocr/utils/ic15_dict.txt'))
    parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
    # `type=bool` treats any non-empty string (including "false") as True.
    parser.add_argument("--e2e_pgnet_polygon", type=str2bool, default=True)
    parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')

    # SR params
    parser.add_argument("--sr_model_path", type=str)
    parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
    parser.add_argument("--sr_batch_num", type=int, default=1)

    # params .yaml
    parser.add_argument("--det_yaml_path", type=str, default=None)
    parser.add_argument("--rec_yaml_path", type=str, default=None)
    parser.add_argument("--cls_yaml_path", type=str, default=None)
    parser.add_argument("--e2e_yaml_path", type=str, default=None)
    parser.add_argument("--sr_yaml_path", type=str, default=None)

    # multi-process
    parser.add_argument("--use_mp", type=str2bool, default=False)
    parser.add_argument("--total_process_num", type=int, default=1)
    parser.add_argument("--process_id", type=int, default=0)

    parser.add_argument("--benchmark", type=str2bool, default=False)
    parser.add_argument("--save_log_path", type=str, default="./log_output/")

    parser.add_argument("--show_log", type=str2bool, default=True)

    return parser
|
141
|
+
|
142
|
+
def parse_args():
    """Parse ``sys.argv`` with the parser built by ``init_args`` and return the Namespace."""
    return init_args().parse_args()
|
145
|
+
|
146
|
+
def get_default_config(args):
    """Return the attribute dict of a parsed Namespace.

    This is ``vars(args)`` itself, not a copy, so writes to the returned dict
    are visible on ``args`` and vice versa.
    """
    config = vars(args)
    return config
|
148
|
+
|
149
|
+
|
150
|
+
def read_network_config_from_yaml(yaml_path, char_num=None):
    """Load the ``Architecture`` section of a YAML model config.

    Args:
        yaml_path: path to the YAML file.
        char_num: size of the character dictionary. When given and the head
            is named ``MultiHead``, its ``out_channels_list`` is overwritten
            with ``char_num`` (CTCLabelDecode), ``char_num + 2``
            (SARLabelDecode) and ``char_num + 3`` (NRTRLabelDecode).

    Returns:
        The ``Architecture`` mapping (mutated in place as described above).

    Raises:
        FileNotFoundError: if ``yaml_path`` does not exist.
        ValueError: if the file is empty, is not a mapping, or has no
            ``Architecture`` section.
    """
    if not os.path.exists(yaml_path):
        raise FileNotFoundError('{} is not existed.'.format(yaml_path))
    import yaml
    with open(yaml_path, encoding='utf-8') as f:
        res = yaml.safe_load(f)
    # safe_load returns None for an empty document (and may return a scalar or
    # list); report that as "no Architecture" instead of an AttributeError.
    if not isinstance(res, dict) or res.get('Architecture') is None:
        raise ValueError('{} has no Architecture'.format(yaml_path))
    arch = res['Architecture']
    # Not every architecture declares a Head; don't KeyError on those.
    head = arch.get('Head') or {}
    if head.get('name') == 'MultiHead' and char_num is not None:
        head['out_channels_list'] = {
            'CTCLabelDecode': char_num,
            'SARLabelDecode': char_num + 2,
            'NRTRLabelDecode': char_num + 3
        }
    return arch
|
165
|
+
|
166
|
+
def AnalysisConfig(weights_path, yaml_path=None, char_num=None):
    """Check that a weights file exists and optionally load its network config.

    Args:
        weights_path: path to the model weights; must exist.
        yaml_path: optional YAML config path passed to
            ``read_network_config_from_yaml``.
        char_num: forwarded to ``read_network_config_from_yaml``.

    Returns:
        The ``Architecture`` mapping when ``yaml_path`` is given, else None.

    Raises:
        FileNotFoundError: if ``weights_path`` does not exist.
    """
    weights_abs = os.path.abspath(weights_path)
    if os.path.exists(weights_abs):
        if yaml_path is None:
            return None
        return read_network_config_from_yaml(yaml_path, char_num=char_num)
    raise FileNotFoundError('{} is not found.'.format(weights_path))
|
172
|
+
|
173
|
+
|
174
|
+
def resize_img(img, input_size=600):
    """Scale an image uniformly so that its longer side becomes ``input_size``.

    Height and width are taken from the first two axes of ``np.array(img)``.
    Both axes use the same factor, so the aspect ratio is kept; the factor can
    be above 1 (upscaling) as well as below 1.
    """
    arr = np.array(img)
    longest_side = np.max(arr.shape[:2])
    factor = float(input_size) / float(longest_side)
    return cv2.resize(arr, None, None, fx=factor, fy=factor)
|
184
|
+
|
185
|
+
|
186
|
+
def str_count(s):
    """Return the display width of ``s`` measured in full-width characters.

    ASCII letters, digits (``str.isdigit``) and whitespace are half-width.
    Every other character counts as one unit. The half-width characters
    together contribute ``floor(n / 2)``, where ``n`` is their count. The
    result is therefore ``len(s) - ceil(n / 2)``.

    Args:
        s: the input string.

    Returns:
        int: the width in full-width units (0 for an empty string).
    """
    import string
    half_width = sum(
        1 for c in s if c in string.ascii_letters or c.isdigit() or c.isspace()
    )
    return len(s) - math.ceil(half_width / 2)
|
208
|
+
|
209
|
+
|
210
|
+
def base64_to_cv2(b64str):
    """Decode a base64 string holding an encoded image into an OpenCV array.

    Args:
        b64str: base64 text (``str``) of an encoded image file.

    Returns:
        The image decoded with ``cv2.IMREAD_COLOR``. Per OpenCV's ``imdecode``
        documentation, this is None if the bytes cannot be decoded.
    """
    import base64
    raw = base64.b64decode(b64str.encode('utf8'))
    # np.fromstring's binary mode is deprecated; frombuffer is the direct,
    # zero-copy replacement for viewing raw bytes as uint8.
    buf = np.frombuffer(raw, dtype=np.uint8)
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
216
|
+
|
217
|
+
|
218
|
+
def get_arch_config(model_path):
    """Look up a model's architecture entry in ``arch_config.yaml``.

    The lookup key is the model file's stem (its name without the final
    suffix). The whole file at ``DEFAULT_CFG_PATH`` is loaded with OmegaConf
    on every call.

    Raises:
        ValueError: if the stem is not a top-level key of the config.
    """
    from omegaconf import OmegaConf

    configs = OmegaConf.load(DEFAULT_CFG_PATH)
    arch_name = Path(model_path).stem
    if arch_name in configs:
        return configs[arch_name]
    raise ValueError(f"architecture {arch_name} is not in arch_config.yaml")
|
@@ -9,7 +9,7 @@ from magic_pdf.libs.config_reader import get_device
|
|
9
9
|
|
10
10
|
|
11
11
|
class RapidTableModel(object):
|
12
|
-
def __init__(self, ocr_engine, table_sub_model_name):
|
12
|
+
def __init__(self, ocr_engine, table_sub_model_name='slanet_plus'):
|
13
13
|
sub_model_list = [model.value for model in ModelType]
|
14
14
|
if table_sub_model_name is None:
|
15
15
|
input_args = RapidTableInput()
|
@@ -23,25 +23,17 @@ class RapidTableModel(object):
|
|
23
23
|
|
24
24
|
self.table_model = RapidTable(input_args)
|
25
25
|
|
26
|
-
#
|
27
|
-
#
|
28
|
-
#
|
29
|
-
#
|
30
|
-
# self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
|
31
|
-
# else:
|
32
|
-
# from rapidocr_onnxruntime import RapidOCR
|
33
|
-
# self.ocr_engine = RapidOCR()
|
26
|
+
# self.ocr_model_name = "RapidOCR"
|
27
|
+
# if torch.cuda.is_available():
|
28
|
+
# from rapidocr_paddle import RapidOCR
|
29
|
+
# self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
|
34
30
|
# else:
|
35
|
-
#
|
36
|
-
# self.ocr_engine =
|
31
|
+
# from rapidocr_onnxruntime import RapidOCR
|
32
|
+
# self.ocr_engine = RapidOCR()
|
33
|
+
|
34
|
+
self.ocr_model_name = "PaddleOCR"
|
35
|
+
self.ocr_engine = ocr_engine
|
37
36
|
|
38
|
-
self.ocr_model_name = "RapidOCR"
|
39
|
-
if torch.cuda.is_available():
|
40
|
-
from rapidocr_paddle import RapidOCR
|
41
|
-
self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
|
42
|
-
else:
|
43
|
-
from rapidocr_onnxruntime import RapidOCR
|
44
|
-
self.ocr_engine = RapidOCR()
|
45
37
|
|
46
38
|
def predict(self, image):
|
47
39
|
|
@@ -4,6 +4,7 @@ import os
|
|
4
4
|
import re
|
5
5
|
import statistics
|
6
6
|
import time
|
7
|
+
import warnings
|
7
8
|
from typing import List
|
8
9
|
|
9
10
|
import cv2
|
@@ -11,6 +12,7 @@ import fitz
|
|
11
12
|
import torch
|
12
13
|
import numpy as np
|
13
14
|
from loguru import logger
|
15
|
+
from tqdm import tqdm
|
14
16
|
|
15
17
|
from magic_pdf.config.enums import SupportedPdfParseMethod
|
16
18
|
from magic_pdf.config.ocr_content_type import BlockType, ContentType
|
@@ -21,20 +23,9 @@ from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_l
|
|
21
23
|
from magic_pdf.libs.convert_utils import dict_to_list
|
22
24
|
from magic_pdf.libs.hash_utils import compute_md5
|
23
25
|
from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
|
24
|
-
from magic_pdf.libs.performance_stats import measure_time, PerformanceStats
|
25
26
|
from magic_pdf.model.magic_model import MagicModel
|
26
27
|
from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text, llm_aided_title
|
27
28
|
|
28
|
-
from concurrent.futures import ThreadPoolExecutor
|
29
|
-
|
30
|
-
try:
|
31
|
-
import torchtext
|
32
|
-
|
33
|
-
if torchtext.__version__ >= '0.18.0':
|
34
|
-
torchtext.disable_torchtext_deprecation_warning()
|
35
|
-
except ImportError:
|
36
|
-
pass
|
37
|
-
|
38
29
|
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
|
39
30
|
from magic_pdf.post_proc.para_split_v3 import para_split
|
40
31
|
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
|
@@ -42,7 +33,7 @@ from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
|
|
42
33
|
from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split_v2
|
43
34
|
from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans_v2, fix_discarded_block
|
44
35
|
from magic_pdf.pre_proc.ocr_span_list_modify import get_qa_need_list_v2, remove_overlaps_low_confidence_spans, \
|
45
|
-
remove_overlaps_min_spans,
|
36
|
+
remove_overlaps_min_spans, remove_x_overlapping_chars
|
46
37
|
|
47
38
|
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
48
39
|
|
@@ -64,14 +55,6 @@ def __replace_STX_ETX(text_str: str):
|
|
64
55
|
return text_str
|
65
56
|
|
66
57
|
|
67
|
-
def __replace_0xfffd(text_str: str):
|
68
|
-
"""Replace \ufffd, as these characters become garbled when extracted using pymupdf."""
|
69
|
-
if text_str:
|
70
|
-
s = text_str.replace('\ufffd', " ")
|
71
|
-
return s
|
72
|
-
return text_str
|
73
|
-
|
74
|
-
|
75
58
|
# 连写字符拆分
|
76
59
|
def __replace_ligatures(text: str):
|
77
60
|
ligatures = {
|
@@ -84,16 +67,17 @@ def chars_to_content(span):
|
|
84
67
|
# 检查span中的char是否为空
|
85
68
|
if len(span['chars']) == 0:
|
86
69
|
pass
|
87
|
-
# span['content'] = ''
|
88
|
-
elif check_chars_is_overlap_in_span(span['chars']):
|
89
|
-
pass
|
90
70
|
else:
|
91
71
|
# 先给chars按char['bbox']的中心点的x坐标排序
|
92
72
|
span['chars'] = sorted(span['chars'], key=lambda x: (x['bbox'][0] + x['bbox'][2]) / 2)
|
93
73
|
|
94
|
-
#
|
95
|
-
|
96
|
-
|
74
|
+
# Calculate the width of each character
|
75
|
+
char_widths = [char['bbox'][2] - char['bbox'][0] for char in span['chars']]
|
76
|
+
# Calculate the median width
|
77
|
+
median_width = statistics.median(char_widths)
|
78
|
+
|
79
|
+
# 通过x轴重叠比率移除一部分char
|
80
|
+
span = remove_x_overlapping_chars(span, median_width)
|
97
81
|
|
98
82
|
content = ''
|
99
83
|
for char in span['chars']:
|
@@ -101,13 +85,12 @@ def chars_to_content(span):
|
|
101
85
|
# 如果下一个char的x0和上一个char的x1距离超过0.25个字符宽度,则需要在中间插入一个空格
|
102
86
|
char1 = char
|
103
87
|
char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None
|
104
|
-
if char2 and char2['bbox'][0] - char1['bbox'][2] >
|
88
|
+
if char2 and char2['bbox'][0] - char1['bbox'][2] > median_width * 0.25 and char['c'] != ' ' and char2['c'] != ' ':
|
105
89
|
content += f"{char['c']} "
|
106
90
|
else:
|
107
91
|
content += char['c']
|
108
92
|
|
109
|
-
content = __replace_ligatures(content)
|
110
|
-
span['content'] = __replace_0xfffd(content)
|
93
|
+
span['content'] = __replace_ligatures(content)
|
111
94
|
|
112
95
|
del span['chars']
|
113
96
|
|
@@ -122,10 +105,6 @@ def fill_char_in_spans(spans, all_chars):
|
|
122
105
|
spans = sorted(spans, key=lambda x: x['bbox'][1])
|
123
106
|
|
124
107
|
for char in all_chars:
|
125
|
-
# 跳过非法bbox的char
|
126
|
-
# x1, y1, x2, y2 = char['bbox']
|
127
|
-
# if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
|
128
|
-
# continue
|
129
108
|
|
130
109
|
for span in spans:
|
131
110
|
if calculate_char_in_span(char['bbox'], span['bbox'], char['c']):
|
@@ -215,7 +194,7 @@ def calculate_contrast(img, img_mode) -> float:
|
|
215
194
|
std_dev = np.std(gray_img)
|
216
195
|
# 对比度定义为标准差除以平均值(加上小常数避免除零错误)
|
217
196
|
contrast = std_dev / (mean_value + 1e-6)
|
218
|
-
# logger.
|
197
|
+
# logger.debug(f"contrast: {contrast}")
|
219
198
|
return round(contrast, 2)
|
220
199
|
|
221
200
|
# @measure_time
|
@@ -308,41 +287,53 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
|
|
308
287
|
if len(need_ocr_spans) > 0:
|
309
288
|
|
310
289
|
# 初始化ocr模型
|
311
|
-
atom_model_manager = AtomModelSingleton()
|
312
|
-
ocr_model = atom_model_manager.get_atom_model(
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
)
|
290
|
+
# atom_model_manager = AtomModelSingleton()
|
291
|
+
# ocr_model = atom_model_manager.get_atom_model(
|
292
|
+
# atom_model_name='ocr',
|
293
|
+
# ocr_show_log=False,
|
294
|
+
# det_db_box_thresh=0.3,
|
295
|
+
# lang=lang
|
296
|
+
# )
|
318
297
|
|
319
298
|
for span in need_ocr_spans:
|
320
299
|
# 对span的bbox截图再ocr
|
321
300
|
span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode='cv2')
|
322
301
|
|
323
302
|
# 计算span的对比度,低于0.20的span不进行ocr
|
324
|
-
if calculate_contrast(span_img, img_mode='bgr') <= 0.
|
303
|
+
if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
|
325
304
|
spans.remove(span)
|
326
305
|
continue
|
306
|
+
# pass
|
307
|
+
|
308
|
+
span['content'] = ''
|
309
|
+
span['score'] = 1
|
310
|
+
span['np_img'] = span_img
|
327
311
|
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
312
|
+
|
313
|
+
# ocr_res = ocr_model.ocr(span_img, det=False)
|
314
|
+
# if ocr_res and len(ocr_res) > 0:
|
315
|
+
# if len(ocr_res[0]) > 0:
|
316
|
+
# ocr_text, ocr_score = ocr_res[0][0]
|
317
|
+
# # logger.info(f"ocr_text: {ocr_text}, ocr_score: {ocr_score}")
|
318
|
+
# if ocr_score > 0.5 and len(ocr_text) > 0:
|
319
|
+
# span['content'] = ocr_text
|
320
|
+
# span['score'] = float(round(ocr_score, 2))
|
321
|
+
# else:
|
322
|
+
# spans.remove(span)
|
338
323
|
|
339
324
|
return spans
|
340
325
|
|
341
326
|
|
342
327
|
def model_init(model_name: str):
|
343
328
|
from transformers import LayoutLMv3ForTokenClassification
|
344
|
-
|
345
|
-
|
329
|
+
device_name = get_device()
|
330
|
+
bf_16_support = False
|
331
|
+
if device_name.startswith("cuda"):
|
332
|
+
bf_16_support = torch.cuda.is_bf16_supported()
|
333
|
+
elif device_name.startswith("mps"):
|
334
|
+
bf_16_support = True
|
335
|
+
|
336
|
+
device = torch.device(device_name)
|
346
337
|
if model_name == 'layoutreader':
|
347
338
|
# 检测modelscope的缓存目录是否存在
|
348
339
|
layoutreader_model_dir = get_local_layoutreader_model_dir()
|
@@ -357,7 +348,10 @@ def model_init(model_name: str):
|
|
357
348
|
model = LayoutLMv3ForTokenClassification.from_pretrained(
|
358
349
|
'hantian/layoutreader'
|
359
350
|
)
|
360
|
-
|
351
|
+
if bf_16_support:
|
352
|
+
model.to(device).eval().bfloat16()
|
353
|
+
else:
|
354
|
+
model.to(device).eval()
|
361
355
|
else:
|
362
356
|
logger.error('model name not allow')
|
363
357
|
exit(1)
|
@@ -383,9 +377,12 @@ def do_predict(boxes: List[List[int]], model) -> List[int]:
|
|
383
377
|
from magic_pdf.model.sub_modules.reading_oreder.layoutreader.helpers import (
|
384
378
|
boxes2inputs, parse_logits, prepare_inputs)
|
385
379
|
|
386
|
-
|
387
|
-
|
388
|
-
|
380
|
+
with warnings.catch_warnings():
|
381
|
+
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
|
382
|
+
|
383
|
+
inputs = boxes2inputs(boxes)
|
384
|
+
inputs = prepare_inputs(inputs, model)
|
385
|
+
logits = model(**inputs).logits.cpu().squeeze(0)
|
389
386
|
return parse_logits(logits, len(boxes))
|
390
387
|
|
391
388
|
|
@@ -463,20 +460,20 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
|
|
463
460
|
if (
|
464
461
|
block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
|
465
462
|
): # 可能是双列结构,可以切细点
|
466
|
-
lines = int(block_height / line_height)
|
463
|
+
lines = int(block_height / line_height)
|
467
464
|
else:
|
468
465
|
# 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
|
469
466
|
if block_weight > page_w * 0.4:
|
470
467
|
lines = 3
|
471
|
-
line_height = (y1 - y0) / lines
|
472
468
|
elif block_weight > page_w * 0.25: # (可能是三列结构,也切细点)
|
473
|
-
lines = int(block_height / line_height)
|
469
|
+
lines = int(block_height / line_height)
|
474
470
|
else: # 判断长宽比
|
475
471
|
if block_height / block_weight > 1.2: # 细长的不分
|
476
472
|
return [[x0, y0, x1, y1]]
|
477
473
|
else: # 不细长的还是分成两行
|
478
474
|
lines = 2
|
479
|
-
|
475
|
+
|
476
|
+
line_height = (y1 - y0) / lines
|
480
477
|
|
481
478
|
# 确定从哪个y位置开始绘制线条
|
482
479
|
current_y = y0
|
@@ -492,7 +489,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
|
|
492
489
|
else:
|
493
490
|
return [[x0, y0, x1, y1]]
|
494
491
|
|
495
|
-
|
492
|
+
|
496
493
|
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
|
497
494
|
page_line_list = []
|
498
495
|
|
@@ -936,17 +933,18 @@ def pdf_parse_union(
|
|
936
933
|
logger.warning('end_page_id is out of range, use pdf_docs length')
|
937
934
|
end_page_id = len(dataset) - 1
|
938
935
|
|
939
|
-
"""初始化启动时间"""
|
940
|
-
start_time = time.time()
|
936
|
+
# """初始化启动时间"""
|
937
|
+
# start_time = time.time()
|
941
938
|
|
942
|
-
for page_id, page in enumerate(dataset):
|
943
|
-
|
944
|
-
|
945
|
-
|
946
|
-
|
947
|
-
|
948
|
-
)
|
949
|
-
|
939
|
+
# for page_id, page in enumerate(dataset):
|
940
|
+
for page_id, page in tqdm(enumerate(dataset), total=len(dataset), desc="Processing pages"):
|
941
|
+
# """debug时输出每页解析的耗时."""
|
942
|
+
# if debug_mode:
|
943
|
+
# time_now = time.time()
|
944
|
+
# logger.info(
|
945
|
+
# f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
|
946
|
+
# )
|
947
|
+
# start_time = time_now
|
950
948
|
|
951
949
|
"""解析pdf中的每一页"""
|
952
950
|
if start_page_id <= page_id <= end_page_id:
|
@@ -962,7 +960,47 @@ def pdf_parse_union(
|
|
962
960
|
)
|
963
961
|
pdf_info_dict[f'page_{page_id}'] = page_info
|
964
962
|
|
965
|
-
|
963
|
+
need_ocr_list = []
|
964
|
+
img_crop_list = []
|
965
|
+
text_block_list = []
|
966
|
+
for pange_id, page_info in pdf_info_dict.items():
|
967
|
+
for block in page_info['preproc_blocks']:
|
968
|
+
if block['type'] in ['table', 'image']:
|
969
|
+
for sub_block in block['blocks']:
|
970
|
+
if sub_block['type'] in ['image_caption', 'image_footnote', 'table_caption', 'table_footnote']:
|
971
|
+
text_block_list.append(sub_block)
|
972
|
+
elif block['type'] in ['text', 'title']:
|
973
|
+
text_block_list.append(block)
|
974
|
+
for block in page_info['discarded_blocks']:
|
975
|
+
text_block_list.append(block)
|
976
|
+
for block in text_block_list:
|
977
|
+
for line in block['lines']:
|
978
|
+
for span in line['spans']:
|
979
|
+
if 'np_img' in span:
|
980
|
+
need_ocr_list.append(span)
|
981
|
+
img_crop_list.append(span['np_img'])
|
982
|
+
span.pop('np_img')
|
983
|
+
if len(img_crop_list) > 0:
|
984
|
+
# Get OCR results for this language's images
|
985
|
+
atom_model_manager = AtomModelSingleton()
|
986
|
+
ocr_model = atom_model_manager.get_atom_model(
|
987
|
+
atom_model_name='ocr',
|
988
|
+
ocr_show_log=False,
|
989
|
+
det_db_box_thresh=0.3,
|
990
|
+
lang=lang
|
991
|
+
)
|
992
|
+
# rec_start = time.time()
|
993
|
+
ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
|
994
|
+
# Verify we have matching counts
|
995
|
+
assert len(ocr_res_list) == len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
|
996
|
+
# Process OCR results for this language
|
997
|
+
for index, span in enumerate(need_ocr_list):
|
998
|
+
ocr_text, ocr_score = ocr_res_list[index]
|
999
|
+
span['content'] = ocr_text
|
1000
|
+
span['score'] = float(round(ocr_score, 2))
|
1001
|
+
# rec_time = time.time() - rec_start
|
1002
|
+
# logger.info(f'ocr-dynamic-rec time: {round(rec_time, 2)}, total images processed: {len(img_crop_list)}')
|
1003
|
+
|
966
1004
|
|
967
1005
|
"""分段"""
|
968
1006
|
para_split(pdf_info_dict)
|
@@ -108,29 +108,32 @@ def __is_list_or_index_block(block):
|
|
108
108
|
):
|
109
109
|
multiple_para_flag = True
|
110
110
|
|
111
|
-
|
112
|
-
line_mid_x = (line['bbox'][0] + line['bbox'][2]) / 2
|
113
|
-
block_mid_x = (block['bbox_fs'][0] + block['bbox_fs'][2]) / 2
|
114
|
-
if (
|
115
|
-
line['bbox'][0] - block['bbox_fs'][0] > 0.7 * line_height
|
116
|
-
and block['bbox_fs'][2] - line['bbox'][2] > 0.7 * line_height
|
117
|
-
):
|
118
|
-
external_sides_not_close_num += 1
|
119
|
-
if abs(line_mid_x - block_mid_x) < line_height / 2:
|
120
|
-
center_close_num += 1
|
111
|
+
block_text = ''
|
121
112
|
|
113
|
+
for line in block['lines']:
|
122
114
|
line_text = ''
|
123
115
|
|
124
116
|
for span in line['spans']:
|
125
117
|
span_type = span['type']
|
126
118
|
if span_type == ContentType.Text:
|
127
119
|
line_text += span['content'].strip()
|
128
|
-
|
129
120
|
# 添加所有文本,包括空行,保持与block['lines']长度一致
|
130
121
|
lines_text_list.append(line_text)
|
131
122
|
block_text = ''.join(lines_text_list)
|
132
|
-
|
133
|
-
|
123
|
+
|
124
|
+
block_lang = detect_lang(block_text)
|
125
|
+
# logger.info(f"block_lang: {block_lang}")
|
126
|
+
|
127
|
+
for line in block['lines']:
|
128
|
+
line_mid_x = (line['bbox'][0] + line['bbox'][2]) / 2
|
129
|
+
block_mid_x = (block['bbox_fs'][0] + block['bbox_fs'][2]) / 2
|
130
|
+
if (
|
131
|
+
line['bbox'][0] - block['bbox_fs'][0] > 0.7 * line_height
|
132
|
+
and block['bbox_fs'][2] - line['bbox'][2] > 0.7 * line_height
|
133
|
+
):
|
134
|
+
external_sides_not_close_num += 1
|
135
|
+
if abs(line_mid_x - block_mid_x) < line_height / 2:
|
136
|
+
center_close_num += 1
|
134
137
|
|
135
138
|
# 计算line左侧顶格数量是否大于2,是否顶格用abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height/2 来判断
|
136
139
|
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
|
@@ -62,7 +62,15 @@ def merge_spans_to_line(spans, threshold=0.6):
|
|
62
62
|
|
63
63
|
def span_block_type_compatible(span_type, block_type):
|
64
64
|
if span_type in [ContentType.Text, ContentType.InlineEquation]:
|
65
|
-
return block_type in [
|
65
|
+
return block_type in [
|
66
|
+
BlockType.Text,
|
67
|
+
BlockType.Title,
|
68
|
+
BlockType.ImageCaption,
|
69
|
+
BlockType.ImageFootnote,
|
70
|
+
BlockType.TableCaption,
|
71
|
+
BlockType.TableFootnote,
|
72
|
+
BlockType.Discarded
|
73
|
+
]
|
66
74
|
elif span_type == ContentType.InterlineEquation:
|
67
75
|
return block_type in [BlockType.InterlineEquation, BlockType.Text]
|
68
76
|
elif span_type == ContentType.Image:
|