magic-pdf 0.9.2__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110) hide show
  1. magic_pdf/config/constants.py +53 -0
  2. magic_pdf/config/drop_reason.py +35 -0
  3. magic_pdf/config/drop_tag.py +19 -0
  4. magic_pdf/config/make_content_config.py +11 -0
  5. magic_pdf/{libs/ModelBlockTypeEnum.py → config/model_block_type.py} +2 -1
  6. magic_pdf/data/read_api.py +1 -1
  7. magic_pdf/dict2md/mkcontent.py +226 -185
  8. magic_pdf/dict2md/ocr_mkcontent.py +12 -12
  9. magic_pdf/filter/pdf_meta_scan.py +101 -79
  10. magic_pdf/integrations/rag/utils.py +4 -5
  11. magic_pdf/libs/config_reader.py +6 -6
  12. magic_pdf/libs/draw_bbox.py +13 -6
  13. magic_pdf/libs/pdf_image_tools.py +36 -12
  14. magic_pdf/libs/version.py +1 -1
  15. magic_pdf/model/doc_analyze_by_custom_model.py +2 -0
  16. magic_pdf/model/magic_model.py +13 -13
  17. magic_pdf/model/pdf_extract_kit.py +142 -351
  18. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +21 -0
  19. magic_pdf/model/sub_modules/mfd/__init__.py +0 -0
  20. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +12 -0
  21. magic_pdf/model/sub_modules/mfd/yolov8/__init__.py +0 -0
  22. magic_pdf/model/sub_modules/mfr/__init__.py +0 -0
  23. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +98 -0
  24. magic_pdf/model/sub_modules/mfr/unimernet/__init__.py +0 -0
  25. magic_pdf/model/sub_modules/model_init.py +149 -0
  26. magic_pdf/model/sub_modules/model_utils.py +51 -0
  27. magic_pdf/model/sub_modules/ocr/__init__.py +0 -0
  28. magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py +0 -0
  29. magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +285 -0
  30. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +176 -0
  31. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +213 -0
  32. magic_pdf/model/sub_modules/reading_oreder/__init__.py +0 -0
  33. magic_pdf/model/sub_modules/reading_oreder/layoutreader/__init__.py +0 -0
  34. magic_pdf/model/sub_modules/reading_oreder/layoutreader/xycut.py +242 -0
  35. magic_pdf/model/sub_modules/table/__init__.py +0 -0
  36. magic_pdf/model/sub_modules/table/rapidtable/__init__.py +0 -0
  37. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +16 -0
  38. magic_pdf/model/sub_modules/table/structeqtable/__init__.py +0 -0
  39. magic_pdf/model/{pek_sub_modules/structeqtable/StructTableModel.py → sub_modules/table/structeqtable/struct_eqtable.py} +3 -11
  40. magic_pdf/model/sub_modules/table/table_utils.py +11 -0
  41. magic_pdf/model/sub_modules/table/tablemaster/__init__.py +0 -0
  42. magic_pdf/model/{ppTableModel.py → sub_modules/table/tablemaster/tablemaster_paddle.py} +31 -29
  43. magic_pdf/para/para_split.py +411 -248
  44. magic_pdf/para/para_split_v2.py +352 -182
  45. magic_pdf/para/para_split_v3.py +121 -66
  46. magic_pdf/pdf_parse_by_ocr.py +2 -0
  47. magic_pdf/pdf_parse_by_txt.py +2 -0
  48. magic_pdf/pdf_parse_union_core.py +174 -100
  49. magic_pdf/pdf_parse_union_core_v2.py +253 -50
  50. magic_pdf/pipe/AbsPipe.py +28 -44
  51. magic_pdf/pipe/OCRPipe.py +5 -5
  52. magic_pdf/pipe/TXTPipe.py +5 -6
  53. magic_pdf/pipe/UNIPipe.py +24 -25
  54. magic_pdf/post_proc/pdf_post_filter.py +7 -14
  55. magic_pdf/pre_proc/cut_image.py +9 -11
  56. magic_pdf/pre_proc/equations_replace.py +203 -212
  57. magic_pdf/pre_proc/ocr_detect_all_bboxes.py +235 -49
  58. magic_pdf/pre_proc/ocr_dict_merge.py +5 -5
  59. magic_pdf/pre_proc/ocr_span_list_modify.py +122 -63
  60. magic_pdf/pre_proc/pdf_pre_filter.py +37 -33
  61. magic_pdf/pre_proc/remove_bbox_overlap.py +20 -18
  62. magic_pdf/pre_proc/remove_colored_strip_bbox.py +36 -14
  63. magic_pdf/pre_proc/remove_footer_header.py +2 -5
  64. magic_pdf/pre_proc/remove_rotate_bbox.py +111 -63
  65. magic_pdf/pre_proc/resolve_bbox_conflict.py +10 -17
  66. magic_pdf/resources/model_config/model_configs.yaml +2 -1
  67. magic_pdf/spark/spark_api.py +15 -17
  68. magic_pdf/tools/cli.py +3 -4
  69. magic_pdf/tools/cli_dev.py +6 -9
  70. magic_pdf/tools/common.py +70 -36
  71. magic_pdf/user_api.py +29 -38
  72. {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/METADATA +18 -13
  73. magic_pdf-0.10.0.dist-info/RECORD +198 -0
  74. {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/WHEEL +1 -1
  75. magic_pdf/libs/Constants.py +0 -53
  76. magic_pdf/libs/MakeContentConfig.py +0 -11
  77. magic_pdf/libs/drop_reason.py +0 -27
  78. magic_pdf/libs/drop_tag.py +0 -19
  79. magic_pdf/model/pek_sub_modules/post_process.py +0 -36
  80. magic_pdf/model/pek_sub_modules/self_modify.py +0 -388
  81. magic_pdf/para/para_pipeline.py +0 -297
  82. magic_pdf-0.9.2.dist-info/RECORD +0 -178
  83. /magic_pdf/{libs → config}/ocr_content_type.py +0 -0
  84. /magic_pdf/model/{pek_sub_modules → sub_modules}/__init__.py +0 -0
  85. /magic_pdf/model/{pek_sub_modules/layoutlmv3 → sub_modules/layout}/__init__.py +0 -0
  86. /magic_pdf/model/{pek_sub_modules/structeqtable → sub_modules/layout/doclayout_yolo}/__init__.py +0 -0
  87. /magic_pdf/model/{v3 → sub_modules/layout/layoutlmv3}/__init__.py +0 -0
  88. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/backbone.py +0 -0
  89. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/beit.py +0 -0
  90. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/deit.py +0 -0
  91. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/__init__.py +0 -0
  92. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/__init__.py +0 -0
  93. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/cord.py +0 -0
  94. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/data_collator.py +0 -0
  95. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/funsd.py +0 -0
  96. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/image_utils.py +0 -0
  97. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/xfund.py +0 -0
  98. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/__init__.py +0 -0
  99. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +0 -0
  100. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +0 -0
  101. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +0 -0
  102. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +0 -0
  103. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +0 -0
  104. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/model_init.py +0 -0
  105. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/rcnn_vl.py +0 -0
  106. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/visualizer.py +0 -0
  107. /magic_pdf/model/{v3 → sub_modules/reading_oreder/layoutreader}/helpers.py +0 -0
  108. {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/LICENSE.md +0 -0
  109. {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/entry_points.txt +0 -0
  110. {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/top_level.txt +0 -0
@@ -1,195 +1,32 @@
1
- from loguru import logger
1
+ # flake8: noqa
2
2
  import os
3
3
  import time
4
- from pathlib import Path
5
- import shutil
6
- from magic_pdf.libs.Constants import *
7
- from magic_pdf.libs.clean_memory import clean_memory
8
- from magic_pdf.model.model_list import AtomicModel
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import yaml
9
+ from loguru import logger
10
+ from PIL import Image
9
11
 
10
12
  os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
11
13
  os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
14
+
12
15
  try:
13
- import cv2
14
- import yaml
15
- import argparse
16
- import numpy as np
17
- import torch
18
16
  import torchtext
19
17
 
20
- if torchtext.__version__ >= "0.18.0":
18
+ if torchtext.__version__ >= '0.18.0':
21
19
  torchtext.disable_torchtext_deprecation_warning()
22
- from PIL import Image
23
- from torchvision import transforms
24
- from torch.utils.data import Dataset, DataLoader
25
- from ultralytics import YOLO
26
- from unimernet.common.config import Config
27
- import unimernet.tasks as tasks
28
- from unimernet.processors import load_processor
29
- from doclayout_yolo import YOLOv10
30
-
31
- except ImportError as e:
32
- logger.exception(e)
33
- logger.error(
34
- 'Required dependency not installed, please install by \n'
35
- '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
36
- exit(1)
37
-
38
- from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
39
- from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
40
- from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
41
- from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
42
- from magic_pdf.model.ppTableModel import ppTableModel
43
-
44
-
45
- def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
46
- if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
47
- table_model = StructTableModel(model_path, max_time=max_time)
48
- elif table_model_type == MODEL_NAME.TABLE_MASTER:
49
- config = {
50
- "model_dir": model_path,
51
- "device": _device_
52
- }
53
- table_model = ppTableModel(config)
54
- else:
55
- logger.error("table model type not allow")
56
- exit(1)
57
- return table_model
58
-
59
-
60
- def mfd_model_init(weight):
61
- mfd_model = YOLO(weight)
62
- return mfd_model
63
-
64
-
65
- def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
66
- args = argparse.Namespace(cfg_path=cfg_path, options=None)
67
- cfg = Config(args)
68
- cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
69
- cfg.config.model.model_config.model_name = weight_dir
70
- cfg.config.model.tokenizer_config.path = weight_dir
71
- task = tasks.setup_task(cfg)
72
- model = task.build_model(cfg)
73
- model.to(_device_)
74
- model.eval()
75
- vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
76
- mfr_transform = transforms.Compose([vis_processor, ])
77
- return [model, mfr_transform]
78
-
79
-
80
- def layout_model_init(weight, config_file, device):
81
- model = Layoutlmv3_Predictor(weight, config_file, device)
82
- return model
83
-
84
-
85
- def doclayout_yolo_model_init(weight):
86
- model = YOLOv10(weight)
87
- return model
88
-
89
-
90
- def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
91
- if lang is not None:
92
- model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
93
- else:
94
- model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
95
- return model
96
-
97
-
98
- class MathDataset(Dataset):
99
- def __init__(self, image_paths, transform=None):
100
- self.image_paths = image_paths
101
- self.transform = transform
102
-
103
- def __len__(self):
104
- return len(self.image_paths)
105
-
106
- def __getitem__(self, idx):
107
- # if not pil image, then convert to pil image
108
- if isinstance(self.image_paths[idx], str):
109
- raw_image = Image.open(self.image_paths[idx])
110
- else:
111
- raw_image = self.image_paths[idx]
112
- if self.transform:
113
- image = self.transform(raw_image)
114
- return image
115
-
116
-
117
- class AtomModelSingleton:
118
- _instance = None
119
- _models = {}
120
-
121
- def __new__(cls, *args, **kwargs):
122
- if cls._instance is None:
123
- cls._instance = super().__new__(cls)
124
- return cls._instance
125
-
126
- def get_atom_model(self, atom_model_name: str, **kwargs):
127
- lang = kwargs.get("lang", None)
128
- layout_model_name = kwargs.get("layout_model_name", None)
129
- key = (atom_model_name, layout_model_name, lang)
130
- if key not in self._models:
131
- self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
132
- return self._models[key]
133
-
134
-
135
- def atom_model_init(model_name: str, **kwargs):
136
-
137
- if model_name == AtomicModel.Layout:
138
- if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
139
- atom_model = layout_model_init(
140
- kwargs.get("layout_weights"),
141
- kwargs.get("layout_config_file"),
142
- kwargs.get("device")
143
- )
144
- elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
145
- atom_model = doclayout_yolo_model_init(
146
- kwargs.get("doclayout_yolo_weights"),
147
- )
148
- elif model_name == AtomicModel.MFD:
149
- atom_model = mfd_model_init(
150
- kwargs.get("mfd_weights")
151
- )
152
- elif model_name == AtomicModel.MFR:
153
- atom_model = mfr_model_init(
154
- kwargs.get("mfr_weight_dir"),
155
- kwargs.get("mfr_cfg_path"),
156
- kwargs.get("device")
157
- )
158
- elif model_name == AtomicModel.OCR:
159
- atom_model = ocr_model_init(
160
- kwargs.get("ocr_show_log"),
161
- kwargs.get("det_db_box_thresh"),
162
- kwargs.get("lang")
163
- )
164
- elif model_name == AtomicModel.Table:
165
- atom_model = table_model_init(
166
- kwargs.get("table_model_name"),
167
- kwargs.get("table_model_path"),
168
- kwargs.get("table_max_time"),
169
- kwargs.get("device")
170
- )
171
- else:
172
- logger.error("model name not allow")
173
- exit(1)
174
-
175
- return atom_model
176
-
20
+ except ImportError:
21
+ pass
177
22
 
178
- # Unified crop img logic
179
- def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
180
- crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
181
- crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
182
- # Create a white background with an additional width and height of 50
183
- crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
184
- crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
185
- return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
186
-
187
- # Crop image
188
- crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
189
- cropped_img = input_pil_img.crop(crop_box)
190
- return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
191
- return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
192
- return return_image, return_list
23
+ from magic_pdf.config.constants import *
24
+ from magic_pdf.model.model_list import AtomicModel
25
+ from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
26
+ from magic_pdf.model.sub_modules.model_utils import (
27
+ clean_vram, crop_img, get_res_list_from_layout_res)
28
+ from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
29
+ get_adjusted_mfdetrec_res, get_ocr_result_list)
193
30
 
194
31
 
195
32
  class CustomPEKModel:
@@ -208,61 +45,80 @@ class CustomPEKModel:
208
45
  model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
209
46
  # 构建 model_configs.yaml 文件的完整路径
210
47
  config_path = os.path.join(model_config_dir, 'model_configs.yaml')
211
- with open(config_path, "r", encoding='utf-8') as f:
48
+ with open(config_path, 'r', encoding='utf-8') as f:
212
49
  self.configs = yaml.load(f, Loader=yaml.FullLoader)
213
50
  # 初始化解析配置
214
51
 
215
52
  # layout config
216
- self.layout_config = kwargs.get("layout_config")
217
- self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)
53
+ self.layout_config = kwargs.get('layout_config')
54
+ self.layout_model_name = self.layout_config.get(
55
+ 'model', MODEL_NAME.DocLayout_YOLO
56
+ )
218
57
 
219
58
  # formula config
220
- self.formula_config = kwargs.get("formula_config")
221
- self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
222
- self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
223
- self.apply_formula = self.formula_config.get("enable", True)
59
+ self.formula_config = kwargs.get('formula_config')
60
+ self.mfd_model_name = self.formula_config.get(
61
+ 'mfd_model', MODEL_NAME.YOLO_V8_MFD
62
+ )
63
+ self.mfr_model_name = self.formula_config.get(
64
+ 'mfr_model', MODEL_NAME.UniMerNet_v2_Small
65
+ )
66
+ self.apply_formula = self.formula_config.get('enable', True)
224
67
 
225
68
  # table config
226
- self.table_config = kwargs.get("table_config")
227
- self.apply_table = self.table_config.get("enable", False)
228
- self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
229
- self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
69
+ self.table_config = kwargs.get('table_config')
70
+ self.apply_table = self.table_config.get('enable', False)
71
+ self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
72
+ self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
230
73
 
231
74
  # ocr config
232
75
  self.apply_ocr = ocr
233
- self.lang = kwargs.get("lang", None)
76
+ self.lang = kwargs.get('lang', None)
234
77
 
235
78
  logger.info(
236
- "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
237
- "apply_table: {}, table_model: {}, lang: {}".format(
238
- self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
79
+ 'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
80
+ 'apply_table: {}, table_model: {}, lang: {}'.format(
81
+ self.layout_model_name,
82
+ self.apply_formula,
83
+ self.apply_ocr,
84
+ self.apply_table,
85
+ self.table_model_name,
86
+ self.lang,
239
87
  )
240
88
  )
241
89
  # 初始化解析方案
242
- self.device = kwargs.get("device", "cpu")
243
- logger.info("using device: {}".format(self.device))
244
- models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
245
- logger.info("using models_dir: {}".format(models_dir))
90
+ self.device = kwargs.get('device', 'cpu')
91
+ logger.info('using device: {}'.format(self.device))
92
+ models_dir = kwargs.get(
93
+ 'models_dir', os.path.join(root_dir, 'resources', 'models')
94
+ )
95
+ logger.info('using models_dir: {}'.format(models_dir))
246
96
 
247
97
  atom_model_manager = AtomModelSingleton()
248
98
 
249
99
  # 初始化公式识别
250
100
  if self.apply_formula:
251
-
252
101
  # 初始化公式检测模型
253
102
  self.mfd_model = atom_model_manager.get_atom_model(
254
103
  atom_model_name=AtomicModel.MFD,
255
- mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
104
+ mfd_weights=str(
105
+ os.path.join(
106
+ models_dir, self.configs['weights'][self.mfd_model_name]
107
+ )
108
+ ),
109
+ device=self.device,
256
110
  )
257
111
 
258
112
  # 初始化公式解析模型
259
- mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
260
- mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
261
- self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
113
+ mfr_weight_dir = str(
114
+ os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
115
+ )
116
+ mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
117
+ self.mfr_model = atom_model_manager.get_atom_model(
262
118
  atom_model_name=AtomicModel.MFR,
263
119
  mfr_weight_dir=mfr_weight_dir,
264
120
  mfr_cfg_path=mfr_cfg_path,
265
- device=self.device
121
+ device=self.device,
266
122
  )
267
123
 
268
124
  # 初始化layout模型
@@ -270,172 +126,110 @@ class CustomPEKModel:
270
126
  self.layout_model = atom_model_manager.get_atom_model(
271
127
  atom_model_name=AtomicModel.Layout,
272
128
  layout_model_name=MODEL_NAME.LAYOUTLMv3,
273
- layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
274
- layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
275
- device=self.device
129
+ layout_weights=str(
130
+ os.path.join(
131
+ models_dir, self.configs['weights'][self.layout_model_name]
132
+ )
133
+ ),
134
+ layout_config_file=str(
135
+ os.path.join(
136
+ model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
137
+ )
138
+ ),
139
+ device=self.device,
276
140
  )
277
141
  elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
278
142
  self.layout_model = atom_model_manager.get_atom_model(
279
143
  atom_model_name=AtomicModel.Layout,
280
144
  layout_model_name=MODEL_NAME.DocLayout_YOLO,
281
- doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
145
+ doclayout_yolo_weights=str(
146
+ os.path.join(
147
+ models_dir, self.configs['weights'][self.layout_model_name]
148
+ )
149
+ ),
150
+ device=self.device,
282
151
  )
283
152
  # 初始化ocr
284
- if self.apply_ocr:
285
- self.ocr_model = atom_model_manager.get_atom_model(
286
- atom_model_name=AtomicModel.OCR,
287
- ocr_show_log=show_log,
288
- det_db_box_thresh=0.3,
289
- lang=self.lang
290
- )
153
+ self.ocr_model = atom_model_manager.get_atom_model(
154
+ atom_model_name=AtomicModel.OCR,
155
+ ocr_show_log=show_log,
156
+ det_db_box_thresh=0.3,
157
+ lang=self.lang
158
+ )
291
159
  # init table model
292
160
  if self.apply_table:
293
- table_model_dir = self.configs["weights"][self.table_model_name]
161
+ table_model_dir = self.configs['weights'][self.table_model_name]
294
162
  self.table_model = atom_model_manager.get_atom_model(
295
163
  atom_model_name=AtomicModel.Table,
296
164
  table_model_name=self.table_model_name,
297
165
  table_model_path=str(os.path.join(models_dir, table_model_dir)),
298
166
  table_max_time=self.table_max_time,
299
- device=self.device
167
+ device=self.device,
300
168
  )
301
169
 
302
170
  logger.info('DocAnalysis init done!')
303
171
 
304
172
  def __call__(self, image):
305
173
 
306
- page_start = time.time()
307
-
308
- latex_filling_list = []
309
- mf_image_list = []
310
-
311
174
  # layout检测
312
175
  layout_start = time.time()
176
+ layout_res = []
313
177
  if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
314
178
  # layoutlmv3
315
179
  layout_res = self.layout_model(image, ignore_catids=[])
316
180
  elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
317
181
  # doclayout_yolo
318
- layout_res = []
319
- doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
320
- for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
321
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
322
- new_item = {
323
- 'category_id': int(cla.item()),
324
- 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
325
- 'score': round(float(conf.item()), 3),
326
- }
327
- layout_res.append(new_item)
182
+ layout_res = self.layout_model.predict(image)
328
183
  layout_cost = round(time.time() - layout_start, 2)
329
- logger.info(f"layout detection time: {layout_cost}")
184
+ logger.info(f'layout detection time: {layout_cost}')
330
185
 
331
186
  pil_img = Image.fromarray(image)
332
187
 
333
188
  if self.apply_formula:
334
189
  # 公式检测
335
190
  mfd_start = time.time()
336
- mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
337
- logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
338
- for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
339
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
340
- new_item = {
341
- 'category_id': 13 + int(cla.item()),
342
- 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
343
- 'score': round(float(conf.item()), 2),
344
- 'latex': '',
345
- }
346
- layout_res.append(new_item)
347
- latex_filling_list.append(new_item)
348
- bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
349
- mf_image_list.append(bbox_img)
191
+ mfd_res = self.mfd_model.predict(image)
192
+ logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')
350
193
 
351
194
  # 公式识别
352
195
  mfr_start = time.time()
353
- dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
354
- dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
355
- mfr_res = []
356
- for mf_img in dataloader:
357
- mf_img = mf_img.to(self.device)
358
- with torch.no_grad():
359
- output = self.mfr_model.generate({'image': mf_img})
360
- mfr_res.extend(output['pred_str'])
361
- for res, latex in zip(latex_filling_list, mfr_res):
362
- res['latex'] = latex_rm_whitespace(latex)
196
+ formula_list = self.mfr_model.predict(mfd_res, image)
197
+ layout_res.extend(formula_list)
363
198
  mfr_cost = round(time.time() - mfr_start, 2)
364
- logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
199
+ logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
365
200
 
366
- # Select regions for OCR / formula regions / table regions
367
- ocr_res_list = []
368
- table_res_list = []
369
- single_page_mfdetrec_res = []
370
- for res in layout_res:
371
- if int(res['category_id']) in [13, 14]:
372
- single_page_mfdetrec_res.append({
373
- "bbox": [int(res['poly'][0]), int(res['poly'][1]),
374
- int(res['poly'][4]), int(res['poly'][5])],
375
- })
376
- elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
377
- ocr_res_list.append(res)
378
- elif int(res['category_id']) in [5]:
379
- table_res_list.append(res)
201
+ # 清理显存
202
+ clean_vram(self.device, vram_threshold=8)
380
203
 
381
- if torch.cuda.is_available() and self.device != 'cpu':
382
- properties = torch.cuda.get_device_properties(self.device)
383
- total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
384
- if total_memory <= 10:
385
- gc_start = time.time()
386
- clean_memory()
387
- gc_time = round(time.time() - gc_start, 2)
388
- logger.info(f"gc time: {gc_time}")
204
+ # 从layout_res中获取ocr区域、表格区域、公式区域
205
+ ocr_res_list, table_res_list, single_page_mfdetrec_res = (
206
+ get_res_list_from_layout_res(layout_res)
207
+ )
389
208
 
390
209
  # ocr识别
391
- if self.apply_ocr:
392
- ocr_start = time.time()
393
- # Process each area that requires OCR processing
394
- for res in ocr_res_list:
395
- new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
396
- paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
397
- # Adjust the coordinates of the formula area
398
- adjusted_mfdetrec_res = []
399
- for mf_res in single_page_mfdetrec_res:
400
- mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
401
- # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
402
- x0 = mf_xmin - xmin + paste_x
403
- y0 = mf_ymin - ymin + paste_y
404
- x1 = mf_xmax - xmin + paste_x
405
- y1 = mf_ymax - ymin + paste_y
406
- # Filter formula blocks outside the graph
407
- if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
408
- continue
409
- else:
410
- adjusted_mfdetrec_res.append({
411
- "bbox": [x0, y0, x1, y1],
412
- })
413
-
414
- # OCR recognition
415
- new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
210
+ ocr_start = time.time()
211
+ # Process each area that requires OCR processing
212
+ for res in ocr_res_list:
213
+ new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
214
+ adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
215
+
216
+ # OCR recognition
217
+ new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
218
+ if self.apply_ocr:
416
219
  ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
220
+ else:
221
+ ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
417
222
 
418
- # Integration results
419
- if ocr_res:
420
- for box_ocr_res in ocr_res:
421
- p1, p2, p3, p4 = box_ocr_res[0]
422
- text, score = box_ocr_res[1]
223
+ # Integration results
224
+ if ocr_res:
225
+ ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
226
+ layout_res.extend(ocr_result_list)
423
227
 
424
- # Convert the coordinates back to the original coordinate system
425
- p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
426
- p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
427
- p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
428
- p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
429
-
430
- layout_res.append({
431
- 'category_id': 15,
432
- 'poly': p1 + p2 + p3 + p4,
433
- 'score': round(score, 2),
434
- 'text': text,
435
- })
436
-
437
- ocr_cost = round(time.time() - ocr_start, 2)
228
+ ocr_cost = round(time.time() - ocr_start, 2)
229
+ if self.apply_ocr:
438
230
  logger.info(f"ocr time: {ocr_cost}")
231
+ else:
232
+ logger.info(f"det time: {ocr_cost}")
439
233
 
440
234
  # 表格识别 table recognition
441
235
  if self.apply_table:
@@ -443,41 +237,38 @@ class CustomPEKModel:
443
237
  for res in table_res_list:
444
238
  new_image, _ = crop_img(res, pil_img)
445
239
  single_table_start_time = time.time()
446
- # logger.info("------------------table recognition processing begins-----------------")
447
- latex_code = None
448
240
  html_code = None
449
241
  if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
450
242
  with torch.no_grad():
451
- table_result = self.table_model.predict(new_image, "html")
243
+ table_result = self.table_model.predict(new_image, 'html')
452
244
  if len(table_result) > 0:
453
245
  html_code = table_result[0]
454
- else:
246
+ elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
455
247
  html_code = self.table_model.img2html(new_image)
456
-
248
+ elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
249
+ html_code, table_cell_bboxes, elapse = self.table_model.predict(
250
+ new_image
251
+ )
457
252
  run_time = time.time() - single_table_start_time
458
- # logger.info(f"------------table recognition processing ends within {run_time}s-----")
459
253
  if run_time > self.table_max_time:
460
- logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
254
+ logger.warning(
255
+ f'table recognition processing exceeds max time {self.table_max_time}s'
256
+ )
461
257
  # 判断是否返回正常
462
-
463
- if latex_code:
464
- expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
465
- if expected_ending:
466
- res["latex"] = latex_code
467
- else:
468
- logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
469
- elif html_code:
470
- expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
258
+ if html_code:
259
+ expected_ending = html_code.strip().endswith(
260
+ '</html>'
261
+ ) or html_code.strip().endswith('</table>')
471
262
  if expected_ending:
472
- res["html"] = html_code
263
+ res['html'] = html_code
473
264
  else:
474
- logger.warning(f"table recognition processing fails, not found expected HTML table end")
265
+ logger.warning(
266
+ 'table recognition processing fails, not found expected HTML table end'
267
+ )
475
268
  else:
476
- logger.warning(f"table recognition processing fails, not get latex or html return")
477
- logger.info(f"table time: {round(time.time() - table_start, 2)}")
478
-
479
- logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
269
+ logger.warning(
270
+ 'table recognition processing fails, not get html return'
271
+ )
272
+ logger.info(f'table time: {round(time.time() - table_start, 2)}')
480
273
 
481
274
  return layout_res
482
-
483
-
@@ -0,0 +1,21 @@
1
+ from doclayout_yolo import YOLOv10
2
+
3
+
4
+ class DocLayoutYOLOModel(object):
5
+ def __init__(self, weight, device):
6
+ self.model = YOLOv10(weight)
7
+ self.device = device
8
+
9
+ def predict(self, image):
10
+ layout_res = []
11
+ doclayout_yolo_res = self.model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
12
+ for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(),
13
+ doclayout_yolo_res.boxes.cls.cpu()):
14
+ xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
15
+ new_item = {
16
+ 'category_id': int(cla.item()),
17
+ 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
18
+ 'score': round(float(conf.item()), 3),
19
+ }
20
+ layout_res.append(new_item)
21
+ return layout_res
File without changes
@@ -0,0 +1,12 @@
1
+ from ultralytics import YOLO
2
+
3
+
4
+ class YOLOv8MFDModel(object):
5
+ def __init__(self, weight, device='cpu'):
6
+ self.mfd_model = YOLO(weight)
7
+ self.device = device
8
+
9
+ def predict(self, image):
10
+ mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
11
+ return mfd_res
12
+
File without changes
File without changes