magic-pdf 0.6.0__py3-none-any.whl → 0.6.2b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
magic_pdf/cli/magicpdf.py CHANGED
@@ -85,10 +85,10 @@ def do_parse(
85
85
  f_dump_content_list=True,
86
86
  f_make_md_mode=MakeMode.MM_MD,
87
87
  ):
88
+
88
89
  orig_model_list = copy.deepcopy(model_list)
89
90
 
90
91
  local_image_dir, local_md_dir = prepare_env(pdf_file_name, parse_method)
91
- logger.info(f"local output dir is {local_md_dir}")
92
92
  image_writer, md_writer = DiskReaderWriter(local_image_dir), DiskReaderWriter(local_md_dir)
93
93
  image_dir = str(os.path.basename(local_image_dir))
94
94
 
@@ -162,6 +162,7 @@ def do_parse(
162
162
  path=f"{pdf_file_name}_content_list.json",
163
163
  mode=AbsReaderWriter.MODE_TXT,
164
164
  )
165
+ logger.info(f"local output dir is '{local_md_dir}', you can found the result in it.")
165
166
 
166
167
 
167
168
  @click.group()
@@ -179,8 +180,9 @@ def cli():
179
180
  help="指定解析方法。txt: 文本型 pdf 解析方法, ocr: 光学识别解析 pdf, auto: 程序智能选择解析方法",
180
181
  default="auto",
181
182
  )
182
- @click.option("--inside_model", type=click.BOOL, default=False, help="使用内置模型测试")
183
- @click.option("--model_mode", type=click.STRING, default="full", help="内置模型选择。lite: 快速解析,精度较低,full: 高精度解析,速度较慢")
183
+ @click.option("--inside_model", type=click.BOOL, default=True, help="使用内置模型测试")
184
+ @click.option("--model_mode", type=click.STRING, default="full",
185
+ help="内置模型选择。lite: 快速解析,精度较低,full: 高精度解析,速度较慢")
184
186
  def json_command(json, method, inside_model, model_mode):
185
187
  model_config.__use_inside_model__ = inside_model
186
188
  model_config.__model_mode__ = model_mode
@@ -232,8 +234,9 @@ def json_command(json, method, inside_model, model_mode):
232
234
  help="指定解析方法。txt: 文本型 pdf 解析方法, ocr: 光学识别解析 pdf, auto: 程序智能选择解析方法",
233
235
  default="auto",
234
236
  )
235
- @click.option("--inside_model", type=click.BOOL, default=False, help="使用内置模型测试")
236
- @click.option("--model_mode", type=click.STRING, default="full", help="内置模型选择。lite: 快速解析,精度较低,full: 高精度解析,速度较慢")
237
+ @click.option("--inside_model", type=click.BOOL, default=True, help="使用内置模型测试")
238
+ @click.option("--model_mode", type=click.STRING, default="full",
239
+ help="内置模型选择。lite: 快速解析,精度较低,full: 高精度解析,速度较慢")
237
240
  def local_json_command(local_json, method, inside_model, model_mode):
238
241
  model_config.__use_inside_model__ = inside_model
239
242
  model_config.__model_mode__ = model_mode
@@ -277,8 +280,8 @@ def local_json_command(local_json, method, inside_model, model_mode):
277
280
 
278
281
  @cli.command()
279
282
  @click.option(
280
- "--pdf", type=click.Path(exists=True), required=True, help="PDF文件的路径"
281
- )
283
+ "--pdf", type=click.Path(exists=True), required=True,
284
+ help='pdf 文件路径, 支持单个文件或文件列表, 文件列表需要以".list"结尾, 一行一个pdf文件路径')
282
285
  @click.option("--model", type=click.Path(exists=True), help="模型的路径")
283
286
  @click.option(
284
287
  "--method",
@@ -286,8 +289,9 @@ def local_json_command(local_json, method, inside_model, model_mode):
286
289
  help="指定解析方法。txt: 文本型 pdf 解析方法, ocr: 光学识别解析 pdf, auto: 程序智能选择解析方法",
287
290
  default="auto",
288
291
  )
289
- @click.option("--inside_model", type=click.BOOL, default=False, help="使用内置模型测试")
290
- @click.option("--model_mode", type=click.STRING, default="full", help="内置模型选择。lite: 快速解析,精度较低,full: 高精度解析,速度较慢")
292
+ @click.option("--inside_model", type=click.BOOL, default=True, help="使用内置模型测试")
293
+ @click.option("--model_mode", type=click.STRING, default="full",
294
+ help="内置模型选择。lite: 快速解析,精度较低,full: 高精度解析,速度较慢")
291
295
  def pdf_command(pdf, model, method, inside_model, model_mode):
292
296
  model_config.__use_inside_model__ = inside_model
293
297
  model_config.__model_mode__ = model_mode
@@ -296,12 +300,10 @@ def pdf_command(pdf, model, method, inside_model, model_mode):
296
300
  disk_rw = DiskReaderWriter(os.path.dirname(path))
297
301
  return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
298
302
 
299
- pdf_data = read_fn(pdf)
300
-
301
- def get_model_json(model_path):
303
+ def get_model_json(model_path, doc_path):
302
304
  # 这里处理pdf和模型相关的逻辑
303
305
  if model_path is None:
304
- file_name_without_extension, extension = os.path.splitext(pdf)
306
+ file_name_without_extension, extension = os.path.splitext(doc_path)
305
307
  if extension == ".pdf":
306
308
  model_path = file_name_without_extension + ".json"
307
309
  else:
@@ -319,15 +321,35 @@ def pdf_command(pdf, model, method, inside_model, model_mode):
319
321
 
320
322
  return model_json
321
323
 
322
- jso = json_parse.loads(get_model_json(model))
323
- pdf_file_name = Path(pdf).stem
324
+ def parse_doc(doc_path):
325
+ try:
326
+ file_name = str(Path(doc_path).stem)
327
+ pdf_data = read_fn(doc_path)
328
+ jso = json_parse.loads(get_model_json(model, doc_path))
324
329
 
325
- do_parse(
326
- pdf_file_name,
327
- pdf_data,
328
- jso,
329
- method,
330
- )
330
+ do_parse(
331
+ file_name,
332
+ pdf_data,
333
+ jso,
334
+ method,
335
+ )
336
+
337
+ except Exception as e:
338
+ logger.exception(e)
339
+
340
+ if not pdf:
341
+ logger.error(f"Error: Missing argument '--pdf'.")
342
+ exit(f"Error: Missing argument '--pdf'.")
343
+ else:
344
+ '''适配多个文档的list文件输入'''
345
+ if pdf.endswith(".list"):
346
+ with open(pdf, "r") as f:
347
+ for line in f.readlines():
348
+ line = line.strip()
349
+ parse_doc(line)
350
+ else:
351
+ '''适配单个文档的输入'''
352
+ parse_doc(pdf)
331
353
 
332
354
 
333
355
  if __name__ == "__main__":
@@ -112,7 +112,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
112
112
  for line in block['lines']:
113
113
  for span in line['spans']:
114
114
  if span['type'] == ContentType.Image:
115
- para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
115
+ para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
116
116
  for block in para_block['blocks']: # 2nd.拼image_caption
117
117
  if block['type'] == BlockType.ImageCaption:
118
118
  para_text += merge_para_with_text(block)
@@ -128,7 +128,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
128
128
  for line in block['lines']:
129
129
  for span in line['spans']:
130
130
  if span['type'] == ContentType.Table:
131
- para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
131
+ para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
132
132
  for block in para_block['blocks']: # 3rd.拼table_footnote
133
133
  if block['type'] == BlockType.TableFootnote:
134
134
  para_text += merge_para_with_text(block)
@@ -210,28 +210,32 @@ def para_to_standard_format(para, img_buket_path):
210
210
  return para_content
211
211
 
212
212
 
213
- def para_to_standard_format_v2(para_block, img_buket_path):
213
+ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
214
214
  para_type = para_block['type']
215
215
  if para_type == BlockType.Text:
216
216
  para_content = {
217
217
  'type': 'text',
218
218
  'text': merge_para_with_text(para_block),
219
+ 'page_idx': page_idx
219
220
  }
220
221
  elif para_type == BlockType.Title:
221
222
  para_content = {
222
223
  'type': 'text',
223
224
  'text': merge_para_with_text(para_block),
224
- 'text_level': 1
225
+ 'text_level': 1,
226
+ 'page_idx': page_idx
225
227
  }
226
228
  elif para_type == BlockType.InterlineEquation:
227
229
  para_content = {
228
230
  'type': 'equation',
229
231
  'text': merge_para_with_text(para_block),
230
- 'text_format': "latex"
232
+ 'text_format': "latex",
233
+ 'page_idx': page_idx
231
234
  }
232
235
  elif para_type == BlockType.Image:
233
236
  para_content = {
234
237
  'type': 'image',
238
+ 'page_idx': page_idx
235
239
  }
236
240
  for block in para_block['blocks']:
237
241
  if block['type'] == BlockType.ImageBody:
@@ -241,6 +245,7 @@ def para_to_standard_format_v2(para_block, img_buket_path):
241
245
  elif para_type == BlockType.Table:
242
246
  para_content = {
243
247
  'type': 'table',
248
+ 'page_idx': page_idx
244
249
  }
245
250
  for block in para_block['blocks']:
246
251
  if block['type'] == BlockType.TableBody:
@@ -345,6 +350,7 @@ def union_make(pdf_info_dict: list, make_mode: str, drop_mode: str, img_buket_pa
345
350
  raise Exception(f"drop_mode can not be null")
346
351
 
347
352
  paras_of_layout = page_info.get("para_blocks")
353
+ page_idx = page_info.get("page_idx")
348
354
  if not paras_of_layout:
349
355
  continue
350
356
  if make_mode == MakeMode.MM_MD:
@@ -355,7 +361,7 @@ def union_make(pdf_info_dict: list, make_mode: str, drop_mode: str, img_buket_pa
355
361
  output_content.extend(page_markdown)
356
362
  elif make_mode == MakeMode.STANDARD_FORMAT:
357
363
  for para_block in paras_of_layout:
358
- para_content = para_to_standard_format_v2(para_block, img_buket_path)
364
+ para_content = para_to_standard_format_v2(para_block, img_buket_path, page_idx)
359
365
  output_content.append(para_content)
360
366
  if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
361
367
  return '\n\n'.join(output_content)
@@ -10,16 +10,19 @@ from loguru import logger
10
10
 
11
11
  from magic_pdf.libs.commons import parse_bucket_key
12
12
 
13
+ # 定义配置文件名常量
14
+ CONFIG_FILE_NAME = "magic-pdf.json"
15
+
13
16
 
14
17
  def read_config():
15
18
  home_dir = os.path.expanduser("~")
16
19
 
17
- config_file = os.path.join(home_dir, "magic-pdf.json")
20
+ config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
18
21
 
19
22
  if not os.path.exists(config_file):
20
- raise Exception(f"{config_file} not found")
23
+ raise FileNotFoundError(f"{config_file} not found")
21
24
 
22
- with open(config_file, "r") as f:
25
+ with open(config_file, "r", encoding="utf-8") as f:
23
26
  config = json.load(f)
24
27
  return config
25
28
 
@@ -37,7 +40,7 @@ def get_s3_config(bucket_name: str):
37
40
  access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
38
41
 
39
42
  if access_key is None or secret_key is None or storage_endpoint is None:
40
- raise Exception("ak, sk or endpoint not found in magic-pdf.json")
43
+ raise Exception(f"ak, sk or endpoint not found in {CONFIG_FILE_NAME}")
41
44
 
42
45
  # logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
43
46
 
@@ -56,17 +59,32 @@ def get_bucket_name(path):
56
59
 
57
60
  def get_local_dir():
58
61
  config = read_config()
59
- return config.get("temp-output-dir", "/tmp")
62
+ local_dir = config.get("temp-output-dir")
63
+ if local_dir is None:
64
+ logger.warning(f"'temp-output-dir' not found in {CONFIG_FILE_NAME}, use '/tmp' as default")
65
+ return "/tmp"
66
+ else:
67
+ return local_dir
60
68
 
61
69
 
62
70
  def get_local_models_dir():
63
71
  config = read_config()
64
- return config.get("models-dir", "/tmp/models")
72
+ models_dir = config.get("models-dir")
73
+ if models_dir is None:
74
+ logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
75
+ return "/tmp/models"
76
+ else:
77
+ return models_dir
65
78
 
66
79
 
67
80
  def get_device():
68
81
  config = read_config()
69
- return config.get("device-mode", "cpu")
82
+ device = config.get("device-mode")
83
+ if device is None:
84
+ logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
85
+ return "cpu"
86
+ else:
87
+ return device
70
88
 
71
89
 
72
90
  if __name__ == "__main__":
@@ -1,8 +1,19 @@
1
+ import os
1
2
  import unicodedata
3
+
4
+ if not os.getenv("FTLANG_CACHE"):
5
+ current_file_path = os.path.abspath(__file__)
6
+ current_dir = os.path.dirname(current_file_path)
7
+ root_dir = os.path.dirname(current_dir)
8
+ ftlang_cache_dir = os.path.join(root_dir, 'resources', 'fasttext-langdetect')
9
+ os.environ["FTLANG_CACHE"] = str(ftlang_cache_dir)
10
+ # print(os.getenv("FTLANG_CACHE"))
11
+
2
12
  from fast_langdetect import detect_language
3
13
 
4
14
 
5
15
  def detect_lang(text: str) -> str:
16
+
6
17
  if len(text) == 0:
7
18
  return ""
8
19
  try:
@@ -18,6 +29,7 @@ def detect_lang(text: str) -> str:
18
29
 
19
30
 
20
31
  if __name__ == '__main__':
32
+ print(os.getenv("FTLANG_CACHE"))
21
33
  print(detect_lang("This is a test."))
22
34
  print(detect_lang("<html>This is a test</html>"))
23
35
  print(detect_lang("这个是中文测试。"))
magic_pdf/libs/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.6.0"
1
+ __version__ = "0.6.2b1"
@@ -1,2 +1,2 @@
1
- __use_inside_model__ = False
1
+ __use_inside_model__ = True
2
2
  __model_mode__ = "full"
@@ -48,10 +48,28 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
48
48
  return images
49
49
 
50
50
 
51
- def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False):
51
+ class ModelSingleton:
52
+ _instance = None
53
+ _models = {}
54
+
55
+ def __new__(cls, *args, **kwargs):
56
+ if cls._instance is None:
57
+ cls._instance = super().__new__(cls)
58
+ return cls._instance
59
+
60
+ def get_model(self, ocr: bool, show_log: bool):
61
+ key = (ocr, show_log)
62
+ if key not in self._models:
63
+ self._models[key] = custom_model_init(ocr=ocr, show_log=show_log)
64
+ return self._models[key]
65
+
66
+
67
+ def custom_model_init(ocr: bool = False, show_log: bool = False):
52
68
  model = None
53
69
 
54
70
  if model_config.__model_mode__ == "lite":
71
+ logger.warning("The Lite mode is provided for developers to conduct testing only, and the output quality is "
72
+ "not guaranteed to be reliable.")
55
73
  model = MODEL.Paddle
56
74
  elif model_config.__model_mode__ == "full":
57
75
  model = MODEL.PEK
@@ -76,6 +94,14 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False):
76
94
  logger.error("use_inside_model is False, not allow to use inside model")
77
95
  exit(1)
78
96
 
97
+ return custom_model
98
+
99
+
100
+ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False):
101
+
102
+ model_manager = ModelSingleton()
103
+ custom_model = model_manager.get_model(ocr, show_log)
104
+
79
105
  images = load_images_from_pdf(pdf_bytes)
80
106
 
81
107
  model_json = []
@@ -9,13 +9,14 @@ from magic_pdf.libs.coordinate_transform import get_scale_ratio
9
9
  from magic_pdf.libs.ocr_content_type import ContentType
10
10
  from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
11
11
  from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
12
- from magic_pdf.libs.math import float_gt
12
+ from magic_pdf.libs.local_math import float_gt
13
13
  from magic_pdf.libs.boxbase import (
14
14
  _is_in,
15
15
  bbox_relative_pos,
16
16
  bbox_distance,
17
17
  _is_part_overlap,
18
- calculate_overlap_area_in_bbox1_area_ratio, calculate_iou,
18
+ calculate_overlap_area_in_bbox1_area_ratio,
19
+ calculate_iou,
19
20
  )
20
21
  from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
21
22
 
@@ -78,9 +79,23 @@ class MagicModel:
78
79
  for layout_det2 in layout_dets:
79
80
  if layout_det1 == layout_det2:
80
81
  continue
81
- if layout_det1["category_id"] in [0,1,2,3,4,5,6,7,8,9] and layout_det2["category_id"] in [0,1,2,3,4,5,6,7,8,9]:
82
- if calculate_iou(layout_det1['bbox'], layout_det2['bbox']) > 0.9:
83
- if layout_det1['score'] < layout_det2['score']:
82
+ if layout_det1["category_id"] in [
83
+ 0,
84
+ 1,
85
+ 2,
86
+ 3,
87
+ 4,
88
+ 5,
89
+ 6,
90
+ 7,
91
+ 8,
92
+ 9,
93
+ ] and layout_det2["category_id"] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
94
+ if (
95
+ calculate_iou(layout_det1["bbox"], layout_det2["bbox"])
96
+ > 0.9
97
+ ):
98
+ if layout_det1["score"] < layout_det2["score"]:
84
99
  layout_det_need_remove = layout_det1
85
100
  else:
86
101
  layout_det_need_remove = layout_det2
@@ -97,11 +112,11 @@ class MagicModel:
97
112
  def __init__(self, model_list: list, docs: fitz.Document):
98
113
  self.__model_list = model_list
99
114
  self.__docs = docs
100
- '''为所有模型数据添加bbox信息(缩放,poly->bbox)'''
115
+ """为所有模型数据添加bbox信息(缩放,poly->bbox)"""
101
116
  self.__fix_axis()
102
- '''删除置信度特别低的模型数据(<0.05),提高质量'''
117
+ """删除置信度特别低的模型数据(<0.05),提高质量"""
103
118
  self.__fix_by_remove_low_confidence()
104
- '''删除高iou(>0.9)数据中置信度较低的那个'''
119
+ """删除高iou(>0.9)数据中置信度较低的那个"""
105
120
  self.__fix_by_remove_high_iou_and_low_confidence()
106
121
 
107
122
  def __reduct_overlap(self, bboxes):
@@ -125,16 +140,6 @@ class MagicModel:
125
140
  ret = []
126
141
  MAX_DIS_OF_POINT = 10**9 + 7
127
142
 
128
- def expand_bbox(bbox1, bbox2):
129
- x0 = min(bbox1[0], bbox2[0])
130
- y0 = min(bbox1[1], bbox2[1])
131
- x1 = max(bbox1[2], bbox2[2])
132
- y1 = max(bbox1[3], bbox2[3])
133
- return [x0, y0, x1, y1]
134
-
135
- def get_bbox_area(bbox):
136
- return abs(bbox[2] - bbox[0]) * abs(bbox[3] - bbox[1])
137
-
138
143
  # subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
139
144
  # 再求出筛选出的 subjects 和 object 的最短距离!
140
145
  def may_find_other_nearest_bbox(subject_idx, object_idx):
@@ -177,6 +182,13 @@ class MagicModel:
177
182
 
178
183
  return ret
179
184
 
185
+ def expand_bbbox(idxes):
186
+ x0s = [all_bboxes[idx]["bbox"][0] for idx in idxes]
187
+ y0s = [all_bboxes[idx]["bbox"][1] for idx in idxes]
188
+ x1s = [all_bboxes[idx]["bbox"][2] for idx in idxes]
189
+ y1s = [all_bboxes[idx]["bbox"][3] for idx in idxes]
190
+ return min(x0s), min(y0s), max(x1s), max(y1s)
191
+
180
192
  subjects = self.__reduct_overlap(
181
193
  list(
182
194
  map(
@@ -268,7 +280,9 @@ class MagicModel:
268
280
  or dis[i][j] == MAX_DIS_OF_POINT
269
281
  ):
270
282
  continue
271
- left, right, _, _ = bbox_relative_pos(all_bboxes[i]["bbox"], all_bboxes[j]["bbox"]) # 由 pos_flag_count 相关逻辑保证本段逻辑准确性
283
+ left, right, _, _ = bbox_relative_pos(
284
+ all_bboxes[i]["bbox"], all_bboxes[j]["bbox"]
285
+ ) # 由 pos_flag_count 相关逻辑保证本段逻辑准确性
272
286
  if left or right:
273
287
  one_way_dis = all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0]
274
288
  else:
@@ -322,6 +336,10 @@ class MagicModel:
322
336
  break
323
337
 
324
338
  if is_nearest:
339
+ nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
340
+ n_dis = bbox_distance(all_bboxes[i]["bbox"], [nx0, ny0, nx1, ny1])
341
+ if float_gt(dis[i][j], n_dis):
342
+ continue
325
343
  tmp.append(k)
326
344
  seen.add(k)
327
345
 
@@ -331,20 +349,7 @@ class MagicModel:
331
349
 
332
350
  # 已经获取到某个 figure 下所有的最靠近的 captions,以及最靠近这些 captions 的 captions 。
333
351
  # 先扩一下 bbox,
334
- x0s = [all_bboxes[idx]["bbox"][0] for idx in seen] + [
335
- all_bboxes[i]["bbox"][0]
336
- ]
337
- y0s = [all_bboxes[idx]["bbox"][1] for idx in seen] + [
338
- all_bboxes[i]["bbox"][1]
339
- ]
340
- x1s = [all_bboxes[idx]["bbox"][2] for idx in seen] + [
341
- all_bboxes[i]["bbox"][2]
342
- ]
343
- y1s = [all_bboxes[idx]["bbox"][3] for idx in seen] + [
344
- all_bboxes[i]["bbox"][3]
345
- ]
346
-
347
- ox0, oy0, ox1, oy1 = min(x0s), min(y0s), max(x1s), max(y1s)
352
+ ox0, oy0, ox1, oy1 = expand_bbbox(list(seen) + [i])
348
353
  ix0, iy0, ix1, iy1 = all_bboxes[i]["bbox"]
349
354
 
350
355
  # 分成了 4 个截取空间,需要计算落在每个截取空间下 objects 合并后占据的矩形面积
@@ -455,8 +460,10 @@ class MagicModel:
455
460
  with_caption_subject.add(j)
456
461
  return ret, total_subject_object_dis
457
462
 
458
- def get_imgs(self, page_no: int): # @许瑞
459
- records, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
463
+ def get_imgs(self, page_no: int):
464
+ figure_captions, _ = self.__tie_up_category_by_distance(
465
+ page_no, 3, 4
466
+ )
460
467
  return [
461
468
  {
462
469
  "bbox": record["all"],
@@ -464,7 +471,7 @@ class MagicModel:
464
471
  "img_caption_bbox": record.get("object_body", None),
465
472
  "score": record["score"],
466
473
  }
467
- for record in records
474
+ for record in figure_captions
468
475
  ]
469
476
 
470
477
  def get_tables(
@@ -535,6 +542,7 @@ class MagicModel:
535
542
  if not any(span == existing_span for existing_span in new_spans):
536
543
  new_spans.append(span)
537
544
  return new_spans
545
+
538
546
  all_spans = []
539
547
  model_page_info = self.__model_list[page_no]
540
548
  layout_dets = model_page_info["layout_dets"]
@@ -548,10 +556,7 @@ class MagicModel:
548
556
  for layout_det in layout_dets:
549
557
  category_id = layout_det["category_id"]
550
558
  if category_id in allow_category_id_list:
551
- span = {
552
- "bbox": layout_det["bbox"],
553
- "score": layout_det["score"]
554
- }
559
+ span = {"bbox": layout_det["bbox"], "score": layout_det["score"]}
555
560
  if category_id == 3:
556
561
  span["type"] = ContentType.Image
557
562
  elif category_id == 5:
@@ -604,7 +609,6 @@ class MagicModel:
604
609
  return self.__model_list[page_no]
605
610
 
606
611
 
607
-
608
612
  if __name__ == "__main__":
609
613
  drw = DiskReaderWriter(r"D:/project/20231108code-clean")
610
614
  if 0: