magic-pdf 0.9.3__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. magic_pdf/config/constants.py +53 -0
  2. magic_pdf/config/drop_reason.py +35 -0
  3. magic_pdf/config/drop_tag.py +19 -0
  4. magic_pdf/config/make_content_config.py +11 -0
  5. magic_pdf/{libs/ModelBlockTypeEnum.py → config/model_block_type.py} +2 -1
  6. magic_pdf/data/read_api.py +1 -1
  7. magic_pdf/dict2md/mkcontent.py +226 -185
  8. magic_pdf/dict2md/ocr_mkcontent.py +11 -11
  9. magic_pdf/filter/pdf_meta_scan.py +101 -79
  10. magic_pdf/integrations/rag/utils.py +4 -5
  11. magic_pdf/libs/config_reader.py +5 -5
  12. magic_pdf/libs/draw_bbox.py +3 -2
  13. magic_pdf/libs/pdf_image_tools.py +36 -12
  14. magic_pdf/libs/version.py +1 -1
  15. magic_pdf/model/doc_analyze_by_custom_model.py +2 -0
  16. magic_pdf/model/magic_model.py +13 -13
  17. magic_pdf/model/pdf_extract_kit.py +122 -76
  18. magic_pdf/model/sub_modules/model_init.py +40 -35
  19. magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +33 -7
  20. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +12 -4
  21. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +2 -0
  22. magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +30 -28
  23. magic_pdf/para/para_split.py +411 -248
  24. magic_pdf/para/para_split_v2.py +352 -182
  25. magic_pdf/para/para_split_v3.py +110 -53
  26. magic_pdf/pdf_parse_by_ocr.py +2 -0
  27. magic_pdf/pdf_parse_by_txt.py +2 -0
  28. magic_pdf/pdf_parse_union_core.py +174 -100
  29. magic_pdf/pdf_parse_union_core_v2.py +202 -36
  30. magic_pdf/pipe/AbsPipe.py +28 -44
  31. magic_pdf/pipe/OCRPipe.py +5 -5
  32. magic_pdf/pipe/TXTPipe.py +5 -6
  33. magic_pdf/pipe/UNIPipe.py +24 -25
  34. magic_pdf/post_proc/pdf_post_filter.py +7 -14
  35. magic_pdf/pre_proc/cut_image.py +9 -11
  36. magic_pdf/pre_proc/equations_replace.py +203 -212
  37. magic_pdf/pre_proc/ocr_detect_all_bboxes.py +235 -49
  38. magic_pdf/pre_proc/ocr_dict_merge.py +5 -5
  39. magic_pdf/pre_proc/ocr_span_list_modify.py +122 -63
  40. magic_pdf/pre_proc/pdf_pre_filter.py +37 -33
  41. magic_pdf/pre_proc/remove_bbox_overlap.py +20 -18
  42. magic_pdf/pre_proc/remove_colored_strip_bbox.py +36 -14
  43. magic_pdf/pre_proc/remove_footer_header.py +2 -5
  44. magic_pdf/pre_proc/remove_rotate_bbox.py +111 -63
  45. magic_pdf/pre_proc/resolve_bbox_conflict.py +10 -17
  46. magic_pdf/spark/spark_api.py +15 -17
  47. magic_pdf/tools/cli.py +3 -4
  48. magic_pdf/tools/cli_dev.py +6 -9
  49. magic_pdf/tools/common.py +26 -36
  50. magic_pdf/user_api.py +29 -38
  51. {magic_pdf-0.9.3.dist-info → magic_pdf-0.10.0.dist-info}/METADATA +11 -12
  52. {magic_pdf-0.9.3.dist-info → magic_pdf-0.10.0.dist-info}/RECORD +57 -58
  53. magic_pdf/libs/Constants.py +0 -55
  54. magic_pdf/libs/MakeContentConfig.py +0 -11
  55. magic_pdf/libs/drop_reason.py +0 -27
  56. magic_pdf/libs/drop_tag.py +0 -19
  57. magic_pdf/para/para_pipeline.py +0 -297
  58. magic_pdf/{libs → config}/ocr_content_type.py +0 -0
  59. {magic_pdf-0.9.3.dist-info → magic_pdf-0.10.0.dist-info}/LICENSE.md +0 -0
  60. {magic_pdf-0.9.3.dist-info → magic_pdf-0.10.0.dist-info}/WHEEL +0 -0
  61. {magic_pdf-0.9.3.dist-info → magic_pdf-0.10.0.dist-info}/entry_points.txt +0 -0
  62. {magic_pdf-0.9.3.dist-info → magic_pdf-0.10.0.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,12 @@
1
- import numpy as np
2
- import torch
3
- from loguru import logger
1
+ # flake8: noqa
4
2
  import os
5
3
  import time
4
+
6
5
  import cv2
6
+ import numpy as np
7
+ import torch
7
8
  import yaml
9
+ from loguru import logger
8
10
  from PIL import Image
9
11
 
10
12
  os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
@@ -13,16 +15,18 @@ os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
13
15
  try:
14
16
  import torchtext
15
17
 
16
- if torchtext.__version__ >= "0.18.0":
18
+ if torchtext.__version__ >= '0.18.0':
17
19
  torchtext.disable_torchtext_deprecation_warning()
18
20
  except ImportError:
19
21
  pass
20
22
 
21
- from magic_pdf.libs.Constants import *
23
+ from magic_pdf.config.constants import *
22
24
  from magic_pdf.model.model_list import AtomicModel
23
25
  from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
24
- from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
25
- from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
26
+ from magic_pdf.model.sub_modules.model_utils import (
27
+ clean_vram, crop_img, get_res_list_from_layout_res)
28
+ from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
29
+ get_adjusted_mfdetrec_res, get_ocr_result_list)
26
30
 
27
31
 
28
32
  class CustomPEKModel:
@@ -41,42 +45,54 @@ class CustomPEKModel:
41
45
  model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
42
46
  # 构建 model_configs.yaml 文件的完整路径
43
47
  config_path = os.path.join(model_config_dir, 'model_configs.yaml')
44
- with open(config_path, "r", encoding='utf-8') as f:
48
+ with open(config_path, 'r', encoding='utf-8') as f:
45
49
  self.configs = yaml.load(f, Loader=yaml.FullLoader)
46
50
  # 初始化解析配置
47
51
 
48
52
  # layout config
49
- self.layout_config = kwargs.get("layout_config")
50
- self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)
53
+ self.layout_config = kwargs.get('layout_config')
54
+ self.layout_model_name = self.layout_config.get(
55
+ 'model', MODEL_NAME.DocLayout_YOLO
56
+ )
51
57
 
52
58
  # formula config
53
- self.formula_config = kwargs.get("formula_config")
54
- self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
55
- self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
56
- self.apply_formula = self.formula_config.get("enable", True)
59
+ self.formula_config = kwargs.get('formula_config')
60
+ self.mfd_model_name = self.formula_config.get(
61
+ 'mfd_model', MODEL_NAME.YOLO_V8_MFD
62
+ )
63
+ self.mfr_model_name = self.formula_config.get(
64
+ 'mfr_model', MODEL_NAME.UniMerNet_v2_Small
65
+ )
66
+ self.apply_formula = self.formula_config.get('enable', True)
57
67
 
58
68
  # table config
59
- self.table_config = kwargs.get("table_config")
60
- self.apply_table = self.table_config.get("enable", False)
61
- self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
62
- self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)
69
+ self.table_config = kwargs.get('table_config')
70
+ self.apply_table = self.table_config.get('enable', False)
71
+ self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
72
+ self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
63
73
 
64
74
  # ocr config
65
75
  self.apply_ocr = ocr
66
- self.lang = kwargs.get("lang", None)
76
+ self.lang = kwargs.get('lang', None)
67
77
 
68
78
  logger.info(
69
- "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
70
- "apply_table: {}, table_model: {}, lang: {}".format(
71
- self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
72
- self.lang
79
+ 'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
80
+ 'apply_table: {}, table_model: {}, lang: {}'.format(
81
+ self.layout_model_name,
82
+ self.apply_formula,
83
+ self.apply_ocr,
84
+ self.apply_table,
85
+ self.table_model_name,
86
+ self.lang,
73
87
  )
74
88
  )
75
89
  # 初始化解析方案
76
- self.device = kwargs.get("device", "cpu")
77
- logger.info("using device: {}".format(self.device))
78
- models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
79
- logger.info("using models_dir: {}".format(models_dir))
90
+ self.device = kwargs.get('device', 'cpu')
91
+ logger.info('using device: {}'.format(self.device))
92
+ models_dir = kwargs.get(
93
+ 'models_dir', os.path.join(root_dir, 'resources', 'models')
94
+ )
95
+ logger.info('using models_dir: {}'.format(models_dir))
80
96
 
81
97
  atom_model_manager = AtomModelSingleton()
82
98
 
@@ -85,18 +101,24 @@ class CustomPEKModel:
85
101
  # 初始化公式检测模型
86
102
  self.mfd_model = atom_model_manager.get_atom_model(
87
103
  atom_model_name=AtomicModel.MFD,
88
- mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])),
89
- device=self.device
104
+ mfd_weights=str(
105
+ os.path.join(
106
+ models_dir, self.configs['weights'][self.mfd_model_name]
107
+ )
108
+ ),
109
+ device=self.device,
90
110
  )
91
111
 
92
112
  # 初始化公式解析模型
93
- mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
94
- mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
113
+ mfr_weight_dir = str(
114
+ os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
115
+ )
116
+ mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
95
117
  self.mfr_model = atom_model_manager.get_atom_model(
96
118
  atom_model_name=AtomicModel.MFR,
97
119
  mfr_weight_dir=mfr_weight_dir,
98
120
  mfr_cfg_path=mfr_cfg_path,
99
- device=self.device
121
+ device=self.device,
100
122
  )
101
123
 
102
124
  # 初始化layout模型
@@ -104,42 +126,51 @@ class CustomPEKModel:
104
126
  self.layout_model = atom_model_manager.get_atom_model(
105
127
  atom_model_name=AtomicModel.Layout,
106
128
  layout_model_name=MODEL_NAME.LAYOUTLMv3,
107
- layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
108
- layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
109
- device=self.device
129
+ layout_weights=str(
130
+ os.path.join(
131
+ models_dir, self.configs['weights'][self.layout_model_name]
132
+ )
133
+ ),
134
+ layout_config_file=str(
135
+ os.path.join(
136
+ model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
137
+ )
138
+ ),
139
+ device=self.device,
110
140
  )
111
141
  elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
112
142
  self.layout_model = atom_model_manager.get_atom_model(
113
143
  atom_model_name=AtomicModel.Layout,
114
144
  layout_model_name=MODEL_NAME.DocLayout_YOLO,
115
- doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
116
- device=self.device
145
+ doclayout_yolo_weights=str(
146
+ os.path.join(
147
+ models_dir, self.configs['weights'][self.layout_model_name]
148
+ )
149
+ ),
150
+ device=self.device,
117
151
  )
118
152
  # 初始化ocr
119
- if self.apply_ocr:
120
- self.ocr_model = atom_model_manager.get_atom_model(
121
- atom_model_name=AtomicModel.OCR,
122
- ocr_show_log=show_log,
123
- det_db_box_thresh=0.3,
124
- lang=self.lang
125
- )
153
+ self.ocr_model = atom_model_manager.get_atom_model(
154
+ atom_model_name=AtomicModel.OCR,
155
+ ocr_show_log=show_log,
156
+ det_db_box_thresh=0.3,
157
+ lang=self.lang
158
+ )
126
159
  # init table model
127
160
  if self.apply_table:
128
- table_model_dir = self.configs["weights"][self.table_model_name]
161
+ table_model_dir = self.configs['weights'][self.table_model_name]
129
162
  self.table_model = atom_model_manager.get_atom_model(
130
163
  atom_model_name=AtomicModel.Table,
131
164
  table_model_name=self.table_model_name,
132
165
  table_model_path=str(os.path.join(models_dir, table_model_dir)),
133
166
  table_max_time=self.table_max_time,
134
- device=self.device
167
+ device=self.device,
135
168
  )
136
169
 
137
170
  logger.info('DocAnalysis init done!')
138
171
 
139
172
  def __call__(self, image):
140
173
 
141
- page_start = time.time()
142
-
143
174
  # layout检测
144
175
  layout_start = time.time()
145
176
  layout_res = []
@@ -150,7 +181,7 @@ class CustomPEKModel:
150
181
  # doclayout_yolo
151
182
  layout_res = self.layout_model.predict(image)
152
183
  layout_cost = round(time.time() - layout_start, 2)
153
- logger.info(f"layout detection time: {layout_cost}")
184
+ logger.info(f'layout detection time: {layout_cost}')
154
185
 
155
186
  pil_img = Image.fromarray(image)
156
187
 
@@ -158,40 +189,47 @@ class CustomPEKModel:
158
189
  # 公式检测
159
190
  mfd_start = time.time()
160
191
  mfd_res = self.mfd_model.predict(image)
161
- logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
192
+ logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')
162
193
 
163
194
  # 公式识别
164
195
  mfr_start = time.time()
165
196
  formula_list = self.mfr_model.predict(mfd_res, image)
166
197
  layout_res.extend(formula_list)
167
198
  mfr_cost = round(time.time() - mfr_start, 2)
168
- logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")
199
+ logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
169
200
 
170
201
  # 清理显存
171
202
  clean_vram(self.device, vram_threshold=8)
172
203
 
173
204
  # 从layout_res中获取ocr区域、表格区域、公式区域
174
- ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)
205
+ ocr_res_list, table_res_list, single_page_mfdetrec_res = (
206
+ get_res_list_from_layout_res(layout_res)
207
+ )
175
208
 
176
209
  # ocr识别
177
- if self.apply_ocr:
178
- ocr_start = time.time()
179
- # Process each area that requires OCR processing
180
- for res in ocr_res_list:
181
- new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
182
- adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
183
-
184
- # OCR recognition
185
- new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
210
+ ocr_start = time.time()
211
+ # Process each area that requires OCR processing
212
+ for res in ocr_res_list:
213
+ new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
214
+ adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
215
+
216
+ # OCR recognition
217
+ new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
218
+ if self.apply_ocr:
186
219
  ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
220
+ else:
221
+ ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
187
222
 
188
- # Integration results
189
- if ocr_res:
190
- ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
191
- layout_res.extend(ocr_result_list)
223
+ # Integration results
224
+ if ocr_res:
225
+ ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
226
+ layout_res.extend(ocr_result_list)
192
227
 
193
- ocr_cost = round(time.time() - ocr_start, 2)
228
+ ocr_cost = round(time.time() - ocr_start, 2)
229
+ if self.apply_ocr:
194
230
  logger.info(f"ocr time: {ocr_cost}")
231
+ else:
232
+ logger.info(f"det time: {ocr_cost}")
195
233
 
196
234
  # 表格识别 table recognition
197
235
  if self.apply_table:
@@ -202,27 +240,35 @@ class CustomPEKModel:
202
240
  html_code = None
203
241
  if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
204
242
  with torch.no_grad():
205
- table_result = self.table_model.predict(new_image, "html")
243
+ table_result = self.table_model.predict(new_image, 'html')
206
244
  if len(table_result) > 0:
207
245
  html_code = table_result[0]
208
246
  elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
209
247
  html_code = self.table_model.img2html(new_image)
210
248
  elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
211
- html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image)
249
+ html_code, table_cell_bboxes, elapse = self.table_model.predict(
250
+ new_image
251
+ )
212
252
  run_time = time.time() - single_table_start_time
213
253
  if run_time > self.table_max_time:
214
- logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")
254
+ logger.warning(
255
+ f'table recognition processing exceeds max time {self.table_max_time}s'
256
+ )
215
257
  # 判断是否返回正常
216
258
  if html_code:
217
- expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
259
+ expected_ending = html_code.strip().endswith(
260
+ '</html>'
261
+ ) or html_code.strip().endswith('</table>')
218
262
  if expected_ending:
219
- res["html"] = html_code
263
+ res['html'] = html_code
220
264
  else:
221
- logger.warning(f"table recognition processing fails, not found expected HTML table end")
265
+ logger.warning(
266
+ 'table recognition processing fails, not found expected HTML table end'
267
+ )
222
268
  else:
223
- logger.warning(f"table recognition processing fails, not get html return")
224
- logger.info(f"table time: {round(time.time() - table_start, 2)}")
225
-
226
- logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
269
+ logger.warning(
270
+ 'table recognition processing fails, not get html return'
271
+ )
272
+ logger.info(f'table time: {round(time.time() - table_start, 2)}')
227
273
 
228
274
  return layout_res
@@ -1,17 +1,22 @@
1
1
  from loguru import logger
2
2
 
3
- from magic_pdf.libs.Constants import MODEL_NAME
3
+ from magic_pdf.config.constants import MODEL_NAME
4
4
  from magic_pdf.model.model_list import AtomicModel
5
- from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
6
- from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
5
+ from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
6
+ DocLayoutYOLOModel
7
+ from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
8
+ Layoutlmv3_Predictor
7
9
  from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
8
-
9
10
  from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
10
- from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
11
+ from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
12
+ ModifiedPaddleOCR
13
+ from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
14
+ RapidTableModel
11
15
  # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
12
- from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
13
- from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
14
- from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
16
+ from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
17
+ StructTableModel
18
+ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
19
+ TableMasterPaddleModel
15
20
 
16
21
 
17
22
  def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
@@ -19,14 +24,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
19
24
  table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
20
25
  elif table_model_type == MODEL_NAME.TABLE_MASTER:
21
26
  config = {
22
- "model_dir": model_path,
23
- "device": _device_
27
+ 'model_dir': model_path,
28
+ 'device': _device_
24
29
  }
25
30
  table_model = TableMasterPaddleModel(config)
26
31
  elif table_model_type == MODEL_NAME.RAPID_TABLE:
27
32
  table_model = RapidTableModel()
28
33
  else:
29
- logger.error("table model type not allow")
34
+ logger.error('table model type not allow')
30
35
  exit(1)
31
36
 
32
37
  return table_model
@@ -58,7 +63,7 @@ def ocr_model_init(show_log: bool = False,
58
63
  use_dilation=True,
59
64
  det_db_unclip_ratio=1.8,
60
65
  ):
61
- if lang is not None:
66
+ if lang is not None and lang != '':
62
67
  model = ModifiedPaddleOCR(
63
68
  show_log=show_log,
64
69
  det_db_box_thresh=det_db_box_thresh,
@@ -87,8 +92,8 @@ class AtomModelSingleton:
87
92
  return cls._instance
88
93
 
89
94
  def get_atom_model(self, atom_model_name: str, **kwargs):
90
- lang = kwargs.get("lang", None)
91
- layout_model_name = kwargs.get("layout_model_name", None)
95
+ lang = kwargs.get('lang', None)
96
+ layout_model_name = kwargs.get('layout_model_name', None)
92
97
  key = (atom_model_name, layout_model_name, lang)
93
98
  if key not in self._models:
94
99
  self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
@@ -98,47 +103,47 @@ class AtomModelSingleton:
98
103
  def atom_model_init(model_name: str, **kwargs):
99
104
  atom_model = None
100
105
  if model_name == AtomicModel.Layout:
101
- if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
106
+ if kwargs.get('layout_model_name') == MODEL_NAME.LAYOUTLMv3:
102
107
  atom_model = layout_model_init(
103
- kwargs.get("layout_weights"),
104
- kwargs.get("layout_config_file"),
105
- kwargs.get("device")
108
+ kwargs.get('layout_weights'),
109
+ kwargs.get('layout_config_file'),
110
+ kwargs.get('device')
106
111
  )
107
- elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
112
+ elif kwargs.get('layout_model_name') == MODEL_NAME.DocLayout_YOLO:
108
113
  atom_model = doclayout_yolo_model_init(
109
- kwargs.get("doclayout_yolo_weights"),
110
- kwargs.get("device")
114
+ kwargs.get('doclayout_yolo_weights'),
115
+ kwargs.get('device')
111
116
  )
112
117
  elif model_name == AtomicModel.MFD:
113
118
  atom_model = mfd_model_init(
114
- kwargs.get("mfd_weights"),
115
- kwargs.get("device")
119
+ kwargs.get('mfd_weights'),
120
+ kwargs.get('device')
116
121
  )
117
122
  elif model_name == AtomicModel.MFR:
118
123
  atom_model = mfr_model_init(
119
- kwargs.get("mfr_weight_dir"),
120
- kwargs.get("mfr_cfg_path"),
121
- kwargs.get("device")
124
+ kwargs.get('mfr_weight_dir'),
125
+ kwargs.get('mfr_cfg_path'),
126
+ kwargs.get('device')
122
127
  )
123
128
  elif model_name == AtomicModel.OCR:
124
129
  atom_model = ocr_model_init(
125
- kwargs.get("ocr_show_log"),
126
- kwargs.get("det_db_box_thresh"),
127
- kwargs.get("lang")
130
+ kwargs.get('ocr_show_log'),
131
+ kwargs.get('det_db_box_thresh'),
132
+ kwargs.get('lang')
128
133
  )
129
134
  elif model_name == AtomicModel.Table:
130
135
  atom_model = table_model_init(
131
- kwargs.get("table_model_name"),
132
- kwargs.get("table_model_path"),
133
- kwargs.get("table_max_time"),
134
- kwargs.get("device")
136
+ kwargs.get('table_model_name'),
137
+ kwargs.get('table_model_path'),
138
+ kwargs.get('table_max_time'),
139
+ kwargs.get('device')
135
140
  )
136
141
  else:
137
- logger.error("model name not allow")
142
+ logger.error('model name not allow')
138
143
  exit(1)
139
144
 
140
145
  if atom_model is None:
141
- logger.error("model init failed")
146
+ logger.error('model init failed')
142
147
  exit(1)
143
148
  else:
144
149
  return atom_model
@@ -71,7 +71,13 @@ def remove_intervals(original, masks):
71
71
 
72
72
  def update_det_boxes(dt_boxes, mfd_res):
73
73
  new_dt_boxes = []
74
+ angle_boxes_list = []
74
75
  for text_box in dt_boxes:
76
+
77
+ if calculate_is_angle(text_box):
78
+ angle_boxes_list.append(text_box)
79
+ continue
80
+
75
81
  text_bbox = points_to_bbox(text_box)
76
82
  masks_list = []
77
83
  for mf_box in mfd_res:
@@ -85,6 +91,9 @@ def update_det_boxes(dt_boxes, mfd_res):
85
91
  temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
86
92
  if len(temp_dt_box) > 0:
87
93
  new_dt_boxes.extend(temp_dt_box)
94
+
95
+ new_dt_boxes.extend(angle_boxes_list)
96
+
88
97
  return new_dt_boxes
89
98
 
90
99
 
@@ -143,9 +152,11 @@ def merge_det_boxes(dt_boxes):
143
152
  angle_boxes_list = []
144
153
  for text_box in dt_boxes:
145
154
  text_bbox = points_to_bbox(text_box)
146
- if text_bbox[2] <= text_bbox[0] or text_bbox[3] <= text_bbox[1]:
155
+
156
+ if calculate_is_angle(text_box):
147
157
  angle_boxes_list.append(text_box)
148
158
  continue
159
+
149
160
  text_box_dict = {
150
161
  'bbox': text_bbox,
151
162
  'type': 'text',
@@ -200,15 +211,21 @@ def get_ocr_result_list(ocr_res, useful_list):
200
211
  ocr_result_list = []
201
212
  for box_ocr_res in ocr_res:
202
213
 
203
- p1, p2, p3, p4 = box_ocr_res[0]
204
- text, score = box_ocr_res[1]
205
- average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
206
- if average_angle_degrees > 0.5:
214
+ if len(box_ocr_res) == 2:
215
+ p1, p2, p3, p4 = box_ocr_res[0]
216
+ text, score = box_ocr_res[1]
217
+ else:
218
+ p1, p2, p3, p4 = box_ocr_res
219
+ text, score = "", 1
220
+ # average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
221
+ # if average_angle_degrees > 0.5:
222
+ poly = [p1, p2, p3, p4]
223
+ if calculate_is_angle(poly):
207
224
  # logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
208
225
  # 与x轴的夹角超过0.5度,对边界做一下矫正
209
226
  # 计算几何中心
210
- x_center = sum(point[0] for point in box_ocr_res[0]) / 4
211
- y_center = sum(point[1] for point in box_ocr_res[0]) / 4
227
+ x_center = sum(point[0] for point in poly) / 4
228
+ y_center = sum(point[1] for point in poly) / 4
212
229
  new_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
213
230
  new_width = p3[0] - p1[0]
214
231
  p1 = [x_center - new_width / 2, y_center - new_height / 2]
@@ -257,3 +274,12 @@ def calculate_angle_degrees(poly):
257
274
  # logger.info(f"average_angle_degrees: {average_angle_degrees}")
258
275
  return average_angle_degrees
259
276
 
277
+
278
+ def calculate_is_angle(poly):
279
+ p1, p2, p3, p4 = poly
280
+ height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
281
+ if 0.8 * height <= (p3[1] - p1[1]) <= 1.2 * height:
282
+ return False
283
+ else:
284
+ # logger.info((p3[1] - p1[1])/height)
285
+ return True
@@ -78,9 +78,18 @@ class ModifiedPaddleOCR(PaddleOCR):
78
78
  for idx, img in enumerate(imgs):
79
79
  img = preprocess_image(img)
80
80
  dt_boxes, elapse = self.text_detector(img)
81
- if not dt_boxes:
81
+ if dt_boxes is None:
82
82
  ocr_res.append(None)
83
83
  continue
84
+ dt_boxes = sorted_boxes(dt_boxes)
85
+ # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
86
+ dt_boxes = merge_det_boxes(dt_boxes)
87
+ if mfd_res:
88
+ bef = time.time()
89
+ dt_boxes = update_det_boxes(dt_boxes, mfd_res)
90
+ aft = time.time()
91
+ logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
92
+ len(dt_boxes), aft - bef))
84
93
  tmp_res = [box.tolist() for box in dt_boxes]
85
94
  ocr_res.append(tmp_res)
86
95
  return ocr_res
@@ -125,9 +134,8 @@ class ModifiedPaddleOCR(PaddleOCR):
125
134
 
126
135
  dt_boxes = sorted_boxes(dt_boxes)
127
136
 
128
- # @todo 目前是在bbox层merge,对倾斜文本行的兼容性不佳,需要修改成支持poly的merge
129
- # dt_boxes = merge_det_boxes(dt_boxes)
130
-
137
+ # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
138
+ dt_boxes = merge_det_boxes(dt_boxes)
131
139
 
132
140
  if mfd_res:
133
141
  bef = time.time()
@@ -10,5 +10,7 @@ class RapidTableModel(object):
10
10
 
11
11
  def predict(self, image):
12
12
  ocr_result, _ = self.ocr_engine(np.asarray(image))
13
+ if ocr_result is None:
14
+ return None, None, None
13
15
  html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
14
16
  return html_code, table_cell_bboxes, elapse