magic-pdf 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121) hide show
  1. magic_pdf/__init__.py +0 -0
  2. magic_pdf/cli/__init__.py +0 -0
  3. magic_pdf/cli/magicpdf.py +294 -0
  4. magic_pdf/dict2md/__init__.py +0 -0
  5. magic_pdf/dict2md/mkcontent.py +397 -0
  6. magic_pdf/dict2md/ocr_mkcontent.py +356 -0
  7. magic_pdf/filter/__init__.py +0 -0
  8. magic_pdf/filter/pdf_classify_by_type.py +381 -0
  9. magic_pdf/filter/pdf_meta_scan.py +368 -0
  10. magic_pdf/layout/__init__.py +0 -0
  11. magic_pdf/layout/bbox_sort.py +681 -0
  12. magic_pdf/layout/layout_det_utils.py +182 -0
  13. magic_pdf/layout/layout_sort.py +732 -0
  14. magic_pdf/layout/layout_spiler_recog.py +101 -0
  15. magic_pdf/layout/mcol_sort.py +336 -0
  16. magic_pdf/libs/Constants.py +11 -0
  17. magic_pdf/libs/MakeContentConfig.py +10 -0
  18. magic_pdf/libs/ModelBlockTypeEnum.py +9 -0
  19. magic_pdf/libs/__init__.py +0 -0
  20. magic_pdf/libs/boxbase.py +408 -0
  21. magic_pdf/libs/calc_span_stats.py +239 -0
  22. magic_pdf/libs/commons.py +204 -0
  23. magic_pdf/libs/config_reader.py +63 -0
  24. magic_pdf/libs/convert_utils.py +5 -0
  25. magic_pdf/libs/coordinate_transform.py +9 -0
  26. magic_pdf/libs/detect_language_from_model.py +21 -0
  27. magic_pdf/libs/draw_bbox.py +227 -0
  28. magic_pdf/libs/drop_reason.py +27 -0
  29. magic_pdf/libs/drop_tag.py +19 -0
  30. magic_pdf/libs/hash_utils.py +15 -0
  31. magic_pdf/libs/json_compressor.py +27 -0
  32. magic_pdf/libs/language.py +31 -0
  33. magic_pdf/libs/markdown_utils.py +31 -0
  34. magic_pdf/libs/math.py +9 -0
  35. magic_pdf/libs/nlp_utils.py +203 -0
  36. magic_pdf/libs/ocr_content_type.py +21 -0
  37. magic_pdf/libs/path_utils.py +23 -0
  38. magic_pdf/libs/pdf_image_tools.py +33 -0
  39. magic_pdf/libs/safe_filename.py +11 -0
  40. magic_pdf/libs/textbase.py +33 -0
  41. magic_pdf/libs/version.py +1 -0
  42. magic_pdf/libs/vis_utils.py +308 -0
  43. magic_pdf/model/__init__.py +0 -0
  44. magic_pdf/model/doc_analyze_by_360layout.py +8 -0
  45. magic_pdf/model/doc_analyze_by_pp_structurev2.py +125 -0
  46. magic_pdf/model/magic_model.py +632 -0
  47. magic_pdf/para/__init__.py +0 -0
  48. magic_pdf/para/block_continuation_processor.py +562 -0
  49. magic_pdf/para/block_termination_processor.py +480 -0
  50. magic_pdf/para/commons.py +222 -0
  51. magic_pdf/para/denoise.py +246 -0
  52. magic_pdf/para/draw.py +121 -0
  53. magic_pdf/para/exceptions.py +198 -0
  54. magic_pdf/para/layout_match_processor.py +40 -0
  55. magic_pdf/para/para_pipeline.py +297 -0
  56. magic_pdf/para/para_split.py +644 -0
  57. magic_pdf/para/para_split_v2.py +772 -0
  58. magic_pdf/para/raw_processor.py +207 -0
  59. magic_pdf/para/stats.py +268 -0
  60. magic_pdf/para/title_processor.py +1014 -0
  61. magic_pdf/pdf_parse_by_ocr.py +219 -0
  62. magic_pdf/pdf_parse_by_ocr_v2.py +17 -0
  63. magic_pdf/pdf_parse_by_txt.py +410 -0
  64. magic_pdf/pdf_parse_by_txt_v2.py +56 -0
  65. magic_pdf/pdf_parse_for_train.py +685 -0
  66. magic_pdf/pdf_parse_union_core.py +241 -0
  67. magic_pdf/pipe/AbsPipe.py +112 -0
  68. magic_pdf/pipe/OCRPipe.py +28 -0
  69. magic_pdf/pipe/TXTPipe.py +29 -0
  70. magic_pdf/pipe/UNIPipe.py +83 -0
  71. magic_pdf/pipe/__init__.py +0 -0
  72. magic_pdf/post_proc/__init__.py +0 -0
  73. magic_pdf/post_proc/detect_para.py +3472 -0
  74. magic_pdf/post_proc/pdf_post_filter.py +67 -0
  75. magic_pdf/post_proc/remove_footnote.py +153 -0
  76. magic_pdf/pre_proc/__init__.py +0 -0
  77. magic_pdf/pre_proc/citationmarker_remove.py +157 -0
  78. magic_pdf/pre_proc/construct_page_dict.py +72 -0
  79. magic_pdf/pre_proc/cut_image.py +71 -0
  80. magic_pdf/pre_proc/detect_equation.py +134 -0
  81. magic_pdf/pre_proc/detect_footer_by_model.py +64 -0
  82. magic_pdf/pre_proc/detect_footer_header_by_statistics.py +284 -0
  83. magic_pdf/pre_proc/detect_footnote.py +170 -0
  84. magic_pdf/pre_proc/detect_header.py +64 -0
  85. magic_pdf/pre_proc/detect_images.py +647 -0
  86. magic_pdf/pre_proc/detect_page_number.py +64 -0
  87. magic_pdf/pre_proc/detect_tables.py +62 -0
  88. magic_pdf/pre_proc/equations_replace.py +559 -0
  89. magic_pdf/pre_proc/fix_image.py +244 -0
  90. magic_pdf/pre_proc/fix_table.py +270 -0
  91. magic_pdf/pre_proc/main_text_font.py +23 -0
  92. magic_pdf/pre_proc/ocr_detect_all_bboxes.py +115 -0
  93. magic_pdf/pre_proc/ocr_detect_layout.py +133 -0
  94. magic_pdf/pre_proc/ocr_dict_merge.py +336 -0
  95. magic_pdf/pre_proc/ocr_span_list_modify.py +258 -0
  96. magic_pdf/pre_proc/pdf_pre_filter.py +74 -0
  97. magic_pdf/pre_proc/post_layout_split.py +0 -0
  98. magic_pdf/pre_proc/remove_bbox_overlap.py +98 -0
  99. magic_pdf/pre_proc/remove_colored_strip_bbox.py +79 -0
  100. magic_pdf/pre_proc/remove_footer_header.py +117 -0
  101. magic_pdf/pre_proc/remove_rotate_bbox.py +188 -0
  102. magic_pdf/pre_proc/resolve_bbox_conflict.py +191 -0
  103. magic_pdf/pre_proc/solve_line_alien.py +29 -0
  104. magic_pdf/pre_proc/statistics.py +12 -0
  105. magic_pdf/rw/AbsReaderWriter.py +34 -0
  106. magic_pdf/rw/DiskReaderWriter.py +66 -0
  107. magic_pdf/rw/S3ReaderWriter.py +107 -0
  108. magic_pdf/rw/__init__.py +0 -0
  109. magic_pdf/spark/__init__.py +0 -0
  110. magic_pdf/spark/spark_api.py +51 -0
  111. magic_pdf/train_utils/__init__.py +0 -0
  112. magic_pdf/train_utils/convert_to_train_format.py +65 -0
  113. magic_pdf/train_utils/extract_caption.py +59 -0
  114. magic_pdf/train_utils/remove_footer_header.py +159 -0
  115. magic_pdf/train_utils/vis_utils.py +327 -0
  116. magic_pdf/user_api.py +136 -0
  117. magic_pdf-0.5.4.dist-info/LICENSE.md +661 -0
  118. magic_pdf-0.5.4.dist-info/METADATA +24 -0
  119. magic_pdf-0.5.4.dist-info/RECORD +121 -0
  120. magic_pdf-0.5.4.dist-info/WHEEL +5 -0
  121. magic_pdf-0.5.4.dist-info/top_level.txt +1 -0
@@ -0,0 +1,632 @@
1
+ import json
2
+ import math
3
+
4
+ from magic_pdf.libs.commons import fitz
5
+ from loguru import logger
6
+
7
+ from magic_pdf.libs.commons import join_path
8
+ from magic_pdf.libs.coordinate_transform import get_scale_ratio
9
+ from magic_pdf.libs.ocr_content_type import ContentType
10
+ from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
11
+ from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
12
+ from magic_pdf.libs.math import float_gt
13
+ from magic_pdf.libs.boxbase import (
14
+ _is_in,
15
+ bbox_relative_pos,
16
+ bbox_distance,
17
+ _is_part_overlap,
18
+ calculate_overlap_area_in_bbox1_area_ratio, calculate_iou,
19
+ )
20
+ from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
21
+
22
+ # Threshold compared against calculate_overlap_area_in_bbox1_area_ratio(object_bbox, region):
+ # an object box is counted as lying in a candidate caption region only when that ratio exceeds this value.
+ CAPATION_OVERLAP_AREA_RATIO = 0.6
23
+
24
+
25
+ class MagicModel:
26
+ """
27
+ Every function returns an empty list when it finds no matching elements.
28
+
29
+ """
30
+
31
+ def __fix_axis(self):
32
+ for model_page_info in self.__model_list:
33
+ need_remove_list = []
34
+ page_no = model_page_info["page_info"]["page_no"]
35
+ horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
36
+ model_page_info, self.__docs[page_no]
37
+ )
38
+ layout_dets = model_page_info["layout_dets"]
39
+ for layout_det in layout_dets:
40
+
41
+ if layout_det.get("bbox") is not None:
42
+ # 兼容直接输出bbox的模型数据,如paddle
43
+ x0, y0, x1, y1 = layout_det["bbox"]
44
+ else:
45
+ # 兼容直接输出poly的模型数据,如xxx
46
+ x0, y0, _, _, x1, y1, _, _ = layout_det["poly"]
47
+
48
+ bbox = [
49
+ int(x0 / horizontal_scale_ratio),
50
+ int(y0 / vertical_scale_ratio),
51
+ int(x1 / horizontal_scale_ratio),
52
+ int(y1 / vertical_scale_ratio),
53
+ ]
54
+ layout_det["bbox"] = bbox
55
+ # 删除高度或者宽度小于等于0的spans
56
+ if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
57
+ need_remove_list.append(layout_det)
58
+ for need_remove in need_remove_list:
59
+ layout_dets.remove(need_remove)
60
+
61
+ def __fix_by_remove_low_confidence(self):
62
+ for model_page_info in self.__model_list:
63
+ need_remove_list = []
64
+ layout_dets = model_page_info["layout_dets"]
65
+ for layout_det in layout_dets:
66
+ if layout_det["score"] <= 0.05:
67
+ need_remove_list.append(layout_det)
68
+ else:
69
+ continue
70
+ for need_remove in need_remove_list:
71
+ layout_dets.remove(need_remove)
72
+
73
+ def __fix_by_remove_high_iou_and_low_confidence(self):
74
+ for model_page_info in self.__model_list:
75
+ need_remove_list = []
76
+ layout_dets = model_page_info["layout_dets"]
77
+ for layout_det1 in layout_dets:
78
+ for layout_det2 in layout_dets:
79
+ if layout_det1 == layout_det2:
80
+ continue
81
+ if layout_det1["category_id"] in [0,1,2,3,4,5,6,7,8,9] and layout_det2["category_id"] in [0,1,2,3,4,5,6,7,8,9]:
82
+ if calculate_iou(layout_det1['bbox'], layout_det2['bbox']) > 0.9:
83
+ if layout_det1['score'] < layout_det2['score']:
84
+ layout_det_need_remove = layout_det1
85
+ else:
86
+ layout_det_need_remove = layout_det2
87
+
88
+ if layout_det_need_remove not in need_remove_list:
89
+ need_remove_list.append(layout_det_need_remove)
90
+ else:
91
+ continue
92
+ else:
93
+ continue
94
+ for need_remove in need_remove_list:
95
+ layout_dets.remove(need_remove)
96
+
97
+ def __init__(self, model_list: list, docs: fitz.Document):
98
+ self.__model_list = model_list
99
+ self.__docs = docs
100
+ '''为所有模型数据添加bbox信息(缩放,poly->bbox)'''
101
+ self.__fix_axis()
102
+ '''删除置信度特别低的模型数据(<0.05),提高质量'''
103
+ self.__fix_by_remove_low_confidence()
104
+ '''删除高iou(>0.9)数据中置信度较低的那个'''
105
+ self.__fix_by_remove_high_iou_and_low_confidence()
106
+
107
+ def __reduct_overlap(self, bboxes):
108
+ N = len(bboxes)
109
+ keep = [True] * N
110
+ for i in range(N):
111
+ for j in range(N):
112
+ if i == j:
113
+ continue
114
+ if _is_in(bboxes[i]["bbox"], bboxes[j]["bbox"]):
115
+ keep[i] = False
116
+
117
+ return [bboxes[i] for i in range(N) if keep[i]]
118
+
119
+ def __tie_up_category_by_distance(
120
+ self, page_no, subject_category_id, object_category_id
121
+ ):
122
+ """
123
+ 假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object 只能属于一个 subject
124
+ """
125
+ ret = []
126
+ MAX_DIS_OF_POINT = 10**9 + 7
127
+
128
+ def expand_bbox(bbox1, bbox2):
129
+ x0 = min(bbox1[0], bbox2[0])
130
+ y0 = min(bbox1[1], bbox2[1])
131
+ x1 = max(bbox1[2], bbox2[2])
132
+ y1 = max(bbox1[3], bbox2[3])
133
+ return [x0, y0, x1, y1]
134
+
135
+ def get_bbox_area(bbox):
136
+ return abs(bbox[2] - bbox[0]) * abs(bbox[3] - bbox[1])
137
+
138
+ # subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
139
+ # 再求出筛选出的 subjects 和 object 的最短距离!
140
+ def may_find_other_nearest_bbox(subject_idx, object_idx):
141
+ ret = float("inf")
142
+
143
+ x0 = min(
144
+ all_bboxes[subject_idx]["bbox"][0], all_bboxes[object_idx]["bbox"][0]
145
+ )
146
+ y0 = min(
147
+ all_bboxes[subject_idx]["bbox"][1], all_bboxes[object_idx]["bbox"][1]
148
+ )
149
+ x1 = max(
150
+ all_bboxes[subject_idx]["bbox"][2], all_bboxes[object_idx]["bbox"][2]
151
+ )
152
+ y1 = max(
153
+ all_bboxes[subject_idx]["bbox"][3], all_bboxes[object_idx]["bbox"][3]
154
+ )
155
+
156
+ object_area = abs(
157
+ all_bboxes[object_idx]["bbox"][2] - all_bboxes[object_idx]["bbox"][0]
158
+ ) * abs(
159
+ all_bboxes[object_idx]["bbox"][3] - all_bboxes[object_idx]["bbox"][1]
160
+ )
161
+
162
+ for i in range(len(all_bboxes)):
163
+ if (
164
+ i == subject_idx
165
+ or all_bboxes[i]["category_id"] != subject_category_id
166
+ ):
167
+ continue
168
+ if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]["bbox"]) or _is_in(
169
+ all_bboxes[i]["bbox"], [x0, y0, x1, y1]
170
+ ):
171
+
172
+ i_area = abs(
173
+ all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0]
174
+ ) * abs(all_bboxes[i]["bbox"][3] - all_bboxes[i]["bbox"][1])
175
+ if i_area >= object_area:
176
+ ret = min(float("inf"), dis[i][object_idx])
177
+
178
+ return ret
179
+
180
+ subjects = self.__reduct_overlap(
181
+ list(
182
+ map(
183
+ lambda x: {"bbox": x["bbox"], "score": x["score"]},
184
+ filter(
185
+ lambda x: x["category_id"] == subject_category_id,
186
+ self.__model_list[page_no]["layout_dets"],
187
+ ),
188
+ )
189
+ )
190
+ )
191
+
192
+ objects = self.__reduct_overlap(
193
+ list(
194
+ map(
195
+ lambda x: {"bbox": x["bbox"], "score": x["score"]},
196
+ filter(
197
+ lambda x: x["category_id"] == object_category_id,
198
+ self.__model_list[page_no]["layout_dets"],
199
+ ),
200
+ )
201
+ )
202
+ )
203
+ subject_object_relation_map = {}
204
+
205
+ subjects.sort(
206
+ key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2
207
+ ) # get the distance !
208
+
209
+ all_bboxes = []
210
+
211
+ for v in subjects:
212
+ all_bboxes.append(
213
+ {
214
+ "category_id": subject_category_id,
215
+ "bbox": v["bbox"],
216
+ "score": v["score"],
217
+ }
218
+ )
219
+
220
+ for v in objects:
221
+ all_bboxes.append(
222
+ {
223
+ "category_id": object_category_id,
224
+ "bbox": v["bbox"],
225
+ "score": v["score"],
226
+ }
227
+ )
228
+
229
+ N = len(all_bboxes)
230
+ dis = [[MAX_DIS_OF_POINT] * N for _ in range(N)]
231
+
232
+ for i in range(N):
233
+ for j in range(i):
234
+ if (
235
+ all_bboxes[i]["category_id"] == subject_category_id
236
+ and all_bboxes[j]["category_id"] == subject_category_id
237
+ ):
238
+ continue
239
+
240
+ dis[i][j] = bbox_distance(all_bboxes[i]["bbox"], all_bboxes[j]["bbox"])
241
+ dis[j][i] = dis[i][j]
242
+
243
+ used = set()
244
+ for i in range(N):
245
+ # 求第 i 个 subject 所关联的 object
246
+ if all_bboxes[i]["category_id"] != subject_category_id:
247
+ continue
248
+ seen = set()
249
+ candidates = []
250
+ arr = []
251
+ for j in range(N):
252
+
253
+ pos_flag_count = sum(
254
+ list(
255
+ map(
256
+ lambda x: 1 if x else 0,
257
+ bbox_relative_pos(
258
+ all_bboxes[i]["bbox"], all_bboxes[j]["bbox"]
259
+ ),
260
+ )
261
+ )
262
+ )
263
+ if pos_flag_count > 1:
264
+ continue
265
+ if (
266
+ all_bboxes[j]["category_id"] != object_category_id
267
+ or j in used
268
+ or dis[i][j] == MAX_DIS_OF_POINT
269
+ ):
270
+ continue
271
+ left, right, _, _ = bbox_relative_pos(all_bboxes[i]["bbox"], all_bboxes[j]["bbox"]) # 由 pos_flag_count 相关逻辑保证本段逻辑准确性
272
+ if left or right:
273
+ one_way_dis = all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0]
274
+ else:
275
+ one_way_dis = all_bboxes[i]["bbox"][3] - all_bboxes[i]["bbox"][1]
276
+ if dis[i][j] > one_way_dis:
277
+ continue
278
+ arr.append((dis[i][j], j))
279
+
280
+ arr.sort(key=lambda x: x[0])
281
+ if len(arr) > 0:
282
+ # bug: 离该subject 最近的 object 可能跨越了其它的 subject 。比如 [this subect] [some sbuject] [the nearest objec of subject]
283
+ if may_find_other_nearest_bbox(i, arr[0][1]) >= arr[0][0]:
284
+
285
+ candidates.append(arr[0][1])
286
+ seen.add(arr[0][1])
287
+
288
+ # 已经获取初始种子
289
+ for j in set(candidates):
290
+ tmp = []
291
+ for k in range(i + 1, N):
292
+ pos_flag_count = sum(
293
+ list(
294
+ map(
295
+ lambda x: 1 if x else 0,
296
+ bbox_relative_pos(
297
+ all_bboxes[j]["bbox"], all_bboxes[k]["bbox"]
298
+ ),
299
+ )
300
+ )
301
+ )
302
+
303
+ if pos_flag_count > 1:
304
+ continue
305
+
306
+ if (
307
+ all_bboxes[k]["category_id"] != object_category_id
308
+ or k in used
309
+ or k in seen
310
+ or dis[j][k] == MAX_DIS_OF_POINT
311
+ or dis[j][k] > dis[i][j]
312
+ ):
313
+ continue
314
+
315
+ is_nearest = True
316
+ for l in range(i + 1, N):
317
+ if l in (j, k) or l in used or l in seen:
318
+ continue
319
+
320
+ if not float_gt(dis[l][k], dis[j][k]):
321
+ is_nearest = False
322
+ break
323
+
324
+ if is_nearest:
325
+ tmp.append(k)
326
+ seen.add(k)
327
+
328
+ candidates = tmp
329
+ if len(candidates) == 0:
330
+ break
331
+
332
+ # 已经获取到某个 figure 下所有的最靠近的 captions,以及最靠近这些 captions 的 captions 。
333
+ # 先扩一下 bbox,
334
+ x0s = [all_bboxes[idx]["bbox"][0] for idx in seen] + [
335
+ all_bboxes[i]["bbox"][0]
336
+ ]
337
+ y0s = [all_bboxes[idx]["bbox"][1] for idx in seen] + [
338
+ all_bboxes[i]["bbox"][1]
339
+ ]
340
+ x1s = [all_bboxes[idx]["bbox"][2] for idx in seen] + [
341
+ all_bboxes[i]["bbox"][2]
342
+ ]
343
+ y1s = [all_bboxes[idx]["bbox"][3] for idx in seen] + [
344
+ all_bboxes[i]["bbox"][3]
345
+ ]
346
+
347
+ ox0, oy0, ox1, oy1 = min(x0s), min(y0s), max(x1s), max(y1s)
348
+ ix0, iy0, ix1, iy1 = all_bboxes[i]["bbox"]
349
+
350
+ # 分成了 4 个截取空间,需要计算落在每个截取空间下 objects 合并后占据的矩形面积
351
+ caption_poses = [
352
+ [ox0, oy0, ix0, oy1],
353
+ [ox0, oy0, ox1, iy0],
354
+ [ox0, iy1, ox1, oy1],
355
+ [ix1, oy0, ox1, oy1],
356
+ ]
357
+
358
+ caption_areas = []
359
+ for bbox in caption_poses:
360
+ embed_arr = []
361
+ for idx in seen:
362
+ if (
363
+ calculate_overlap_area_in_bbox1_area_ratio(
364
+ all_bboxes[idx]["bbox"], bbox
365
+ )
366
+ > CAPATION_OVERLAP_AREA_RATIO
367
+ ):
368
+ embed_arr.append(idx)
369
+
370
+ if len(embed_arr) > 0:
371
+ embed_x0 = min([all_bboxes[idx]["bbox"][0] for idx in embed_arr])
372
+ embed_y0 = min([all_bboxes[idx]["bbox"][1] for idx in embed_arr])
373
+ embed_x1 = max([all_bboxes[idx]["bbox"][2] for idx in embed_arr])
374
+ embed_y1 = max([all_bboxes[idx]["bbox"][3] for idx in embed_arr])
375
+ caption_areas.append(
376
+ int(abs(embed_x1 - embed_x0) * abs(embed_y1 - embed_y0))
377
+ )
378
+ else:
379
+ caption_areas.append(0)
380
+
381
+ subject_object_relation_map[i] = []
382
+ if max(caption_areas) > 0:
383
+ max_area_idx = caption_areas.index(max(caption_areas))
384
+ caption_bbox = caption_poses[max_area_idx]
385
+
386
+ for j in seen:
387
+ if (
388
+ calculate_overlap_area_in_bbox1_area_ratio(
389
+ all_bboxes[j]["bbox"], caption_bbox
390
+ )
391
+ > CAPATION_OVERLAP_AREA_RATIO
392
+ ):
393
+ used.add(j)
394
+ subject_object_relation_map[i].append(j)
395
+
396
+ for i in sorted(subject_object_relation_map.keys()):
397
+ result = {
398
+ "subject_body": all_bboxes[i]["bbox"],
399
+ "all": all_bboxes[i]["bbox"],
400
+ "score": all_bboxes[i]["score"],
401
+ }
402
+
403
+ if len(subject_object_relation_map[i]) > 0:
404
+ x0 = min(
405
+ [all_bboxes[j]["bbox"][0] for j in subject_object_relation_map[i]]
406
+ )
407
+ y0 = min(
408
+ [all_bboxes[j]["bbox"][1] for j in subject_object_relation_map[i]]
409
+ )
410
+ x1 = max(
411
+ [all_bboxes[j]["bbox"][2] for j in subject_object_relation_map[i]]
412
+ )
413
+ y1 = max(
414
+ [all_bboxes[j]["bbox"][3] for j in subject_object_relation_map[i]]
415
+ )
416
+ result["object_body"] = [x0, y0, x1, y1]
417
+ result["all"] = [
418
+ min(x0, all_bboxes[i]["bbox"][0]),
419
+ min(y0, all_bboxes[i]["bbox"][1]),
420
+ max(x1, all_bboxes[i]["bbox"][2]),
421
+ max(y1, all_bboxes[i]["bbox"][3]),
422
+ ]
423
+ ret.append(result)
424
+
425
+ total_subject_object_dis = 0
426
+ # 计算已经配对的 distance 距离
427
+ for i in subject_object_relation_map.keys():
428
+ for j in subject_object_relation_map[i]:
429
+ total_subject_object_dis += bbox_distance(
430
+ all_bboxes[i]["bbox"], all_bboxes[j]["bbox"]
431
+ )
432
+
433
+ # 计算未匹配的 subject 和 object 的距离(非精确版)
434
+ with_caption_subject = set(
435
+ [
436
+ key
437
+ for key in subject_object_relation_map.keys()
438
+ if len(subject_object_relation_map[i]) > 0
439
+ ]
440
+ )
441
+ for i in range(N):
442
+ if all_bboxes[i]["category_id"] != object_category_id or i in used:
443
+ continue
444
+ candidates = []
445
+ for j in range(N):
446
+ if (
447
+ all_bboxes[j]["category_id"] != subject_category_id
448
+ or j in with_caption_subject
449
+ ):
450
+ continue
451
+ candidates.append((dis[i][j], j))
452
+ if len(candidates) > 0:
453
+ candidates.sort(key=lambda x: x[0])
454
+ total_subject_object_dis += candidates[0][1]
455
+ with_caption_subject.add(j)
456
+ return ret, total_subject_object_dis
457
+
458
+ def get_imgs(self, page_no: int): # @许瑞
459
+ records, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
460
+ return [
461
+ {
462
+ "bbox": record["all"],
463
+ "img_body_bbox": record["subject_body"],
464
+ "img_caption_bbox": record.get("object_body", None),
465
+ "score": record["score"],
466
+ }
467
+ for record in records
468
+ ]
469
+
470
+ def get_tables(
471
+ self, page_no: int
472
+ ) -> list: # 3个坐标, caption, table主体,table-note
473
+ with_captions, _ = self.__tie_up_category_by_distance(page_no, 5, 6)
474
+ with_footnotes, _ = self.__tie_up_category_by_distance(page_no, 5, 7)
475
+ ret = []
476
+ N, M = len(with_captions), len(with_footnotes)
477
+ assert N == M
478
+ for i in range(N):
479
+ record = {
480
+ "score": with_captions[i]["score"],
481
+ "table_caption_bbox": with_captions[i].get("object_body", None),
482
+ "table_body_bbox": with_captions[i]["subject_body"],
483
+ "table_footnote_bbox": with_footnotes[i].get("object_body", None),
484
+ }
485
+
486
+ x0 = min(with_captions[i]["all"][0], with_footnotes[i]["all"][0])
487
+ y0 = min(with_captions[i]["all"][1], with_footnotes[i]["all"][1])
488
+ x1 = max(with_captions[i]["all"][2], with_footnotes[i]["all"][2])
489
+ y1 = max(with_captions[i]["all"][3], with_footnotes[i]["all"][3])
490
+ record["bbox"] = [x0, y0, x1, y1]
491
+ ret.append(record)
492
+ return ret
493
+
494
def get_equations(self, page_no: int) -> tuple[list, list, list]:
    """
    Return the equation blocks of a page (coordinates and, for the first
    two groups, the ``latex`` text).

    Returns:
        A 3-tuple ``(inline_equations, interline_equations,
        interline_equations_blocks)``. The method was previously annotated
        ``-> list`` although it has always returned this tuple.
    """
    inline_equations = self.__get_blocks_by_type(
        ModelBlockTypeEnum.EMBEDDING.value, page_no, ["latex"]
    )
    interline_equations = self.__get_blocks_by_type(
        ModelBlockTypeEnum.ISOLATED.value, page_no, ["latex"]
    )
    # These blocks carry bbox and score only (no extra columns requested).
    interline_equations_blocks = self.__get_blocks_by_type(
        ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no
    )
    return inline_equations, interline_equations, interline_equations_blocks
505
+
506
def get_discarded(self, page_no: int) -> list:
    """Return the ABANDON blocks of a page (in-house model; coordinates only)."""
    return self.__get_blocks_by_type(ModelBlockTypeEnum.ABANDON.value, page_no)
509
+
510
def get_text_blocks(self, page_no: int) -> list:
    """Return the PLAIN_TEXT blocks of a page (in-house model; coordinates only, no text)."""
    return self.__get_blocks_by_type(ModelBlockTypeEnum.PLAIN_TEXT.value, page_no)
513
+
514
def get_title_blocks(self, page_no: int) -> list:
    """Return the TITLE blocks of a page (in-house model; coordinates only, no text)."""
    return self.__get_blocks_by_type(ModelBlockTypeEnum.TITLE.value, page_no)
517
+
518
+ def get_ocr_text(self, page_no: int) -> list: # paddle 搞的,有字也有坐标
519
+ text_spans = []
520
+ model_page_info = self.__model_list[page_no]
521
+ layout_dets = model_page_info["layout_dets"]
522
+ for layout_det in layout_dets:
523
+ if layout_det["category_id"] == "15":
524
+ span = {
525
+ "bbox": layout_det["bbox"],
526
+ "content": layout_det["text"],
527
+ }
528
+ text_spans.append(span)
529
+ return text_spans
530
+
531
+ def get_all_spans(self, page_no: int) -> list:
532
+ def remove_duplicate_spans(spans):
533
+ new_spans = []
534
+ for span in spans:
535
+ if not any(span == existing_span for existing_span in new_spans):
536
+ new_spans.append(span)
537
+ return new_spans
538
+ all_spans = []
539
+ model_page_info = self.__model_list[page_no]
540
+ layout_dets = model_page_info["layout_dets"]
541
+ allow_category_id_list = [3, 5, 13, 14, 15]
542
+ """当成span拼接的"""
543
+ # 3: 'image', # 图片
544
+ # 5: 'table', # 表格
545
+ # 13: 'inline_equation', # 行内公式
546
+ # 14: 'interline_equation', # 行间公式
547
+ # 15: 'text', # ocr识别文本
548
+ for layout_det in layout_dets:
549
+ category_id = layout_det["category_id"]
550
+ if category_id in allow_category_id_list:
551
+ span = {
552
+ "bbox": layout_det["bbox"],
553
+ "score": layout_det["score"]
554
+ }
555
+ if category_id == 3:
556
+ span["type"] = ContentType.Image
557
+ elif category_id == 5:
558
+ span["type"] = ContentType.Table
559
+ elif category_id == 13:
560
+ span["content"] = layout_det["latex"]
561
+ span["type"] = ContentType.InlineEquation
562
+ elif category_id == 14:
563
+ span["content"] = layout_det["latex"]
564
+ span["type"] = ContentType.InterlineEquation
565
+ elif category_id == 15:
566
+ span["content"] = layout_det["text"]
567
+ span["type"] = ContentType.Text
568
+ all_spans.append(span)
569
+ return remove_duplicate_spans(all_spans)
570
+
571
+ def get_page_size(self, page_no: int): # 获取页面宽高
572
+ # 获取当前页的page对象
573
+ page = self.__docs[page_no]
574
+ # 获取当前页的宽高
575
+ page_w = page.rect.width
576
+ page_h = page.rect.height
577
+ return page_w, page_h
578
+
579
+ def __get_blocks_by_type(
580
+ self, type: int, page_no: int, extra_col: list[str] = []
581
+ ) -> list:
582
+ blocks = []
583
+ for page_dict in self.__model_list:
584
+ layout_dets = page_dict.get("layout_dets", [])
585
+ page_info = page_dict.get("page_info", {})
586
+ page_number = page_info.get("page_no", -1)
587
+ if page_no != page_number:
588
+ continue
589
+ for item in layout_dets:
590
+ category_id = item.get("category_id", -1)
591
+ bbox = item.get("bbox", None)
592
+
593
+ if category_id == type:
594
+ block = {
595
+ "bbox": bbox,
596
+ "score": item.get("score"),
597
+ }
598
+ for col in extra_col:
599
+ block[col] = item.get(col, None)
600
+ blocks.append(block)
601
+ return blocks
602
+
603
+ def get_model_list(self, page_no):
604
+ return self.__model_list[page_no]
605
+
606
+
607
+
608
# Ad-hoc debugging entry point: loads a model JSON and its PDF from
# hard-coded local paths and prints the image records of the first 7 pages.
if __name__ == "__main__":
    drw = DiskReaderWriter(r"D:/project/20231108code-clean")

    # Disabled sample (Windows paths); kept for reference, never executed.
    if 0:
        pdf_file_path = r"linshixuqiu\19983-00.pdf"
        model_file_path = r"linshixuqiu\19983-00_new.json"
        pdf_bytes = drw.read(pdf_file_path, AbsReaderWriter.MODE_BIN)
        model_json_txt = drw.read(model_file_path, AbsReaderWriter.MODE_TXT)
        model_list = json.loads(model_json_txt)
        write_path = r"D:\project\20231108code-clean\linshixuqiu\19983-00"
        img_bucket_path = "imgs"
        img_writer = DiskReaderWriter(join_path(write_path, img_bucket_path))
        pdf_docs = fitz.open("pdf", pdf_bytes)
        magic_model = MagicModel(model_list, pdf_docs)

    # Active sample.
    if 1:
        model_json_txt = drw.read("/opt/data/pdf/20240418/j.chroma.2009.03.042.json")
        model_list = json.loads(model_json_txt)
        pdf_bytes = drw.read(
            "/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf",
            AbsReaderWriter.MODE_BIN,
        )
        pdf_docs = fitz.open("pdf", pdf_bytes)
        magic_model = MagicModel(model_list, pdf_docs)
        for i in range(7):
            print(magic_model.get_imgs(i))
File without changes