magic-pdf 1.0.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in a supported public registry. It is provided for informational purposes only.
Files changed (30)
  1. magic_pdf/dict2md/ocr_mkcontent.py +24 -0
  2. magic_pdf/filter/__init__.py +1 -1
  3. magic_pdf/filter/pdf_classify_by_type.py +6 -4
  4. magic_pdf/filter/pdf_meta_scan.py +4 -4
  5. magic_pdf/libs/boxbase.py +5 -2
  6. magic_pdf/libs/draw_bbox.py +14 -2
  7. magic_pdf/libs/language.py +9 -0
  8. magic_pdf/libs/pdf_check.py +11 -1
  9. magic_pdf/libs/version.py +1 -1
  10. magic_pdf/model/batch_analyze.py +103 -99
  11. magic_pdf/model/doc_analyze_by_custom_model.py +87 -36
  12. magic_pdf/model/magic_model.py +161 -4
  13. magic_pdf/model/pdf_extract_kit.py +23 -28
  14. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +4 -3
  15. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +7 -3
  16. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +1 -1
  17. magic_pdf/model/sub_modules/model_init.py +34 -19
  18. magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +33 -26
  19. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +25 -6
  20. magic_pdf/pdf_parse_union_core_v2.py +176 -61
  21. magic_pdf/post_proc/llm_aided.py +55 -24
  22. magic_pdf/pre_proc/ocr_dict_merge.py +14 -2
  23. magic_pdf/pre_proc/ocr_span_list_modify.py +1 -1
  24. magic_pdf/resources/model_config/model_configs.yaml +2 -2
  25. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/METADATA +36 -19
  26. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/RECORD +30 -30
  27. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/LICENSE.md +0 -0
  28. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/WHEEL +0 -0
  29. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/entry_points.txt +0 -0
  30. {magic_pdf-1.0.1.dist-info → magic_pdf-1.2.0.dist-info}/top_level.txt +0 -0
magic_pdf/dict2md/ocr_mkcontent.py CHANGED
@@ -126,11 +126,35 @@ def detect_language(text):
         return 'empty'
 
 
+def full_to_half(text: str) -> str:
+    """Convert full-width characters to half-width characters using code point manipulation.
+
+    Args:
+        text: String containing full-width characters
+
+    Returns:
+        String with full-width characters converted to half-width
+    """
+    result = []
+    for char in text:
+        code = ord(char)
+        # Full-width ASCII variants (FF01-FF5E)
+        if 0xFF01 <= code <= 0xFF5E:
+            result.append(chr(code - 0xFEE0))  # Shift to ASCII range
+        # Full-width space
+        elif code == 0x3000:
+            result.append(' ')
+        else:
+            result.append(char)
+    return ''.join(result)
+
+
 def merge_para_with_text(para_block):
     block_text = ''
     for line in para_block['lines']:
         for span in line['spans']:
             if span['type'] in [ContentType.Text]:
+                span['content'] = full_to_half(span['content'])
                 block_text += span['content']
     block_lang = detect_lang(block_text)
 
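The new full_to_half pass normalizes full-width ASCII (U+FF01–U+FF5E, shifted down by 0xFEE0) and the ideographic space (U+3000) before span text is merged. A quick usage sketch, assuming the function is imported from the file shown above:

    from magic_pdf.dict2md.ocr_mkcontent import full_to_half

    print(full_to_half('ＡＢＣ　１２３！'))  # -> 'ABC 123!' (full-width letters, U+3000 space, digits, '！')
    print(full_to_half('plain ascii'))       # unchanged: code points outside the mapped ranges pass through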
magic_pdf/filter/__init__.py CHANGED
@@ -23,7 +23,7 @@ def classify(pdf_bytes: bytes) -> SupportedPdfParseMethod:
         pdf_meta['image_info_per_page'],
         pdf_meta['text_len_per_page'],
         pdf_meta['imgs_per_page'],
-        pdf_meta['text_layout_per_page'],
+        # pdf_meta['text_layout_per_page'],
         pdf_meta['invalid_chars'],
     )
     if is_text_pdf:
magic_pdf/filter/pdf_classify_by_type.py CHANGED
@@ -305,7 +305,8 @@ def classify_by_img_narrow_strips(page_width, page_height, img_sz_list):
 
 
 def classify(total_page: int, page_width, page_height, img_sz_list: list, text_len_list: list, img_num_list: list,
-             text_layout_list: list, invalid_chars: bool):
+             # text_layout_list: list,
+             invalid_chars: bool):
     """
    Image and page dimensions here are in pts
    :param total_page:
@@ -321,7 +322,7 @@ def classify(total_page: int, page_width, page_height, img_sz_list: list, text_l
         'by_text_len': classify_by_text_len(text_len_list, total_page),
         'by_avg_words': classify_by_avg_words(text_len_list),
         'by_img_num': classify_by_img_num(img_sz_list, img_num_list),
-        'by_text_layout': classify_by_text_layout(text_layout_list),
+        # 'by_text_layout': classify_by_text_layout(text_layout_list),
         'by_img_narrow_strips': classify_by_img_narrow_strips(page_width, page_height, img_sz_list),
         'by_invalid_chars': invalid_chars,
     }
@@ -332,9 +333,10 @@ def classify(total_page: int, page_width, page_height, img_sz_list: list, text_l
         return False, results
     else:
         logger.warning(
-            f"pdf is not classified by area and text_len, by_image_area: {results['by_image_area']},"
+            f"OCR needed based on classification result, by_image_area: {results['by_image_area']},"
             f" by_text: {results['by_text_len']}, by_avg_words: {results['by_avg_words']}, by_img_num: {results['by_img_num']},"
-            f" by_text_layout: {results['by_text_layout']}, by_img_narrow_strips: {results['by_img_narrow_strips']},"
+            # f" by_text_layout: {results['by_text_layout']},"
+            f" by_img_narrow_strips: {results['by_img_narrow_strips']},"
             f" by_invalid_chars: {results['by_invalid_chars']}",
             file=sys.stderr)  # This case helps quickly spot unusual PDFs so the classification algorithm can be tuned for them
         return False, results
magic_pdf/filter/pdf_meta_scan.py CHANGED
@@ -356,9 +356,9 @@ def pdf_meta_scan(pdf_bytes: bytes):
     # logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}")
     text_len_per_page = get_pdf_textlen_per_page(doc)
     # logger.info(f"text_len_per_page: {text_len_per_page}")
-    text_layout_per_page = get_pdf_text_layout_per_page(doc)
+    # text_layout_per_page = get_pdf_text_layout_per_page(doc)
     # logger.info(f"text_layout_per_page: {text_layout_per_page}")
-    text_language = get_language(doc)
+    # text_language = get_language(doc)
     # logger.info(f"text_language: {text_language}")
     invalid_chars = check_invalid_chars(pdf_bytes)
     # logger.info(f"invalid_chars: {invalid_chars}")
@@ -372,8 +372,8 @@ def pdf_meta_scan(pdf_bytes: bytes):
         'page_height_pts': int(page_height_pts),
         'image_info_per_page': image_info_per_page,
         'text_len_per_page': text_len_per_page,
-        'text_layout_per_page': text_layout_per_page,
-        'text_language': text_language,
+        # 'text_layout_per_page': text_layout_per_page,
+        # 'text_language': text_language,
         # "svgs_per_page": svgs_per_page,
         'imgs_per_page': imgs_per_page,  # add per-page image-count list
         'junk_img_bojids': junk_img_bojids,  # add junk-image bojid list
magic_pdf/libs/boxbase.py CHANGED
@@ -185,10 +185,13 @@ def calculate_iou(bbox1, bbox2):
     bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
     bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
 
+    if any([bbox1_area == 0, bbox2_area == 0]):
+        return 0
+
     # Compute the intersection over union by taking the intersection area
     # and dividing it by the sum of both areas minus the intersection area
-    iou = intersection_area / float(bbox1_area + bbox2_area -
-                                    intersection_area)
+    iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area)
+
     return iou
 
 
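The new guard matters because two degenerate (zero-area) boxes would otherwise reach a 0/0 division. A standalone sketch of the guarded IoU; the intersection computation is restated here for illustration and assumed to match the part of calculate_iou outside this hunk:

    def iou(bbox1, bbox2):
        # Boxes are (x0, y0, x1, y1)
        x0, y0 = max(bbox1[0], bbox2[0]), max(bbox1[1], bbox2[1])
        x1, y1 = min(bbox1[2], bbox2[2]), min(bbox1[3], bbox2[3])
        intersection_area = max(0, x1 - x0) * max(0, y1 - y0)
        bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
        bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
        if bbox1_area == 0 or bbox2_area == 0:
            return 0  # zero-area box: without the guard, two empty boxes would divide 0 by 0
        return intersection_area / float(bbox1_area + bbox2_area - intersection_area)

    print(iou((0, 0, 10, 10), (5, 5, 15, 15)))  # 25 / 175 ≈ 0.143
    print(iou((0, 0, 10, 10), (3, 3, 3, 8)))    # 0: second box has zero width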
magic_pdf/libs/draw_bbox.py CHANGED
@@ -362,12 +362,24 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
     for page in pdf_info:
         page_line_list = []
         for block in page['preproc_blocks']:
-            if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]:
+            if block['type'] in [BlockType.Text]:
                 for line in block['lines']:
                     bbox = line['bbox']
                     index = line['index']
                     page_line_list.append({'index': index, 'bbox': bbox})
-            if block['type'] in [BlockType.Image, BlockType.Table]:
+            elif block['type'] in [BlockType.Title, BlockType.InterlineEquation]:
+                if 'virtual_lines' in block:
+                    if len(block['virtual_lines']) > 0 and block['virtual_lines'][0].get('index', None) is not None:
+                        for line in block['virtual_lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
+                else:
+                    for line in block['lines']:
+                        bbox = line['bbox']
+                        index = line['index']
+                        page_line_list.append({'index': index, 'bbox': bbox})
+            elif block['type'] in [BlockType.Image, BlockType.Table]:
                 for sub_block in block['blocks']:
                     if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
                         if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
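The rewritten branch is easier to read as a selection rule: Text blocks keep their real lines; Title and InterlineEquation blocks now prefer indexed virtual_lines, fall back to real lines only when the virtual_lines key is absent, and contribute nothing when virtual_lines exists but carries no index. A hypothetical helper condensing that rule (pick_lines is not part of the package; illustration only):

    def pick_lines(block):
        if 'virtual_lines' in block:
            vlines = block['virtual_lines']
            if len(vlines) > 0 and vlines[0].get('index', None) is not None:
                return vlines  # indexed virtual lines win
            return []  # key present but unindexed: the hunk appends nothing in this case
        return block['lines']  # no virtual_lines key: fall back to real lines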
magic_pdf/libs/language.py CHANGED
@@ -12,12 +12,20 @@ if not os.getenv("FTLANG_CACHE"):
 from fast_langdetect import detect_language
 
 
+def remove_invalid_surrogates(text):
+    # Remove invalid UTF-16 surrogates
+    return ''.join(c for c in text if not (0xD800 <= ord(c) <= 0xDFFF))
+
+
 def detect_lang(text: str) -> str:
 
     if len(text) == 0:
         return ""
 
     text = text.replace("\n", "")
+    text = remove_invalid_surrogates(text)
+
+    # print(text)
     try:
         lang_upper = detect_language(text)
     except:
@@ -37,3 +45,4 @@ if __name__ == '__main__':
     print(detect_lang("<html>This is a test</html>"))
     print(detect_lang("这个是中文测试。"))
     print(detect_lang("<html>这个是中文测试。</html>"))
+    print(detect_lang("〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试"))
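The added test line is the motivation: lone UTF-16 surrogates (U+D800–U+DFFF) that survive PDF text extraction are not valid Unicode scalar values and cannot be UTF-8 encoded, which can break downstream detection, so detect_lang now strips them first. A standalone sketch of the same filter:

    # Mirrors remove_invalid_surrogates above; in a Python 3 str literal,
    # '\ud835' and '\udc46' remain two lone surrogate code points (they do not pair up)
    s = '〖\ud835\udc46\ud835〗这是个中文测试'
    cleaned = ''.join(c for c in s if not (0xD800 <= ord(c) <= 0xDFFF))
    print(cleaned)  # -> '〖〗这是个中文测试'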
magic_pdf/libs/pdf_check.py CHANGED
@@ -4,6 +4,7 @@ from loguru import logger
 import re
 from io import BytesIO
 from pdfminer.high_level import extract_text
+from pdfminer.layout import LAParams
 
 
 def calculate_sample_count(total_page: int):
@@ -41,7 +42,16 @@ def detect_invalid_chars(src_pdf_bytes: bytes) -> bool:
     sample_docs = extract_pages(src_pdf_bytes)
     sample_pdf_bytes = sample_docs.tobytes()
     sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
-    text = extract_text(sample_pdf_file_like_object)
+    laparams = LAParams(
+        line_overlap=0.5,
+        char_margin=2.0,
+        line_margin=0.5,
+        word_margin=0.1,
+        boxes_flow=None,
+        detect_vertical=False,
+        all_texts=False,
+    )
+    text = extract_text(pdf_file=sample_pdf_file_like_object, laparams=laparams)
     text = text.replace("\n", "")
     # logger.info(text)
     '''Garbled text extracted by pdfminer is characterized by (cid:xxx) tokens'''
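Two notes on this hunk. Passing boxes_flow=None disables pdfminer's advanced layout analysis (per the pdfminer.six LAParams documentation), which is cheaper when only the raw character stream matters. And the docstring points at the (cid:xxx) tokens pdfminer emits for glyphs it cannot map to Unicode; the counting logic lives outside this hunk, but it can be sketched roughly like this (looks_garbled and its threshold are hypothetical, not the package's actual code):

    import re

    def looks_garbled(text: str) -> bool:
        # pdfminer renders unmappable glyphs as '(cid:123)'-style tokens
        cid_count = len(re.findall(r'\(cid:\d+\)', text))
        return cid_count > 10  # illustrative threshold only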
magic_pdf/libs/version.py CHANGED
@@ -1 +1 @@
-__version__ = "1.0.1"
+__version__ = "1.2.0"
magic_pdf/model/batch_analyze.py CHANGED
@@ -7,19 +7,19 @@ from loguru import logger
 from PIL import Image
 
 from magic_pdf.config.constants import MODEL_NAME
-from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
-from magic_pdf.data.dataset import Dataset
-from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_device
-from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
+# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
+# from magic_pdf.data.dataset import Dataset
+# from magic_pdf.libs.clean_memory import clean_memory
+# from magic_pdf.libs.config_reader import get_device
+# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
-from magic_pdf.operators.models import InferenceResult
+# from magic_pdf.operators.models import InferenceResult
 
-YOLO_LAYOUT_BASE_BATCH_SIZE = 4
+YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
 MFR_BASE_BATCH_SIZE = 16
 
@@ -44,19 +44,20 @@ class BatchAnalyze:
         modified_images = []
         for image_index, image in enumerate(images):
             pil_img = Image.fromarray(image)
-            width, height = pil_img.size
-            if height > width:
-                input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
-                new_image, useful_list = crop_img(
-                    input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
-                )
-                layout_images.append(new_image)
-                modified_images.append([image_index, useful_list])
-            else:
-                layout_images.append(pil_img)
+            # width, height = pil_img.size
+            # if height > width:
+            #     input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
+            #     new_image, useful_list = crop_img(
+            #         input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
+            #     )
+            #     layout_images.append(new_image)
+            #     modified_images.append([image_index, useful_list])
+            # else:
+            layout_images.append(pil_img)
 
         images_layout_res += self.model.layout_model.batch_predict(
-            layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+            # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+            layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
         )
 
         for image_index, useful_list in modified_images:
@@ -78,7 +79,8 @@ class BatchAnalyze:
         # Formula detection
         mfd_start_time = time.time()
         images_mfd_res = self.model.mfd_model.batch_predict(
-            images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+            # images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+            images, MFD_BASE_BATCH_SIZE
         )
         logger.info(
             f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
@@ -91,10 +93,12 @@ class BatchAnalyze:
             images,
             batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
         )
+        mfr_count = 0
         for image_index in range(len(images)):
             images_layout_res[image_index] += images_formula_list[image_index]
+            mfr_count += len(images_formula_list[image_index])
         logger.info(
-            f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
+            f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
         )
 
         # Free VRAM
@@ -159,7 +163,7 @@ class BatchAnalyze:
             elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
                 html_code = self.model.table_model.img2html(new_image)
             elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
-                html_code, table_cell_bboxes, elapse = (
+                html_code, table_cell_bboxes, logic_points, elapse = (
                     self.model.table_model.predict(new_image)
                 )
                 run_time = time.time() - single_table_start_time
@@ -195,81 +199,81 @@ class BatchAnalyze:
         return images_layout_res
 
 
-def doc_batch_analyze(
-    dataset: Dataset,
-    ocr: bool = False,
-    show_log: bool = False,
-    start_page_id=0,
-    end_page_id=None,
-    lang=None,
-    layout_model=None,
-    formula_enable=None,
-    table_enable=None,
-    batch_ratio: int | None = None,
-) -> InferenceResult:
-    """Perform batch analysis on a document dataset.
-
-    Args:
-        dataset (Dataset): The dataset containing document pages to be analyzed.
-        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
-        show_log (bool, optional): Flag to enable logging. Defaults to False.
-        start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
-        end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
-        lang (str, optional): Language for OCR. Defaults to None.
-        layout_model (optional): Layout model to be used for analysis. Defaults to None.
-        formula_enable (optional): Flag to enable formula detection. Defaults to None.
-        table_enable (optional): Flag to enable table detection. Defaults to None.
-        batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
-
-    Raises:
-        CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
-
-    Returns:
-        InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
-    """
-
-    if not torch.cuda.is_available():
-        raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
-
-    lang = None if lang == '' else lang
-    # TODO: auto detect batch size
-    batch_ratio = 1 if batch_ratio is None else batch_ratio
-    end_page_id = end_page_id if end_page_id else len(dataset)
-
-    model_manager = ModelSingleton()
-    custom_model: CustomPEKModel = model_manager.get_model(
-        ocr, show_log, lang, layout_model, formula_enable, table_enable
-    )
-    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-
-    model_json = []
-
-    # batch analyze
-    images = []
-    for index in range(len(dataset)):
-        if start_page_id <= index <= end_page_id:
-            page_data = dataset.get_page(index)
-            img_dict = page_data.get_image()
-            images.append(img_dict['img'])
-    analyze_result = batch_model(images)
-
-    for index in range(len(dataset)):
-        page_data = dataset.get_page(index)
-        img_dict = page_data.get_image()
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        if start_page_id <= index <= end_page_id:
-            result = analyze_result.pop(0)
-        else:
-            result = []
-
-        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
-
-    # TODO: clean memory when gpu memory is not enough
-    clean_memory_start_time = time.time()
-    clean_memory(get_device())
-    logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
-
-    return InferenceResult(model_json, dataset)
+# def doc_batch_analyze(
+#     dataset: Dataset,
+#     ocr: bool = False,
+#     show_log: bool = False,
+#     start_page_id=0,
+#     end_page_id=None,
+#     lang=None,
+#     layout_model=None,
+#     formula_enable=None,
+#     table_enable=None,
+#     batch_ratio: int | None = None,
+# ) -> InferenceResult:
+#     """Perform batch analysis on a document dataset.
+#
+#     Args:
+#         dataset (Dataset): The dataset containing document pages to be analyzed.
+#         ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
+#         show_log (bool, optional): Flag to enable logging. Defaults to False.
+#         start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
+#         end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
+#         lang (str, optional): Language for OCR. Defaults to None.
+#         layout_model (optional): Layout model to be used for analysis. Defaults to None.
+#         formula_enable (optional): Flag to enable formula detection. Defaults to None.
+#         table_enable (optional): Flag to enable table detection. Defaults to None.
+#         batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
+#
+#     Raises:
+#         CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
+#
+#     Returns:
+#         InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
+#     """
+#
+#     if not torch.cuda.is_available():
+#         raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
+#
+#     lang = None if lang == '' else lang
+#     # TODO: auto detect batch size
+#     batch_ratio = 1 if batch_ratio is None else batch_ratio
+#     end_page_id = end_page_id if end_page_id else len(dataset)
+#
+#     model_manager = ModelSingleton()
+#     custom_model: CustomPEKModel = model_manager.get_model(
+#         ocr, show_log, lang, layout_model, formula_enable, table_enable
+#     )
+#     batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+#
+#     model_json = []
+#
+#     # batch analyze
+#     images = []
+#     for index in range(len(dataset)):
+#         if start_page_id <= index <= end_page_id:
+#             page_data = dataset.get_page(index)
+#             img_dict = page_data.get_image()
+#             images.append(img_dict['img'])
+#     analyze_result = batch_model(images)
+#
+#     for index in range(len(dataset)):
+#         page_data = dataset.get_page(index)
+#         img_dict = page_data.get_image()
+#         page_width = img_dict['width']
+#         page_height = img_dict['height']
+#         if start_page_id <= index <= end_page_id:
+#             result = analyze_result.pop(0)
+#         else:
+#             result = []
+#
+#         page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+#         page_dict = {'layout_dets': result, 'page_info': page_info}
+#         model_json.append(page_dict)
+#
+#     # TODO: clean memory when gpu memory is not enough
+#     clean_memory_start_time = time.time()
+#     clean_memory(get_device())
+#     logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
+#
+#     return InferenceResult(model_json, dataset)
magic_pdf/model/doc_analyze_by_custom_model.py CHANGED
@@ -1,17 +1,22 @@
 import os
 import time
+import torch
 
+os.environ['FLAGS_npu_jit_compile'] = '0'  # disable paddle JIT compilation
+os.environ['FLAGS_use_stride_kernel'] = '0'
+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS ops to fall back
+os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations update checks
 # Disable paddle's signal handling
 import paddle
-from loguru import logger
-
 paddle.disable_signal_handler()
 
-os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations update checks
+from loguru import logger
+
+from magic_pdf.model.batch_analyze import BatchAnalyze
+from magic_pdf.model.sub_modules.model_utils import get_vram
 
 try:
     import torchtext
-
     if torchtext.__version__ >= '0.18.0':
         torchtext.disable_torchtext_deprecation_warning()
 except ImportError:
@@ -28,20 +33,6 @@ from magic_pdf.model.model_list import MODEL
 from magic_pdf.operators.models import InferenceResult
 
 
-def dict_compare(d1, d2):
-    return d1.items() == d2.items()
-
-
-def remove_duplicates_dicts(lst):
-    unique_dicts = []
-    for dict_item in lst:
-        if not any(
-            dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
-        ):
-            unique_dicts.append(dict_item)
-    return unique_dicts
-
-
 class ModelSingleton:
     _instance = None
     _models = {}
@@ -154,33 +145,93 @@ def doc_analyze(
     table_enable=None,
 ) -> InferenceResult:
 
+    end_page_id = (
+        end_page_id
+        if end_page_id is not None and end_page_id >= 0
+        else len(dataset) - 1
+    )
+
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(
         ocr, show_log, lang, layout_model, formula_enable, table_enable
     )
 
+    batch_analyze = False
+    batch_ratio = 1
+    device = get_device()
+
+    npu_support = False
+    if str(device).startswith("npu"):
+        import torch_npu
+        if torch_npu.npu.is_available():
+            npu_support = True
+
+    if torch.cuda.is_available() and device != 'cpu' or npu_support:
+        gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
+        if gpu_memory is not None and gpu_memory >= 8:
+
+            if gpu_memory >= 40:
+                batch_ratio = 32
+            elif gpu_memory >= 20:
+                batch_ratio = 16
+            elif gpu_memory >= 16:
+                batch_ratio = 8
+            elif gpu_memory >= 10:
+                batch_ratio = 4
+            else:
+                batch_ratio = 2
+
+            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
+            batch_analyze = True
+
     model_json = []
     doc_analyze_start = time.time()
 
-    if end_page_id is None:
-        end_page_id = len(dataset)
-
-    for index in range(len(dataset)):
-        page_data = dataset.get_page(index)
-        img_dict = page_data.get_image()
-        img = img_dict['img']
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        if start_page_id <= index <= end_page_id:
-            page_start = time.time()
-            result = custom_model(img)
-            logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
-        else:
-            result = []
+    if batch_analyze:
+        # batch analyze
+        images = []
+        page_wh_list = []
+        for index in range(len(dataset)):
+            if start_page_id <= index <= end_page_id:
+                page_data = dataset.get_page(index)
+                img_dict = page_data.get_image()
+                images.append(img_dict['img'])
+                page_wh_list.append((img_dict['width'], img_dict['height']))
+        batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+        analyze_result = batch_model(images)
+
+        for index in range(len(dataset)):
+            if start_page_id <= index <= end_page_id:
+                result = analyze_result.pop(0)
+                page_width, page_height = page_wh_list.pop(0)
+            else:
+                result = []
+                page_height = 0
+                page_width = 0
+
+            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
 
-        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
+    else:
+        # single analyze
+
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            img = img_dict['img']
+            page_width = img_dict['width']
+            page_height = img_dict['height']
+            if start_page_id <= index <= end_page_id:
+                page_start = time.time()
+                result = custom_model(img)
+                logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
+            else:
+                result = []
+
+            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
 
     gc_start = time.time()
     clean_memory(get_device())
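The VRAM-to-batch_ratio ladder above drives the new auto-batching: batching switches on at 8 GB (VIRTUAL_VRAM_SIZE can override detection), the ratio scales the MFR batch size, and the earlier batch_analyze.py hunks pin the layout and MFD batch sizes back to their base values. Restated as a pure function for clarity (pick_batch_ratio is illustrative, not a package API):

    def pick_batch_ratio(gpu_memory_gb: int) -> int:
        # Mirrors the thresholds in doc_analyze above
        if gpu_memory_gb < 8:
            return 1   # batch analysis stays disabled
        if gpu_memory_gb >= 40:
            return 32
        if gpu_memory_gb >= 20:
            return 16
        if gpu_memory_gb >= 16:
            return 8
        if gpu_memory_gb >= 10:
            return 4
        return 2       # 8-9 GB

    assert pick_batch_ratio(24) == 16
    assert pick_batch_ratio(12) == 4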