magic-pdf 1.2.2__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. magic_pdf/data/batch_build_dataset.py +156 -0
  2. magic_pdf/data/dataset.py +56 -25
  3. magic_pdf/data/utils.py +108 -9
  4. magic_pdf/dict2md/ocr_mkcontent.py +4 -3
  5. magic_pdf/libs/pdf_image_tools.py +11 -6
  6. magic_pdf/libs/performance_stats.py +12 -1
  7. magic_pdf/libs/version.py +1 -1
  8. magic_pdf/model/batch_analyze.py +175 -201
  9. magic_pdf/model/doc_analyze_by_custom_model.py +142 -92
  10. magic_pdf/model/pdf_extract_kit.py +5 -38
  11. magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
  12. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
  13. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
  14. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
  15. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
  16. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
  17. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
  18. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
  19. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
  20. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
  21. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
  22. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
  23. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
  24. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
  25. magic_pdf/model/sub_modules/model_init.py +50 -37
  26. magic_pdf/model/sub_modules/model_utils.py +18 -12
  27. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
  28. magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
  29. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
  30. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
  31. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
  32. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
  33. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
  34. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
  35. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
  36. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
  37. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
  38. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
  39. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
  40. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
  41. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
  42. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
  43. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
  44. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
  45. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
  46. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
  47. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
  48. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
  49. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
  50. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
  51. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
  52. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
  53. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
  54. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
  55. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
  56. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
  57. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
  58. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
  59. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
  60. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
  61. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
  62. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
  63. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
  64. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
  65. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
  66. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
  67. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
  68. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
  69. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
  70. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
  71. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
  72. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
  73. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
  74. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
  75. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
  76. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
  77. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
  78. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
  79. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +15 -19
  80. magic_pdf/pdf_parse_union_core_v2.py +112 -74
  81. magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
  82. magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
  83. magic_pdf/resources/model_config/model_configs.yaml +1 -1
  84. magic_pdf/resources/slanet_plus/slanet-plus.onnx +0 -0
  85. magic_pdf/tools/cli.py +30 -12
  86. magic_pdf/tools/common.py +90 -12
  87. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/METADATA +92 -59
  88. magic_pdf-1.3.1.dist-info/RECORD +203 -0
  89. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/WHEEL +1 -1
  90. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
  91. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
  92. magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
  93. magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
  94. magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
  95. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
  96. magic_pdf-1.2.2.dist-info/RECORD +0 -147
  97. /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
  98. /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
  99. /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
  100. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/LICENSE.md +0 -0
  101. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/entry_points.txt +0 -0
  102. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,227 @@
1
+ import os
2
+ import math
3
+ from pathlib import Path
4
+ import numpy as np
5
+ import cv2
6
+ import argparse
7
+
8
+
9
# Root of the ``paddleocr2pytorch`` package: three directory levels above
# this file (tools/infer/pytorchocr_utility.py -> tools/infer -> tools -> root).
root_dir = Path(__file__).resolve().parents[2]
# YAML file holding per-model network architectures, read by get_arch_config().
DEFAULT_CFG_PATH = root_dir.joinpath("pytorchocr", "utils", "resources", "arch_config.yaml")
11
+
12
+
13
def init_args():
    """Build the argument parser for the PyTorch port of PaddleOCR inference.

    Returns:
        argparse.ArgumentParser: parser covering the prediction engine,
        text detectors (DB/EAST/SAST/PSE/FCE), recognizer, angle classifier,
        e2e (PGNet), super-resolution, YAML overrides, multi-process and
        logging options.
    """

    def str2bool(v):
        # argparse hands us raw strings; ``type=bool`` would turn "False" into True.
        return v.lower() in ("true", "t", "1")

    def str2int_list(v):
        # "8,16,32" -> [8, 16, 32]; ``type=list`` would split into characters.
        return [int(item) for item in v.split(",") if item.strip()]

    def str2str_list(v):
        # "0,180" -> ['0', '180']; ``type=list`` would split into characters.
        return [item.strip() for item in v.split(",") if item.strip()]

    # Package root (``paddleocr2pytorch``); this file lives in tools/infer/.
    pkg_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

    parser = argparse.ArgumentParser()
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=False)
    parser.add_argument("--det", type=str2bool, default=True)
    parser.add_argument("--rec", type=str2bool, default=True)
    parser.add_argument("--device", type=str, default='cpu')
    parser.add_argument("--gpu_mem", type=int, default=500)
    parser.add_argument("--warmup", type=str2bool, default=False)

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_path", type=str)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default='max')

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
    parser.add_argument("--max_batch_size", type=int, default=10)
    parser.add_argument("--use_dilation", type=str2bool, default=False)
    parser.add_argument("--det_db_score_mode", type=str, default="fast")

    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # SAST params
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
    parser.add_argument("--det_sast_polygon", type=str2bool, default=False)

    # PSE params
    parser.add_argument("--det_pse_thresh", type=float, default=0)
    parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
    parser.add_argument("--det_pse_min_area", type=float, default=16)
    parser.add_argument("--det_pse_box_type", type=str, default='box')
    parser.add_argument("--det_pse_scale", type=int, default=1)

    # FCE params (list values are given on the CLI as comma-separated strings)
    parser.add_argument("--scales", type=str2int_list, default=[8, 16, 32])
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--fourier_degree", type=int, default=5)
    parser.add_argument("--det_fce_box_type", type=str, default='poly')

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
    parser.add_argument("--rec_model_path", type=str)
    parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
    parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)

    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument("--drop_score", type=float, default=0.5)
    parser.add_argument("--limited_max_width", type=int, default=1280)
    parser.add_argument("--limited_min_width", type=int, default=16)

    # NOTE(review): doc/fonts/simfang.ttf is not among the files shipped in
    # this package as far as visible here -- confirm before relying on it.
    parser.add_argument(
        "--vis_font_path", type=str,
        default=os.path.join(pkg_root, 'doc/fonts/simfang.ttf'))
    # The bundled dictionaries live under pytorchocr/utils/resources/dict/.
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default=os.path.join(pkg_root, 'pytorchocr', 'utils', 'resources', 'dict', 'ppocr_keys_v1.txt'))

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_path", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    parser.add_argument("--label_list", type=str2str_list, default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)

    # params for e2e
    parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
    parser.add_argument("--e2e_model_path", type=str)
    parser.add_argument("--e2e_limit_side_len", type=float, default=768)
    parser.add_argument("--e2e_limit_type", type=str, default='max')

    # PGNet params
    parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
    # NOTE(review): ic15_dict.txt is not among the bundled dictionaries listed
    # for this package -- confirm the path if PGNet is ever used.
    parser.add_argument(
        "--e2e_char_dict_path", type=str,
        default=os.path.join(pkg_root, 'pytorchocr/utils/ic15_dict.txt'))
    parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
    parser.add_argument("--e2e_pgnet_polygon", type=str2bool, default=True)
    parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')

    # SR params
    parser.add_argument("--sr_model_path", type=str)
    parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
    parser.add_argument("--sr_batch_num", type=int, default=1)

    # params .yaml
    parser.add_argument("--det_yaml_path", type=str, default=None)
    parser.add_argument("--rec_yaml_path", type=str, default=None)
    parser.add_argument("--cls_yaml_path", type=str, default=None)
    parser.add_argument("--e2e_yaml_path", type=str, default=None)
    parser.add_argument("--sr_yaml_path", type=str, default=None)

    # multi-process
    parser.add_argument("--use_mp", type=str2bool, default=False)
    parser.add_argument("--total_process_num", type=int, default=1)
    parser.add_argument("--process_id", type=int, default=0)

    parser.add_argument("--benchmark", type=str2bool, default=False)
    parser.add_argument("--save_log_path", type=str, default="./log_output/")

    parser.add_argument("--show_log", type=str2bool, default=True)

    return parser
141
+
142
def parse_args():
    """Parse ``sys.argv`` using the parser built by ``init_args()``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    return init_args().parse_args()
145
+
146
def get_default_config(args):
    """Expose a parsed namespace as a plain attribute dictionary.

    Args:
        args: an ``argparse.Namespace`` (or any object with ``__dict__``).

    Returns:
        dict: ``vars(args)``, i.e. the object's own attribute dict, not a copy.
    """
    namespace_dict = vars(args)
    return namespace_dict
148
+
149
+
150
def read_network_config_from_yaml(yaml_path, char_num=None):
    """Load the ``Architecture`` section of a YAML model config.

    Args:
        yaml_path: path to the YAML file.
        char_num: character-dictionary size. When given and the head is a
            ``MultiHead``, its ``out_channels_list`` is filled in from it.

    Returns:
        The ``Architecture`` mapping from the YAML file (possibly updated
        in place with ``Head.out_channels_list``).

    Raises:
        FileNotFoundError: if ``yaml_path`` does not exist.
        ValueError: if the file is empty, is not a mapping, or has no
            ``Architecture`` section.
    """
    if not os.path.exists(yaml_path):
        raise FileNotFoundError('{} is not existed.'.format(yaml_path))
    import yaml
    with open(yaml_path, encoding='utf-8') as f:
        res = yaml.safe_load(f)
    # An empty file loads as None; report it like a missing Architecture
    # instead of crashing with AttributeError on ``res.get``.
    if not isinstance(res, dict) or res.get('Architecture') is None:
        raise ValueError('{} has no Architecture'.format(yaml_path))
    arch = res['Architecture']
    # Not every architecture necessarily declares a Head; treat it as absent.
    head = arch.get('Head') or {}
    if head.get('name') == 'MultiHead' and char_num is not None:
        # Per-decoder output sizes. The +2 / +3 offsets are kept from the
        # original code (presumably extra special tokens for SAR / NRTR).
        head['out_channels_list'] = {
            'CTCLabelDecode': char_num,
            'SARLabelDecode': char_num + 2,
            'NRTRLabelDecode': char_num + 3
        }
    return arch
165
+
166
def AnalysisConfig(weights_path, yaml_path=None, char_num=None):
    """Check that the weights file exists and optionally load a YAML network config.

    Args:
        weights_path: path to the model weights; must exist.
        yaml_path: optional YAML config; when given, its ``Architecture``
            section is returned.
        char_num: forwarded to ``read_network_config_from_yaml``.

    Returns:
        The architecture mapping when ``yaml_path`` is given, otherwise None.

    Raises:
        FileNotFoundError: if ``weights_path`` does not exist.
    """
    weights_abspath = os.path.abspath(weights_path)
    if not os.path.exists(weights_abspath):
        raise FileNotFoundError('{} is not found.'.format(weights_path))

    if yaml_path is None:
        return None
    return read_network_config_from_yaml(yaml_path, char_num=char_num)
172
+
173
+
174
def resize_img(img, input_size=600):
    """Scale an image uniformly so that its longest side equals ``input_size``.

    Args:
        img: array-like image; converted with ``np.array``.
        input_size: target length of the longer of the first two dimensions.

    Returns:
        The resized image as returned by ``cv2.resize``.
    """
    arr = np.array(img)
    longest_side = np.max(arr.shape[:2])
    scale = float(input_size) / float(longest_side)
    return cv2.resize(arr, None, None, fx=scale, fy=scale)
184
+
185
+
186
def str_count(s):
    """Return the display length of ``s`` measured in full-width units.

    Each ASCII letter, digit (``str.isdigit``) or whitespace character counts
    as half a unit. Every other character counts as one unit. The half-width
    total is rounded up.

    Args:
        s (str): input string.

    Returns:
        int: ``len(s) - ceil(half_width_count / 2)``.
    """
    import string
    half_width_count = sum(
        1 for c in s
        if c in string.ascii_letters or c.isdigit() or c.isspace()
    )
    return len(s) - math.ceil(half_width_count / 2)
208
+
209
+
210
def base64_to_cv2(b64str):
    """Decode a base64-encoded image string into a BGR image.

    Args:
        b64str (str): base64 text of an encoded image file.

    Returns:
        The image decoded by ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``
        (OpenCV returns None when the bytes are not a decodable image).

    Raises:
        binascii.Error: if ``b64str`` is not valid base64.
    """
    import base64
    raw = base64.b64decode(b64str.encode('utf8'))
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    # cv2.imdecode only reads the buffer, so the read-only view is fine.
    buf = np.frombuffer(raw, np.uint8)
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)
216
+
217
+
218
def get_arch_config(model_path):
    """Look up a model's architecture definition in ``arch_config.yaml``.

    The lookup key is the weights file name without its extension, e.g.
    ``.../foo_model.pth`` -> ``foo_model``.

    Args:
        model_path: path to the model weights file.

    Returns:
        The OmegaConf node for that architecture.

    Raises:
        ValueError: if the file name has no entry in ``arch_config.yaml``.
    """
    from omegaconf import OmegaConf
    all_arch_config = OmegaConf.load(DEFAULT_CFG_PATH)
    file_name = Path(model_path).stem
    if file_name in all_arch_config:
        return all_arch_config[file_name]
    raise ValueError(f"architecture {file_name} is not in arch_config.yaml")
@@ -1,3 +1,5 @@
1
+ import os
2
+ from pathlib import Path
1
3
  import cv2
2
4
  import numpy as np
3
5
  import torch
@@ -9,7 +11,7 @@ from magic_pdf.libs.config_reader import get_device
9
11
 
10
12
 
11
13
  class RapidTableModel(object):
12
- def __init__(self, ocr_engine, table_sub_model_name):
14
+ def __init__(self, ocr_engine, table_sub_model_name='slanet_plus'):
13
15
  sub_model_list = [model.value for model in ModelType]
14
16
  if table_sub_model_name is None:
15
17
  input_args = RapidTableInput()
@@ -17,31 +19,25 @@ class RapidTableModel(object):
17
19
  if torch.cuda.is_available() and table_sub_model_name == "unitable":
18
20
  input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
19
21
  else:
20
- input_args = RapidTableInput(model_type=table_sub_model_name)
22
+ root_dir = Path(__file__).absolute().parent.parent.parent.parent.parent
23
+ slanet_plus_model_path = os.path.join(root_dir, 'resources', 'slanet_plus', 'slanet-plus.onnx')
24
+ input_args = RapidTableInput(model_type=table_sub_model_name, model_path=slanet_plus_model_path)
21
25
  else:
22
26
  raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
23
27
 
24
28
  self.table_model = RapidTable(input_args)
25
29
 
26
- # if ocr_engine is None:
27
- # self.ocr_model_name = "RapidOCR"
28
- # if torch.cuda.is_available():
29
- # from rapidocr_paddle import RapidOCR
30
- # self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
31
- # else:
32
- # from rapidocr_onnxruntime import RapidOCR
33
- # self.ocr_engine = RapidOCR()
30
+ # self.ocr_model_name = "RapidOCR"
31
+ # if torch.cuda.is_available():
32
+ # from rapidocr_paddle import RapidOCR
33
+ # self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
34
34
  # else:
35
- # self.ocr_model_name = "PaddleOCR"
36
- # self.ocr_engine = ocr_engine
35
+ # from rapidocr_onnxruntime import RapidOCR
36
+ # self.ocr_engine = RapidOCR()
37
+
38
+ self.ocr_model_name = "PaddleOCR"
39
+ self.ocr_engine = ocr_engine
37
40
 
38
- self.ocr_model_name = "RapidOCR"
39
- if torch.cuda.is_available():
40
- from rapidocr_paddle import RapidOCR
41
- self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
42
- else:
43
- from rapidocr_onnxruntime import RapidOCR
44
- self.ocr_engine = RapidOCR()
45
41
 
46
42
  def predict(self, image):
47
43
 
@@ -4,6 +4,7 @@ import os
4
4
  import re
5
5
  import statistics
6
6
  import time
7
+ import warnings
7
8
  from typing import List
8
9
 
9
10
  import cv2
@@ -11,6 +12,7 @@ import fitz
11
12
  import torch
12
13
  import numpy as np
13
14
  from loguru import logger
15
+ from tqdm import tqdm
14
16
 
15
17
  from magic_pdf.config.enums import SupportedPdfParseMethod
16
18
  from magic_pdf.config.ocr_content_type import BlockType, ContentType
@@ -21,20 +23,9 @@ from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_l
21
23
  from magic_pdf.libs.convert_utils import dict_to_list
22
24
  from magic_pdf.libs.hash_utils import compute_md5
23
25
  from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
24
- from magic_pdf.libs.performance_stats import measure_time, PerformanceStats
25
26
  from magic_pdf.model.magic_model import MagicModel
26
27
  from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text, llm_aided_title
27
28
 
28
- from concurrent.futures import ThreadPoolExecutor
29
-
30
- try:
31
- import torchtext
32
-
33
- if torchtext.__version__ >= '0.18.0':
34
- torchtext.disable_torchtext_deprecation_warning()
35
- except ImportError:
36
- pass
37
-
38
29
  from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
39
30
  from magic_pdf.post_proc.para_split_v3 import para_split
40
31
  from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
@@ -42,7 +33,7 @@ from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
42
33
  from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split_v2
43
34
  from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans_v2, fix_discarded_block
44
35
  from magic_pdf.pre_proc.ocr_span_list_modify import get_qa_need_list_v2, remove_overlaps_low_confidence_spans, \
45
- remove_overlaps_min_spans, check_chars_is_overlap_in_span
36
+ remove_overlaps_min_spans, remove_x_overlapping_chars
46
37
 
47
38
  os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
48
39
 
@@ -64,14 +55,6 @@ def __replace_STX_ETX(text_str: str):
64
55
  return text_str
65
56
 
66
57
 
67
- def __replace_0xfffd(text_str: str):
68
- """Replace \ufffd, as these characters become garbled when extracted using pymupdf."""
69
- if text_str:
70
- s = text_str.replace('\ufffd', " ")
71
- return s
72
- return text_str
73
-
74
-
75
58
  # 连写字符拆分
76
59
  def __replace_ligatures(text: str):
77
60
  ligatures = {
@@ -84,16 +67,17 @@ def chars_to_content(span):
84
67
  # 检查span中的char是否为空
85
68
  if len(span['chars']) == 0:
86
69
  pass
87
- # span['content'] = ''
88
- elif check_chars_is_overlap_in_span(span['chars']):
89
- pass
90
70
  else:
91
71
  # 先给chars按char['bbox']的中心点的x坐标排序
92
72
  span['chars'] = sorted(span['chars'], key=lambda x: (x['bbox'][0] + x['bbox'][2]) / 2)
93
73
 
94
- # 求char的平均宽度
95
- char_width_sum = sum([char['bbox'][2] - char['bbox'][0] for char in span['chars']])
96
- char_avg_width = char_width_sum / len(span['chars'])
74
+ # Calculate the width of each character
75
+ char_widths = [char['bbox'][2] - char['bbox'][0] for char in span['chars']]
76
+ # Calculate the median width
77
+ median_width = statistics.median(char_widths)
78
+
79
+ # 通过x轴重叠比率移除一部分char
80
+ span = remove_x_overlapping_chars(span, median_width)
97
81
 
98
82
  content = ''
99
83
  for char in span['chars']:
@@ -101,13 +85,12 @@ def chars_to_content(span):
101
85
  # 如果下一个char的x0和上一个char的x1距离超过0.25个字符宽度,则需要在中间插入一个空格
102
86
  char1 = char
103
87
  char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None
104
- if char2 and char2['bbox'][0] - char1['bbox'][2] > char_avg_width * 0.25 and char['c'] != ' ' and char2['c'] != ' ':
88
+ if char2 and char2['bbox'][0] - char1['bbox'][2] > median_width * 0.25 and char['c'] != ' ' and char2['c'] != ' ':
105
89
  content += f"{char['c']} "
106
90
  else:
107
91
  content += char['c']
108
92
 
109
- content = __replace_ligatures(content)
110
- span['content'] = __replace_0xfffd(content)
93
+ span['content'] = __replace_ligatures(content)
111
94
 
112
95
  del span['chars']
113
96
 
@@ -122,10 +105,6 @@ def fill_char_in_spans(spans, all_chars):
122
105
  spans = sorted(spans, key=lambda x: x['bbox'][1])
123
106
 
124
107
  for char in all_chars:
125
- # 跳过非法bbox的char
126
- # x1, y1, x2, y2 = char['bbox']
127
- # if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
128
- # continue
129
108
 
130
109
  for span in spans:
131
110
  if calculate_char_in_span(char['bbox'], span['bbox'], char['c']):
@@ -215,7 +194,7 @@ def calculate_contrast(img, img_mode) -> float:
215
194
  std_dev = np.std(gray_img)
216
195
  # 对比度定义为标准差除以平均值(加上小常数避免除零错误)
217
196
  contrast = std_dev / (mean_value + 1e-6)
218
- # logger.info(f"contrast: {contrast}")
197
+ # logger.debug(f"contrast: {contrast}")
219
198
  return round(contrast, 2)
220
199
 
221
200
  # @measure_time
@@ -308,41 +287,53 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
308
287
  if len(need_ocr_spans) > 0:
309
288
 
310
289
  # 初始化ocr模型
311
- atom_model_manager = AtomModelSingleton()
312
- ocr_model = atom_model_manager.get_atom_model(
313
- atom_model_name='ocr',
314
- ocr_show_log=False,
315
- det_db_box_thresh=0.3,
316
- lang=lang
317
- )
290
+ # atom_model_manager = AtomModelSingleton()
291
+ # ocr_model = atom_model_manager.get_atom_model(
292
+ # atom_model_name='ocr',
293
+ # ocr_show_log=False,
294
+ # det_db_box_thresh=0.3,
295
+ # lang=lang
296
+ # )
318
297
 
319
298
  for span in need_ocr_spans:
320
299
  # 对span的bbox截图再ocr
321
300
  span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode='cv2')
322
301
 
323
302
  # 计算span的对比度,低于0.20的span不进行ocr
324
- if calculate_contrast(span_img, img_mode='bgr') <= 0.20:
303
+ if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
325
304
  spans.remove(span)
326
305
  continue
306
+ # pass
307
+
308
+ span['content'] = ''
309
+ span['score'] = 1
310
+ span['np_img'] = span_img
327
311
 
328
- ocr_res = ocr_model.ocr(span_img, det=False)
329
- if ocr_res and len(ocr_res) > 0:
330
- if len(ocr_res[0]) > 0:
331
- ocr_text, ocr_score = ocr_res[0][0]
332
- # logger.info(f"ocr_text: {ocr_text}, ocr_score: {ocr_score}")
333
- if ocr_score > 0.5 and len(ocr_text) > 0:
334
- span['content'] = ocr_text
335
- span['score'] = ocr_score
336
- else:
337
- spans.remove(span)
312
+
313
+ # ocr_res = ocr_model.ocr(span_img, det=False)
314
+ # if ocr_res and len(ocr_res) > 0:
315
+ # if len(ocr_res[0]) > 0:
316
+ # ocr_text, ocr_score = ocr_res[0][0]
317
+ # # logger.info(f"ocr_text: {ocr_text}, ocr_score: {ocr_score}")
318
+ # if ocr_score > 0.5 and len(ocr_text) > 0:
319
+ # span['content'] = ocr_text
320
+ # span['score'] = float(round(ocr_score, 2))
321
+ # else:
322
+ # spans.remove(span)
338
323
 
339
324
  return spans
340
325
 
341
326
 
342
327
  def model_init(model_name: str):
343
328
  from transformers import LayoutLMv3ForTokenClassification
344
- device = torch.device(get_device())
345
-
329
+ device_name = get_device()
330
+ bf_16_support = False
331
+ if device_name.startswith("cuda"):
332
+ bf_16_support = torch.cuda.is_bf16_supported()
333
+ elif device_name.startswith("mps"):
334
+ bf_16_support = True
335
+
336
+ device = torch.device(device_name)
346
337
  if model_name == 'layoutreader':
347
338
  # 检测modelscope的缓存目录是否存在
348
339
  layoutreader_model_dir = get_local_layoutreader_model_dir()
@@ -357,7 +348,10 @@ def model_init(model_name: str):
357
348
  model = LayoutLMv3ForTokenClassification.from_pretrained(
358
349
  'hantian/layoutreader'
359
350
  )
360
- model.to(device).eval()
351
+ if bf_16_support:
352
+ model.to(device).eval().bfloat16()
353
+ else:
354
+ model.to(device).eval()
361
355
  else:
362
356
  logger.error('model name not allow')
363
357
  exit(1)
@@ -383,9 +377,12 @@ def do_predict(boxes: List[List[int]], model) -> List[int]:
383
377
  from magic_pdf.model.sub_modules.reading_oreder.layoutreader.helpers import (
384
378
  boxes2inputs, parse_logits, prepare_inputs)
385
379
 
386
- inputs = boxes2inputs(boxes)
387
- inputs = prepare_inputs(inputs, model)
388
- logits = model(**inputs).logits.cpu().squeeze(0)
380
+ with warnings.catch_warnings():
381
+ warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
382
+
383
+ inputs = boxes2inputs(boxes)
384
+ inputs = prepare_inputs(inputs, model)
385
+ logits = model(**inputs).logits.cpu().squeeze(0)
389
386
  return parse_logits(logits, len(boxes))
390
387
 
391
388
 
@@ -463,20 +460,20 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
463
460
  if (
464
461
  block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
465
462
  ): # 可能是双列结构,可以切细点
466
- lines = int(block_height / line_height) + 1
463
+ lines = int(block_height / line_height)
467
464
  else:
468
465
  # 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
469
466
  if block_weight > page_w * 0.4:
470
467
  lines = 3
471
- line_height = (y1 - y0) / lines
472
468
  elif block_weight > page_w * 0.25: # (可能是三列结构,也切细点)
473
- lines = int(block_height / line_height) + 1
469
+ lines = int(block_height / line_height)
474
470
  else: # 判断长宽比
475
471
  if block_height / block_weight > 1.2: # 细长的不分
476
472
  return [[x0, y0, x1, y1]]
477
473
  else: # 不细长的还是分成两行
478
474
  lines = 2
479
- line_height = (y1 - y0) / lines
475
+
476
+ line_height = (y1 - y0) / lines
480
477
 
481
478
  # 确定从哪个y位置开始绘制线条
482
479
  current_y = y0
@@ -492,7 +489,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
492
489
  else:
493
490
  return [[x0, y0, x1, y1]]
494
491
 
495
- # @measure_time
492
+
496
493
  def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
497
494
  page_line_list = []
498
495
 
@@ -936,17 +933,18 @@ def pdf_parse_union(
936
933
  logger.warning('end_page_id is out of range, use pdf_docs length')
937
934
  end_page_id = len(dataset) - 1
938
935
 
939
- """初始化启动时间"""
940
- start_time = time.time()
936
+ # """初始化启动时间"""
937
+ # start_time = time.time()
941
938
 
942
- for page_id, page in enumerate(dataset):
943
- """debug时输出每页解析的耗时."""
944
- if debug_mode:
945
- time_now = time.time()
946
- logger.info(
947
- f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
948
- )
949
- start_time = time_now
939
+ # for page_id, page in enumerate(dataset):
940
+ for page_id, page in tqdm(enumerate(dataset), total=len(dataset), desc="Processing pages"):
941
+ # """debug时输出每页解析的耗时."""
942
+ # if debug_mode:
943
+ # time_now = time.time()
944
+ # logger.info(
945
+ # f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
946
+ # )
947
+ # start_time = time_now
950
948
 
951
949
  """解析pdf中的每一页"""
952
950
  if start_page_id <= page_id <= end_page_id:
@@ -962,7 +960,47 @@ def pdf_parse_union(
962
960
  )
963
961
  pdf_info_dict[f'page_{page_id}'] = page_info
964
962
 
965
- # PerformanceStats.print_stats()
963
+ need_ocr_list = []
964
+ img_crop_list = []
965
+ text_block_list = []
966
+ for pange_id, page_info in pdf_info_dict.items():
967
+ for block in page_info['preproc_blocks']:
968
+ if block['type'] in ['table', 'image']:
969
+ for sub_block in block['blocks']:
970
+ if sub_block['type'] in ['image_caption', 'image_footnote', 'table_caption', 'table_footnote']:
971
+ text_block_list.append(sub_block)
972
+ elif block['type'] in ['text', 'title']:
973
+ text_block_list.append(block)
974
+ for block in page_info['discarded_blocks']:
975
+ text_block_list.append(block)
976
+ for block in text_block_list:
977
+ for line in block['lines']:
978
+ for span in line['spans']:
979
+ if 'np_img' in span:
980
+ need_ocr_list.append(span)
981
+ img_crop_list.append(span['np_img'])
982
+ span.pop('np_img')
983
+ if len(img_crop_list) > 0:
984
+ # Get OCR results for this language's images
985
+ atom_model_manager = AtomModelSingleton()
986
+ ocr_model = atom_model_manager.get_atom_model(
987
+ atom_model_name='ocr',
988
+ ocr_show_log=False,
989
+ det_db_box_thresh=0.3,
990
+ lang=lang
991
+ )
992
+ # rec_start = time.time()
993
+ ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
994
+ # Verify we have matching counts
995
+ assert len(ocr_res_list) == len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
996
+ # Process OCR results for this language
997
+ for index, span in enumerate(need_ocr_list):
998
+ ocr_text, ocr_score = ocr_res_list[index]
999
+ span['content'] = ocr_text
1000
+ span['score'] = float(f"{ocr_score:.3f}")
1001
+ # rec_time = time.time() - rec_start
1002
+ # logger.info(f'ocr-dynamic-rec time: {round(rec_time, 2)}, total images processed: {len(img_crop_list)}')
1003
+
966
1004
 
967
1005
  """分段"""
968
1006
  para_split(pdf_info_dict)
@@ -62,7 +62,15 @@ def merge_spans_to_line(spans, threshold=0.6):
62
62
 
63
63
  def span_block_type_compatible(span_type, block_type):
64
64
  if span_type in [ContentType.Text, ContentType.InlineEquation]:
65
- return block_type in [BlockType.Text, BlockType.Title, BlockType.ImageCaption, BlockType.ImageFootnote, BlockType.TableCaption, BlockType.TableFootnote]
65
+ return block_type in [
66
+ BlockType.Text,
67
+ BlockType.Title,
68
+ BlockType.ImageCaption,
69
+ BlockType.ImageFootnote,
70
+ BlockType.TableCaption,
71
+ BlockType.TableFootnote,
72
+ BlockType.Discarded
73
+ ]
66
74
  elif span_type == ContentType.InterlineEquation:
67
75
  return block_type in [BlockType.InterlineEquation, BlockType.Text]
68
76
  elif span_type == ContentType.Image: