magic-pdf 0.6.0__py3-none-any.whl → 0.6.2b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,20 +1,31 @@
1
+ from loguru import logger
1
2
  import os
2
- import cv2
3
- import yaml
4
3
  import time
5
- import argparse
6
- import numpy as np
7
- import torch
8
- from loguru import logger
9
4
 
10
- from paddleocr import draw_ocr
11
- from PIL import Image
12
- from torchvision import transforms
13
- from torch.utils.data import Dataset, DataLoader
14
- from ultralytics import YOLO
15
- from unimernet.common.config import Config
16
- import unimernet.tasks as tasks
17
- from unimernet.processors import load_processor
5
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates

# Import the heavy model dependencies as one group so that a missing package
# produces a single installation hint instead of a bare traceback.
try:
    import cv2
    import yaml
    import argparse
    import numpy as np
    import torch
    import torchtext

    # Detect the helper directly instead of comparing version strings:
    # string comparison is lexicographic, so "0.9.0" >= "0.18.0" is True and
    # the call below would raise AttributeError (not caught by this handler)
    # on releases that do not provide the helper.
    if hasattr(torchtext, "disable_torchtext_deprecation_warning"):
        torchtext.disable_torchtext_deprecation_warning()
    from PIL import Image
    from torchvision import transforms
    from torch.utils.data import Dataset, DataLoader
    from ultralytics import YOLO
    from unimernet.common.config import Config
    import unimernet.tasks as tasks
    from unimernet.processors import load_processor

except ImportError as e:
    logger.exception(e)
    logger.error(
        'Required dependency not installed, please install by \n'
        '"pip install magic-pdf[full] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
    # Same effect as exit(1), but does not rely on the `exit` helper that the
    # `site` module injects (it is missing under `python -S` and in some
    # frozen/embedded interpreters).
    raise SystemExit(1)
18
29
 
19
30
  from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
20
31
  from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
@@ -79,7 +90,7 @@ class CustomPEKModel:
79
90
  model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
80
91
  # 构建 model_configs.yaml 文件的完整路径
81
92
  config_path = os.path.join(model_config_dir, 'model_configs.yaml')
82
- with open(config_path, "r") as f:
93
+ with open(config_path, "r", encoding='utf-8') as f:
83
94
  self.configs = yaml.load(f, Loader=yaml.FullLoader)
84
95
  # 初始化解析配置
85
96
  self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
@@ -95,6 +106,7 @@ class CustomPEKModel:
95
106
  self.device = kwargs.get("device", self.configs["config"]["device"])
96
107
  logger.info("using device: {}".format(self.device))
97
108
  models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
109
+ logger.info("using models_dir: {}".format(models_dir))
98
110
 
99
111
  # 初始化公式识别
100
112
  if self.apply_formula:
@@ -130,66 +142,110 @@ class CustomPEKModel:
130
142
  layout_cost = round(time.time() - layout_start, 2)
131
143
  logger.info(f"layout detection cost: {layout_cost}")
132
144
 
133
- # 公式检测
134
- mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
135
- for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
136
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
137
- new_item = {
138
- 'category_id': 13 + int(cla.item()),
139
- 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
140
- 'score': round(float(conf.item()), 2),
141
- 'latex': '',
142
- }
143
- layout_res.append(new_item)
144
- latex_filling_list.append(new_item)
145
- bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
146
- mf_image_list.append(bbox_img)
147
-
148
- # 公式识别
149
- mfr_start = time.time()
150
- dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
151
- dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
152
- mfr_res = []
153
- for mf_img in dataloader:
154
- mf_img = mf_img.to(self.device)
155
- output = self.mfr_model.generate({'image': mf_img})
156
- mfr_res.extend(output['pred_str'])
157
- for res, latex in zip(latex_filling_list, mfr_res):
158
- res['latex'] = latex_rm_whitespace(latex)
159
- mfr_cost = round(time.time() - mfr_start, 2)
160
- logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
145
+ if self.apply_formula:
146
+ # 公式检测
147
+ mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
148
+ for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
149
+ xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
150
+ new_item = {
151
+ 'category_id': 13 + int(cla.item()),
152
+ 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
153
+ 'score': round(float(conf.item()), 2),
154
+ 'latex': '',
155
+ }
156
+ layout_res.append(new_item)
157
+ latex_filling_list.append(new_item)
158
+ bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
159
+ mf_image_list.append(bbox_img)
160
+
161
+ # 公式识别
162
+ mfr_start = time.time()
163
+ dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
164
+ dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
165
+ mfr_res = []
166
+ for mf_img in dataloader:
167
+ mf_img = mf_img.to(self.device)
168
+ output = self.mfr_model.generate({'image': mf_img})
169
+ mfr_res.extend(output['pred_str'])
170
+ for res, latex in zip(latex_filling_list, mfr_res):
171
+ res['latex'] = latex_rm_whitespace(latex)
172
+ mfr_cost = round(time.time() - mfr_start, 2)
173
+ logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
161
174
 
162
175
  # ocr识别
163
176
  if self.apply_ocr:
164
177
  ocr_start = time.time()
165
178
  pil_img = Image.fromarray(image)
179
+
180
+ # 筛选出需要OCR的区域和公式区域
181
+ ocr_res_list = []
166
182
  single_page_mfdetrec_res = []
167
183
  for res in layout_res:
168
184
  if int(res['category_id']) in [13, 14]:
169
- xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
170
- xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
171
185
  single_page_mfdetrec_res.append({
172
- "bbox": [xmin, ymin, xmax, ymax],
186
+ "bbox": [int(res['poly'][0]), int(res['poly'][1]),
187
+ int(res['poly'][4]), int(res['poly'][5])],
173
188
  })
174
- for res in layout_res:
175
- if int(res['category_id']) in [0, 1, 2, 4, 6, 7]: # 需要进行ocr的类别
176
- xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
177
- xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
178
- crop_box = (xmin, ymin, xmax, ymax)
179
- cropped_img = Image.new('RGB', pil_img.size, 'white')
180
- cropped_img.paste(pil_img.crop(crop_box), crop_box)
181
- cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
182
- ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
183
- if ocr_res:
184
- for box_ocr_res in ocr_res:
185
- p1, p2, p3, p4 = box_ocr_res[0]
186
- text, score = box_ocr_res[1]
187
- layout_res.append({
188
- 'category_id': 15,
189
- 'poly': p1 + p2 + p3 + p4,
190
- 'score': round(score, 2),
191
- 'text': text,
192
- })
189
+ elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
190
+ ocr_res_list.append(res)
191
+
192
+ # 对每一个需OCR处理的区域进行处理
193
+ for res in ocr_res_list:
194
+ xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
195
+ xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
196
+
197
+ paste_x = 50
198
+ paste_y = 50
199
+ # 创建一个宽高各多50的白色背景
200
+ new_width = xmax - xmin + paste_x * 2
201
+ new_height = ymax - ymin + paste_y * 2
202
+ new_image = Image.new('RGB', (new_width, new_height), 'white')
203
+
204
+ # 裁剪图像
205
+ crop_box = (xmin, ymin, xmax, ymax)
206
+ cropped_img = pil_img.crop(crop_box)
207
+ new_image.paste(cropped_img, (paste_x, paste_y))
208
+
209
+ # 调整公式区域坐标
210
+ adjusted_mfdetrec_res = []
211
+ for mf_res in single_page_mfdetrec_res:
212
+ mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
213
+ # 将公式区域坐标调整为相对于裁剪区域的坐标
214
+ x0 = mf_xmin - xmin + paste_x
215
+ y0 = mf_ymin - ymin + paste_y
216
+ x1 = mf_xmax - xmin + paste_x
217
+ y1 = mf_ymax - ymin + paste_y
218
+ # 过滤在图外的公式块
219
+ if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
220
+ continue
221
+ else:
222
+ adjusted_mfdetrec_res.append({
223
+ "bbox": [x0, y0, x1, y1],
224
+ })
225
+
226
+ # OCR识别
227
+ new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
228
+ ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
229
+
230
+ # 整合结果
231
+ if ocr_res:
232
+ for box_ocr_res in ocr_res:
233
+ p1, p2, p3, p4 = box_ocr_res[0]
234
+ text, score = box_ocr_res[1]
235
+
236
+ # 将坐标转换回原图坐标系
237
+ p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
238
+ p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
239
+ p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
240
+ p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
241
+
242
+ layout_res.append({
243
+ 'category_id': 15,
244
+ 'poly': p1 + p2 + p3 + p4,
245
+ 'score': round(score, 2),
246
+ 'text': text,
247
+ })
248
+
193
249
  ocr_cost = round(time.time() - ocr_start, 2)
194
250
  logger.info(f"ocr cost: {ocr_cost}")
195
251
 
@@ -79,12 +79,13 @@ def setup(args, device):
79
79
  cfg.freeze()
80
80
  default_setup(cfg, args)
81
81
 
82
- register_coco_instances(
83
- "scihub_train",
84
- {},
85
- cfg.SCIHUB_DATA_DIR_TRAIN + ".json",
86
- cfg.SCIHUB_DATA_DIR_TRAIN
87
- )
82
+ #@todo 可以删掉这块?
83
+ # register_coco_instances(
84
+ # "scihub_train",
85
+ # {},
86
+ # cfg.SCIHUB_DATA_DIR_TRAIN + ".json",
87
+ # cfg.SCIHUB_DATA_DIR_TRAIN
88
+ # )
88
89
 
89
90
  return cfg
90
91
 
@@ -10,12 +10,17 @@ from paddleocr import PaddleOCR
10
10
  from paddleocr.ppocr.utils.logging import get_logger
11
11
  from paddleocr.ppocr.utils.utility import check_and_read, alpha_to_color, binarize_img
12
12
  from paddleocr.tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
13
+
14
+ from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
15
+
13
16
  logger = get_logger()
14
17
 
18
+
15
19
  def img_decode(content: bytes):
16
20
  np_arr = np.frombuffer(content, dtype=np.uint8)
17
21
  return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
18
22
 
23
+
19
24
  def check_img(img):
20
25
  if isinstance(img, bytes):
21
26
  img = img_decode(img)
@@ -51,6 +56,7 @@ def check_img(img):
51
56
 
52
57
  return img
53
58
 
59
+
54
60
  def sorted_boxes(dt_boxes):
55
61
  """
56
62
  Sort text boxes in order from top to bottom, left to right
@@ -75,49 +81,87 @@ def sorted_boxes(dt_boxes):
75
81
  return _boxes
76
82
 
77
83
 
78
- def formula_in_text(mf_bbox, text_bbox):
79
- x1, y1, x2, y2 = mf_bbox
80
- x3, y3 = text_bbox[0]
81
- x4, y4 = text_bbox[2]
82
- left_box, right_box = None, None
83
- same_line = abs((y1+y2)/2 - (y3+y4)/2) / abs(y4-y3) < 0.2
84
- if not same_line:
85
- return False, left_box, right_box
86
- else:
87
- drop_origin = False
88
- left_x = x1 - 1
89
- right_x = x2 + 1
90
- if x3 < x1 and x2 < x4:
91
- drop_origin = True
92
- left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32')
93
- right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32')
94
- if x3 < x1 and x1 <= x4 <= x2:
95
- drop_origin = True
96
- left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32')
97
- if x1 <= x3 <= x2 and x2 < x4:
98
- drop_origin = True
99
- right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32')
100
- if x1 <= x3 < x4 <= x2:
101
- drop_origin = True
102
- return drop_origin, left_box, right_box
103
-
104
-
105
- def update_det_boxes(dt_boxes, mfdetrec_res):
106
- new_dt_boxes = dt_boxes
107
- for mf_box in mfdetrec_res:
108
- flag, left_box, right_box = False, None, None
109
- for idx, text_box in enumerate(new_dt_boxes):
110
- ret, left_box, right_box = formula_in_text(mf_box['bbox'], text_box)
111
- if ret:
112
- new_dt_boxes.pop(idx)
113
- if left_box is not None:
114
- new_dt_boxes.append(left_box)
115
- if right_box is not None:
116
- new_dt_boxes.append(right_box)
117
- break
118
-
84
def bbox_to_points(bbox):
    """Convert an ``[x0, y0, x1, y1]`` bbox into a float32 array of four corners.

    The corners are emitted in the order (x0, y0), (x1, y0), (x1, y1), (x0, y1).
    """
    left, top, right, bottom = bbox
    corners = [
        [left, top],
        [right, top],
        [right, bottom],
        [left, bottom],
    ]
    return np.array(corners).astype('float32')
88
+
89
+
90
def points_to_bbox(points):
    """Convert a four-corner point array into ``[x0, y0, x1, y1]``.

    x0 and y0 are read from the first point, x1 from the second point and
    y1 from the third point; the fourth point is not consulted.
    """
    (x0, y0), (x1, _), (_, y1) = points[0], points[1], points[2]
    return [x0, y0, x1, y1]
96
+
97
+
98
def merge_intervals(intervals):
    """Merge overlapping ``[start, end]`` intervals.

    Intervals are processed in ascending order of their start value. An
    interval is folded into the previous one when its start is less than or
    equal to the previous end; otherwise it starts a new entry.

    Args:
        intervals: iterable of two-element sequences (lists or tuples).

    Returns:
        A new list of ``[start, end]`` lists, sorted by start. Neither the
        input container nor the intervals inside it are modified.
    """
    merged = []
    # sorted() works on a copy, so the caller's ordering is left untouched.
    for interval in sorted(intervals, key=lambda item: item[0]):
        start, end = interval[0], interval[1]
        if merged and start <= merged[-1][1]:
            # Overlap (or touching): extend the previous merged interval.
            merged[-1][1] = max(merged[-1][1], end)
        else:
            # Store a fresh list so later extension never writes through to
            # the caller's interval objects.
            merged.append([start, end])

    return merged
113
+
114
+
115
def remove_intervals(original, masks):
    """Subtract the mask intervals from ``original`` and return what remains.

    Intervals are treated as inclusive ``[start, end]`` pairs on a unit grid:
    the piece kept to the left of a mask ends at ``mask_start - 1`` and the
    remainder resumes at ``mask_end + 1``.

    Args:
        original: two-element ``[start, end]`` range to cut pieces out of.
        masks: list of ``[start, end]`` ranges to remove; they are merged
            with ``merge_intervals`` first.

    Returns:
        List of ``[start, end]`` pieces of ``original`` not covered by any
        mask, in ascending order. Empty when ``original`` is fully covered.
    """
    remaining = []
    cursor, range_end = original

    for mask_start, mask_end in merge_intervals(masks):
        # Merged masks come back sorted by start, so once one begins past the
        # end of the range none of the later ones can overlap either.
        if mask_start > range_end:
            break

        # Mask lies entirely before the part of the range still to process.
        if mask_end < cursor:
            continue

        # Keep the piece to the left of the mask. Comparing against the
        # computed end (rather than mask_start) avoids emitting an inverted
        # interval when coordinates are non-integer and the gap is below 1.
        left_end = mask_start - 1
        if cursor <= left_end:
            remaining.append([cursor, left_end])

        cursor = max(mask_end + 1, cursor)

    # Whatever is left after the last overlapping mask.
    if cursor <= range_end:
        remaining.append([cursor, range_end])

    return remaining
144
+
145
+
146
def update_det_boxes(dt_boxes, mfd_res):
    """Cut formula regions out of the detected text boxes.

    For every text box, the x-ranges of the formula boxes for which
    ``__is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox)`` is truthy are
    removed from the text box's x-range. Each surviving x-range becomes a new
    four-point box that keeps the text box's y0 and y1. A text box whose
    x-range is removed completely contributes nothing to the result.

    Args:
        dt_boxes: iterable of four-point text boxes.
        mfd_res: iterable of dicts with a ``'bbox'`` entry ``[x0, y0, x1, y1]``.

    Returns:
        A new list of four-point boxes.
    """
    split_boxes = []
    for text_box in dt_boxes:
        text_bbox = points_to_bbox(text_box)

        # x-extents of the formulas that share this text line
        formula_x_ranges = [
            [mf_box['bbox'][0], mf_box['bbox'][2]]
            for mf_box in mfd_res
            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_box['bbox'])
        ]

        kept_x_ranges = remove_intervals([text_bbox[0], text_bbox[2]], formula_x_ranges)
        split_boxes.extend(
            bbox_to_points([x_range[0], text_bbox[1], x_range[1], text_bbox[3]])
            for x_range in kept_x_ranges
        )
    return split_boxes
120
163
 
164
+
121
165
  class ModifiedPaddleOCR(PaddleOCR):
122
166
  def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):
123
167
  """
@@ -197,7 +241,7 @@ class ModifiedPaddleOCR(PaddleOCR):
197
241
  if not rec:
198
242
  return cls_res
199
243
  return ocr_res
200
-
244
+
201
245
  def __call__(self, img, cls=True, mfd_res=None):
202
246
  time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
203
247
 
@@ -226,7 +270,7 @@ class ModifiedPaddleOCR(PaddleOCR):
226
270
  dt_boxes = update_det_boxes(dt_boxes, mfd_res)
227
271
  aft = time.time()
228
272
  logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
229
- len(dt_boxes), aft-bef))
273
+ len(dt_boxes), aft - bef))
230
274
 
231
275
  for bno in range(len(dt_boxes)):
232
276
  tmp_box = copy.deepcopy(dt_boxes[bno])
@@ -5,7 +5,7 @@ from loguru import logger
5
5
  try:
6
6
  from paddleocr import PPStructure
7
7
  except ImportError:
8
- logger.error('paddleocr not installed, please install by "pip install magic-pdf[cpu]" or "pip install magic-pdf[gpu]"')
8
+ logger.error('paddleocr not installed, please install by "pip install magic-pdf[lite]"')
9
9
  exit(1)
10
10
 
11
11
 
@@ -7,7 +7,7 @@ from magic_pdf.layout.layout_sort import get_bboxes_layout, LAYOUT_UNPROC, get_c
7
7
  from magic_pdf.libs.convert_utils import dict_to_list
8
8
  from magic_pdf.libs.drop_reason import DropReason
9
9
  from magic_pdf.libs.hash_utils import compute_md5
10
- from magic_pdf.libs.math import float_equal
10
+ from magic_pdf.libs.local_math import float_equal
11
11
  from magic_pdf.libs.ocr_content_type import ContentType
12
12
  from magic_pdf.model.magic_model import MagicModel
13
13
  from magic_pdf.para.para_split_v2 import para_split
@@ -111,7 +111,8 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter,
111
111
  spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter)
112
112
 
113
113
  '''将所有区块的bbox整理到一起'''
114
- # @todo interline_equation_blocks参数不够准,后面切换到interline_equations上
114
+ # interline_equation_blocks参数不够准,后面切换到interline_equations上
115
+ interline_equation_blocks = []
115
116
  if len(interline_equation_blocks) > 0:
116
117
  all_bboxes, all_discarded_blocks, drop_reasons = ocr_prepare_bboxes_for_layout_split(
117
118
  img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
@@ -120,6 +121,7 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter,
120
121
  all_bboxes, all_discarded_blocks, drop_reasons = ocr_prepare_bboxes_for_layout_split(
121
122
  img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
122
123
  interline_equations, page_w, page_h)
124
+
123
125
  if len(drop_reasons) > 0:
124
126
  need_drop = True
125
127
  drop_reason.append(DropReason.OVERLAP_BLOCKS_CAN_NOT_SEPARATION)
@@ -135,7 +135,11 @@ def remove_citation_marker(with_char_text_blcoks):
135
135
 
136
136
  if max_font_sz-span_font_sz<1: # 先以字体过滤正文,如果是正文就不再继续判断了
137
137
  continue
138
-
138
+
139
+ # 对被除数为0的情况进行过滤
140
+ if span_hi==0 or min_font_sz==0:
141
+ continue
142
+
139
143
  if (base_span_mid_y-span_mid_y)/span_hi>0.2 or (base_span_mid_y-span_mid_y>0 and abs(span_font_sz-min_font_sz)/min_font_sz<0.1):
140
144
  """
141
145
  1. 它的前一个char如果是句号或者逗号的话,那么肯定是角标而不是公式
@@ -36,9 +36,12 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
36
36
  all_bboxes = fix_text_overlap_title_blocks(all_bboxes)
37
37
  '''任何框体与舍弃框重叠,优先信任舍弃框'''
38
38
  all_bboxes = remove_need_drop_blocks(all_bboxes, discarded_blocks)
39
- # @todo interline_equation 与title或text框冲突的情况,分两种情况处理
39
+
40
+ # interline_equation 与title或text框冲突的情况,分两种情况处理
40
41
  '''interline_equation框与文本类型框iou比较接近1的时候,信任行间公式框'''
42
+ all_bboxes = fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes)
41
43
  '''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框'''
44
+ # 通过后续大框套小框逻辑删除
42
45
 
43
46
  '''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)'''
44
47
  for discarded in discarded_blocks:
@@ -57,6 +60,34 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
57
60
  return all_bboxes, all_discarded_blocks, drop_reasons
58
61
 
59
62
 
63
def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
    """Remove text blocks whose IoU with an interline-equation block exceeds 0.8.

    Blocks are classified by their element at index 7 and compared on their
    first four elements. Matching text blocks are removed from ``all_bboxes``
    in place; the same list is returned.
    """
    # Split out the two block kinds that take part in the comparison.
    text_blocks = [blk for blk in all_bboxes if blk[7] == BlockType.Text]
    equation_blocks = [blk for blk in all_bboxes if blk[7] == BlockType.InterlineEquation]

    # Collect first, delete afterwards, so each text block is removed once.
    to_drop = []
    for equation in equation_blocks:
        equation_bbox = equation[:4]
        for text in text_blocks:
            if calculate_iou(equation_bbox, text[:4]) > 0.8 and text not in to_drop:
                to_drop.append(text)

    for blk in to_drop:
        all_bboxes.remove(blk)

    return all_bboxes
89
+
90
+
60
91
  def fix_text_overlap_title_blocks(all_bboxes):
61
92
  # 先提取所有text和title block
62
93
  text_blocks = []
@@ -68,12 +99,19 @@ def fix_text_overlap_title_blocks(all_bboxes):
68
99
  if block[7] == BlockType.Title:
69
100
  title_blocks.append(block)
70
101
 
102
+ need_remove = []
103
+
71
104
  for text_block in text_blocks:
72
105
  for title_block in title_blocks:
73
106
  text_block_bbox = text_block[:4]
74
107
  title_block_bbox = title_block[:4]
75
108
  if calculate_iou(text_block_bbox, title_block_bbox) > 0.8:
76
- all_bboxes.remove(title_block)
109
+ if title_block not in need_remove:
110
+ need_remove.append(title_block)
111
+
112
+ if len(need_remove) > 0:
113
+ for block in need_remove:
114
+ all_bboxes.remove(block)
77
115
 
78
116
  return all_bboxes
79
117
 
@@ -5,19 +5,24 @@ from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, g
5
5
  from magic_pdf.libs.drop_tag import DropTag
6
6
  from magic_pdf.libs.ocr_content_type import ContentType, BlockType
7
7
 
8
+
8
9
  def remove_overlaps_low_confidence_spans(spans):
9
10
  dropped_spans = []
10
11
  # 删除重叠spans中置信度低的的那些
11
12
  for span1 in spans:
12
13
  for span2 in spans:
13
14
  if span1 != span2:
14
- if calculate_iou(span1['bbox'], span2['bbox']) > 0.9:
15
- if span1['score'] < span2['score']:
16
- span_need_remove = span1
17
- else:
18
- span_need_remove = span2
19
- if span_need_remove is not None and span_need_remove not in dropped_spans:
20
- dropped_spans.append(span_need_remove)
15
+ # span1 span2 任何一个都不应该在 dropped_spans 中
16
+ if span1 in dropped_spans or span2 in dropped_spans:
17
+ continue
18
+ else:
19
+ if calculate_iou(span1['bbox'], span2['bbox']) > 0.9:
20
+ if span1['score'] < span2['score']:
21
+ span_need_remove = span1
22
+ else:
23
+ span_need_remove = span2
24
+ if span_need_remove is not None and span_need_remove not in dropped_spans:
25
+ dropped_spans.append(span_need_remove)
21
26
 
22
27
  if len(dropped_spans) > 0:
23
28
  for span_need_remove in dropped_spans:
@@ -1,6 +1,6 @@
1
1
  AUG:
2
2
  DETR: true
3
- CACHE_DIR: /mnt/localdata/users/yupanhuang/cache/huggingface
3
+ CACHE_DIR: ~/cache/huggingface
4
4
  CUDNN_BENCHMARK: false
5
5
  DATALOADER:
6
6
  ASPECT_RATIO_GROUPING: true
@@ -294,7 +294,7 @@ MODEL:
294
294
  POS_TYPE: abs
295
295
  WEIGHTS:
296
296
  OUTPUT_DIR:
297
- SCIHUB_DATA_DIR_TRAIN: /mnt/petrelfs/share_data/zhaozhiyuan/publaynet/layout_scihub/train
297
+ SCIHUB_DATA_DIR_TRAIN: ~/publaynet/layout_scihub/train
298
298
  SEED: 42
299
299
  SOLVER:
300
300
  AMP: