magic-pdf 0.7.1__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. magic_pdf/dict2md/ocr_mkcontent.py +130 -76
  2. magic_pdf/integrations/__init__.py +0 -0
  3. magic_pdf/integrations/rag/__init__.py +0 -0
  4. magic_pdf/integrations/rag/api.py +82 -0
  5. magic_pdf/integrations/rag/type.py +82 -0
  6. magic_pdf/integrations/rag/utils.py +285 -0
  7. magic_pdf/layout/layout_sort.py +472 -283
  8. magic_pdf/libs/boxbase.py +188 -149
  9. magic_pdf/libs/draw_bbox.py +113 -87
  10. magic_pdf/libs/ocr_content_type.py +21 -18
  11. magic_pdf/libs/version.py +1 -1
  12. magic_pdf/model/doc_analyze_by_custom_model.py +14 -2
  13. magic_pdf/model/magic_model.py +283 -166
  14. magic_pdf/model/model_list.py +8 -0
  15. magic_pdf/model/pdf_extract_kit.py +105 -15
  16. magic_pdf/model/pek_sub_modules/self_modify.py +84 -0
  17. magic_pdf/para/para_split_v2.py +26 -27
  18. magic_pdf/pdf_parse_union_core.py +34 -6
  19. magic_pdf/pipe/AbsPipe.py +4 -1
  20. magic_pdf/pipe/OCRPipe.py +7 -4
  21. magic_pdf/pipe/TXTPipe.py +7 -4
  22. magic_pdf/pipe/UNIPipe.py +11 -6
  23. magic_pdf/pre_proc/ocr_detect_all_bboxes.py +12 -3
  24. magic_pdf/pre_proc/ocr_dict_merge.py +60 -59
  25. magic_pdf/tools/cli.py +56 -29
  26. magic_pdf/tools/cli_dev.py +61 -64
  27. magic_pdf/tools/common.py +57 -37
  28. magic_pdf/user_api.py +17 -9
  29. {magic_pdf-0.7.1.dist-info → magic_pdf-0.8.1.dist-info}/METADATA +72 -27
  30. {magic_pdf-0.7.1.dist-info → magic_pdf-0.8.1.dist-info}/RECORD +34 -29
  31. {magic_pdf-0.7.1.dist-info → magic_pdf-0.8.1.dist-info}/LICENSE.md +0 -0
  32. {magic_pdf-0.7.1.dist-info → magic_pdf-0.8.1.dist-info}/WHEEL +0 -0
  33. {magic_pdf-0.7.1.dist-info → magic_pdf-0.8.1.dist-info}/entry_points.txt +0 -0
  34. {magic_pdf-0.7.1.dist-info → magic_pdf-0.8.1.dist-info}/top_level.txt +0 -0
magic_pdf/model/model_list.py CHANGED
@@ -1,3 +1,11 @@
 class MODEL:
     Paddle = "pp_structure_v2"
     PEK = "pdf_extract_kit"
+
+
+class AtomicModel:
+    Layout = "layout"
+    MFD = "mfd"
+    MFR = "mfr"
+    OCR = "ocr"
+    Table = "table"
magic_pdf/model/pdf_extract_kit.py CHANGED
@@ -3,6 +3,7 @@ import os
 import time
 
 from magic_pdf.libs.Constants import *
+from magic_pdf.model.model_list import AtomicModel
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update check
 try:
@@ -64,7 +65,8 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
     model = task.build_model(cfg)
     model = model.to(_device_)
     vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
-    return model, vis_processor
+    mfr_transform = transforms.Compose([vis_processor, ])
+    return [model, mfr_transform]
 
 
 def layout_model_init(weight, config_file, device):
@@ -72,6 +74,11 @@ def layout_model_init(weight, config_file, device):
     return model
 
 
+def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3):
+    model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
+    return model
+
+
 class MathDataset(Dataset):
     def __init__(self, image_paths, transform=None):
         self.image_paths = image_paths
@@ -91,6 +98,58 @@ class MathDataset(Dataset):
         return image
 
 
+class AtomModelSingleton:
+    _instance = None
+    _models = {}
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def get_atom_model(self, atom_model_name: str, **kwargs):
+        if atom_model_name not in self._models:
+            self._models[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs)
+        return self._models[atom_model_name]
+
+
+def atom_model_init(model_name: str, **kwargs):
+
+    if model_name == AtomicModel.Layout:
+        atom_model = layout_model_init(
+            kwargs.get("layout_weights"),
+            kwargs.get("layout_config_file"),
+            kwargs.get("device")
+        )
+    elif model_name == AtomicModel.MFD:
+        atom_model = mfd_model_init(
+            kwargs.get("mfd_weights")
+        )
+    elif model_name == AtomicModel.MFR:
+        atom_model = mfr_model_init(
+            kwargs.get("mfr_weight_dir"),
+            kwargs.get("mfr_cfg_path"),
+            kwargs.get("device")
+        )
+    elif model_name == AtomicModel.OCR:
+        atom_model = ocr_model_init(
+            kwargs.get("ocr_show_log"),
+            kwargs.get("det_db_box_thresh")
+        )
+    elif model_name == AtomicModel.Table:
+        atom_model = table_model_init(
+            kwargs.get("table_model_type"),
+            kwargs.get("table_model_path"),
+            kwargs.get("table_max_time"),
+            kwargs.get("device")
+        )
+    else:
+        logger.error("model name not allow")
+        exit(1)
+
+    return atom_model
+
+
 class CustomPEKModel:
 
     def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
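Aside: AtomModelSingleton above gives each atom model name exactly one cached instance, so any later get_atom_model call with the same name returns the already-loaded model instead of re-initializing it. A minimal standalone sketch of the same pattern (all names below are illustrative, not part of the package):

class ModelCacheSingleton:
    # One shared manager instance, plus one cached object per model name.
    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(self, name: str, **kwargs):
        if name not in self._models:
            # First request for this name: build and cache.
            self._models[name] = load_model(name, **kwargs)
        return self._models[name]

def load_model(name: str, **kwargs):
    # Stand-in for an expensive model load such as atom_model_init.
    return {"name": name, **kwargs}

a = ModelCacheSingleton().get_model("ocr", det_db_box_thresh=0.3)
b = ModelCacheSingleton().get_model("ocr")
assert a is b  # loaded once, shared everywhere

Note the cache key is only the model name, so kwargs passed on a later call for an already-cached name are ignored; get_atom_model in the diff behaves the same way.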
@@ -130,32 +189,62 @@ class CustomPEKModel:
         models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
         logger.info("using models_dir: {}".format(models_dir))
 
+        atom_model_manager = AtomModelSingleton()
+
         # init formula recognition
         if self.apply_formula:
             # init formula detection model
-            self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
-
+            # self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
+            self.mfd_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.MFD,
+                mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
+            )
             # init formula parsing model
             mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
             mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
-            self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
-            self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
+            # self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
+            # self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
+            self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.MFR,
+                mfr_weight_dir=mfr_weight_dir,
+                mfr_cfg_path=mfr_cfg_path,
+                device=self.device
+            )
 
         # init layout model
-        self.layout_model = Layoutlmv3_Predictor(
-            str(os.path.join(models_dir, self.configs['weights']['layout'])),
-            str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
+        # self.layout_model = Layoutlmv3_Predictor(
+        #     str(os.path.join(models_dir, self.configs['weights']['layout'])),
+        #     str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
+        #     device=self.device
+        # )
+        self.layout_model = atom_model_manager.get_atom_model(
+            atom_model_name=AtomicModel.Layout,
+            layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
+            layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
             device=self.device
         )
         # init ocr
         if self.apply_ocr:
-            self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
 
+            # self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
+            self.ocr_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.OCR,
+                ocr_show_log=show_log,
+                det_db_box_thresh=0.3
+            )
         # init table model
         if self.apply_table:
             table_model_dir = self.configs["weights"][self.table_model_type]
-            self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
-                                                max_time=self.table_max_time, _device_=self.device)
+            # self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
+            #                                     max_time=self.table_max_time, _device_=self.device)
+            self.table_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Table,
+                table_model_type=self.table_model_type,
+                table_model_path=str(os.path.join(models_dir, table_model_dir)),
+                table_max_time=self.table_max_time,
+                device=self.device
+            )
+
         logger.info('DocAnalysis init done!')
 
     def __call__(self, image):
@@ -291,11 +380,12 @@ class CustomPEKModel:
                 logger.info("------------------table recognition processing begins-----------------")
                 latex_code = None
                 html_code = None
-                with torch.no_grad():
-                    if self.table_model_type == STRUCT_EQTABLE:
+                if self.table_model_type == STRUCT_EQTABLE:
+                    with torch.no_grad():
                         latex_code = self.table_model.image2latex(new_image)[0]
-                    else:
-                        html_code = self.table_model.img2html(new_image)
+                else:
+                    html_code = self.table_model.img2html(new_image)
+
                 run_time = time.time() - single_table_start_time
                 logger.info(f"------------table recognition processing ends within {run_time}s-----")
                 if run_time > self.table_max_time:
magic_pdf/model/pek_sub_modules/self_modify.py CHANGED
@@ -12,6 +12,7 @@ from paddleocr.ppocr.utils.utility import check_and_read, alpha_to_color, binari
 from paddleocr.tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
 
 from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
+from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
 
 logger = get_logger()
 
@@ -162,6 +163,86 @@ def update_det_boxes(dt_boxes, mfd_res):
     return new_dt_boxes
 
 
+def merge_overlapping_spans(spans):
+    """
+    Merges overlapping spans on the same line.
+
+    :param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
+    :return: A list of merged spans
+    """
+    # Return an empty list if the input spans list is empty
+    if not spans:
+        return []
+
+    # Sort spans by their starting x-coordinate
+    spans.sort(key=lambda x: x[0])
+
+    # Initialize the list of merged spans
+    merged = []
+    for span in spans:
+        # Unpack span coordinates
+        x1, y1, x2, y2 = span
+        # If the merged list is empty or there's no horizontal overlap, add the span directly
+        if not merged or merged[-1][2] < x1:
+            merged.append(span)
+        else:
+            # If there is horizontal overlap, merge the current span with the previous one
+            last_span = merged.pop()
+            # Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)
+            x1 = min(last_span[0], x1)
+            y1 = min(last_span[1], y1)
+            x2 = max(last_span[2], x2)
+            y2 = max(last_span[3], y2)
+            # Add the merged span back to the list
+            merged.append((x1, y1, x2, y2))
+
+    # Return the list of merged spans
+    return merged
+
+
+def merge_det_boxes(dt_boxes):
+    """
+    Merge detection boxes.
+
+    This function takes a list of detected bounding boxes, each represented by four corner points.
+    The goal is to merge these bounding boxes into larger text regions.
+
+    Parameters:
+    dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.
+
+    Returns:
+    list: A list containing the merged text regions, where each region is represented by four corner points.
+    """
+    # Convert the detection boxes into a dictionary format with bounding boxes and type
+    dt_boxes_dict_list = []
+    for text_box in dt_boxes:
+        text_bbox = points_to_bbox(text_box)
+        text_box_dict = {
+            'bbox': text_bbox,
+            'type': 'text',
+        }
+        dt_boxes_dict_list.append(text_box_dict)
+
+    # Merge adjacent text regions into lines
+    lines = merge_spans_to_line(dt_boxes_dict_list)
+
+    # Initialize a new list for storing the merged text regions
+    new_dt_boxes = []
+    for line in lines:
+        line_bbox_list = []
+        for span in line:
+            line_bbox_list.append(span['bbox'])
+
+        # Merge overlapping text regions within the same line
+        merged_spans = merge_overlapping_spans(line_bbox_list)
+
+        # Convert the merged text regions back to point format and add them to the new detection box list
+        for span in merged_spans:
+            new_dt_boxes.append(bbox_to_points(span))
+
+    return new_dt_boxes
+
+
 class ModifiedPaddleOCR(PaddleOCR):
     def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):
         """
@@ -265,6 +346,9 @@ class ModifiedPaddleOCR(PaddleOCR):
         img_crop_list = []
 
         dt_boxes = sorted_boxes(dt_boxes)
+
+        dt_boxes = merge_det_boxes(dt_boxes)
+
         if mfd_res:
             bef = time.time()
             dt_boxes = update_det_boxes(dt_boxes, mfd_res)
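Because merge_overlapping_spans operates on plain (x1, y1, x2, y2) tuples, its behavior is easy to check in isolation. A small sketch with made-up coordinates:

# Hypothetical input: two horizontally overlapping spans and one separate span.
spans = [(10, 0, 50, 12), (45, 1, 90, 13), (120, 0, 150, 12)]
# The first two overlap (45 <= 50) and are merged into one box covering both;
# the third starts past x = 90, so it is kept as-is.
print(merge_overlapping_spans(spans))
# [(10, 0, 90, 13), (120, 0, 150, 12)]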
magic_pdf/para/para_split_v2.py CHANGED
@@ -1,3 +1,5 @@
+import copy
+
 from sklearn.cluster import DBSCAN
 import numpy as np
 from loguru import logger
@@ -167,7 +169,7 @@ def cluster_line_x(lines: list) -> dict:
     x0_lst = np.array([[round(line['bbox'][0]), 0] for line in lines])
     x0_clusters = DBSCAN(eps=min_distance, min_samples=min_sample).fit(x0_lst)
     x0_uniq_label = np.unique(x0_clusters.labels_)
-    #x1_lst = np.array([[line['bbox'][2], 0] for line in lines])
+    # x1_lst = np.array([[line['bbox'][2], 0] for line in lines])
     x0_2_new_val = {}  # maps each old x0 value to its new value
     min_x0 = round(lines[0]["bbox"][0])
     for label in x0_uniq_label:
@@ -200,7 +202,9 @@ def __valign_lines(blocks, layout_bboxes):
     min_distance = 3
     min_sample = 2
     new_layout_bboxes = []
-
+    # add bbox_fs for para split calculation
+    for block in blocks:
+        block["bbox_fs"] = copy.deepcopy(block["bbox"])
     for layout_box in layout_bboxes:
         blocks_in_layoutbox = [b for b in blocks if
                                b["type"] == BlockType.Text and is_in_layout(b['bbox'], layout_box['layout_bbox'])]
@@ -245,16 +249,15 @@
         # line lengths inside the block changed, so recompute the block bbox
         for block in blocks_in_layoutbox:
             if len(block["lines"]) > 0:
-                block['bbox'] = [min([line['bbox'][0] for line in block['lines']]),
-                                 min([line['bbox'][1] for line in block['lines']]),
-                                 max([line['bbox'][2] for line in block['lines']]),
-                                 max([line['bbox'][3] for line in block['lines']])]
-
+                block['bbox_fs'] = [min([line['bbox'][0] for line in block['lines']]),
+                                    min([line['bbox'][1] for line in block['lines']]),
+                                    max([line['bbox'][2] for line in block['lines']]),
+                                    max([line['bbox'][3] for line in block['lines']])]
         """recompute the layout bbox, since the block bboxes changed"""
-        layout_x0 = min([block['bbox'][0] for block in blocks_in_layoutbox])
-        layout_y0 = min([block['bbox'][1] for block in blocks_in_layoutbox])
-        layout_x1 = max([block['bbox'][2] for block in blocks_in_layoutbox])
-        layout_y1 = max([block['bbox'][3] for block in blocks_in_layoutbox])
+        layout_x0 = min([block['bbox_fs'][0] for block in blocks_in_layoutbox])
+        layout_y0 = min([block['bbox_fs'][1] for block in blocks_in_layoutbox])
+        layout_x1 = max([block['bbox_fs'][2] for block in blocks_in_layoutbox])
+        layout_y1 = max([block['bbox_fs'][3] for block in blocks_in_layoutbox])
         new_layout_bboxes.append([layout_x0, layout_y0, layout_x1, layout_y1])
 
     return new_layout_bboxes
@@ -312,7 +315,7 @@ def __group_line_by_layout(blocks, layout_bboxes):
     # currently each block is a single line, so one block is one paragraph
     blocks_group = []
    for lyout in layout_bboxes:
-        blocks_in_layout = [block for block in blocks if is_in_layout(block['bbox'], lyout['layout_bbox'])]
+        blocks_in_layout = [block for block in blocks if is_in_layout(block.get('bbox_fs', None), lyout['layout_bbox'])]
         blocks_group.append(blocks_in_layout)
     return blocks_group
 
@@ -365,7 +368,8 @@ def __split_para_in_layoutbox(blocks_group, new_layout_bbox, lang="en"):
         for i in range(0, len(list_start)):
             index = list_start[i] - 1
             if index >= 0:
-                if "content" in lines[index]["spans"][-1]:
+                if "content" in lines[index]["spans"][-1] and lines[index]["spans"][-1].get('type', '') not in [
+                        ContentType.InlineEquation, ContentType.InterlineEquation]:
                     lines[index]["spans"][-1]["content"] += '\n\n'
         layout_list_info = [False, False]  # whether this layout starts and/or ends with a list
         for content_type, start, end in text_segments:
  for content_type, start, end in text_segments:
@@ -477,7 +481,7 @@ def __connect_list_inter_page(pre_page_paras, next_page_paras, pre_page_layout_b
477
481
  break
478
482
  # 如果这些行的缩进是相等的,那么连到上一个layout的最后一个段落上。
479
483
  if len(may_list_lines) > 0 and len(set([x['bbox'][0] for x in may_list_lines])) == 1:
480
- #pre_page_paras[-1].append(may_list_lines)
484
+ # pre_page_paras[-1].append(may_list_lines)
481
485
  # 下一页合并到上一页最后一段,打一个cross_page的标签
482
486
  for line in may_list_lines:
483
487
  for span in line["spans"]:
@@ -537,7 +541,6 @@ def __connect_para_inter_layoutbox(blocks_group, new_layout_bbox):
             next_first_line_text = ''.join([__get_span_text(span) for span in next_first_line['spans']])
             next_first_line_type = next_first_line['spans'][0]['type']
             if pre_last_line_type not in [TEXT, INLINE_EQUATION] or next_first_line_type not in [TEXT, INLINE_EQUATION]:
-                #connected_layout_paras.append(layout_paras[i])
                 connected_layout_blocks.append(blocks_group[i])
                 continue
             pre_layout = __find_layout_bbox_by_line(pre_last_line['bbox'], new_layout_bbox)
@@ -552,10 +555,8 @@ def __connect_para_inter_layoutbox(blocks_group, new_layout_bbox):
                     -1] not in LINE_STOP_FLAG and \
                     next_first_line['bbox'][0] == next_x0_min:  # the previous line fills the whole row without a stop mark, and the next line has no leading indent
                 """the paragraph-connection condition holds: join the previous layout's last paragraph with the next layout's first."""
-                #connected_layout_paras[-1][-1].extend(layout_paras[i][0])
                 connected_layout_blocks[-1][-1]["lines"].extend(blocks_group[i][0]["lines"])
-                #layout_paras[i].pop(0) # drop the next layout's first paragraph; it was merged into the previous layout's last paragraph
-                blocks_group[i][0]["lines"] = []  #clear the lines of the next layout's first paragraph; they were merged into the previous layout's last paragraph
+                blocks_group[i][0]["lines"] = []  # clear the lines of the next layout's first paragraph; they were merged into the previous layout's last paragraph
                 blocks_group[i][0][LINES_DELETED] = True
                 # if len(layout_paras[i]) == 0:
                 #     layout_paras.pop(i)
@@ -564,7 +565,6 @@ def __connect_para_inter_layoutbox(blocks_group, new_layout_bbox):
                 connected_layout_blocks.append(blocks_group[i])
             else:
                 """the connection condition fails: append the previous layout's paragraphs to the result as-is."""
-                #connected_layout_paras.append(layout_paras[i])
                 connected_layout_blocks.append(blocks_group[i])
     return connected_layout_blocks
 
@@ -622,7 +622,7 @@ def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_b
                 span[CROSS_PAGE] = True
         pre_last_para.extend(next_first_para)
 
-        #next_page_paras[0].pop(0) # drop the next page's first paragraph; it was merged into the previous page's last paragraph
+        # next_page_paras[0].pop(0)  # drop the next page's first paragraph; it was merged into the previous page's last paragraph
         next_page_paras[0][0]["lines"] = []
         next_page_paras[0][0][LINES_DELETED] = True
         return True
@@ -666,16 +666,15 @@ def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang):
         layout_box = new_layout_bbox[layout_i]
         single_line_paras_tag = []
         for i in range(len(layout_para)):
-            #single_line_paras_tag.append(len(layout_para[i]) == 1 and layout_para[i][0]['spans'][0]['type'] == TEXT)
+            # single_line_paras_tag.append(len(layout_para[i]) == 1 and layout_para[i][0]['spans'][0]['type'] == TEXT)
             single_line_paras_tag.append(layout_para[i]['type'] == BlockType.Text and len(layout_para[i]["lines"]) == 1)
         """find runs of consecutive single-line texts; if their line heights match, merge them into one paragraph."""
         consecutive_single_line_indices = find_consecutive_true_regions(single_line_paras_tag)
         if len(consecutive_single_line_indices) > 0:
-            #index_offset = 0
             """check whether these lines have the same height and are centered"""
             for start, end in consecutive_single_line_indices:
-                #start += index_offset
-                #end += index_offset
+                # start += index_offset
+                # end += index_offset
                 line_hi = np.array([block["lines"][0]['bbox'][3] - block["lines"][0]['bbox'][1] for block in
                                     layout_para[start:end + 1]])
                 first_line_text = ''.join([__get_span_text(span) for span in layout_para[start]["lines"][0]['spans']])
@@ -700,9 +699,9 @@ def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang):
                 for i_para in range(start + 1, end + 1):
                     layout_para[i_para]["lines"] = []
                     layout_para[i_para][LINES_DELETED] = True
-                #layout_para[start:end + 1] = [merge_para]
+                # layout_para[start:end + 1] = [merge_para]
 
-                #index_offset -= end - start
+                # index_offset -= end - start
 
     return
 
@@ -742,7 +741,7 @@ def para_split(pdf_info_dict, debug_mode, lang="en"):
     new_layout_of_pages = []  # list of lists: one layout list per page
     all_page_list_info = []  # whether each page starts and/or ends with a list
     for page_num, page in pdf_info_dict.items():
-        blocks = page['preproc_blocks']
+        blocks = copy.deepcopy(page['preproc_blocks'])
         layout_bboxes = page['layout_bboxes']
         new_layout_bbox = __common_pre_proc(blocks, layout_bboxes)
         new_layout_of_pages.append(new_layout_bbox)
magic_pdf/pdf_parse_union_core.py CHANGED
@@ -41,6 +41,23 @@ def remove_horizontal_overlap_block_which_smaller(all_bboxes):
     return is_useful_block_horz_overlap, all_bboxes
 
 
+def __replace_STX_ETX(text_str: str):
+    """ Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
+    Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
+
+    Args:
+        text_str (str): raw text
+
+    Returns:
+        _type_: replaced text
+    """
+    if text_str:
+        s = text_str.replace('\u0002', "'")
+        s = s.replace("\u0003", "'")
+        return s
+    return text_str
+
+
 def txt_spans_extract(pdf_page, inline_equations, interline_equations):
     text_raw_blocks = pdf_page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]
     char_level_text_blocks = pdf_page.get_text("rawdict", flags=fitz.TEXTFLAGS_TEXT)[
@@ -63,7 +80,7 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations):
             spans.append(
                 {
                     "bbox": list(span["bbox"]),
-                    "content": span["text"],
+                    "content": __replace_STX_ETX(span["text"]),
                     "type": ContentType.Text,
                     "score": 1.0,
                 }
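The replacement itself is easy to verify; an illustrative check:

# STX (\u0002) and ETX (\u0003) control characters become apostrophes.
assert __replace_STX_ETX('don\u0002t') == "don't"
assert __replace_STX_ETX('') == ''  # falsy input is returned unchanged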
@@ -175,7 +192,7 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter,
     sorted_blocks = sort_blocks_by_layout(all_bboxes, layout_bboxes)
 
     '''fill the spans into the sorted blocks'''
-    block_with_spans, spans = fill_spans_in_blocks(sorted_blocks, spans, 0.6)
+    block_with_spans, spans = fill_spans_in_blocks(sorted_blocks, spans, 0.3)
 
     '''fix the blocks'''
     fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)
@@ -208,13 +225,17 @@ def pdf_parse_union(pdf_bytes,
     magic_model = MagicModel(model_list, pdf_docs)
 
     '''parse the pdf according to the requested page range'''
-    end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
+    # end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
+    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(pdf_docs) - 1
+
+    if end_page_id > len(pdf_docs) - 1:
+        logger.warning("end_page_id is out of range, use pdf_docs length")
+        end_page_id = len(pdf_docs) - 1
 
     '''init the start time'''
     start_time = time.time()
 
-    for page_id in range(start_page_id, end_page_id + 1):
-
+    for page_id, page in enumerate(pdf_docs):
         '''in debug mode, log the parse time of each page'''
         if debug_mode:
             time_now = time.time()
@@ -224,7 +245,14 @@ def pdf_parse_union(pdf_bytes,
             start_time = time_now
 
         '''parse each page of the pdf'''
-        page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
+        if start_page_id <= page_id <= end_page_id:
+            page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
+        else:
+            page_w = page.rect.width
+            page_h = page.rect.height
+            page_info = ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
+                                                        [], [], [], [],
+                                                        True, "skip page")
         pdf_info_dict[f"page_{page_id}"] = page_info
 
     """split paragraphs"""
magic_pdf/pipe/AbsPipe.py CHANGED
@@ -16,12 +16,15 @@ class AbsPipe(ABC):
     PIP_OCR = "ocr"
     PIP_TXT = "txt"
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False):
+    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+                 start_page_id=0, end_page_id=None):
         self.pdf_bytes = pdf_bytes
         self.model_list = model_list
         self.image_writer = image_writer
         self.pdf_mid_data = None  # uncompressed
         self.is_debug = is_debug
+        self.start_page_id = start_page_id
+        self.end_page_id = end_page_id
 
     def get_compress_pdf_mid_data(self):
         return JsonCompressor.compress_json(self.pdf_mid_data)
magic_pdf/pipe/OCRPipe.py CHANGED
@@ -9,17 +9,20 @@ from magic_pdf.user_api import parse_ocr_pdf
 
 class OCRPipe(AbsPipe):
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug)
+    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+                 start_page_id=0, end_page_id=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
 
     def pipe_classify(self):
         pass
 
     def pipe_analyze(self):
-        self.model_list = doc_analyze(self.pdf_bytes, ocr=True)
+        self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
+                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
     def pipe_parse(self):
-        self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug)
+        self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
magic_pdf/pipe/TXTPipe.py CHANGED
@@ -10,17 +10,20 @@ from magic_pdf.user_api import parse_txt_pdf
 
 class TXTPipe(AbsPipe):
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug)
+    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+                 start_page_id=0, end_page_id=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
 
     def pipe_classify(self):
         pass
 
     def pipe_analyze(self):
-        self.model_list = doc_analyze(self.pdf_bytes, ocr=False)
+        self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
+                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
     def pipe_parse(self):
-        self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug)
+        self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
magic_pdf/pipe/UNIPipe.py CHANGED
@@ -13,9 +13,10 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
 
 class UNIPipe(AbsPipe):
 
-    def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False):
+    def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
+                 start_page_id=0, end_page_id=None):
         self.pdf_type = jso_useful_key["_pdf_type"]
-        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug)
+        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id)
         if len(self.model_list) == 0:
             self.input_model_is_empty = True
         else:
@@ -26,17 +27,21 @@ class UNIPipe(AbsPipe):
 
     def pipe_analyze(self):
         if self.pdf_type == self.PIP_TXT:
-            self.model_list = doc_analyze(self.pdf_bytes, ocr=False)
+            self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
         elif self.pdf_type == self.PIP_OCR:
-            self.model_list = doc_analyze(self.pdf_bytes, ocr=True)
+            self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
     def pipe_parse(self):
         if self.pdf_type == self.PIP_TXT:
             self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
-                                                is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty)
+                                                is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
+                                                start_page_id=self.start_page_id, end_page_id=self.end_page_id)
         elif self.pdf_type == self.PIP_OCR:
             self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
-                                              is_debug=self.is_debug)
+                                              is_debug=self.is_debug,
+                                              start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
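With these changes the page range threads from the pipe constructors through doc_analyze and the parse_* entry points. A usage sketch; the input path, output directory, and the DiskReaderWriter import path are assumptions based on the project README rather than part of this diff:

from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter  # assumed import path

pdf_bytes = open("demo.pdf", "rb").read()                 # illustrative input file
jso_useful_key = {"_pdf_type": "", "model_list": []}      # empty list: run analysis in-process
image_writer = DiskReaderWriter("/tmp/magic_pdf/images")  # illustrative output dir

pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer,
               start_page_id=0, end_page_id=4)  # only pages 0-4 are parsed
pipe.pipe_classify()
pipe.pipe_analyze()  # forwards start_page_id/end_page_id to doc_analyze
pipe.pipe_parse()    # forwards them to parse_union_pdf / parse_ocr_pdf
content_list = pipe.pipe_mk_uni_format("/tmp/magic_pdf/images")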