magic-pdf 0.9.1__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. magic_pdf/dict2md/ocr_mkcontent.py +1 -1
  2. magic_pdf/libs/Constants.py +3 -1
  3. magic_pdf/libs/config_reader.py +1 -1
  4. magic_pdf/libs/draw_bbox.py +10 -4
  5. magic_pdf/libs/version.py +1 -1
  6. magic_pdf/model/pdf_extract_kit.py +42 -310
  7. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +21 -0
  8. magic_pdf/model/sub_modules/mfd/__init__.py +0 -0
  9. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +12 -0
  10. magic_pdf/model/sub_modules/mfd/yolov8/__init__.py +0 -0
  11. magic_pdf/model/sub_modules/mfr/__init__.py +0 -0
  12. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +98 -0
  13. magic_pdf/model/sub_modules/mfr/unimernet/__init__.py +0 -0
  14. magic_pdf/model/sub_modules/model_init.py +144 -0
  15. magic_pdf/model/sub_modules/model_utils.py +51 -0
  16. magic_pdf/model/sub_modules/ocr/__init__.py +0 -0
  17. magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py +0 -0
  18. magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +259 -0
  19. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +168 -0
  20. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +213 -0
  21. magic_pdf/model/sub_modules/reading_oreder/__init__.py +0 -0
  22. magic_pdf/model/sub_modules/reading_oreder/layoutreader/__init__.py +0 -0
  23. magic_pdf/model/sub_modules/reading_oreder/layoutreader/xycut.py +242 -0
  24. magic_pdf/model/sub_modules/table/__init__.py +0 -0
  25. magic_pdf/model/sub_modules/table/rapidtable/__init__.py +0 -0
  26. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +14 -0
  27. magic_pdf/model/sub_modules/table/structeqtable/__init__.py +0 -0
  28. magic_pdf/model/{pek_sub_modules/structeqtable/StructTableModel.py → sub_modules/table/structeqtable/struct_eqtable.py} +3 -11
  29. magic_pdf/model/sub_modules/table/table_utils.py +11 -0
  30. magic_pdf/model/sub_modules/table/tablemaster/__init__.py +0 -0
  31. magic_pdf/model/{ppTableModel.py → sub_modules/table/tablemaster/tablemaster_paddle.py} +1 -1
  32. magic_pdf/para/para_split_v3.py +13 -15
  33. magic_pdf/pdf_parse_union_core_v2.py +56 -19
  34. magic_pdf/resources/model_config/model_configs.yaml +2 -1
  35. magic_pdf/tools/common.py +47 -3
  36. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/METADATA +35 -25
  37. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/RECORD +65 -44
  38. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/WHEEL +1 -1
  39. magic_pdf/model/pek_sub_modules/post_process.py +0 -36
  40. magic_pdf/model/pek_sub_modules/self_modify.py +0 -388
  41. /magic_pdf/model/{pek_sub_modules → sub_modules}/__init__.py +0 -0
  42. /magic_pdf/model/{pek_sub_modules/layoutlmv3 → sub_modules/layout}/__init__.py +0 -0
  43. /magic_pdf/model/{pek_sub_modules/structeqtable → sub_modules/layout/doclayout_yolo}/__init__.py +0 -0
  44. /magic_pdf/model/{v3 → sub_modules/layout/layoutlmv3}/__init__.py +0 -0
  45. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/backbone.py +0 -0
  46. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/beit.py +0 -0
  47. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/deit.py +0 -0
  48. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/__init__.py +0 -0
  49. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/__init__.py +0 -0
  50. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/cord.py +0 -0
  51. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/data_collator.py +0 -0
  52. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/funsd.py +0 -0
  53. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/image_utils.py +0 -0
  54. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/xfund.py +0 -0
  55. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/__init__.py +0 -0
  56. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +0 -0
  57. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +0 -0
  58. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +0 -0
  59. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +0 -0
  60. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +0 -0
  61. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/model_init.py +0 -0
  62. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/rcnn_vl.py +0 -0
  63. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/visualizer.py +0 -0
  64. /magic_pdf/model/{v3 → sub_modules/reading_oreder/layoutreader}/helpers.py +0 -0
  65. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/LICENSE.md +0 -0
  66. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/entry_points.txt +0 -0
  67. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,213 @@
1
+ import copy
2
+ import time
3
+
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from paddleocr import PaddleOCR
8
+ from paddleocr.paddleocr import check_img, logger
9
+ from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img
10
+ from paddleocr.tools.infer.predict_system import sorted_boxes
11
+ from paddleocr.tools.infer.utility import slice_generator, merge_fragmented, get_rotate_crop_image, \
12
+ get_minarea_rect_crop
13
+
14
+ from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes
15
+
16
+
17
class ModifiedPaddleOCR(PaddleOCR):
    """``PaddleOCR`` whose detection boxes can be adjusted around formula regions.

    Mirrors ``PaddleOCR.ocr`` / ``PaddleOCR.__call__`` but accepts an extra
    ``mfd_res`` argument. When it is truthy, the sorted detection boxes are
    passed through ``update_det_boxes(dt_boxes, mfd_res)`` before cropping and
    recognition (the debug log describes this step as splitting text boxes by
    formula; the exact semantics live in ``ocr_utils.update_det_boxes``).
    """

    def ocr(
        self,
        img,
        det=True,
        rec=True,
        cls=True,
        bin=False,
        inv=False,
        alpha_color=(255, 255, 255),
        slice=None,
        mfd_res=None,
    ):
        """
        OCR with PaddleOCR

        Args:
            img: Image for OCR. It can be an ndarray, img_path, or a list of ndarrays.
            det: Use text detection or not. If False, only text recognition will be executed. Default is True.
            rec: Use text recognition or not. If False, only text detection will be executed. Default is True.
            cls: Use angle classifier or not. Default is True. If True, the text with a rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance.
            bin: Binarize image to black and white. Default is False.
            inv: Invert image colors. Default is False.
            alpha_color: Set RGB color Tuple for transparent parts replacement. Default is pure white.
            slice: Use sliding window inference for large images. Both det and rec must be True. Requires int values for slice["horizontal_stride"], slice["vertical_stride"], slice["merge_x_thres"], slice["merge_y_thres"] (See doc/doc_en/slice_en.md). Default is None (no slicing); an empty dict also disables slicing.
            mfd_res: Optional formula-detection result forwarded to ``__call__``
                (only used when both det and rec are True); when truthy the
                detected text boxes are updated with ``update_det_boxes``.

        Returns:
            If both det and rec are True, returns a list of OCR results for each image. Each OCR result is a list of bounding boxes and recognized text for each detected text region (or None when nothing was found).
            If det is True and rec is False, returns a list of detected bounding boxes for each image.
            If det is False and rec is True, returns a list of recognized text for each image.
            If both det and rec are False, returns a list of angle classification results for each image.

        Raises:
            AssertionError: If the input image is not of type ndarray, list, str, or bytes.
            SystemExit: If det is True and the input is a list of images (exit status 0, after logging an error).

        Note:
            - If the angle classifier is not initialized (use_angle_cls=False), it will not be used during the forward process.
            - For PDF files, if the input is a list of images and the page_num is specified, only the first page_num images will be processed.
            - The preprocess_image function is used to preprocess the input image by applying alpha color replacement, inversion, and binarization if specified.
        """
        assert isinstance(img, (np.ndarray, list, str, bytes))
        if isinstance(img, list) and det:
            logger.error("When input a list of images, det must be false")
            # Same effect as the ``exit(0)`` builtin, but ``exit`` is injected by the
            # ``site`` module and is missing under ``python -S`` or in some
            # embedded/frozen interpreters.
            raise SystemExit(0)
        if cls and not self.use_angle_cls:
            logger.warning(
                "Since the angle classifier is not initialized, it will not be used during the forward process"
            )

        img, flag_gif, flag_pdf = check_img(img, alpha_color)
        # for infer pdf file
        if isinstance(img, list) and flag_pdf:
            if self.page_num > len(img) or self.page_num == 0:
                imgs = img
            else:
                imgs = img[: self.page_num]
        else:
            imgs = [img]

        def preprocess_image(_image):
            # Optional colour fixes applied before detection/recognition.
            _image = alpha_to_color(_image, alpha_color)
            if inv:
                _image = cv2.bitwise_not(_image)
            if bin:
                _image = binarize_img(_image)
            return _image

        if det and rec:
            ocr_res = []
            for img in imgs:
                img = preprocess_image(img)
                dt_boxes, rec_res, _ = self.__call__(img, cls, slice, mfd_res=mfd_res)
                if not dt_boxes and not rec_res:
                    ocr_res.append(None)
                    continue
                tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
                ocr_res.append(tmp_res)
            return ocr_res
        elif det and not rec:
            ocr_res = []
            for img in imgs:
                img = preprocess_image(img)
                dt_boxes, _ = self.text_detector(img)
                if dt_boxes.size == 0:
                    ocr_res.append(None)
                    continue
                tmp_res = [box.tolist() for box in dt_boxes]
                ocr_res.append(tmp_res)
            return ocr_res
        else:
            ocr_res = []
            cls_res = []
            for img in imgs:
                if not isinstance(img, list):
                    img = preprocess_image(img)
                    img = [img]
                if self.use_angle_cls and cls:
                    img, cls_res_tmp, _ = self.text_classifier(img)
                    if not rec:
                        cls_res.append(cls_res_tmp)
                rec_res, _ = self.text_recognizer(img)
                ocr_res.append(rec_res)
            if not rec:
                return cls_res
            return ocr_res

    def __call__(self, img, cls=True, slice=None, mfd_res=None):
        """Run detection, optional formula-aware box update, cls and rec on one image.

        Args:
            img: Single image (must support ``.copy()``), or None.
            cls: Run the angle classifier when it is initialized.
            slice: Sliding-window settings (see ``ocr``); falsy disables slicing.
            mfd_res: Optional formula-detection result for ``update_det_boxes``.

        Returns:
            ``(filter_boxes, filter_rec_res, time_dict)``. The first two are None
            when ``img`` is None or no boxes were detected. Recognition results
            whose score is below ``self.drop_score`` are dropped together with
            their box.
        """
        time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0}

        if img is None:
            logger.debug("no valid image provided")
            return None, None, time_dict

        start = time.time()
        ori_im = img.copy()
        if slice:
            slice_gen = slice_generator(
                img,
                horizontal_stride=slice["horizontal_stride"],
                vertical_stride=slice["vertical_stride"],
            )
            elapsed = []
            dt_slice_boxes = []
            for slice_crop, v_start, h_start in slice_gen:
                dt_boxes, elapse = self.text_detector(slice_crop, use_slice=True)
                if dt_boxes.size:
                    # Shift slice-local coordinates back into full-image coordinates.
                    dt_boxes[:, :, 0] += h_start
                    dt_boxes[:, :, 1] += v_start
                    dt_slice_boxes.append(dt_boxes)
                    elapsed.append(elapse)
            elapse = sum(elapsed)
            if dt_slice_boxes:
                dt_boxes = merge_fragmented(
                    boxes=np.concatenate(dt_slice_boxes),
                    x_threshold=slice["merge_x_thres"],
                    y_threshold=slice["merge_y_thres"],
                )
            else:
                # np.concatenate([]) raises ValueError; when no slice produced a box,
                # take the ``dt_boxes is None`` early-return path below instead.
                dt_boxes = None
        else:
            dt_boxes, elapse = self.text_detector(img)

        time_dict["det"] = elapse

        if dt_boxes is None:
            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
            end = time.time()
            time_dict["all"] = end - start
            return None, None, time_dict
        else:
            logger.debug(
                "dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse)
            )
        img_crop_list = []

        dt_boxes = sorted_boxes(dt_boxes)

        if mfd_res:
            bef = time.time()
            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
            aft = time.time()
            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
                len(dt_boxes), aft - bef))

        for box in dt_boxes:
            # Deep copy so the crop helpers cannot mutate the box we return.
            tmp_box = copy.deepcopy(box)
            if self.args.det_box_type == "quad":
                img_crop = get_rotate_crop_image(ori_im, tmp_box)
            else:
                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
            img_crop_list.append(img_crop)
        if self.use_angle_cls and cls:
            img_crop_list, _, elapse = self.text_classifier(img_crop_list)
            time_dict["cls"] = elapse
            logger.debug(
                "cls num : {}, elapsed : {}".format(len(img_crop_list), elapse)
            )
        if len(img_crop_list) > 1000:
            logger.debug(
                f"rec crops num: {len(img_crop_list)}, time and memory cost may be large."
            )

        rec_res, elapse = self.text_recognizer(img_crop_list)
        time_dict["rec"] = elapse
        logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
        if self.args.save_crop_res:
            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            score = rec_result[1]
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        end = time.time()
        time_dict["all"] = end - start
        return filter_boxes, filter_rec_res, time_dict
File without changes
@@ -0,0 +1,242 @@
1
+ from typing import List
2
+ import cv2
3
+ import numpy as np
4
+
5
+
6
def projection_by_bboxes(boxes: np.ndarray, axis: int) -> np.ndarray:
    """
    Build a per-pixel projection histogram from a set of bboxes.

    Args:
        boxes: [N, 4] integer array; columns ``axis`` and ``axis + 2`` hold the
            start/end coordinate along the projected direction.
        axis: 0 projects the x coordinates (columns 0 and 2) onto the horizontal
            axis, 1 projects the y coordinates (columns 1 and 3) onto the vertical
            axis.

    Returns:
        1-D int histogram whose length is the largest coordinate along ``axis``
        (the real image size is not needed, since only the gaps between text
        boxes matter). Entry ``i`` counts the boxes with ``start <= i < end``.
        An empty ``boxes`` array yields an empty histogram.
    """
    assert axis in [0, 1]
    spans = boxes[:, axis::2]
    if spans.size == 0:
        # np.max on a zero-size array raises ValueError; no boxes -> no profile.
        return np.zeros(0, dtype=int)
    length = np.max(spans)
    res = np.zeros(length, dtype=int)
    # TODO: how to remove for loop?
    for start, end in spans:
        res[start:end] += 1
    return res
25
+
26
+
27
# from: https://dothinking.github.io/2021-06-19-%E9%80%92%E5%BD%92%E6%8A%95%E5%BD%B1%E5%88%86%E5%89%B2%E7%AE%97%E6%B3%95/#:~:text=%E9%80%92%E5%BD%92%E6%8A%95%E5%BD%B1%E5%88%86%E5%89%B2%EF%BC%88Recursive%20XY,%EF%BC%8C%E5%8F%AF%E4%BB%A5%E5%88%92%E5%88%86%E6%AE%B5%E8%90%BD%E3%80%81%E8%A1%8C%E3%80%82
def split_projection_profile(arr_values: np.array, min_value: float, min_gap: float):
    """Split a 1-D projection profile into contiguous occupied groups.

    An index counts as occupied when ``arr_values[i] > min_value``. Consecutive
    occupied indices stay in the same group unless the step between them is
    larger than ``min_gap``.

    Args:
        arr_values (np.array): 1-d array representing the projection profile.
        min_value (float): Values not greater than this are treated as empty.
        min_gap (float): Steps between occupied indices up to this size do not
            start a new group.

    Returns:
        tuple | None: ``(starts, ends)`` arrays of group boundaries, with ``ends``
        exclusive (slice-style), or None when no index is occupied.
    """
    occupied = np.flatnonzero(arr_values > min_value)
    if occupied.size == 0:
        return None

    # Positions (into ``occupied``) after which a new group begins.
    breaks = np.flatnonzero(np.diff(occupied) > min_gap)

    group_starts = np.concatenate(([occupied[0]], occupied[breaks + 1]))
    # +1 so that the end index is exclusive, as in a slice.
    group_ends = np.concatenate((occupied[breaks], [occupied[-1]])) + 1

    return group_starts, group_ends
69
+
70
+
71
def recursive_xy_cut(boxes: np.ndarray, indices: List[int], res: List[int]):
    """Order boxes with a recursive XY-cut, appending original indices to ``res``.

    The boxes are first cut into horizontal bands using their y projection.
    Each band is then cut into columns using its x projection. A band that
    cannot be split along x is emitted in x order; otherwise each column is
    processed recursively.

    Args:
        boxes: (N, 4) array, ``[x0, y0, x1, y1]`` per row.
        indices: Original index of every box, kept aligned with ``boxes``
            throughout the recursion. It is indexed with NumPy arrays, so it
            must be an ndarray despite the ``List[int]`` annotation.
        res: Output list; indices are appended in the computed order.

    NOTE(review): boxes with zero extent contribute nothing to the projections
    and may be missing from ``res`` — confirm whether callers care.
    """
    assert len(boxes) == len(indices)

    # Sort top-to-bottom, then cut along the y axis.
    order_y = boxes[:, 1].argsort()
    boxes_y = boxes[order_y]
    ids_y = indices[order_y]

    y_cuts = split_projection_profile(projection_by_bboxes(boxes=boxes_y, axis=1), 0, 1)
    if not y_cuts:
        return

    for band_top, band_bottom in zip(*y_cuts):
        # Boxes whose top edge falls inside this horizontal band.
        in_band = (boxes_y[:, 1] >= band_top) & (boxes_y[:, 1] < band_bottom)
        band_boxes = boxes_y[in_band]
        band_ids = ids_y[in_band]

        # Sort the band left-to-right, then try to cut along the x axis.
        order_x = band_boxes[:, 0].argsort()
        band_boxes = band_boxes[order_x]
        band_ids = band_ids[order_x]

        x_cuts = split_projection_profile(projection_by_bboxes(boxes=band_boxes, axis=0), 0, 1)
        if not x_cuts:
            continue

        col_lefts, col_rights = x_cuts
        if len(col_lefts) == 1:
            # The band cannot be split along x: emit it as-is.
            res.extend(band_ids)
            continue

        # The band splits into columns: recurse into each one.
        for col_left, col_right in zip(col_lefts, col_rights):
            in_col = (band_boxes[:, 0] >= col_left) & (band_boxes[:, 0] < col_right)
            recursive_xy_cut(band_boxes[in_col], band_ids[in_col], res)
126
+
127
+
128
def points_to_bbox(points):
    """Return the axis-aligned ``[left, top, right, bottom]`` enclosing 4 points.

    Args:
        points: Flat sequence ``[x1, y1, x2, y2, x3, y3, x4, y4]``.

    Returns:
        List ``[left, top, right, bottom]``, each value clamped to be >= 0.
    """
    assert len(points) == 8

    xs = points[0::2]
    ys = points[1::2]
    # Ordered as left, top, right, bottom; negative values are clamped to 0.
    extremes = (min(xs), min(ys), max(xs), max(ys))
    return [max(value, 0) for value in extremes]
142
+
143
+
144
def bbox2points(bbox):
    """Expand ``[left, top, right, bottom]`` into 4 corner points.

    Corners are emitted clockwise from the top-left as a flat list
    ``[x1, y1, x2, y2, x3, y3, x4, y4]``.
    """
    x_min, y_min, x_max, y_max = bbox
    top_left = [x_min, y_min]
    top_right = [x_max, y_min]
    bottom_right = [x_max, y_max]
    bottom_left = [x_min, y_max]
    return top_left + top_right + bottom_right + bottom_left
147
+
148
+
149
def vis_polygon(img, points, thickness=2, color=None):
    """Draw the closed outline points[0] -> points[1] -> points[2] -> points[3] -> points[0].

    Args:
        img: Image passed to ``cv2.line`` for every edge; also returned.
        points: Four vertices, each indexable as ``(x, y)``.
        thickness: Line thickness forwarded to ``cv2.line``.
        color: Line colour used for all four edges.

    Returns:
        ``img``.
    """
    corner_count = 4
    for idx in range(corner_count):
        head = points[idx]
        tail = points[(idx + 1) % corner_count]
        cv2.line(
            img,
            (head[0], head[1]),
            (tail[0], tail[1]),
            color=color,
            thickness=thickness,
        )
    return img
186
+
187
+
188
def vis_points(
    img: np.ndarray, points, texts: List[str] = None, color=(0, 200, 0)
) -> np.ndarray:
    """Draw each polygon and, when ``texts`` is given, a filled label at its centre.

    Args:
        img: Image drawn on by the OpenCV calls; it is also returned.
        points: [N, 8] -> x1, y1, x2, y2, x3, y3, x4, y4 for each polygon.
        texts: Optional label for every polygon; must contain N entries when
            given. When None, only the polygon outlines are drawn.
        color: Colour of the outlines and of the label background box.

    Returns:
        The annotated image.
    """
    points = np.array(points)
    if texts is not None:
        assert len(texts) == points.shape[0]

    font = cv2.FONT_HERSHEY_SIMPLEX
    for i, _points in enumerate(points):
        vis_polygon(img, _points.reshape(-1, 2), thickness=2, color=color)
        if texts is None:
            # ``texts`` defaults to None; previously ``texts[i]`` raised TypeError here.
            continue

        left, top, right, bottom = points_to_bbox(_points)
        cx = (left + right) // 2
        cy = (top + bottom) // 2

        txt = texts[i]
        cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]

        # Filled background box behind the label text.
        img = cv2.rectangle(
            img,
            (cx - 5 * len(txt), cy - cat_size[1] - 5),
            (cx - 5 * len(txt) + cat_size[0], cy - 5),
            color,
            -1,
        )

        img = cv2.putText(
            img,
            txt,
            (cx - 5 * len(txt), cy - 5),
            font,
            0.5,
            (255, 255, 255),
            thickness=1,
            lineType=cv2.LINE_AA,
        )

    return img
237
+
238
+
239
def vis_polygons_with_index(image, points):
    """Return a copy of ``image`` with each polygon outlined and labelled by its index."""
    labels = [f"{idx}" for idx in range(len(points))]
    canvas = image.copy()
    return vis_points(canvas, points, labels)
File without changes
@@ -0,0 +1,14 @@
1
+ import numpy as np
2
+ from rapid_table import RapidTable
3
+ from rapidocr_paddle import RapidOCR
4
+
5
+
6
class RapidTableModel(object):
    """Table recognizer that feeds RapidOCR text results into RapidTable."""

    def __init__(self, use_cuda=True):
        """Build the table model and the OCR engine.

        Args:
            use_cuda: Value passed to RapidOCR's ``det_use_cuda``, ``cls_use_cuda``
                and ``rec_use_cuda``. Defaults to True, the previously hard-coded
                setting.
        """
        self.table_model = RapidTable()
        self.ocr_engine = RapidOCR(
            det_use_cuda=use_cuda, cls_use_cuda=use_cuda, rec_use_cuda=use_cuda
        )

    def predict(self, image):
        """Recognize a table image.

        Args:
            image: Anything ``np.asarray`` accepts (e.g. a PIL image or ndarray).

        Returns:
            ``(html_code, table_cell_bboxes, elapse)`` as returned by the table model.
        """
        # Convert once; the same array is used for OCR and table recognition.
        img = np.asarray(image)
        ocr_result, _ = self.ocr_engine(img)
        html_code, table_cell_bboxes, elapse = self.table_model(img, ocr_result)
        return html_code, table_cell_bboxes, elapse
@@ -1,8 +1,8 @@
1
- import re
2
-
3
1
  import torch
4
2
  from struct_eqtable import build_model
5
3
 
4
+ from magic_pdf.model.sub_modules.table.table_utils import minify_html
5
+
6
6
 
7
7
  class StructTableModel:
8
8
  def __init__(self, model_path, max_new_tokens=1024, max_time=60):
@@ -31,15 +31,7 @@ class StructTableModel:
31
31
  )
32
32
 
33
33
  if output_format == "html":
34
- results = [self.minify_html(html) for html in results]
34
+ results = [minify_html(html) for html in results]
35
35
 
36
36
  return results
37
37
 
38
- def minify_html(self, html):
39
- # 移除多余的空白字符
40
- html = re.sub(r'\s+', ' ', html)
41
- # 移除行尾的空白字符
42
- html = re.sub(r'\s*>\s*', '>', html)
43
- # 移除标签前的空白字符
44
- html = re.sub(r'\s*<\s*', '<', html)
45
- return html.strip()
@@ -0,0 +1,11 @@
1
+ import re
2
+
3
+
4
def minify_html(html):
    """Collapse insignificant whitespace in an HTML string.

    Every whitespace run becomes one space. Whitespace on either side of ``>``
    and ``<`` is then removed, and the result is stripped. Note that this
    also affects ``<``/``>`` characters that appear in text content.
    """
    # Collapse any run of whitespace into a single space.
    collapsed = re.sub(r'\s+', ' ', html)
    # Drop whitespace around tag closers first, then around tag openers.
    for pattern, replacement in ((r'\s*>\s*', '>'), (r'\s*<\s*', '<')):
        collapsed = re.sub(pattern, replacement, collapsed)
    return collapsed.strip()
@@ -7,7 +7,7 @@ from PIL import Image
7
7
  import numpy as np
8
8
 
9
9
 
10
- class ppTableModel(object):
10
+ class TableMasterPaddleModel(object):
11
11
  """
12
12
  This class is responsible for converting image of table into HTML format using a pre-trained model.
13
13
 
@@ -77,14 +77,12 @@ def __is_list_or_index_block(block):
77
77
 
78
78
  # 如果首行左边不顶格而右边顶格,末行左边顶格而右边不顶格 (第一行可能可以右边不顶格)
79
79
  if (first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2 and
80
- # block['bbox_fs'][2] - first_line['bbox'][2] < line_height and
81
80
  abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2 and
82
81
  block['bbox_fs'][2] - last_line['bbox'][2] > line_height
83
82
  ):
84
83
  multiple_para_flag = True
85
84
 
86
85
  for line in block['lines']:
87
-
88
86
  line_mid_x = (line['bbox'][0] + line['bbox'][2]) / 2
89
87
  block_mid_x = (block['bbox_fs'][0] + block['bbox_fs'][2]) / 2
90
88
  if (
@@ -102,13 +100,13 @@ def __is_list_or_index_block(block):
102
100
  if span_type == ContentType.Text:
103
101
  line_text += span['content'].strip()
104
102
 
103
+ # 添加所有文本,包括空行,保持与block['lines']长度一致
105
104
  lines_text_list.append(line_text)
106
105
 
107
106
  # 计算line左侧顶格数量是否大于2,是否顶格用abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height/2 来判断
108
107
  if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
109
108
  left_close_num += 1
110
109
  elif line['bbox'][0] - block['bbox_fs'][0] > line_height:
111
- # logger.info(f"{line_text}, {block['bbox_fs']}, {line['bbox']}")
112
110
  left_not_close_num += 1
113
111
 
114
112
  # 计算右侧是否顶格
@@ -117,7 +115,6 @@ def __is_list_or_index_block(block):
117
115
  else:
118
116
  # 右侧不顶格情况下是否有一段距离,拍脑袋用0.3block宽度做阈值
119
117
  closed_area = 0.26 * block_weight
120
- # closed_area = 5 * line_height
121
118
  if block['bbox_fs'][2] - line['bbox'][2] > closed_area:
122
119
  right_not_close_num += 1
123
120
 
@@ -128,6 +125,7 @@ def __is_list_or_index_block(block):
128
125
  num_start_count = 0
129
126
  num_end_count = 0
130
127
  flag_end_count = 0
128
+
131
129
  if len(lines_text_list) > 0:
132
130
  for line_text in lines_text_list:
133
131
  if len(line_text) > 0:
@@ -138,11 +136,10 @@ def __is_list_or_index_block(block):
138
136
  if line_text[-1].isdigit():
139
137
  num_end_count += 1
140
138
 
141
- if flag_end_count / len(lines_text_list) >= 0.8:
142
- line_end_flag = True
143
-
144
139
  if num_start_count / len(lines_text_list) >= 0.8 or num_end_count / len(lines_text_list) >= 0.8:
145
140
  line_num_flag = True
141
+ if flag_end_count / len(lines_text_list) >= 0.8:
142
+ line_end_flag = True
146
143
 
147
144
  # 有的目录右侧不贴边, 目前认为左边或者右边有一边全贴边,且符合数字规则极为index
148
145
  if ((left_close_num / len(block['lines']) >= 0.8 or right_close_num / len(block['lines']) >= 0.8)
@@ -176,7 +173,7 @@ def __is_list_or_index_block(block):
176
173
  # 这种是大部分line item 都有结束标识符的情况,按结束标识符区分不同item
177
174
  elif line_end_flag:
178
175
  for i, line in enumerate(block['lines']):
179
- if lines_text_list[i][-1] in LIST_END_FLAG:
176
+ if len(lines_text_list[i]) > 0 and lines_text_list[i][-1] in LIST_END_FLAG:
180
177
  line[ListLineTag.IS_LIST_END_LINE] = True
181
178
  if i + 1 < len(block['lines']):
182
179
  block['lines'][i + 1][ListLineTag.IS_LIST_START_LINE] = True
@@ -187,17 +184,18 @@ def __is_list_or_index_block(block):
187
184
  if line_start_flag:
188
185
  line[ListLineTag.IS_LIST_START_LINE] = True
189
186
  line_start_flag = False
190
- # elif abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
187
+
191
188
  if abs(block['bbox_fs'][2] - line['bbox'][2]) > 0.1 * block_weight:
192
189
  line[ListLineTag.IS_LIST_END_LINE] = True
193
190
  line_start_flag = True
194
- # 一种有缩进的特殊有序list,start line 左侧不贴边且以数字开头,end line 以 IS_LIST_END_LINE 结尾且数量和start line 一致
195
- elif num_start_count >= 2 and num_start_count == flag_end_count: # 简单一点先不考虑左侧不贴边的情况
191
+ # 一种有缩进的特殊有序list,start line 左侧不贴边且以数字开头,end line 以 IS_LIST_END_FLAG 结尾且数量和start line 一致
192
+ elif num_start_count >= 2 and num_start_count == flag_end_count:
196
193
  for i, line in enumerate(block['lines']):
197
- if lines_text_list[i][0].isdigit():
198
- line[ListLineTag.IS_LIST_START_LINE] = True
199
- if lines_text_list[i][-1] in LIST_END_FLAG:
200
- line[ListLineTag.IS_LIST_END_LINE] = True
194
+ if len(lines_text_list[i]) > 0:
195
+ if lines_text_list[i][0].isdigit():
196
+ line[ListLineTag.IS_LIST_START_LINE] = True
197
+ if lines_text_list[i][-1] in LIST_END_FLAG:
198
+ line[ListLineTag.IS_LIST_END_LINE] = True
201
199
  else:
202
200
  # 正常有缩进的list处理
203
201
  for line in block['lines']: