magic-pdf 0.10.4__py3-none-any.whl → 0.10.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. magic_pdf/config/constants.py +5 -0
  2. magic_pdf/data/data_reader_writer/base.py +13 -1
  3. magic_pdf/data/dataset.py +175 -4
  4. magic_pdf/data/utils.py +2 -2
  5. magic_pdf/dict2md/ocr_mkcontent.py +2 -2
  6. magic_pdf/filter/__init__.py +32 -0
  7. magic_pdf/filter/pdf_meta_scan.py +3 -2
  8. magic_pdf/libs/draw_bbox.py +11 -10
  9. magic_pdf/libs/pdf_check.py +30 -30
  10. magic_pdf/libs/version.py +1 -1
  11. magic_pdf/model/__init__.py +124 -0
  12. magic_pdf/model/doc_analyze_by_custom_model.py +119 -60
  13. magic_pdf/model/operators.py +190 -0
  14. magic_pdf/model/pdf_extract_kit.py +20 -1
  15. magic_pdf/model/sub_modules/model_init.py +13 -3
  16. magic_pdf/model/sub_modules/model_utils.py +11 -5
  17. magic_pdf/para/para_split_v3.py +2 -2
  18. magic_pdf/pdf_parse_by_ocr.py +4 -5
  19. magic_pdf/pdf_parse_by_txt.py +4 -5
  20. magic_pdf/pdf_parse_union_core_v2.py +10 -11
  21. magic_pdf/pipe/AbsPipe.py +3 -2
  22. magic_pdf/pipe/OCRPipe.py +54 -15
  23. magic_pdf/pipe/TXTPipe.py +5 -4
  24. magic_pdf/pipe/UNIPipe.py +82 -30
  25. magic_pdf/pipe/operators.py +138 -0
  26. magic_pdf/pre_proc/cut_image.py +2 -2
  27. magic_pdf/tools/common.py +108 -59
  28. magic_pdf/user_api.py +47 -24
  29. {magic_pdf-0.10.4.dist-info → magic_pdf-0.10.6.dist-info}/METADATA +7 -4
  30. {magic_pdf-0.10.4.dist-info → magic_pdf-0.10.6.dist-info}/RECORD +34 -32
  31. {magic_pdf-0.10.4.dist-info → magic_pdf-0.10.6.dist-info}/LICENSE.md +0 -0
  32. {magic_pdf-0.10.4.dist-info → magic_pdf-0.10.6.dist-info}/WHEEL +0 -0
  33. {magic_pdf-0.10.4.dist-info → magic_pdf-0.10.6.dist-info}/entry_points.txt +0 -0
  34. {magic_pdf-0.10.4.dist-info → magic_pdf-0.10.6.dist-info}/top_level.txt +0 -0
magic_pdf/model/doc_analyze_by_custom_model.py
@@ -1,14 +1,34 @@
+import os
 import time
 
 import fitz
 import numpy as np
 from loguru import logger
 
+# Turn off paddle's signal handling
+import paddle
+paddle.disable_signal_handler()
+
+os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates
+os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
+
+try:
+    import torchtext
+
+    if torchtext.__version__ >= '0.18.0':
+        torchtext.disable_torchtext_deprecation_warning()
+except ImportError:
+    pass
+
+import magic_pdf.model as model_config
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \
-    get_formula_config
+from magic_pdf.libs.config_reader import (get_device, get_formula_config,
+                                          get_layout_config,
+                                          get_local_models_dir,
+                                          get_table_recog_config)
 from magic_pdf.model.model_list import MODEL
-import magic_pdf.model as model_config
+from magic_pdf.model.operators import InferenceResult
 
 
 def dict_compare(d1, d2):
@@ -19,25 +39,31 @@ def remove_duplicates_dicts(lst):
     unique_dicts = []
     for dict_item in lst:
         if not any(
-                dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
+            dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
         ):
             unique_dicts.append(dict_item)
     return unique_dicts
 
 
-def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
+def load_images_from_pdf(
+    pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None
+) -> list:
     try:
         from PIL import Image
     except ImportError:
-        logger.error("Pillow not installed, please install by pip.")
+        logger.error('Pillow not installed, please install by pip.')
         exit(1)
 
     images = []
-    with fitz.open("pdf", pdf_bytes) as doc:
+    with fitz.open('pdf', pdf_bytes) as doc:
         pdf_page_num = doc.page_count
-        end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
+        end_page_id = (
+            end_page_id
+            if end_page_id is not None and end_page_id >= 0
+            else pdf_page_num - 1
+        )
         if end_page_id > pdf_page_num - 1:
-            logger.warning("end_page_id is out of range, use images length")
+            logger.warning('end_page_id is out of range, use images length')
             end_page_id = pdf_page_num - 1
 
         for index in range(0, doc.page_count):
@@ -50,11 +76,11 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
                 if pm.width > 4500 or pm.height > 4500:
                     pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
 
-                img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
+                img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
                 img = np.array(img)
-                img_dict = {"img": img, "width": pm.width, "height": pm.height}
+                img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
             else:
-                img_dict = {"img": [], "width": 0, "height": 0}
+                img_dict = {'img': [], 'width': 0, 'height': 0}
 
             images.append(img_dict)
     return images
@@ -69,117 +95,150 @@ class ModelSingleton:
             cls._instance = super().__new__(cls)
         return cls._instance
 
-    def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None):
+    def get_model(
+        self,
+        ocr: bool,
+        show_log: bool,
+        lang=None,
+        layout_model=None,
+        formula_enable=None,
+        table_enable=None,
+    ):
         key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
         if key not in self._models:
-            self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model,
-                                                  formula_enable=formula_enable, table_enable=table_enable)
+            self._models[key] = custom_model_init(
+                ocr=ocr,
+                show_log=show_log,
+                lang=lang,
+                layout_model=layout_model,
+                formula_enable=formula_enable,
+                table_enable=table_enable,
+            )
         return self._models[key]
 
 
-def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None,
-                      layout_model=None, formula_enable=None, table_enable=None):
+def custom_model_init(
+    ocr: bool = False,
+    show_log: bool = False,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+):
 
     model = None
 
-    if model_config.__model_mode__ == "lite":
-        logger.warning("The Lite mode is provided for developers to conduct testing only, and the output quality is "
-                       "not guaranteed to be reliable.")
+    if model_config.__model_mode__ == 'lite':
+        logger.warning(
+            'The Lite mode is provided for developers to conduct testing only, and the output quality is '
+            'not guaranteed to be reliable.'
+        )
         model = MODEL.Paddle
-    elif model_config.__model_mode__ == "full":
+    elif model_config.__model_mode__ == 'full':
         model = MODEL.PEK
 
     if model_config.__use_inside_model__:
        model_init_start = time.time()
        if model == MODEL.Paddle:
            from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
+
            custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
        elif model == MODEL.PEK:
            from magic_pdf.model.pdf_extract_kit import CustomPEKModel
+
            # read model-dir and device from the config file
            local_models_dir = get_local_models_dir()
            device = get_device()
 
            layout_config = get_layout_config()
            if layout_model is not None:
-                layout_config["model"] = layout_model
+                layout_config['model'] = layout_model
 
            formula_config = get_formula_config()
            if formula_enable is not None:
-                formula_config["enable"] = formula_enable
+                formula_config['enable'] = formula_enable
 
            table_config = get_table_recog_config()
            if table_enable is not None:
-                table_config["enable"] = table_enable
+                table_config['enable'] = table_enable
 
            model_input = {
-                "ocr": ocr,
-                "show_log": show_log,
-                "models_dir": local_models_dir,
-                "device": device,
-                "table_config": table_config,
-                "layout_config": layout_config,
-                "formula_config": formula_config,
-                "lang": lang,
+                'ocr': ocr,
+                'show_log': show_log,
+                'models_dir': local_models_dir,
+                'device': device,
+                'table_config': table_config,
+                'layout_config': layout_config,
+                'formula_config': formula_config,
+                'lang': lang,
            }
 
            custom_model = CustomPEKModel(**model_input)
        else:
-            logger.error("Not allow model_name!")
+            logger.error('Not allow model_name!')
            exit(1)
        model_init_cost = time.time() - model_init_start
-        logger.info(f"model init cost: {model_init_cost}")
+        logger.info(f'model init cost: {model_init_cost}')
    else:
-        logger.error("use_inside_model is False, not allow to use inside model")
+        logger.error('use_inside_model is False, not allow to use inside model')
        exit(1)
 
    return custom_model
 
 
-def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
-                start_page_id=0, end_page_id=None, lang=None,
-                layout_model=None, formula_enable=None, table_enable=None):
+def doc_analyze(
+    dataset: Dataset,
+    ocr: bool = False,
+    show_log: bool = False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+) -> InferenceResult:
 
-    if lang == "":
+    if lang == '':
        lang = None
 
    model_manager = ModelSingleton()
-    custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
-
-    with fitz.open("pdf", pdf_bytes) as doc:
-        pdf_page_num = doc.page_count
-        end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
-        if end_page_id > pdf_page_num - 1:
-            logger.warning("end_page_id is out of range, use images length")
-            end_page_id = pdf_page_num - 1
-
-    images = load_images_from_pdf(pdf_bytes, start_page_id=start_page_id, end_page_id=end_page_id)
+    custom_model = model_manager.get_model(
+        ocr, show_log, lang, layout_model, formula_enable, table_enable
+    )
 
    model_json = []
    doc_analyze_start = time.time()
 
-    for index, img_dict in enumerate(images):
-        img = img_dict["img"]
-        page_width = img_dict["width"]
-        page_height = img_dict["height"]
+    if end_page_id is None:
+        end_page_id = len(dataset)
+
+    for index in range(len(dataset)):
+        page_data = dataset.get_page(index)
+        img_dict = page_data.get_image()
+        img = img_dict['img']
+        page_width = img_dict['width']
+        page_height = img_dict['height']
        if start_page_id <= index <= end_page_id:
            page_start = time.time()
            result = custom_model(img)
            logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
        else:
            result = []
-        page_info = {"page_no": index, "height": page_height, "width": page_width}
-        page_dict = {"layout_dets": result, "page_info": page_info}
+
+        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+        page_dict = {'layout_dets': result, 'page_info': page_info}
        model_json.append(page_dict)
 
    gc_start = time.time()
    clean_memory()
    gc_time = round(time.time() - gc_start, 2)
-    logger.info(f"gc time: {gc_time}")
+    logger.info(f'gc time: {gc_time}')
 
    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
-    doc_analyze_speed = round( (end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
-    logger.info(f"doc analyze time: {round(time.time() - doc_analyze_start, 2)},"
-                f" speed: {doc_analyze_speed} pages/second")
+    doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
+    logger.info(
+        f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
+        f' speed: {doc_analyze_speed} pages/second'
+    )
 
-    return model_json
+    return InferenceResult(model_json, dataset)
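
Note: doc_analyze now takes a Dataset instead of raw pdf_bytes and returns an InferenceResult rather than the plain model_json list. A minimal sketch of the new call pattern; it assumes PymuDocDataset (the concrete Dataset that the old parse_pdf_by_ocr/parse_pdf_by_txt built internally, see below) is still the intended wrapper for PDF bytes, and the file path is illustrative:

# Hedged sketch of the 0.10.6 call pattern; the path and the PymuDocDataset choice are assumptions.
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

with open('example.pdf', 'rb') as f:   # any PDF on disk
    pdf_bytes = f.read()

ds = PymuDocDataset(pdf_bytes)             # wrap the bytes in a Dataset
infer_result = doc_analyze(ds, ocr=True)   # returns InferenceResult, not a list
model_json = infer_result.get_infer_res()  # raw per-page layout_dets, as before
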
magic_pdf/model/operators.py
@@ -0,0 +1,190 @@
+import copy
+import json
+import os
+from typing import Callable
+
+from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
+from magic_pdf.filter import classify
+from magic_pdf.libs.draw_bbox import draw_model_bbox
+from magic_pdf.libs.version import __version__
+from magic_pdf.model import InferenceResultBase
+from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
+from magic_pdf.pipe.operators import PipeResult
+
+
+class InferenceResult(InferenceResultBase):
+    def __init__(self, inference_results: list, dataset: Dataset):
+        """Initialized method.
+
+        Args:
+            inference_results (list): the inference result generated by model
+            dataset (Dataset): the dataset related with model inference result
+        """
+        self._infer_res = inference_results
+        self._dataset = dataset
+
+    def draw_model(self, file_path: str) -> None:
+        """Draw model inference result.
+
+        Args:
+            file_path (str): the output file path
+        """
+        dir_name = os.path.dirname(file_path)
+        base_name = os.path.basename(file_path)
+        if not os.path.exists(dir_name):
+            os.makedirs(dir_name, exist_ok=True)
+        draw_model_bbox(
+            copy.deepcopy(self._infer_res), self._dataset, dir_name, base_name
+        )
+
+    def dump_model(self, writer: DataWriter, file_path: str):
+        """Dump model inference result to file.
+
+        Args:
+            writer (DataWriter): writer handle
+            file_path (str): the location of target file
+        """
+        writer.write_string(
+            file_path, json.dumps(self._infer_res, ensure_ascii=False, indent=4)
+        )
+
+    def get_infer_res(self):
+        """Get the inference result.
+
+        Returns:
+            list: the inference result generated by model
+        """
+        return self._infer_res
+
+    def apply(self, proc: Callable, *args, **kwargs):
+        """Apply callable method which.
+
+        Args:
+            proc (Callable): invoke proc as follows:
+                proc(inference_result, *args, **kwargs)
+
+        Returns:
+            Any: return the result generated by proc
+        """
+        return proc(copy.deepcopy(self._infer_res), *args, **kwargs)
+
+    def pipe_auto_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        """Post-proc the model inference result.
+            step1: classify the dataset type
+            step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
+
+        Args:
+            imageWriter (DataWriter): the image writer handle
+            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
+            end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
+            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
+            lang (str, optional): Defaults to None.
+
+        Returns:
+            PipeResult: the result
+        """
+
+        pdf_proc_method = classify(self._dataset.data_bits())
+
+        if pdf_proc_method == SupportedPdfParseMethod.TXT:
+            return self.pipe_txt_mode(
+                imageWriter, start_page_id, end_page_id, debug_mode, lang
+            )
+        else:
+            return self.pipe_ocr_mode(
+                imageWriter, start_page_id, end_page_id, debug_mode, lang
+            )
+
+    def pipe_txt_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        """Post-proc the model inference result, Extract the text using the
+        third library, such as `pymupdf`
+
+        Args:
+            imageWriter (DataWriter): the image writer handle
+            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
+            end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
+            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
+            lang (str, optional): Defaults to None.
+
+        Returns:
+            PipeResult: the result
+        """
+
+        def proc(*args, **kwargs) -> PipeResult:
+            res = pdf_parse_union(*args, **kwargs)
+            res['_parse_type'] = PARSE_TYPE_TXT
+            res['_version_name'] = __version__
+            if 'lang' in kwargs and kwargs['lang'] is not None:
+                res['lang'] = kwargs['lang']
+            return PipeResult(res, self._dataset)
+
+        res = self.apply(
+            proc,
+            self._dataset,
+            imageWriter,
+            SupportedPdfParseMethod.TXT,
+            start_page_id=start_page_id,
+            end_page_id=end_page_id,
+            debug_mode=debug_mode,
+            lang=lang,
+        )
+        return res
+
+    def pipe_ocr_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        """Post-proc the model inference result, Extract the text using `OCR`
+        technical.
+
+        Args:
+            imageWriter (DataWriter): the image writer handle
+            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
+            end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
+            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
+            lang (str, optional): Defaults to None.
+
+        Returns:
+            PipeResult: the result
+        """
+
+        def proc(*args, **kwargs) -> PipeResult:
+            res = pdf_parse_union(*args, **kwargs)
+            res['_parse_type'] = PARSE_TYPE_OCR
+            res['_version_name'] = __version__
+            if 'lang' in kwargs and kwargs['lang'] is not None:
+                res['lang'] = kwargs['lang']
+            return PipeResult(res, self._dataset)
+
+        res = self.apply(
+            proc,
+            self._dataset,
+            imageWriter,
+            SupportedPdfParseMethod.OCR,
+            start_page_id=start_page_id,
+            end_page_id=end_page_id,
+            debug_mode=debug_mode,
+            lang=lang,
+        )
+        return res
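
The new InferenceResult keeps the per-page model output together with its Dataset, so visualization, serialization and the parse pipeline all hang off one object. A hedged usage sketch; FileBasedDataWriter is assumed to be the package's concrete DataWriter, and all output paths are illustrative:

# Hedged sketch; FileBasedDataWriter and the paths are assumptions, not part of this diff.
# infer_result is an InferenceResult as returned by doc_analyze(...).
from magic_pdf.data.data_reader_writer import FileBasedDataWriter

image_writer = FileBasedDataWriter('output/images')    # where extracted images are written
model_writer = FileBasedDataWriter('output')

infer_result.draw_model('output/model.pdf')            # draw layout boxes per page
infer_result.dump_model(model_writer, 'model.json')    # persist the raw inference result
pipe_result = infer_result.pipe_ocr_mode(image_writer) # or pipe_txt_mode / pipe_auto_mode
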
magic_pdf/model/pdf_extract_kit.py
@@ -179,7 +179,25 @@ class CustomPEKModel:
            layout_res = self.layout_model(image, ignore_catids=[])
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo
-            layout_res = self.layout_model.predict(image)
+            img_pil = Image.fromarray(image)
+            width, height = img_pil.size
+            # logger.info(f'width: {width}, height: {height}')
+            input_res = {"poly":[0,0,width,0,width,height,0,height]}
+            new_image, useful_list = crop_img(input_res, img_pil, crop_paste_x=width//2, crop_paste_y=0)
+            paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+            layout_res = self.layout_model.predict(new_image)
+            for res in layout_res:
+                p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
+                p1 = p1 - paste_x + xmin
+                p2 = p2 - paste_y + ymin
+                p3 = p3 - paste_x + xmin
+                p4 = p4 - paste_y + ymin
+                p5 = p5 - paste_x + xmin
+                p6 = p6 - paste_y + ymin
+                p7 = p7 - paste_x + xmin
+                p8 = p8 - paste_y + ymin
+                res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
+
        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f'layout detection time: {layout_cost}')
 
@@ -215,6 +233,7 @@ class CustomPEKModel:
 
            # OCR recognition
            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+
            if self.apply_ocr:
                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
            else:
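
In the DocLayout_YOLO branch above, the page image is first pasted onto a padded canvas by crop_img, and each predicted polygon is then translated back into the original image frame: every x coordinate is shifted by xmin - paste_x and every y coordinate by ymin - paste_y. A compact sketch of that same remapping (illustration only, not the shipped helper):

# Hedged sketch of the coordinate translation performed in the loop above.
def map_poly_back(poly, paste_x, paste_y, xmin, ymin):
    # poly is [x1, y1, x2, y2, x3, y3, x4, y4] in padded-canvas coordinates
    return [
        v - paste_x + xmin if i % 2 == 0 else v - paste_y + ymin
        for i, v in enumerate(poly)
    ]
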
magic_pdf/model/sub_modules/model_init.py
@@ -92,14 +92,24 @@ class AtomModelSingleton:
        return cls._instance
 
    def get_atom_model(self, atom_model_name: str, **kwargs):
+
        lang = kwargs.get('lang', None)
        layout_model_name = kwargs.get('layout_model_name', None)
-        key = (atom_model_name, layout_model_name, lang)
+        table_model_name = kwargs.get('table_model_name', None)
+
+        if atom_model_name in [AtomicModel.OCR]:
+            key = (atom_model_name, lang)
+        elif atom_model_name in [AtomicModel.Layout]:
+            key = (atom_model_name, layout_model_name)
+        elif atom_model_name in [AtomicModel.Table]:
+            key = (atom_model_name, table_model_name)
+        else:
+            key = atom_model_name
+
        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[key]
 
-
 def atom_model_init(model_name: str, **kwargs):
    atom_model = None
    if model_name == AtomicModel.Layout:
@@ -129,7 +139,7 @@ def atom_model_init(model_name: str, **kwargs):
        atom_model = ocr_model_init(
            kwargs.get('ocr_show_log'),
            kwargs.get('det_db_box_thresh'),
-            kwargs.get('lang')
+            kwargs.get('lang'),
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
magic_pdf/model/sub_modules/model_utils.py
@@ -42,10 +42,16 @@ def get_res_list_from_layout_res(layout_res):
 
 
 def clean_vram(device, vram_threshold=8):
+    total_memory = get_vram(device)
+    if total_memory and total_memory <= vram_threshold:
+        gc_start = time.time()
+        clean_memory()
+        gc_time = round(time.time() - gc_start, 2)
+        logger.info(f"gc time: {gc_time}")
+
+
+def get_vram(device):
    if torch.cuda.is_available() and device != 'cpu':
        total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
-        if total_memory <= vram_threshold:
-            gc_start = time.time()
-            clean_memory()
-            gc_time = round(time.time() - gc_start, 2)
-            logger.info(f"gc time: {gc_time}")
+        return total_memory
+    return None
magic_pdf/para/para_split_v3.py
@@ -112,8 +112,8 @@ def __is_list_or_index_block(block):
                line_mid_x = (line['bbox'][0] + line['bbox'][2]) / 2
                block_mid_x = (block['bbox_fs'][0] + block['bbox_fs'][2]) / 2
                if (
-                    line['bbox'][0] - block['bbox_fs'][0] > 0.8 * line_height
-                    and block['bbox_fs'][2] - line['bbox'][2] > 0.8 * line_height
+                    line['bbox'][0] - block['bbox_fs'][0] > 0.7 * line_height
+                    and block['bbox_fs'][2] - line['bbox'][2] > 0.7 * line_height
                ):
                    external_sides_not_close_num += 1
                if abs(line_mid_x - block_mid_x) < line_height / 2:
magic_pdf/pdf_parse_by_ocr.py
@@ -1,9 +1,9 @@
 from magic_pdf.config.enums import SupportedPdfParseMethod
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
 
 
-def parse_pdf_by_ocr(pdf_bytes,
+def parse_pdf_by_ocr(dataset: Dataset,
                     model_list,
                     imageWriter,
                     start_page_id=0,
@@ -11,9 +11,8 @@ def parse_pdf_by_ocr(pdf_bytes,
                     debug_mode=False,
                     lang=None,
                     ):
-    dataset = PymuDocDataset(pdf_bytes)
-    return pdf_parse_union(dataset,
-                           model_list,
+    return pdf_parse_union(model_list,
+                           dataset,
                           imageWriter,
                           SupportedPdfParseMethod.OCR,
                           start_page_id=start_page_id,
magic_pdf/pdf_parse_by_txt.py
@@ -1,10 +1,10 @@
 from magic_pdf.config.enums import SupportedPdfParseMethod
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
 
 
 def parse_pdf_by_txt(
-    pdf_bytes,
+    dataset: Dataset,
    model_list,
    imageWriter,
    start_page_id=0,
@@ -12,9 +12,8 @@ def parse_pdf_by_txt(
    debug_mode=False,
    lang=None,
 ):
-    dataset = PymuDocDataset(pdf_bytes)
-    return pdf_parse_union(dataset,
-                           model_list,
+    return pdf_parse_union(model_list,
+                           dataset,
                           imageWriter,
                           SupportedPdfParseMethod.TXT,
                           start_page_id=start_page_id,
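
Both parse entry points now receive the Dataset from the caller instead of building a PymuDocDataset from raw bytes themselves, and pdf_parse_union is now called with model_list before dataset. A hedged sketch of adapting an existing call site (pdf_bytes, model_list and imageWriter stand for whatever the caller already has):

# 0.10.4 (sketch):  parse_pdf_by_txt(pdf_bytes, model_list, imageWriter)
# 0.10.6 (sketch):  the caller constructs the Dataset first
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt

dataset = PymuDocDataset(pdf_bytes)   # previously created inside parse_pdf_by_txt
result = parse_pdf_by_txt(dataset, model_list, imageWriter)
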