magic-pdf 0.6.1__py3-none-any.whl → 0.7.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41):
  1. magic_pdf/dict2md/ocr_mkcontent.py +20 -7
  2. magic_pdf/libs/config_reader.py +28 -10
  3. magic_pdf/libs/language.py +12 -0
  4. magic_pdf/libs/version.py +1 -1
  5. magic_pdf/model/__init__.py +1 -1
  6. magic_pdf/model/doc_analyze_by_custom_model.py +35 -3
  7. magic_pdf/model/magic_model.py +49 -41
  8. magic_pdf/model/pdf_extract_kit.py +155 -60
  9. magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py +7 -6
  10. magic_pdf/model/pek_sub_modules/self_modify.py +87 -43
  11. magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py +22 -0
  12. magic_pdf/model/pp_structure_v2.py +1 -1
  13. magic_pdf/pdf_parse_union_core.py +4 -2
  14. magic_pdf/pre_proc/citationmarker_remove.py +5 -1
  15. magic_pdf/pre_proc/ocr_detect_all_bboxes.py +40 -2
  16. magic_pdf/pre_proc/ocr_span_list_modify.py +12 -7
  17. magic_pdf/resources/fasttext-langdetect/lid.176.ftz +0 -0
  18. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +2 -2
  19. magic_pdf/resources/model_config/model_configs.yaml +4 -0
  20. magic_pdf/rw/AbsReaderWriter.py +1 -18
  21. magic_pdf/rw/DiskReaderWriter.py +32 -24
  22. magic_pdf/rw/S3ReaderWriter.py +83 -48
  23. magic_pdf/tools/cli.py +79 -0
  24. magic_pdf/tools/cli_dev.py +156 -0
  25. magic_pdf/tools/common.py +119 -0
  26. {magic_pdf-0.6.1.dist-info → magic_pdf-0.7.0a1.dist-info}/METADATA +120 -72
  27. {magic_pdf-0.6.1.dist-info → magic_pdf-0.7.0a1.dist-info}/RECORD +34 -35
  28. {magic_pdf-0.6.1.dist-info → magic_pdf-0.7.0a1.dist-info}/WHEEL +1 -1
  29. magic_pdf-0.7.0a1.dist-info/entry_points.txt +3 -0
  30. magic_pdf/cli/magicpdf.py +0 -337
  31. magic_pdf/pdf_parse_for_train.py +0 -685
  32. magic_pdf/train_utils/convert_to_train_format.py +0 -65
  33. magic_pdf/train_utils/extract_caption.py +0 -59
  34. magic_pdf/train_utils/remove_footer_header.py +0 -159
  35. magic_pdf/train_utils/vis_utils.py +0 -327
  36. magic_pdf-0.6.1.dist-info/entry_points.txt +0 -2
  37. /magic_pdf/libs/{math.py → local_math.py} +0 -0
  38. /magic_pdf/{cli → model/pek_sub_modules/structeqtable}/__init__.py +0 -0
  39. /magic_pdf/{train_utils → tools}/__init__.py +0 -0
  40. {magic_pdf-0.6.1.dist-info → magic_pdf-0.7.0a1.dist-info}/LICENSE.md +0 -0
  41. {magic_pdf-0.6.1.dist-info → magic_pdf-0.7.0a1.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,19 @@
1
1
  from loguru import logger
2
2
  import os
3
+ import time
4
+
5
+
6
+ os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
3
7
  try:
4
8
  import cv2
5
9
  import yaml
6
- import time
7
10
  import argparse
8
11
  import numpy as np
9
12
  import torch
13
+ import torchtext
10
14
 
11
- from paddleocr import draw_ocr
15
# Feature-detect the helper instead of comparing version strings: compared as
# strings, "0.9.0" >= "0.18.0" is True, so a pre-0.18 torchtext release that lacks
# this helper would raise AttributeError, which the surrounding
# `except ImportError` handler does not catch.
if hasattr(torchtext, "disable_torchtext_deprecation_warning"):
    torchtext.disable_torchtext_deprecation_warning()
12
17
  from PIL import Image
13
18
  from torchvision import transforms
14
19
  from torch.utils.data import Dataset, DataLoader
@@ -17,13 +22,23 @@ try:
17
22
  import unimernet.tasks as tasks
18
23
  from unimernet.processors import load_processor
19
24
 
20
- from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
21
- from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
22
- from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
23
- except ImportError:
24
- logger.error('Required dependency not installed, please install by \n"pip install magic-pdf[full-cpu] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
25
+ except ImportError as e:
26
+ logger.exception(e)
27
+ logger.error(
28
+ 'Required dependency not installed, please install by \n'
29
+ '"pip install magic-pdf[full] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
25
30
  exit(1)
26
31
 
32
+ from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
33
+ from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
34
+ from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
35
+ from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
36
+
37
+
38
def table_model_init(model_path, max_time=400, _device_='cpu'):
    """Create the ``StructTableModel`` used for table recognition.

    :param model_path: path to the table-model weights.
    :param max_time: forwarded to ``StructTableModel`` as ``max_time``.
    :param _device_: forwarded to ``StructTableModel`` as ``device``.
    :return: the constructed ``StructTableModel`` instance.
    """
    return StructTableModel(model_path, max_time=max_time, device=_device_)
41
+
27
42
 
28
43
  def mfd_model_init(weight):
29
44
  mfd_model = YOLO(weight)
@@ -83,15 +98,17 @@ class CustomPEKModel:
83
98
  model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
84
99
  # 构建 model_configs.yaml 文件的完整路径
85
100
  config_path = os.path.join(model_config_dir, 'model_configs.yaml')
86
- with open(config_path, "r") as f:
101
+ with open(config_path, "r", encoding='utf-8') as f:
87
102
  self.configs = yaml.load(f, Loader=yaml.FullLoader)
88
103
  # 初始化解析配置
89
104
  self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
90
105
  self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
106
+ self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
107
+ self.apply_table = self.table_config.get("is_table_recog_enable", False)
91
108
  self.apply_ocr = ocr
92
109
  logger.info(
93
- "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
94
- self.apply_layout, self.apply_formula, self.apply_ocr
110
+ "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
111
+ self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
95
112
  )
96
113
  )
97
114
  assert self.apply_layout, "DocAnalysis must contain layout model."
@@ -99,6 +116,7 @@ class CustomPEKModel:
99
116
  self.device = kwargs.get("device", self.configs["config"]["device"])
100
117
  logger.info("using device: {}".format(self.device))
101
118
  models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
119
+ logger.info("using models_dir: {}".format(models_dir))
102
120
 
103
121
  # 初始化公式识别
104
122
  if self.apply_formula:
@@ -121,6 +139,11 @@ class CustomPEKModel:
121
139
  if self.apply_ocr:
122
140
  self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
123
141
 
142
+ # init structeqtable
143
+ if self.apply_table:
144
+ max_time = self.table_config.get("max_time", 400)
145
+ self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])),
146
+ max_time=max_time, _device_=self.device)
124
147
  logger.info('DocAnalysis init done!')
125
148
 
126
149
  def __call__(self, image):
@@ -134,67 +157,139 @@ class CustomPEKModel:
134
157
  layout_cost = round(time.time() - layout_start, 2)
135
158
  logger.info(f"layout detection cost: {layout_cost}")
136
159
 
137
- # 公式检测
138
- mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
139
- for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
140
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
141
- new_item = {
142
- 'category_id': 13 + int(cla.item()),
143
- 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
144
- 'score': round(float(conf.item()), 2),
145
- 'latex': '',
146
- }
147
- layout_res.append(new_item)
148
- latex_filling_list.append(new_item)
149
- bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
150
- mf_image_list.append(bbox_img)
151
-
152
- # 公式识别
153
- mfr_start = time.time()
154
- dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
155
- dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
156
- mfr_res = []
157
- for mf_img in dataloader:
158
- mf_img = mf_img.to(self.device)
159
- output = self.mfr_model.generate({'image': mf_img})
160
- mfr_res.extend(output['pred_str'])
161
- for res, latex in zip(latex_filling_list, mfr_res):
162
- res['latex'] = latex_rm_whitespace(latex)
163
- mfr_cost = round(time.time() - mfr_start, 2)
164
- logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
160
+ if self.apply_formula:
161
+ # 公式检测
162
+ mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
163
+ for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
164
+ xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
165
+ new_item = {
166
+ 'category_id': 13 + int(cla.item()),
167
+ 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
168
+ 'score': round(float(conf.item()), 2),
169
+ 'latex': '',
170
+ }
171
+ layout_res.append(new_item)
172
+ latex_filling_list.append(new_item)
173
+ bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
174
+ mf_image_list.append(bbox_img)
175
+
176
+ # 公式识别
177
+ mfr_start = time.time()
178
+ dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
179
+ dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
180
+ mfr_res = []
181
+ for mf_img in dataloader:
182
+ mf_img = mf_img.to(self.device)
183
+ output = self.mfr_model.generate({'image': mf_img})
184
+ mfr_res.extend(output['pred_str'])
185
+ for res, latex in zip(latex_filling_list, mfr_res):
186
+ res['latex'] = latex_rm_whitespace(latex)
187
+ mfr_cost = round(time.time() - mfr_start, 2)
188
+ logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
165
189
 
166
190
  # ocr识别
167
191
  if self.apply_ocr:
168
192
  ocr_start = time.time()
169
193
  pil_img = Image.fromarray(image)
194
+
195
+ # 筛选出需要OCR的区域和公式区域
196
+ ocr_res_list = []
170
197
  single_page_mfdetrec_res = []
171
198
  for res in layout_res:
172
199
  if int(res['category_id']) in [13, 14]:
173
- xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
174
- xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
175
200
  single_page_mfdetrec_res.append({
176
- "bbox": [xmin, ymin, xmax, ymax],
201
+ "bbox": [int(res['poly'][0]), int(res['poly'][1]),
202
+ int(res['poly'][4]), int(res['poly'][5])],
177
203
  })
178
- for res in layout_res:
179
- if int(res['category_id']) in [0, 1, 2, 4, 6, 7]: # 需要进行ocr的类别
180
- xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
181
- xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
182
- crop_box = (xmin, ymin, xmax, ymax)
183
- cropped_img = Image.new('RGB', pil_img.size, 'white')
184
- cropped_img.paste(pil_img.crop(crop_box), crop_box)
185
- cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
186
- ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
187
- if ocr_res:
188
- for box_ocr_res in ocr_res:
189
- p1, p2, p3, p4 = box_ocr_res[0]
190
- text, score = box_ocr_res[1]
191
- layout_res.append({
192
- 'category_id': 15,
193
- 'poly': p1 + p2 + p3 + p4,
194
- 'score': round(score, 2),
195
- 'text': text,
196
- })
204
+ elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
205
+ ocr_res_list.append(res)
206
+
207
+ # 对每一个需OCR处理的区域进行处理
208
+ for res in ocr_res_list:
209
+ xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
210
+ xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
211
+
212
+ paste_x = 50
213
+ paste_y = 50
214
+ # 创建一个宽高各多50的白色背景
215
+ new_width = xmax - xmin + paste_x * 2
216
+ new_height = ymax - ymin + paste_y * 2
217
+ new_image = Image.new('RGB', (new_width, new_height), 'white')
218
+
219
+ # 裁剪图像
220
+ crop_box = (xmin, ymin, xmax, ymax)
221
+ cropped_img = pil_img.crop(crop_box)
222
+ new_image.paste(cropped_img, (paste_x, paste_y))
223
+
224
+ # 调整公式区域坐标
225
+ adjusted_mfdetrec_res = []
226
+ for mf_res in single_page_mfdetrec_res:
227
+ mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
228
+ # 将公式区域坐标调整为相对于裁剪区域的坐标
229
+ x0 = mf_xmin - xmin + paste_x
230
+ y0 = mf_ymin - ymin + paste_y
231
+ x1 = mf_xmax - xmin + paste_x
232
+ y1 = mf_ymax - ymin + paste_y
233
+ # 过滤在图外的公式块
234
+ if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
235
+ continue
236
+ else:
237
+ adjusted_mfdetrec_res.append({
238
+ "bbox": [x0, y0, x1, y1],
239
+ })
240
+
241
+ # OCR识别
242
+ new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
243
+ ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
244
+
245
+ # 整合结果
246
+ if ocr_res:
247
+ for box_ocr_res in ocr_res:
248
+ p1, p2, p3, p4 = box_ocr_res[0]
249
+ text, score = box_ocr_res[1]
250
+
251
+ # 将坐标转换回原图坐标系
252
+ p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
253
+ p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
254
+ p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
255
+ p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
256
+
257
+ layout_res.append({
258
+ 'category_id': 15,
259
+ 'poly': p1 + p2 + p3 + p4,
260
+ 'score': round(score, 2),
261
+ 'text': text,
262
+ })
263
+
197
264
  ocr_cost = round(time.time() - ocr_start, 2)
198
265
  logger.info(f"ocr cost: {ocr_cost}")
199
266
 
267
+ # 表格识别 table recognition
268
+ if self.apply_table:
269
+ pil_img = Image.fromarray(image)
270
+ for layout in layout_res:
271
+ if layout.get("category_id", -1) == 5:
272
+ poly = layout["poly"]
273
+ xmin, ymin = int(poly[0]), int(poly[1])
274
+ xmax, ymax = int(poly[4]), int(poly[5])
275
+
276
+ paste_x = 50
277
+ paste_y = 50
278
+ # 创建一个宽高各多50的白色背景 create a whiteboard with 50 larger width and length
279
+ new_width = xmax - xmin + paste_x * 2
280
+ new_height = ymax - ymin + paste_y * 2
281
+ new_image = Image.new('RGB', (new_width, new_height), 'white')
282
+
283
+ # 裁剪图像 crop image
284
+ crop_box = (xmin, ymin, xmax, ymax)
285
+ cropped_img = pil_img.crop(crop_box)
286
+ new_image.paste(cropped_img, (paste_x, paste_y))
287
+ start_time = time.time()
288
+ logger.info("------------------table recognition processing begins-----------------")
289
+ latex_code = self.table_model.image2latex(new_image)[0]
290
+ end_time = time.time()
291
+ run_time = end_time - start_time
292
+ logger.info(f"------------table recognition processing ends within {run_time}s-----")
293
+ layout["latex"] = latex_code
294
+
200
295
  return layout_res
@@ -79,12 +79,13 @@ def setup(args, device):
79
79
  cfg.freeze()
80
80
  default_setup(cfg, args)
81
81
 
82
- register_coco_instances(
83
- "scihub_train",
84
- {},
85
- cfg.SCIHUB_DATA_DIR_TRAIN + ".json",
86
- cfg.SCIHUB_DATA_DIR_TRAIN
87
- )
82
+ #@todo 可以删掉这块?
83
+ # register_coco_instances(
84
+ # "scihub_train",
85
+ # {},
86
+ # cfg.SCIHUB_DATA_DIR_TRAIN + ".json",
87
+ # cfg.SCIHUB_DATA_DIR_TRAIN
88
+ # )
88
89
 
89
90
  return cfg
90
91
 
@@ -10,12 +10,17 @@ from paddleocr import PaddleOCR
10
10
  from paddleocr.ppocr.utils.logging import get_logger
11
11
  from paddleocr.ppocr.utils.utility import check_and_read, alpha_to_color, binarize_img
12
12
  from paddleocr.tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
13
+
14
+ from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
15
+
13
16
  logger = get_logger()
14
17
 
18
+
15
19
  def img_decode(content: bytes):
16
20
  np_arr = np.frombuffer(content, dtype=np.uint8)
17
21
  return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
18
22
 
23
+
19
24
  def check_img(img):
20
25
  if isinstance(img, bytes):
21
26
  img = img_decode(img)
@@ -51,6 +56,7 @@ def check_img(img):
51
56
 
52
57
  return img
53
58
 
59
+
54
60
  def sorted_boxes(dt_boxes):
55
61
  """
56
62
  Sort text boxes in order from top to bottom, left to right
@@ -75,49 +81,87 @@ def sorted_boxes(dt_boxes):
75
81
  return _boxes
76
82
 
77
83
 
78
- def formula_in_text(mf_bbox, text_bbox):
79
- x1, y1, x2, y2 = mf_bbox
80
- x3, y3 = text_bbox[0]
81
- x4, y4 = text_bbox[2]
82
- left_box, right_box = None, None
83
- same_line = abs((y1+y2)/2 - (y3+y4)/2) / abs(y4-y3) < 0.2
84
- if not same_line:
85
- return False, left_box, right_box
86
- else:
87
- drop_origin = False
88
- left_x = x1 - 1
89
- right_x = x2 + 1
90
- if x3 < x1 and x2 < x4:
91
- drop_origin = True
92
- left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32')
93
- right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32')
94
- if x3 < x1 and x1 <= x4 <= x2:
95
- drop_origin = True
96
- left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32')
97
- if x1 <= x3 <= x2 and x2 < x4:
98
- drop_origin = True
99
- right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32')
100
- if x1 <= x3 < x4 <= x2:
101
- drop_origin = True
102
- return drop_origin, left_box, right_box
103
-
104
-
105
- def update_det_boxes(dt_boxes, mfdetrec_res):
106
- new_dt_boxes = dt_boxes
107
- for mf_box in mfdetrec_res:
108
- flag, left_box, right_box = False, None, None
109
- for idx, text_box in enumerate(new_dt_boxes):
110
- ret, left_box, right_box = formula_in_text(mf_box['bbox'], text_box)
111
- if ret:
112
- new_dt_boxes.pop(idx)
113
- if left_box is not None:
114
- new_dt_boxes.append(left_box)
115
- if right_box is not None:
116
- new_dt_boxes.append(right_box)
117
- break
118
-
84
def bbox_to_points(bbox):
    """Expand an ``[x0, y0, x1, y1]`` box into its four corner points.

    Corners are returned in the order (x0, y0), (x1, y0), (x1, y1), (x0, y1)
    as a float32 array of shape (4, 2).
    """
    left, top, right, bottom = bbox
    corners = [[left, top], [right, top], [right, bottom], [left, bottom]]
    return np.array(corners).astype('float32')
88
+
89
+
90
def points_to_bbox(points):
    """Collapse four corner points into an ``[x0, y0, x1, y1]`` list.

    x0/y0 come from the first point, x1 from the second point and y1 from the
    third point; the fourth point is ignored.
    """
    (left, top), (right, _), (_, bottom) = points[0], points[1], points[2]
    return [left, top, right, bottom]
96
+
97
+
98
def merge_intervals(intervals):
    """Merge overlapping closed intervals.

    :param intervals: iterable of ``[start, end]`` pairs (lists or tuples);
        only the first two items of each pair are read.
    :return: a new list of ``[start, end]`` lists, sorted by start, in which
        intervals that overlap or share an endpoint are merged.

    The input is left untouched. Sorting a copy and building fresh result
    lists avoids reordering the caller's list or overwriting end points inside
    the caller's interval objects (which also fails for tuple intervals).
    """
    merged = []
    # sorted() is stable, so equal starts keep their input order.
    for interval in sorted(intervals, key=lambda iv: iv[0]):
        start, end = interval[0], interval[1]
        # A start strictly beyond the previous end is a gap; touching endpoints merge.
        if not merged or merged[-1][1] < start:
            merged.append([start, end])
        else:
            merged[-1][1] = max(merged[-1][1], end)
    return merged
113
+
114
+
115
def remove_intervals(original, masks):
    """Subtract a set of mask intervals from one closed interval.

    :param original: ``[start, end]`` range to cut.
    :param masks: ``[start, end]`` ranges to remove; they are merged with
        ``merge_intervals`` first.
    :return: list of ``[start, end]`` pieces of ``original`` not covered by
        any mask. A piece stops one unit before a mask starts and resumes one
        unit after it ends (``mask_start - 1`` / ``mask_end + 1``).
    """
    merged = merge_intervals(masks)
    cursor, stop = original

    pieces = []
    for lo, hi in merged:
        # Masks lying entirely outside the remaining range contribute nothing.
        if lo > stop or hi < cursor:
            continue
        if cursor < lo:
            pieces.append([cursor, lo - 1])
        cursor = max(hi + 1, cursor)

    if cursor <= stop:
        pieces.append([cursor, stop])
    return pieces
144
+
145
+
146
def update_det_boxes(dt_boxes, mfd_res):
    """Split OCR text-detection boxes so they no longer cover formula regions.

    Each detected box is reduced to an axis-aligned ``[x0, y0, x1, y1]`` via
    ``points_to_bbox``. The caller supplies formula boxes in ``mfd_res``. Each
    one that ``__is_overlaps_y_exceeds_threshold`` accepts against the text
    box contributes its x-range as a mask. The text box's x-range minus those
    masks yields zero or more new boxes spanning the text box's full height.

    :param dt_boxes: iterable of 4-point text boxes.
    :param mfd_res: iterable of dicts with a ``'bbox'`` entry ``[x0, y0, x1, y1]``.
    :return: list of 4-point float32 arrays; a text box whose x-range is
        entirely masked produces no output.
    """
    result_boxes = []
    for text_box in dt_boxes:
        text_bbox = points_to_bbox(text_box)
        formula_x_ranges = [
            [formula['bbox'][0], formula['bbox'][2]]
            for formula in mfd_res
            if __is_overlaps_y_exceeds_threshold(text_bbox, formula['bbox'])
        ]
        visible_ranges = remove_intervals([text_bbox[0], text_bbox[2]], formula_x_ranges)
        result_boxes.extend(
            bbox_to_points([rng[0], text_bbox[1], rng[1], text_bbox[3]])
            for rng in visible_ranges
        )
    return result_boxes
120
163
 
164
+
121
165
  class ModifiedPaddleOCR(PaddleOCR):
122
166
  def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):
123
167
  """
@@ -197,7 +241,7 @@ class ModifiedPaddleOCR(PaddleOCR):
197
241
  if not rec:
198
242
  return cls_res
199
243
  return ocr_res
200
-
244
+
201
245
  def __call__(self, img, cls=True, mfd_res=None):
202
246
  time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
203
247
 
@@ -226,7 +270,7 @@ class ModifiedPaddleOCR(PaddleOCR):
226
270
  dt_boxes = update_det_boxes(dt_boxes, mfd_res)
227
271
  aft = time.time()
228
272
  logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
229
- len(dt_boxes), aft-bef))
273
+ len(dt_boxes), aft - bef))
230
274
 
231
275
  for bno in range(len(dt_boxes)):
232
276
  tmp_box = copy.deepcopy(dt_boxes[bno])
@@ -0,0 +1,22 @@
1
+ from struct_eqtable.model import StructTable
2
+ from pypandoc import convert_text
3
class StructTableModel:
    """Wrapper around ``StructTable`` that turns table images into LaTeX or HTML."""

    def __init__(self, model_path, max_new_tokens=2048, max_time=400, device='cpu'):
        """
        :param model_path: path to the StructEqTable weights.
        :param max_new_tokens: maximum output tokens length.
        :param max_time: timeout for processing in seconds.
        :param device: any value whose string form starts with ``"cuda"``
            (e.g. ``"cuda"``, ``"cuda:0"``) moves the model to GPU via
            ``.cuda()``; anything else keeps it on the CPU.
        """
        self.model_path = model_path
        self.max_new_tokens = max_new_tokens  # maximum output tokens length
        self.max_time = max_time  # timeout for processing in seconds
        self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
        # Accept any CUDA device string, not only the bare 'cuda'; otherwise
        # 'cuda:0' would silently fall back to the CPU.
        if str(device).startswith('cuda'):
            # NOTE(review): .cuda() targets the current default CUDA device, not
            # the index in e.g. 'cuda:1' — confirm if multi-GPU placement matters.
            self.model = self.model.cuda()

    def image2latex(self, image):
        """Run the model on ``image`` and return its raw ``forward`` output.

        Callers index the result with ``[0]``, so it is presumably a sequence
        of LaTeX strings (one per input) rather than a single ``str``; the
        former ``-> str`` annotation was dropped for that reason.
        TODO confirm against ``StructTable.forward``.
        """
        return self.model.forward(image)

    def image2html(self, image) -> str:
        """Recognise the table in ``image`` and return it as an HTML string (via pandoc)."""
        table_latex = self.image2latex(image)
        if not isinstance(table_latex, str):
            # One image was passed, so convert its single LaTeX result; handing
            # the whole sequence to pandoc would not yield the table's HTML.
            table_latex = table_latex[0]
        return convert_text(table_latex, 'html', format='latex')
@@ -5,7 +5,7 @@ from loguru import logger
5
5
  try:
6
6
  from paddleocr import PPStructure
7
7
  except ImportError:
8
- logger.error('paddleocr not installed, please install by "pip install magic-pdf[cpu]" or "pip install magic-pdf[gpu]"')
8
+ logger.error('paddleocr not installed, please install by "pip install magic-pdf[lite]"')
9
9
  exit(1)
10
10
 
11
11
 
@@ -7,7 +7,7 @@ from magic_pdf.layout.layout_sort import get_bboxes_layout, LAYOUT_UNPROC, get_c
7
7
  from magic_pdf.libs.convert_utils import dict_to_list
8
8
  from magic_pdf.libs.drop_reason import DropReason
9
9
  from magic_pdf.libs.hash_utils import compute_md5
10
- from magic_pdf.libs.math import float_equal
10
+ from magic_pdf.libs.local_math import float_equal
11
11
  from magic_pdf.libs.ocr_content_type import ContentType
12
12
  from magic_pdf.model.magic_model import MagicModel
13
13
  from magic_pdf.para.para_split_v2 import para_split
@@ -111,7 +111,8 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter,
111
111
  spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter)
112
112
 
113
113
  '''将所有区块的bbox整理到一起'''
114
- # @todo interline_equation_blocks参数不够准,后面切换到interline_equations上
114
+ # interline_equation_blocks参数不够准,后面切换到interline_equations上
115
+ interline_equation_blocks = []
115
116
  if len(interline_equation_blocks) > 0:
116
117
  all_bboxes, all_discarded_blocks, drop_reasons = ocr_prepare_bboxes_for_layout_split(
117
118
  img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
@@ -120,6 +121,7 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter,
120
121
  all_bboxes, all_discarded_blocks, drop_reasons = ocr_prepare_bboxes_for_layout_split(
121
122
  img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
122
123
  interline_equations, page_w, page_h)
124
+
123
125
  if len(drop_reasons) > 0:
124
126
  need_drop = True
125
127
  drop_reason.append(DropReason.OVERLAP_BLOCKS_CAN_NOT_SEPARATION)
@@ -135,7 +135,11 @@ def remove_citation_marker(with_char_text_blcoks):
135
135
 
136
136
  if max_font_sz-span_font_sz<1: # 先以字体过滤正文,如果是正文就不再继续判断了
137
137
  continue
138
-
138
+
139
+ # 对被除数为0的情况进行过滤
140
+ if span_hi==0 or min_font_sz==0:
141
+ continue
142
+
139
143
  if (base_span_mid_y-span_mid_y)/span_hi>0.2 or (base_span_mid_y-span_mid_y>0 and abs(span_font_sz-min_font_sz)/min_font_sz<0.1):
140
144
  """
141
145
  1. 它的前一个char如果是句号或者逗号的话,那么肯定是角标而不是公式
@@ -36,9 +36,12 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
36
36
  all_bboxes = fix_text_overlap_title_blocks(all_bboxes)
37
37
  '''任何框体与舍弃框重叠,优先信任舍弃框'''
38
38
  all_bboxes = remove_need_drop_blocks(all_bboxes, discarded_blocks)
39
- # @todo interline_equation 与title或text框冲突的情况,分两种情况处理
39
+
40
+ # interline_equation 与title或text框冲突的情况,分两种情况处理
40
41
  '''interline_equation框与文本类型框iou比较接近1的时候,信任行间公式框'''
42
+ all_bboxes = fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes)
41
43
  '''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框'''
44
+ # 通过后续大框套小框逻辑删除
42
45
 
43
46
  '''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)'''
44
47
  for discarded in discarded_blocks:
@@ -57,6 +60,34 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
57
60
  return all_bboxes, all_discarded_blocks, drop_reasons
58
61
 
59
62
 
63
def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
    """Remove text blocks that nearly coincide with an interline-equation block.

    When an interline-equation block and a text block have IoU > 0.8, the
    equation block is trusted and the text block is dropped.

    :param all_bboxes: list of block records where ``block[:4]`` is the bbox
        and ``block[7]`` the ``BlockType``; it is modified in place.
    :return: the same ``all_bboxes`` list.
    """
    text_blocks = [blk for blk in all_bboxes if blk[7] == BlockType.Text]
    equation_blocks = [blk for blk in all_bboxes if blk[7] == BlockType.InterlineEquation]

    to_drop = []
    for eq_blk in equation_blocks:
        for txt_blk in text_blocks:
            # Membership uses equality, consistent with list.remove below.
            if calculate_iou(eq_blk[:4], txt_blk[:4]) > 0.8 and txt_blk not in to_drop:
                to_drop.append(txt_blk)

    for blk in to_drop:
        all_bboxes.remove(blk)
    return all_bboxes
89
+
90
+
60
91
  def fix_text_overlap_title_blocks(all_bboxes):
61
92
  # 先提取所有text和title block
62
93
  text_blocks = []
@@ -68,12 +99,19 @@ def fix_text_overlap_title_blocks(all_bboxes):
68
99
  if block[7] == BlockType.Title:
69
100
  title_blocks.append(block)
70
101
 
102
+ need_remove = []
103
+
71
104
  for text_block in text_blocks:
72
105
  for title_block in title_blocks:
73
106
  text_block_bbox = text_block[:4]
74
107
  title_block_bbox = title_block[:4]
75
108
  if calculate_iou(text_block_bbox, title_block_bbox) > 0.8:
76
- all_bboxes.remove(title_block)
109
+ if title_block not in need_remove:
110
+ need_remove.append(title_block)
111
+
112
+ if len(need_remove) > 0:
113
+ for block in need_remove:
114
+ all_bboxes.remove(block)
77
115
 
78
116
  return all_bboxes
79
117
 
@@ -5,19 +5,24 @@ from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, g
5
5
  from magic_pdf.libs.drop_tag import DropTag
6
6
  from magic_pdf.libs.ocr_content_type import ContentType, BlockType
7
7
 
8
+
8
9
  def remove_overlaps_low_confidence_spans(spans):
9
10
  dropped_spans = []
10
11
  # 删除重叠spans中置信度低的的那些
11
12
  for span1 in spans:
12
13
  for span2 in spans:
13
14
  if span1 != span2:
14
- if calculate_iou(span1['bbox'], span2['bbox']) > 0.9:
15
- if span1['score'] < span2['score']:
16
- span_need_remove = span1
17
- else:
18
- span_need_remove = span2
19
- if span_need_remove is not None and span_need_remove not in dropped_spans:
20
- dropped_spans.append(span_need_remove)
15
+ # span1 span2 任何一个都不应该在 dropped_spans 中
16
+ if span1 in dropped_spans or span2 in dropped_spans:
17
+ continue
18
+ else:
19
+ if calculate_iou(span1['bbox'], span2['bbox']) > 0.9:
20
+ if span1['score'] < span2['score']:
21
+ span_need_remove = span1
22
+ else:
23
+ span_need_remove = span2
24
+ if span_need_remove is not None and span_need_remove not in dropped_spans:
25
+ dropped_spans.append(span_need_remove)
21
26
 
22
27
  if len(dropped_spans) > 0:
23
28
  for span_need_remove in dropped_spans: