magic-pdf 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
magic_pdf/libs/boxbase.py CHANGED
@@ -185,10 +185,13 @@ def calculate_iou(bbox1, bbox2):
     bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
     bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
 
+    if any([bbox1_area == 0, bbox2_area == 0]):
+        return 0
+
     # Compute the intersection over union by taking the intersection area
     # and dividing it by the sum of both areas minus the intersection area
-    iou = intersection_area / float(bbox1_area + bbox2_area -
-                                    intersection_area)
+    iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area)
+
     return iou
 
 
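For orientation, here is a minimal standalone sketch of the guarded IoU computation (the intersection step is our assumption; only the zero-area guard and the final division come from the hunk above). A degenerate zero-area box now yields 0 instead of risking a ZeroDivisionError when both areas are zero:

def calculate_iou_sketch(bbox1, bbox2):
    # Intersection rectangle (assumed; not shown in the hunk).
    x0, y0 = max(bbox1[0], bbox2[0]), max(bbox1[1], bbox2[1])
    x1, y1 = min(bbox1[2], bbox2[2]), min(bbox1[3], bbox2[3])
    intersection_area = max(0, x1 - x0) * max(0, y1 - y0)

    bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])

    # New in 1.1.0: degenerate boxes short-circuit to 0 instead of dividing by zero.
    if any([bbox1_area == 0, bbox2_area == 0]):
        return 0

    return intersection_area / float(bbox1_area + bbox2_area - intersection_area)

print(calculate_iou_sketch([0, 0, 10, 10], [5, 5, 15, 15]))  # ~0.1429
print(calculate_iou_sketch([3, 3, 3, 3], [3, 3, 3, 3]))      # 0 (previously ZeroDivisionError)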
 
@@ -362,12 +362,24 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
     for page in pdf_info:
         page_line_list = []
         for block in page['preproc_blocks']:
-            if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]:
+            if block['type'] in [BlockType.Text]:
                 for line in block['lines']:
                     bbox = line['bbox']
                     index = line['index']
                     page_line_list.append({'index': index, 'bbox': bbox})
-            if block['type'] in [BlockType.Image, BlockType.Table]:
+            elif block['type'] in [BlockType.Title, BlockType.InterlineEquation]:
+                if 'virtual_lines' in block:
+                    if len(block['virtual_lines']) > 0 and block['virtual_lines'][0].get('index', None) is not None:
+                        for line in block['virtual_lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
+                else:
+                    for line in block['lines']:
+                        bbox = line['bbox']
+                        index = line['index']
+                        page_line_list.append({'index': index, 'bbox': bbox})
+            elif block['type'] in [BlockType.Image, BlockType.Table]:
                 for sub_block in block['blocks']:
                     if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
                         if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
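For Title and InterlineEquation blocks the drawing code now prefers indexed virtual_lines and only falls back to the real lines when no virtual_lines key is present. A hypothetical helper (not in the package) that mirrors that selection logic:

def sortable_lines(block):
    # Mirrors the branch added above: indexed virtual_lines win, otherwise fall
    # back to block['lines']; unindexed virtual_lines contribute nothing.
    if 'virtual_lines' in block:
        virtual = block['virtual_lines']
        if len(virtual) > 0 and virtual[0].get('index', None) is not None:
            return virtual
        return []
    return block['lines']

block = {'virtual_lines': [{'index': 3, 'bbox': [0, 0, 100, 20]}], 'lines': []}
print(sortable_lines(block))  # the virtual lines are used, since they carry an index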
@@ -12,12 +12,20 @@ if not os.getenv("FTLANG_CACHE"):
 from fast_langdetect import detect_language
 
 
+def remove_invalid_surrogates(text):
+    # 移除无效的 UTF-16 代理对
+    return ''.join(c for c in text if not (0xD800 <= ord(c) <= 0xDFFF))
+
+
 def detect_lang(text: str) -> str:
 
     if len(text) == 0:
         return ""
 
     text = text.replace("\n", "")
+    text = remove_invalid_surrogates(text)
+
+    # print(text)
     try:
         lang_upper = detect_language(text)
     except:
@@ -37,3 +45,4 @@ if __name__ == '__main__':
     print(detect_lang("<html>This is a test</html>"))
     print(detect_lang("这个是中文测试。"))
     print(detect_lang("<html>这个是中文测试。</html>"))
+    print(detect_lang("〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试"))
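The new remove_invalid_surrogates helper simply drops every code point in the UTF-16 surrogate range (U+D800 to U+DFFF) before language detection; without it, strings containing lone surrogates, like the new test case above, cannot even be encoded to UTF-8. A standalone sketch:

def remove_invalid_surrogates(text):
    # Drop lone UTF-16 surrogate code points (U+D800..U+DFFF), as in the hunk above.
    return ''.join(c for c in text if not (0xD800 <= ord(c) <= 0xDFFF))

s = "〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试"
# s.encode('utf-8') would raise UnicodeEncodeError because of the lone surrogates;
# after cleaning, encoding (and language detection) can proceed.
print(remove_invalid_surrogates(s).encode('utf-8').decode('utf-8'))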
magic_pdf/libs/version.py CHANGED
@@ -1 +1 @@
-__version__ = "1.0.1"
+__version__ = "1.1.0"
@@ -7,19 +7,19 @@ from loguru import logger
 from PIL import Image
 
 from magic_pdf.config.constants import MODEL_NAME
-from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
-from magic_pdf.data.dataset import Dataset
-from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_device
-from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
+# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
+# from magic_pdf.data.dataset import Dataset
+# from magic_pdf.libs.clean_memory import clean_memory
+# from magic_pdf.libs.config_reader import get_device
+# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
-from magic_pdf.operators.models import InferenceResult
+# from magic_pdf.operators.models import InferenceResult
 
-YOLO_LAYOUT_BASE_BATCH_SIZE = 4
+YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
 MFR_BASE_BATCH_SIZE = 16
 
@@ -44,19 +44,20 @@ class BatchAnalyze:
         modified_images = []
         for image_index, image in enumerate(images):
             pil_img = Image.fromarray(image)
-            width, height = pil_img.size
-            if height > width:
-                input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
-                new_image, useful_list = crop_img(
-                    input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
-                )
-                layout_images.append(new_image)
-                modified_images.append([image_index, useful_list])
-            else:
-                layout_images.append(pil_img)
+            # width, height = pil_img.size
+            # if height > width:
+            #     input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
+            #     new_image, useful_list = crop_img(
+            #         input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
+            #     )
+            #     layout_images.append(new_image)
+            #     modified_images.append([image_index, useful_list])
+            # else:
+            layout_images.append(pil_img)
 
         images_layout_res += self.model.layout_model.batch_predict(
-            layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+            # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+            layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
         )
 
         for image_index, useful_list in modified_images:
@@ -78,7 +79,8 @@ class BatchAnalyze:
         # 公式检测
         mfd_start_time = time.time()
         images_mfd_res = self.model.mfd_model.batch_predict(
-            images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+            # images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+            images, MFD_BASE_BATCH_SIZE
         )
         logger.info(
             f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
@@ -91,10 +93,12 @@ class BatchAnalyze:
             images,
             batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
         )
+        mfr_count = 0
         for image_index in range(len(images)):
             images_layout_res[image_index] += images_formula_list[image_index]
+            mfr_count += len(images_formula_list[image_index])
         logger.info(
-            f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
+            f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
         )
 
         # 清理显存
@@ -159,7 +163,7 @@ class BatchAnalyze:
                    elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
                        html_code = self.model.table_model.img2html(new_image)
                    elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
-                        html_code, table_cell_bboxes, elapse = (
+                        html_code, table_cell_bboxes, logic_points, elapse = (
                            self.model.table_model.predict(new_image)
                        )
                        run_time = time.time() - single_table_start_time
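The RapidTable wrapper's predict now returns four values (the added logic_points) instead of three, and both call sites in this diff unpack accordingly. A hypothetical compatibility shim, useful only if code has to work against both the 1.0.x and 1.1.0 wrappers:

def unpack_table_prediction(result):
    # Accept the old 3-tuple (html, cell_bboxes, elapse) or the new 4-tuple
    # that also carries logic_points; normalise to the 4-value form.
    if len(result) == 4:
        html_code, table_cell_bboxes, logic_points, elapse = result
    else:
        html_code, table_cell_bboxes, elapse = result
        logic_points = None
    return html_code, table_cell_bboxes, logic_points, elapse

print(unpack_table_prediction(('<table></table>', [], [], 0.12)))  # 1.1.0 shape
print(unpack_table_prediction(('<table></table>', [], 0.12)))      # 1.0.x shape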
@@ -195,81 +199,81 @@ class BatchAnalyze:
         return images_layout_res
 
 
-def doc_batch_analyze(
-    dataset: Dataset,
-    ocr: bool = False,
-    show_log: bool = False,
-    start_page_id=0,
-    end_page_id=None,
-    lang=None,
-    layout_model=None,
-    formula_enable=None,
-    table_enable=None,
-    batch_ratio: int | None = None,
-) -> InferenceResult:
-    """Perform batch analysis on a document dataset.
-
-    Args:
-        dataset (Dataset): The dataset containing document pages to be analyzed.
-        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
-        show_log (bool, optional): Flag to enable logging. Defaults to False.
-        start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
-        end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
-        lang (str, optional): Language for OCR. Defaults to None.
-        layout_model (optional): Layout model to be used for analysis. Defaults to None.
-        formula_enable (optional): Flag to enable formula detection. Defaults to None.
-        table_enable (optional): Flag to enable table detection. Defaults to None.
-        batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
-
-    Raises:
-        CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
-
-    Returns:
-        InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
-    """
-
-    if not torch.cuda.is_available():
-        raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
-
-    lang = None if lang == '' else lang
-    # TODO: auto detect batch size
-    batch_ratio = 1 if batch_ratio is None else batch_ratio
-    end_page_id = end_page_id if end_page_id else len(dataset)
-
-    model_manager = ModelSingleton()
-    custom_model: CustomPEKModel = model_manager.get_model(
-        ocr, show_log, lang, layout_model, formula_enable, table_enable
-    )
-    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-
-    model_json = []
-
-    # batch analyze
-    images = []
-    for index in range(len(dataset)):
-        if start_page_id <= index <= end_page_id:
-            page_data = dataset.get_page(index)
-            img_dict = page_data.get_image()
-            images.append(img_dict['img'])
-    analyze_result = batch_model(images)
-
-    for index in range(len(dataset)):
-        page_data = dataset.get_page(index)
-        img_dict = page_data.get_image()
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        if start_page_id <= index <= end_page_id:
-            result = analyze_result.pop(0)
-        else:
-            result = []
-
-        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
-
-    # TODO: clean memory when gpu memory is not enough
-    clean_memory_start_time = time.time()
-    clean_memory(get_device())
-    logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
-
-    return InferenceResult(model_json, dataset)
+# def doc_batch_analyze(
+#     dataset: Dataset,
+#     ocr: bool = False,
+#     show_log: bool = False,
+#     start_page_id=0,
+#     end_page_id=None,
+#     lang=None,
+#     layout_model=None,
+#     formula_enable=None,
+#     table_enable=None,
+#     batch_ratio: int | None = None,
+# ) -> InferenceResult:
+#     """Perform batch analysis on a document dataset.
+#
+#     Args:
+#         dataset (Dataset): The dataset containing document pages to be analyzed.
+#         ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
+#         show_log (bool, optional): Flag to enable logging. Defaults to False.
+#         start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
+#         end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
+#         lang (str, optional): Language for OCR. Defaults to None.
+#         layout_model (optional): Layout model to be used for analysis. Defaults to None.
+#         formula_enable (optional): Flag to enable formula detection. Defaults to None.
+#         table_enable (optional): Flag to enable table detection. Defaults to None.
+#         batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
+#
+#     Raises:
+#         CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
+#
+#     Returns:
+#         InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
+#     """
+#
+#     if not torch.cuda.is_available():
+#         raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
+#
+#     lang = None if lang == '' else lang
+#     # TODO: auto detect batch size
+#     batch_ratio = 1 if batch_ratio is None else batch_ratio
+#     end_page_id = end_page_id if end_page_id else len(dataset)
+#
+#     model_manager = ModelSingleton()
+#     custom_model: CustomPEKModel = model_manager.get_model(
+#         ocr, show_log, lang, layout_model, formula_enable, table_enable
+#     )
+#     batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+#
+#     model_json = []
+#
+#     # batch analyze
+#     images = []
+#     for index in range(len(dataset)):
+#         if start_page_id <= index <= end_page_id:
+#             page_data = dataset.get_page(index)
+#             img_dict = page_data.get_image()
+#             images.append(img_dict['img'])
+#     analyze_result = batch_model(images)
+#
+#     for index in range(len(dataset)):
+#         page_data = dataset.get_page(index)
+#         img_dict = page_data.get_image()
+#         page_width = img_dict['width']
+#         page_height = img_dict['height']
+#         if start_page_id <= index <= end_page_id:
+#             result = analyze_result.pop(0)
+#         else:
+#             result = []
+#
+#         page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+#         page_dict = {'layout_dets': result, 'page_info': page_info}
+#         model_json.append(page_dict)
+#
+#     # TODO: clean memory when gpu memory is not enough
+#     clean_memory_start_time = time.time()
+#     clean_memory(get_device())
+#     logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
+#
+#     return InferenceResult(model_json, dataset)
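With doc_batch_analyze commented out, batching is no longer a separate entry point; doc_analyze (see the hunk further below) decides internally whether to run the BatchAnalyze path. A usage sketch under the assumption that the rest of the 1.0.x public API is unchanged; the file path is only an example:

# Assumes the PymuDocDataset / doc_analyze API known from 1.0.x.
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

with open('demo.pdf', 'rb') as f:
    dataset = PymuDocDataset(f.read())

# No explicit batch_ratio any more: doc_analyze picks one from detected GPU/NPU
# memory (or VIRTUAL_VRAM_SIZE) and otherwise falls back to page-by-page analysis.
infer_result = doc_analyze(dataset, ocr=True)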
@@ -3,8 +3,12 @@ import time
 
 # 关闭paddle的信号处理
 import paddle
+import torch
 from loguru import logger
 
+from magic_pdf.model.batch_analyze import BatchAnalyze
+from magic_pdf.model.sub_modules.model_utils import get_vram
+
 paddle.disable_signal_handler()
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
@@ -154,33 +158,88 @@ def doc_analyze(
     table_enable=None,
 ) -> InferenceResult:
 
+    end_page_id = end_page_id if end_page_id else len(dataset) - 1
+
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(
         ocr, show_log, lang, layout_model, formula_enable, table_enable
     )
 
+    batch_analyze = False
+    device = get_device()
+
+    npu_support = False
+    if str(device).startswith("npu"):
+        import torch_npu
+        if torch_npu.npu.is_available():
+            npu_support = True
+
+    if torch.cuda.is_available() and device != 'cpu' or npu_support:
+        gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
+        if gpu_memory is not None and gpu_memory >= 8:
+
+            if 8 <= gpu_memory < 10:
+                batch_ratio = 2
+            elif 10 <= gpu_memory <= 12:
+                batch_ratio = 4
+            elif 12 < gpu_memory <= 16:
+                batch_ratio = 8
+            elif 16 < gpu_memory <= 24:
+                batch_ratio = 16
+            else:
+                batch_ratio = 32
+
+            if batch_ratio >= 1:
+                logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
+                batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+                batch_analyze = True
+
     model_json = []
     doc_analyze_start = time.time()
 
-    if end_page_id is None:
-        end_page_id = len(dataset)
-
-    for index in range(len(dataset)):
-        page_data = dataset.get_page(index)
-        img_dict = page_data.get_image()
-        img = img_dict['img']
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        if start_page_id <= index <= end_page_id:
-            page_start = time.time()
-            result = custom_model(img)
-            logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
-        else:
-            result = []
+    if batch_analyze:
+        # batch analyze
+        images = []
+        for index in range(len(dataset)):
+            if start_page_id <= index <= end_page_id:
+                page_data = dataset.get_page(index)
+                img_dict = page_data.get_image()
+                images.append(img_dict['img'])
+        analyze_result = batch_model(images)
+
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            page_width = img_dict['width']
+            page_height = img_dict['height']
+            if start_page_id <= index <= end_page_id:
+                result = analyze_result.pop(0)
+            else:
+                result = []
+
+            page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
 
-        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
+    else:
+        # single analyze
+
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            img = img_dict['img']
+            page_width = img_dict['width']
+            page_height = img_dict['height']
+            if start_page_id <= index <= end_page_id:
+                page_start = time.time()
+                result = custom_model(img)
+                logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
+            else:
+                result = []
+
+            page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
 
     gc_start = time.time()
     clean_memory(get_device())
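In short, doc_analyze now switches to batched inference whenever a CUDA or NPU device with at least 8 GB of memory is available (the VIRTUAL_VRAM_SIZE environment variable overrides the detected amount) and scales batch_ratio with VRAM. A standalone sketch of that mapping, with thresholds copied from the hunk and a function name of our own:

def pick_batch_ratio(gpu_memory_gb: int):
    # Thresholds as in the 1.1.0 doc_analyze hunk above.
    if gpu_memory_gb < 8:
        return None  # stays in single-page mode
    if 8 <= gpu_memory_gb < 10:
        return 2
    elif 10 <= gpu_memory_gb <= 12:
        return 4
    elif 12 < gpu_memory_gb <= 16:
        return 8
    elif 16 < gpu_memory_gb <= 24:
        return 16
    else:
        return 32

for vram in (6, 8, 11, 16, 24, 48):
    print(vram, '->', pick_batch_ratio(vram))  # None, 2, 4, 8, 16, 32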
@@ -69,6 +69,7 @@ class CustomPEKModel:
         self.apply_table = self.table_config.get('enable', False)
         self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
         self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
+        self.table_sub_model_name = self.table_config.get('sub_model', None)
 
         # ocr config
         self.apply_ocr = ocr
@@ -144,7 +145,7 @@ class CustomPEKModel:
                        model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
                    )
                ),
-                device=self.device,
+                device='cpu' if str(self.device).startswith("mps") else self.device,
            )
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            self.layout_model = atom_model_manager.get_atom_model(
@@ -174,6 +175,7 @@ class CustomPEKModel:
                table_max_time=self.table_max_time,
                device=self.device,
                ocr_engine=self.ocr_model,
+                table_sub_model_name=self.table_sub_model_name
            )
 
        logger.info('DocAnalysis init done!')
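Together with the table_model_init changes at the end of this diff, this threads an optional 'sub_model' key of the table config through CustomPEKModel into RapidTableModel. A hedged sketch of the relevant config slice (the key names come from the getters above; the concrete values are only examples):

# Hypothetical table section of the model config; only the keys are taken from the diff.
table_config = {
    'enable': True,              # -> self.apply_table
    'max_time': 400,             # -> self.table_max_time (example value)
    'model': 'rapid_table',      # -> self.table_model_name (example value)
    'sub_model': 'slanet_plus',  # new in 1.1.0 -> self.table_sub_model_name (example value)
}
print(table_config.get('sub_model', None))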
@@ -192,24 +194,24 @@ class CustomPEKModel:
            layout_res = self.layout_model(image, ignore_catids=[])
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo
-            if height > width:
-                input_res = {"poly":[0,0,width,0,width,height,0,height]}
-                new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
-                paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
-                layout_res = self.layout_model.predict(new_image)
-                for res in layout_res:
-                    p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
-                    p1 = p1 - paste_x + xmin
-                    p2 = p2 - paste_y + ymin
-                    p3 = p3 - paste_x + xmin
-                    p4 = p4 - paste_y + ymin
-                    p5 = p5 - paste_x + xmin
-                    p6 = p6 - paste_y + ymin
-                    p7 = p7 - paste_x + xmin
-                    p8 = p8 - paste_y + ymin
-                    res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
-            else:
-                layout_res = self.layout_model.predict(image)
+            # if height > width:
+            #     input_res = {"poly":[0,0,width,0,width,height,0,height]}
+            #     new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
+            #     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+            #     layout_res = self.layout_model.predict(new_image)
+            #     for res in layout_res:
+            #         p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
+            #         p1 = p1 - paste_x + xmin
+            #         p2 = p2 - paste_y + ymin
+            #         p3 = p3 - paste_x + xmin
+            #         p4 = p4 - paste_y + ymin
+            #         p5 = p5 - paste_x + xmin
+            #         p6 = p6 - paste_y + ymin
+            #         p7 = p7 - paste_x + xmin
+            #         p8 = p8 - paste_y + ymin
+            #         res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
+            # else:
+            layout_res = self.layout_model.predict(image)
 
        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f'layout detection time: {layout_cost}')
@@ -228,7 +230,7 @@ class CustomPEKModel:
            logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
 
        # 清理显存
-        clean_vram(self.device, vram_threshold=8)
+        clean_vram(self.device, vram_threshold=6)
 
        # 从layout_res中获取ocr区域、表格区域、公式区域
        ocr_res_list, table_res_list, single_page_mfdetrec_res = (
@@ -276,7 +278,7 @@ class CustomPEKModel:
                elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                    html_code = self.table_model.img2html(new_image)
                elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
-                    html_code, table_cell_bboxes, elapse = self.table_model.predict(
+                    html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
                        new_image
                    )
                run_time = time.time() - single_table_start_time
@@ -9,7 +9,11 @@ class DocLayoutYOLOModel(object):
     def predict(self, image):
         layout_res = []
         doclayout_yolo_res = self.model.predict(
-            image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device
+            image,
+            imgsz=1280,
+            conf=0.10,
+            iou=0.45,
+            verbose=False, device=self.device
         )[0]
         for xyxy, conf, cla in zip(
             doclayout_yolo_res.boxes.xyxy.cpu(),
@@ -32,8 +36,8 @@ class DocLayoutYOLOModel(object):
                image_res.cpu()
                for image_res in self.model.predict(
                    images[index : index + batch_size],
-                    imgsz=1024,
-                    conf=0.25,
+                    imgsz=1280,
+                    conf=0.10,
                    iou=0.45,
                    verbose=False,
                    device=self.device,
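Both DocLayout-YOLO paths now run at imgsz=1280 with a much lower confidence cut-off (0.10 instead of 0.25), so more low-scoring candidate boxes are produced and any stricter filtering has to happen downstream. A generic, hypothetical post-filter sketch (the detection dict layout here is ours, not necessarily magic-pdf's):

def filter_detections(detections, min_score=0.25):
    # Keep only detections at or above a stricter score than the model's conf=0.10.
    return [det for det in detections if det['score'] >= min_score]

dets = [
    {'category_id': 1, 'score': 0.92, 'bbox': [10, 10, 200, 60]},
    {'category_id': 5, 'score': 0.12, 'bbox': [15, 400, 180, 420]},  # surfaced by conf=0.10
]
print(filter_detections(dets))  # only the 0.92 box survives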
@@ -89,7 +89,7 @@ class UnimernetModel(object):
            mf_image_list.append(bbox_img)
 
        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
+        dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
@@ -21,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
     TableMasterPaddleModel
 
 
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
     elif table_model_type == MODEL_NAME.TABLE_MASTER:
@@ -31,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
         }
         table_model = TableMasterPaddleModel(config)
     elif table_model_type == MODEL_NAME.RAPID_TABLE:
-        table_model = RapidTableModel(ocr_engine)
+        table_model = RapidTableModel(ocr_engine, table_sub_model_name)
     else:
         logger.error('table model type not allow')
         exit(1)
@@ -163,7 +163,8 @@ def atom_model_init(model_name: str, **kwargs):
            kwargs.get('table_model_path'),
            kwargs.get('table_max_time'),
            kwargs.get('device'),
-            kwargs.get('ocr_engine')
+            kwargs.get('ocr_engine'),
+            kwargs.get('table_sub_model_name')
        )
    elif model_name == AtomicModel.LangDetect:
        if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect: