magic-pdf 0.7.0b1__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. magic_pdf/dict2md/ocr_mkcontent.py +134 -76
  2. magic_pdf/integrations/__init__.py +0 -0
  3. magic_pdf/integrations/rag/__init__.py +0 -0
  4. magic_pdf/integrations/rag/api.py +82 -0
  5. magic_pdf/integrations/rag/type.py +82 -0
  6. magic_pdf/integrations/rag/utils.py +285 -0
  7. magic_pdf/layout/layout_sort.py +472 -283
  8. magic_pdf/libs/Constants.py +27 -1
  9. magic_pdf/libs/boxbase.py +169 -149
  10. magic_pdf/libs/draw_bbox.py +113 -87
  11. magic_pdf/libs/ocr_content_type.py +21 -18
  12. magic_pdf/libs/version.py +1 -1
  13. magic_pdf/model/doc_analyze_by_custom_model.py +14 -2
  14. magic_pdf/model/magic_model.py +230 -161
  15. magic_pdf/model/model_list.py +8 -0
  16. magic_pdf/model/pdf_extract_kit.py +135 -22
  17. magic_pdf/model/pek_sub_modules/self_modify.py +84 -0
  18. magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py +0 -1
  19. magic_pdf/model/ppTableModel.py +67 -0
  20. magic_pdf/para/para_split_v2.py +76 -74
  21. magic_pdf/pdf_parse_union_core.py +34 -6
  22. magic_pdf/pipe/AbsPipe.py +4 -1
  23. magic_pdf/pipe/OCRPipe.py +7 -4
  24. magic_pdf/pipe/TXTPipe.py +7 -4
  25. magic_pdf/pipe/UNIPipe.py +11 -6
  26. magic_pdf/pre_proc/ocr_detect_all_bboxes.py +12 -3
  27. magic_pdf/pre_proc/ocr_dict_merge.py +60 -59
  28. magic_pdf/resources/model_config/model_configs.yaml +3 -1
  29. magic_pdf/tools/cli.py +56 -29
  30. magic_pdf/tools/cli_dev.py +61 -64
  31. magic_pdf/tools/common.py +57 -37
  32. magic_pdf/user_api.py +17 -9
  33. {magic_pdf-0.7.0b1.dist-info → magic_pdf-0.8.0.dist-info}/METADATA +71 -33
  34. {magic_pdf-0.7.0b1.dist-info → magic_pdf-0.8.0.dist-info}/RECORD +38 -32
  35. {magic_pdf-0.7.0b1.dist-info → magic_pdf-0.8.0.dist-info}/LICENSE.md +0 -0
  36. {magic_pdf-0.7.0b1.dist-info → magic_pdf-0.8.0.dist-info}/WHEEL +0 -0
  37. {magic_pdf-0.7.0b1.dist-info → magic_pdf-0.8.0.dist-info}/entry_points.txt +0 -0
  38. {magic_pdf-0.7.0b1.dist-info → magic_pdf-0.8.0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,11 @@
1
1
  class MODEL:
      # String identifiers for the selectable top-level model backends.
      # NOTE(review): the values suggest PaddleOCR PP-StructureV2 and
      # PDF-Extract-Kit respectively; confirm against the code that reads them.
      Paddle = "pp_structure_v2"
      PEK = "pdf_extract_kit"
4
+
5
+
6
class AtomicModel:
    # Keys naming the individual sub-models. atom_model_init() in
    # pdf_extract_kit.py compares its model_name argument against these values
    # to decide which *_model_init function to call.
    Layout = "layout"  # layout model (layout_model_init)
    MFD = "mfd"  # formula detection model (mfd_model_init)
    MFR = "mfr"  # formula recognition/parsing model (mfr_model_init)
    OCR = "ocr"  # OCR model (ocr_model_init)
    Table = "table"  # table recognition model (table_model_init)
@@ -2,7 +2,8 @@ from loguru import logger
2
2
  import os
3
3
  import time
4
4
 
5
- from magic_pdf.libs.Constants import TABLE_MAX_TIME_VALUE
5
+ from magic_pdf.libs.Constants import *
6
+ from magic_pdf.model.model_list import AtomicModel
6
7
 
7
8
  os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
8
9
  try:
@@ -34,10 +35,18 @@ from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Pre
34
35
  from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
35
36
  from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
36
37
  from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
37
-
38
-
39
- def table_model_init(model_path, max_time, _device_='cpu'):
40
- table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
38
+ from magic_pdf.model.ppTableModel import ppTableModel
39
+
40
+
41
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
    """Build the table-recognition model selected by ``table_model_type``.

    :param table_model_type: backend selector; ``STRUCT_EQTABLE`` picks
        ``StructTableModel``, every other value picks ``ppTableModel``.
    :param model_path: directory holding the model weights.
    :param max_time: time budget handed to ``StructTableModel`` only; the
        ``ppTableModel`` branch does not use it.
    :param _device_: device name forwarded to the chosen model.
    :return: the constructed table model instance.
    """
    # Everything that is not StructEqTable goes to the Paddle table model.
    if table_model_type != STRUCT_EQTABLE:
        pp_config = {
            "model_dir": model_path,
            "device": _device_
        }
        return ppTableModel(pp_config)

    return StructTableModel(model_path, max_time=max_time, device=_device_)
42
51
 
43
52
 
@@ -56,7 +65,8 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
56
65
  model = task.build_model(cfg)
57
66
  model = model.to(_device_)
58
67
  vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
59
- return model, vis_processor
68
+ mfr_transform = transforms.Compose([vis_processor, ])
69
+ return [model, mfr_transform]
60
70
 
61
71
 
62
72
  def layout_model_init(weight, config_file, device):
@@ -64,6 +74,11 @@ def layout_model_init(weight, config_file, device):
64
74
  return model
65
75
 
66
76
 
77
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3):
    """Create the OCR model.

    :param show_log: forwarded to ``ModifiedPaddleOCR`` as ``show_log``.
    :param det_db_box_thresh: forwarded to ``ModifiedPaddleOCR`` as
        ``det_db_box_thresh``.
    :return: a new ``ModifiedPaddleOCR`` instance.
    """
    return ModifiedPaddleOCR(
        show_log=show_log,
        det_db_box_thresh=det_db_box_thresh,
    )
80
+
81
+
67
82
  class MathDataset(Dataset):
68
83
  def __init__(self, image_paths, transform=None):
69
84
  self.image_paths = image_paths
@@ -83,6 +98,58 @@ class MathDataset(Dataset):
83
98
  return image
84
99
 
85
100
 
101
class AtomModelSingleton:
    """Process-wide registry that builds each atomic model at most once.

    Every construction returns the same object, and the model cache lives on
    the class, so all callers share the already-initialised models.
    """

    _instance = None  # the single shared instance, created on first use
    _models = {}  # atom model name -> initialised model

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on the first call."""
        shared = cls._instance
        if shared is None:
            shared = super().__new__(cls)
            cls._instance = shared
        return shared

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the model registered under ``atom_model_name``.

        On a cache miss the model is built with
        ``atom_model_init(model_name=atom_model_name, **kwargs)`` and stored.

        NOTE(review): the cache key is the name only, so ``kwargs`` are ignored
        once a model with that name exists (e.g. a different device or table
        model type returns the first model built). Confirm this is intended.
        """
        registry = self._models
        if atom_model_name in registry:
            return registry[atom_model_name]

        built = atom_model_init(model_name=atom_model_name, **kwargs)
        registry[atom_model_name] = built
        return built
114
+
115
+
116
def atom_model_init(model_name: str, **kwargs):
    """Initialise one atomic model chosen by ``model_name``.

    Keyword arguments read per model:

    - ``AtomicModel.Layout``: ``layout_weights``, ``layout_config_file``,
      ``device``
    - ``AtomicModel.MFD``: ``mfd_weights``
    - ``AtomicModel.MFR``: ``mfr_weight_dir``, ``mfr_cfg_path``, ``device``
    - ``AtomicModel.OCR``: ``ocr_show_log``, ``det_db_box_thresh``
    - ``AtomicModel.Table``: ``table_model_type``, ``table_model_path``,
      ``table_max_time``, ``device``

    Optional arguments that are omitted (or ``None``) are not forwarded, so
    the defaults declared by the underlying ``*_model_init`` function apply
    instead of being overridden with ``None``.

    :param model_name: one of the ``AtomicModel`` values.
    :return: the initialised model object (for MFR, whatever
        ``mfr_model_init`` returns).
    :raises SystemExit: with status 1 when ``model_name`` is not recognised.
    """
    # Local import keeps this block self-contained; the builtin ``exit`` is
    # injected by the ``site`` module and is not meant for program code.
    import sys

    def _given(**candidates):
        # Keep only values the caller actually supplied.
        return {key: value for key, value in candidates.items() if value is not None}

    if model_name == AtomicModel.Layout:
        atom_model = layout_model_init(
            kwargs.get("layout_weights"),
            kwargs.get("layout_config_file"),
            kwargs.get("device")
        )
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get("mfd_weights")
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get("mfr_weight_dir"),
            kwargs.get("mfr_cfg_path"),
            **_given(_device_=kwargs.get("device"))
        )
    elif model_name == AtomicModel.OCR:
        atom_model = ocr_model_init(
            **_given(
                show_log=kwargs.get("ocr_show_log"),
                det_db_box_thresh=kwargs.get("det_db_box_thresh"),
            )
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get("table_model_type"),
            kwargs.get("table_model_path"),
            kwargs.get("table_max_time"),
            **_given(_device_=kwargs.get("device"))
        )
    else:
        logger.error("model name not allow")
        sys.exit(1)

    return atom_model
151
+
152
+
86
153
  class CustomPEKModel:
87
154
 
88
155
  def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
@@ -104,9 +171,11 @@ class CustomPEKModel:
104
171
  # 初始化解析配置
105
172
  self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
106
173
  self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
174
+ # table config
107
175
  self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
108
176
  self.apply_table = self.table_config.get("is_table_recog_enable", False)
109
177
  self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
178
+ self.table_model_type = self.table_config.get("model", TABLE_MASTER)
110
179
  self.apply_ocr = ocr
111
180
  logger.info(
112
181
  "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
@@ -120,31 +189,62 @@ class CustomPEKModel:
120
189
  models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
121
190
  logger.info("using models_dir: {}".format(models_dir))
122
191
 
192
+ atom_model_manager = AtomModelSingleton()
193
+
123
194
  # 初始化公式识别
124
195
  if self.apply_formula:
125
196
  # 初始化公式检测模型
126
- self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
127
-
197
+ # self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
198
+ self.mfd_model = atom_model_manager.get_atom_model(
199
+ atom_model_name=AtomicModel.MFD,
200
+ mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
201
+ )
128
202
  # 初始化公式解析模型
129
203
  mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
130
204
  mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
131
- self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
132
- self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
205
+ # self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
206
+ # self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
207
+ self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
208
+ atom_model_name=AtomicModel.MFR,
209
+ mfr_weight_dir=mfr_weight_dir,
210
+ mfr_cfg_path=mfr_cfg_path,
211
+ device=self.device
212
+ )
133
213
 
134
214
  # 初始化layout模型
135
- self.layout_model = Layoutlmv3_Predictor(
136
- str(os.path.join(models_dir, self.configs['weights']['layout'])),
137
- str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
215
+ # self.layout_model = Layoutlmv3_Predictor(
216
+ # str(os.path.join(models_dir, self.configs['weights']['layout'])),
217
+ # str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
218
+ # device=self.device
219
+ # )
220
+ self.layout_model = atom_model_manager.get_atom_model(
221
+ atom_model_name=AtomicModel.Layout,
222
+ layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
223
+ layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
138
224
  device=self.device
139
225
  )
140
226
  # 初始化ocr
141
227
  if self.apply_ocr:
142
- self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
143
228
 
144
- # init structeqtable
229
+ # self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
230
+ self.ocr_model = atom_model_manager.get_atom_model(
231
+ atom_model_name=AtomicModel.OCR,
232
+ ocr_show_log=show_log,
233
+ det_db_box_thresh=0.3
234
+ )
235
+ # init table model
145
236
  if self.apply_table:
146
- self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])),
147
- max_time = self.table_max_time, _device_=self.device)
237
+ table_model_dir = self.configs["weights"][self.table_model_type]
238
+ # self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
239
+ # max_time=self.table_max_time, _device_=self.device)
240
+ self.table_model = atom_model_manager.get_atom_model(
241
+ atom_model_name=AtomicModel.Table,
242
+ table_model_type=self.table_model_type,
243
+ table_model_path=str(os.path.join(models_dir, table_model_dir)),
244
+ table_max_time=self.table_max_time,
245
+ device=self.device
246
+ )
247
+
148
248
  logger.info('DocAnalysis init done!')
149
249
 
150
250
  def __call__(self, image):
@@ -278,16 +378,29 @@ class CustomPEKModel:
278
378
  new_image, _ = crop_img(res, pil_img)
279
379
  single_table_start_time = time.time()
280
380
  logger.info("------------------table recognition processing begins-----------------")
281
- with torch.no_grad():
282
- latex_code = self.table_model.image2latex(new_image)[0]
381
+ latex_code = None
382
+ html_code = None
383
+ if self.table_model_type == STRUCT_EQTABLE:
384
+ with torch.no_grad():
385
+ latex_code = self.table_model.image2latex(new_image)[0]
386
+ else:
387
+ html_code = self.table_model.img2html(new_image)
388
+
283
389
  run_time = time.time() - single_table_start_time
284
390
  logger.info(f"------------table recognition processing ends within {run_time}s-----")
285
391
  if run_time > self.table_max_time:
286
392
  logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
287
393
  # 判断是否返回正常
288
- expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
289
- if latex_code and expected_ending:
290
- res["latex"] = latex_code
394
+
395
+ if latex_code:
396
+ expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith(
397
+ 'end{table}')
398
+ if expected_ending:
399
+ res["latex"] = latex_code
400
+ else:
401
+ logger.warning(f"------------table recognition processing fails----------")
402
+ elif html_code:
403
+ res["html"] = html_code
291
404
  else:
292
405
  logger.warning(f"------------table recognition processing fails----------")
293
406
  table_cost = round(time.time() - table_start, 2)
@@ -12,6 +12,7 @@ from paddleocr.ppocr.utils.utility import check_and_read, alpha_to_color, binari
12
12
  from paddleocr.tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
13
13
 
14
14
  from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
15
+ from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
15
16
 
16
17
  logger = get_logger()
17
18
 
@@ -162,6 +163,86 @@ def update_det_boxes(dt_boxes, mfd_res):
162
163
  return new_dt_boxes
163
164
 
164
165
 
166
def merge_overlapping_spans(spans):
    """
    Merges horizontally overlapping spans on the same line.

    Spans are processed in order of their left edge. A span whose left edge is
    not strictly to the right of the previous merged span's right edge
    (touching counts as overlapping) is fused with it into their common
    bounding box.

    The input is not modified; a sorted copy is used.

    :param spans: An iterable of span coordinates [(x1, y1, x2, y2), ...]
    :return: A new list of merged spans. Spans that were not merged are
        returned as given; merged spans are (x1, y1, x2, y2) tuples.
    """
    # Return an empty list if the input spans list is empty
    if not spans:
        return []

    merged = []
    # sorted() instead of list.sort(): the caller's list keeps its order
    for span in sorted(spans, key=lambda box: box[0]):
        x1, y1, x2, y2 = span
        # No previous span, or a strict horizontal gap: start a new region
        if not merged or merged[-1][2] < x1:
            merged.append(span)
            continue

        # Overlap (or touching): grow the previous region to cover both
        last_x1, last_y1, last_x2, last_y2 = merged[-1]
        merged[-1] = (
            min(last_x1, x1),
            min(last_y1, y1),
            max(last_x2, x2),
            max(last_y2, y2),
        )

    return merged
201
+
202
+
203
def merge_det_boxes(dt_boxes):
    """
    Merge text detection boxes into larger per-line text regions.

    Each detection box is converted to a bounding box, the boxes are grouped
    into lines with ``merge_spans_to_line``, overlapping boxes inside each line
    are fused with ``merge_overlapping_spans``, and the results are converted
    back to corner-point form.

    Parameters:
    dt_boxes (list): Text detection boxes, each given as corner points
        accepted by ``points_to_bbox``.

    Returns:
    list: The merged text regions in the point format produced by
        ``bbox_to_points``.
    """
    # Wrap every box as a 'text' span so the line-merging helper can use it
    text_spans = [
        {
            'bbox': points_to_bbox(quad),
            'type': 'text',
        }
        for quad in dt_boxes
    ]

    merged_boxes = []
    for text_line in merge_spans_to_line(text_spans):
        # Fuse overlapping boxes within this line only
        line_bboxes = [item['bbox'] for item in text_line]
        merged_boxes.extend(
            bbox_to_points(region)
            for region in merge_overlapping_spans(line_bboxes)
        )

    return merged_boxes
244
+
245
+
165
246
  class ModifiedPaddleOCR(PaddleOCR):
166
247
  def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):
167
248
  """
@@ -265,6 +346,9 @@ class ModifiedPaddleOCR(PaddleOCR):
265
346
  img_crop_list = []
266
347
 
267
348
  dt_boxes = sorted_boxes(dt_boxes)
349
+
350
+ dt_boxes = merge_det_boxes(dt_boxes)
351
+
268
352
  if mfd_res:
269
353
  bef = time.time()
270
354
  dt_boxes = update_det_boxes(dt_boxes, mfd_res)
@@ -12,7 +12,6 @@ class StructTableModel:
12
12
  self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
13
13
 
14
14
  def image2latex(self, image) -> str:
      """Run the wrapped StructTable model on ``image`` and return its output.

      NOTE(review): the annotation says ``str``, but CustomPEKModel takes
      ``image2latex(new_image)[0]``, which suggests a sequence of LaTeX
      strings is returned. Confirm the actual return type of ``forward``.
      """
      table_latex = self.model.forward(image)
      return table_latex
18
17
 
@@ -0,0 +1,67 @@
1
+ from paddleocr.ppstructure.table.predict_table import TableSystem
2
+ from paddleocr.ppstructure.utility import init_args
3
+ from magic_pdf.libs.Constants import *
4
+ import os
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
+
9
class ppTableModel(object):
    """
    Converts an image of a table into HTML using PaddleOCR's TableSystem.

    Attributes:
    - table_sys: An instance of TableSystem initialized with parsed arguments.

    Methods:
    - __init__(config): Initializes the model with configuration parameters.
    - img2html(image): Converts a PIL Image or NumPy array to HTML string.
    - parse_args(**kwargs): Parses configuration arguments.
    """

    def __init__(self, config):
        """
        Parameters:
        - config (dict): Configuration dictionary containing model_dir and device.
        """
        args = self.parse_args(**config)
        self.table_sys = TableSystem(args)

    def img2html(self, image):
        """
        Parameters:
        - image (PIL.Image or np.ndarray): The image of the table to be converted.

        Return:
        - HTML (str): The predicted table wrapped as
          ``<td><table border="1">...</table></td>`` followed by a newline.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)
        pred_res, _ = self.table_sys(image)
        pred_html = pred_res["html"]
        # Drop the outer document wrapper and re-wrap the rows in a bordered table cell
        inner_html = pred_html.replace("<html><body><table>", "").replace("</table></body></html>", "")
        res = '<td><table border="1">' + inner_html + "</table></td>\n"
        return res

    def parse_args(self, **kwargs):
        """
        Build the argument namespace for TableSystem.

        Parameters (keyword):
        - model_dir (str): Directory containing the table, detection and
          recognition model sub-directories and character dictionaries.
        - device: Device name or device object; GPU is enabled when its string
          form starts with "cuda" (so "cuda", "cuda:0", ... all count).
          Defaults to "cpu".
        - table_max_len (int): Optional override for TABLE_MAX_LEN.

        Return:
        - The namespace produced by PaddleOCR's argument parser with these
          values installed as defaults.
        """
        parser = init_args()
        model_dir = kwargs.get("model_dir")
        table_model_dir = os.path.join(model_dir, TABLE_MASTER_DIR)
        table_char_dict_path = os.path.join(model_dir, TABLE_MASTER_DICT)
        det_model_dir = os.path.join(model_dir, DETECT_MODEL_DIR)
        rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
        rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
        device = kwargs.get("device", "cpu")
        # Prefix match on the string form: an exact == "cuda" comparison
        # would silently fall back to CPU for indexed devices such as "cuda:0".
        use_gpu = str(device).startswith("cuda")
        config = {
            "use_gpu": use_gpu,
            "table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
            "table_algorithm": TABLE_MASTER,
            "table_model_dir": table_model_dir,
            "table_char_dict_path": table_char_dict_path,
            "det_model_dir": det_model_dir,
            "rec_model_dir": rec_model_dir,
            "rec_char_dict_path": rec_char_dict_path,
        }
        parser.set_defaults(**config)
        # Parse an empty argv so only the defaults set above are used
        return parser.parse_args([])