magic-pdf 0.10.5__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. magic_pdf/config/constants.py +7 -0
  2. magic_pdf/config/exceptions.py +7 -0
  3. magic_pdf/data/data_reader_writer/base.py +13 -1
  4. magic_pdf/data/data_reader_writer/filebase.py +1 -1
  5. magic_pdf/data/data_reader_writer/multi_bucket_s3.py +8 -6
  6. magic_pdf/data/dataset.py +188 -5
  7. magic_pdf/data/read_api.py +59 -12
  8. magic_pdf/data/utils.py +35 -0
  9. magic_pdf/dict2md/ocr_mkcontent.py +16 -15
  10. magic_pdf/filter/__init__.py +32 -0
  11. magic_pdf/filter/pdf_meta_scan.py +3 -2
  12. magic_pdf/libs/clean_memory.py +11 -4
  13. magic_pdf/libs/config_reader.py +9 -0
  14. magic_pdf/libs/draw_bbox.py +19 -22
  15. magic_pdf/libs/language.py +3 -0
  16. magic_pdf/libs/pdf_check.py +30 -30
  17. magic_pdf/libs/version.py +1 -1
  18. magic_pdf/model/__init__.py +1 -1
  19. magic_pdf/model/batch_analyze.py +275 -0
  20. magic_pdf/model/doc_analyze_by_custom_model.py +104 -92
  21. magic_pdf/model/magic_model.py +4 -435
  22. magic_pdf/model/model_list.py +1 -0
  23. magic_pdf/model/pdf_extract_kit.py +35 -5
  24. magic_pdf/model/sub_modules/language_detection/__init__.py +1 -0
  25. magic_pdf/model/sub_modules/language_detection/utils.py +82 -0
  26. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +139 -0
  27. magic_pdf/model/sub_modules/language_detection/yolov11/__init__.py +1 -0
  28. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +44 -7
  29. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +21 -2
  30. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +70 -27
  31. magic_pdf/model/sub_modules/model_init.py +43 -7
  32. magic_pdf/model/sub_modules/model_utils.py +17 -5
  33. magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +51 -1
  34. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +32 -6
  35. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +42 -7
  36. magic_pdf/operators/__init__.py +94 -0
  37. magic_pdf/operators/models.py +154 -0
  38. magic_pdf/operators/pipes.py +191 -0
  39. magic_pdf/pdf_parse_union_core_v2.py +77 -27
  40. magic_pdf/post_proc/__init__.py +1 -0
  41. magic_pdf/post_proc/llm_aided.py +133 -0
  42. magic_pdf/pre_proc/ocr_span_list_modify.py +8 -0
  43. magic_pdf/pre_proc/remove_bbox_overlap.py +1 -1
  44. magic_pdf/resources/yolov11-langdetect/yolo_v11_ft.pt +0 -0
  45. magic_pdf/tools/cli.py +36 -11
  46. magic_pdf/tools/common.py +120 -61
  47. magic_pdf/utils/office_to_pdf.py +29 -0
  48. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/METADATA +78 -25
  49. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/RECORD +54 -55
  50. magic_pdf/para/__init__.py +0 -0
  51. magic_pdf/pdf_parse_by_ocr.py +0 -23
  52. magic_pdf/pdf_parse_by_txt.py +0 -24
  53. magic_pdf/pipe/AbsPipe.py +0 -98
  54. magic_pdf/pipe/OCRPipe.py +0 -41
  55. magic_pdf/pipe/TXTPipe.py +0 -41
  56. magic_pdf/pipe/UNIPipe.py +0 -98
  57. magic_pdf/pipe/__init__.py +0 -0
  58. magic_pdf/rw/AbsReaderWriter.py +0 -17
  59. magic_pdf/rw/DiskReaderWriter.py +0 -74
  60. magic_pdf/rw/S3ReaderWriter.py +0 -142
  61. magic_pdf/rw/__init__.py +0 -0
  62. magic_pdf/user_api.py +0 -121
  63. /magic_pdf/{para → post_proc}/para_split_v3.py +0 -0
  64. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/LICENSE.md +0 -0
  65. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/WHEEL +0 -0
  66. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/entry_points.txt +0 -0
  67. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,8 @@
1
1
  import fitz
2
2
  from magic_pdf.config.constants import CROSS_PAGE
3
- from magic_pdf.config.ocr_content_type import BlockType, CategoryId, ContentType
4
- from magic_pdf.data.dataset import PymuDocDataset
3
+ from magic_pdf.config.ocr_content_type import (BlockType, CategoryId,
4
+ ContentType)
5
+ from magic_pdf.data.dataset import Dataset
5
6
  from magic_pdf.model.magic_model import MagicModel
6
7
 
7
8
 
@@ -194,7 +195,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
194
195
  )
195
196
 
196
197
  # Save the PDF
197
- pdf_docs.save(f'{out_path}/{filename}_layout.pdf')
198
+ pdf_docs.save(f'{out_path}/{filename}')
198
199
 
199
200
 
200
201
  def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
@@ -282,18 +283,17 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
282
283
  draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False)
283
284
 
284
285
  # Save the PDF
285
- pdf_docs.save(f'{out_path}/{filename}_spans.pdf')
286
+ pdf_docs.save(f'{out_path}/{filename}')
286
287
 
287
288
 
288
- def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
289
+ def draw_model_bbox(model_list, dataset: Dataset, out_path, filename):
289
290
  dropped_bbox_list = []
290
291
  tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
291
292
  imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
292
293
  titles_list = []
293
294
  texts_list = []
294
295
  interequations_list = []
295
- pdf_docs = fitz.open('pdf', pdf_bytes)
296
- magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
296
+ magic_model = MagicModel(model_list, dataset)
297
297
  for i in range(len(model_list)):
298
298
  page_dropped_list = []
299
299
  tables_body, tables_caption, tables_footnote = [], [], []
@@ -337,7 +337,8 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
337
337
  dropped_bbox_list.append(page_dropped_list)
338
338
  imgs_footnote_list.append(imgs_footnote)
339
339
 
340
- for i, page in enumerate(pdf_docs):
340
+ for i in range(len(dataset)):
341
+ page = dataset.get_page(i)
341
342
  draw_bbox_with_number(
342
343
  i, dropped_bbox_list, page, [158, 158, 158], True
343
344
  ) # color !
@@ -352,7 +353,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
352
353
  draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
353
354
 
354
355
  # Save the PDF
355
- pdf_docs.save(f'{out_path}/{filename}_model.pdf')
356
+ dataset.dump_to_file(f'{out_path}/{filename}')
356
357
 
357
358
 
358
359
  def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
@@ -390,20 +391,16 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
390
391
  for i, page in enumerate(pdf_docs):
391
392
  draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
392
393
 
393
- pdf_docs.save(f'{out_path}/{filename}_line_sort.pdf')
394
-
394
+ pdf_docs.save(f'{out_path}/{filename}')
395
395
 
396
- def draw_layout_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
397
- layout_bbox_list = []
398
396
 
399
- for page in pdf_info:
400
- page_block_list = []
401
- for block in page['para_blocks']:
402
- bbox = block['bbox']
403
- page_block_list.append(bbox)
404
- layout_bbox_list.append(page_block_list)
397
+ def draw_char_bbox(pdf_bytes, out_path, filename):
405
398
  pdf_docs = fitz.open('pdf', pdf_bytes)
406
399
  for i, page in enumerate(pdf_docs):
407
- draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
408
-
409
- pdf_docs.save(f'{out_path}/{filename}_layout_sort.pdf')
400
+ for block in page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']:
401
+ for line in block['lines']:
402
+ for span in line['spans']:
403
+ for char in span['chars']:
404
+ char_bbox = char['bbox']
405
+ page.draw_rect(char_bbox, color=[1, 0, 0], fill=None, fill_opacity=1, width=0.3, overlay=True,)
406
+ pdf_docs.save(f'{out_path}/{filename}')
@@ -16,11 +16,14 @@ def detect_lang(text: str) -> str:
16
16
 
17
17
  if len(text) == 0:
18
18
  return ""
19
+
20
+ text = text.replace("\n", "")
19
21
  try:
20
22
  lang_upper = detect_language(text)
21
23
  except:
22
24
  html_no_ctrl_chars = ''.join([l for l in text if unicodedata.category(l)[0] not in ['C', ]])
23
25
  lang_upper = detect_language(html_no_ctrl_chars)
26
+
24
27
  try:
25
28
  lang = lang_upper.lower()
26
29
  except:
@@ -1,9 +1,9 @@
1
1
  import fitz
2
2
  import numpy as np
3
3
  from loguru import logger
4
- # import re
5
- # from io import BytesIO
6
- # from pdfminer.high_level import extract_text
4
+ import re
5
+ from io import BytesIO
6
+ from pdfminer.high_level import extract_text
7
7
 
8
8
 
9
9
  def calculate_sample_count(total_page: int):
@@ -33,33 +33,33 @@ def extract_pages(src_pdf_bytes: bytes) -> fitz.Document:
33
33
  return sample_docs
34
34
 
35
35
 
36
- # def detect_invalid_chars(src_pdf_bytes: bytes) -> bool:
37
- # """"
38
- # 检测PDF中是否包含非法字符
39
- # """
40
- # '''pdfminer比较慢,需要先随机抽取10页左右的sample'''
41
- # sample_docs = extract_pages(src_pdf_bytes)
42
- # sample_pdf_bytes = sample_docs.tobytes()
43
- # sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
44
- # text = extract_text(sample_pdf_file_like_object)
45
- # text = text.replace("\n", "")
46
- # # logger.info(text)
47
- # '''乱码文本用pdfminer提取出来的文本特征是(cid:xxx)'''
48
- # cid_pattern = re.compile(r'\(cid:\d+\)')
49
- # matches = cid_pattern.findall(text)
50
- # cid_count = len(matches)
51
- # cid_len = sum(len(match) for match in matches)
52
- # text_len = len(text)
53
- # if text_len == 0:
54
- # cid_chars_radio = 0
55
- # else:
56
- # cid_chars_radio = cid_count/(cid_count + text_len - cid_len)
57
- # logger.info(f"cid_count: {cid_count}, text_len: {text_len}, cid_chars_radio: {cid_chars_radio}")
58
- # '''当一篇文章存在5%以上的文本是乱码时,认为该文档为乱码文档'''
59
- # if cid_chars_radio > 0.05:
60
- # return False # 乱码文档
61
- # else:
62
- # return True # 正常文档
36
def detect_invalid_chars(src_pdf_bytes: bytes) -> bool:
    """Detect whether the text of a PDF is dominated by garbled characters.

    pdfminer is slow, so only a sample of pages (chosen by ``extract_pages``)
    is extracted. Garbled text extracted by pdfminer shows up as
    ``(cid:NNN)`` tokens. When more than 5% of the characters are such
    tokens, the document is considered garbled.

    Args:
        src_pdf_bytes: Raw bytes of the source PDF.

    Returns:
        True if the document looks normal, False if it is considered garbled.
    """
    sample_docs = extract_pages(src_pdf_bytes)
    try:
        sample_pdf_bytes = sample_docs.tobytes()
    finally:
        # The sampled document is only needed to serialize it; release it
        # instead of leaving it to the garbage collector.
        sample_docs.close()
    sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
    text = extract_text(sample_pdf_file_like_object)
    text = text.replace("\n", "")
    # logger.info(text)
    cid_pattern = re.compile(r'\(cid:\d+\)')
    matches = cid_pattern.findall(text)
    cid_count = len(matches)
    cid_len = sum(len(match) for match in matches)
    text_len = len(text)
    if text_len == 0:
        cid_chars_ratio = 0
    else:
        # Each "(cid:NNN)" token counts as a single character in the total.
        cid_chars_ratio = cid_count / (cid_count + text_len - cid_len)
    logger.info(f"cid_count: {cid_count}, text_len: {text_len}, cid_chars_radio: {cid_chars_ratio}")
    # More than 5% garbled characters -> garbled document (False);
    # otherwise a normal document (True).
    if cid_chars_ratio > 0.05:
        return False
    else:
        return True
63
63
 
64
64
 
65
65
  def count_replacement_characters(text: str) -> int:
magic_pdf/libs/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.10.5"
1
+ __version__ = "1.0.0"
@@ -1,2 +1,2 @@
1
1
  __use_inside_model__ = True
2
- __model_mode__ = "full"
2
+ __model_mode__ = 'full'
@@ -0,0 +1,275 @@
1
+ import time
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from loguru import logger
7
+ from PIL import Image
8
+
9
+ from magic_pdf.config.constants import MODEL_NAME
10
+ from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
11
+ from magic_pdf.data.dataset import Dataset
12
+ from magic_pdf.libs.clean_memory import clean_memory
13
+ from magic_pdf.libs.config_reader import get_device
14
+ from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
15
+ from magic_pdf.model.pdf_extract_kit import CustomPEKModel
16
+ from magic_pdf.model.sub_modules.model_utils import (
17
+ clean_vram, crop_img, get_res_list_from_layout_res)
18
+ from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
19
+ get_adjusted_mfdetrec_res, get_ocr_result_list)
20
+ from magic_pdf.operators.models import InferenceResult
21
+
22
# Base batch sizes for the batched sub-models. BatchAnalyze multiplies each of
# these by its batch_ratio before calling the matching batch_predict.
YOLO_LAYOUT_BASE_BATCH_SIZE = 4  # DocLayout-YOLO layout detection
MFD_BASE_BATCH_SIZE = 1  # formula detection (mfd_model)
MFR_BASE_BATCH_SIZE = 16  # formula recognition (mfr_model)
25
+
26
+
27
class BatchAnalyze:
    """Run the CustomPEKModel sub-models over a batch of page images.

    A call runs these stages:
    1. Layout detection (LayoutLMv3 per image, or DocLayout-YOLO batched).
    2. Formula detection and recognition (batched, when ``apply_formula``).
    3. Per-page OCR (or detection only) on the regions the layout step marks for OCR.
    4. Table recognition (when ``apply_table``).
    """

    def __init__(self, model: CustomPEKModel, batch_ratio: int):
        """Store the model bundle and the batch-size multiplier.

        Args:
            model: Holds the layout/mfd/mfr/ocr/table sub-models and their flags.
            batch_ratio: Multiplier applied to the *_BASE_BATCH_SIZE constants.
        """
        self.model = model
        self.batch_ratio = batch_ratio

    def __call__(self, images: list) -> list:
        """Analyze ``images`` and return one layout-result list per image.

        Args:
            images: Page images accepted by ``Image.fromarray``. They are
                converted with COLOR_RGB2BGR before OCR, so presumably RGB
                numpy arrays (TODO confirm).

        Returns:
            A list aligned with ``images``. Each entry is the list of
            detection dicts for that page (each with a 'poly' key). Formula
            and OCR results are appended to it. Table entries may gain an
            'html' key.
        """
        images_layout_res = []

        layout_start_time = time.time()
        if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            # layoutlmv3
            for image in images:
                layout_res = self.model.layout_model(image, ignore_catids=[])
                images_layout_res.append(layout_res)
        elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo
            layout_images = []
            # (image_index, useful_list) for every image that was padded, so
            # its predicted coordinates can be mapped back afterwards.
            modified_images = []
            for image_index, image in enumerate(images):
                pil_img = Image.fromarray(image)
                width, height = pil_img.size
                if height > width:
                    # Portrait pages: crop the whole page with a horizontal
                    # paste margin of width // 2 before layout prediction.
                    input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
                    new_image, useful_list = crop_img(
                        input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
                    )
                    layout_images.append(new_image)
                    modified_images.append([image_index, useful_list])
                else:
                    layout_images.append(pil_img)

            images_layout_res += self.model.layout_model.batch_predict(
                layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
            )

            # Shift the polygons of padded images back into original page
            # coordinates. Even indices are x, odd indices are y.
            # NOTE(review): useful_list[0:2] / [2:4] are presumably the paste
            # offset and the crop origin -- confirm against crop_img.
            for image_index, useful_list in modified_images:
                for res in images_layout_res[image_index]:
                    for i in range(len(res['poly'])):
                        if i % 2 == 0:
                            res['poly'][i] = (
                                res['poly'][i] - useful_list[0] + useful_list[2]
                            )
                        else:
                            res['poly'][i] = (
                                res['poly'][i] - useful_list[1] + useful_list[3]
                            )
        # NOTE(review): for any other layout_model_name, images_layout_res
        # stays empty and the per-image indexing below raises IndexError.
        logger.info(
            f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
        )

        if self.model.apply_formula:
            # Formula detection (MFD)
            mfd_start_time = time.time()
            images_mfd_res = self.model.mfd_model.batch_predict(
                images, self.batch_ratio * MFD_BASE_BATCH_SIZE
            )
            logger.info(
                f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
            )

            # Formula recognition (MFR) on the detected formula regions
            mfr_start_time = time.time()
            images_formula_list = self.model.mfr_model.batch_predict(
                images_mfd_res,
                images,
                batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
            )
            # Merge each page's recognized formulas into its layout results.
            for image_index in range(len(images)):
                images_layout_res[image_index] += images_formula_list[image_index]
            logger.info(
                f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
            )

        # Free GPU memory (VRAM) before the per-page stages.
        # NOTE(review): vram_threshold=8 is presumably in GB -- confirm in clean_vram.
        clean_vram(self.model.device, vram_threshold=8)

        ocr_time = 0
        ocr_count = 0
        table_time = 0
        table_count = 0
        # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
        for index in range(len(images)):
            layout_res = images_layout_res[index]
            pil_img = Image.fromarray(images[index])

            # Split this page's layout results into OCR regions, table regions,
            # and the formula detections OCR must account for.
            ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                get_res_list_from_layout_res(layout_res)
            )
            # OCR recognition
            ocr_start = time.time()
            # Process each area that requires OCR processing
            for res in ocr_res_list:
                new_image, useful_list = crop_img(
                    res, pil_img, crop_paste_x=50, crop_paste_y=50
                )
                # Express the formula detections in this crop's coordinates.
                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                    single_page_mfdetrec_res, useful_list
                )

                # OCR recognition
                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)

                if self.model.apply_ocr:
                    ocr_res = self.model.ocr_model.ocr(
                        new_image, mfd_res=adjusted_mfdetrec_res
                    )[0]
                else:
                    # Text detection only (rec=False); logged as "det time" below.
                    ocr_res = self.model.ocr_model.ocr(
                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
                    )[0]

                # Integration results: map back to page coordinates and append.
                if ocr_res:
                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
                    layout_res.extend(ocr_result_list)
            ocr_time += time.time() - ocr_start
            ocr_count += len(ocr_res_list)

            # Table recognition
            if self.model.apply_table:
                table_start = time.time()
                for res in table_res_list:
                    new_image, _ = crop_img(res, pil_img)
                    single_table_start_time = time.time()
                    html_code = None
                    if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                        with torch.no_grad():
                            table_result = self.model.table_model.predict(
                                new_image, 'html'
                            )
                            if len(table_result) > 0:
                                html_code = table_result[0]
                    elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
                        html_code = self.model.table_model.img2html(new_image)
                    elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
                        # table_cell_bboxes and elapse are not used here.
                        html_code, table_cell_bboxes, elapse = (
                            self.model.table_model.predict(new_image)
                        )
                    run_time = time.time() - single_table_start_time
                    # Exceeding table_max_time only logs a warning; the result
                    # is still used.
                    if run_time > self.model.table_max_time:
                        logger.warning(
                            f'table recognition processing exceeds max time {self.model.table_max_time}s'
                        )
                    # Check whether a valid result was returned: only HTML
                    # ending in </html> or </table> is stored on the region.
                    if html_code:
                        expected_ending = html_code.strip().endswith(
                            '</html>'
                        ) or html_code.strip().endswith('</table>')
                        if expected_ending:
                            res['html'] = html_code
                        else:
                            logger.warning(
                                'table recognition processing fails, not found expected HTML table end'
                            )
                    else:
                        logger.warning(
                            'table recognition processing fails, not get html return'
                        )
                table_time += time.time() - table_start
                table_count += len(table_res_list)

        if self.model.apply_ocr:
            logger.info(f'ocr time: {round(ocr_time, 2)}, image num: {ocr_count}')
        else:
            logger.info(f'det time: {round(ocr_time, 2)}, image num: {ocr_count}')
        if self.model.apply_table:
            logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')

        return images_layout_res
196
+
197
+
198
def doc_batch_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
    batch_ratio: int | None = None,
) -> InferenceResult:
    """Perform batch analysis on a document dataset.

    Every page of ``dataset`` gets an entry in the returned model JSON. Pages
    outside ``[start_page_id, end_page_id]`` get an empty ``layout_dets`` list.

    Args:
        dataset (Dataset): The dataset containing document pages to be analyzed.
        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
        show_log (bool, optional): Flag to enable logging. Defaults to False.
        start_page_id (int, optional): First page index to analyze (inclusive). Defaults to 0.
        end_page_id (int, optional): Last page index to analyze (inclusive).
            ``None`` or a negative value means analyze till the last page.
            ``0`` analyzes only the first page. Defaults to None.
        lang (str, optional): Language for OCR. An empty string is treated as None. Defaults to None.
        layout_model (optional): Layout model to be used for analysis. Defaults to None.
        formula_enable (optional): Flag to enable formula detection. Defaults to None.
        table_enable (optional): Flag to enable table detection. Defaults to None.
        batch_ratio (int | None, optional): Multiplier for the per-model base batch sizes.
            Defaults to None, which sets it to 1.

    Raises:
        CUDA_NOT_AVAILABLE: If CUDA is not available, as batch analysis is not supported in CPU mode.

    Returns:
        InferenceResult: The per-page model JSON together with the dataset.
    """

    if not torch.cuda.is_available():
        raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')

    lang = None if lang == '' else lang
    # TODO: auto detect batch size
    batch_ratio = 1 if batch_ratio is None else batch_ratio
    # Only None (or a negative index) means "till the last page". A plain
    # truthiness test would silently turn end_page_id=0 into "all pages".
    if end_page_id is None or end_page_id < 0:
        end_page_id = len(dataset) - 1

    model_manager = ModelSingleton()
    custom_model: CustomPEKModel = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )
    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)

    # Render each page exactly once. Keep the image only for pages that will
    # be analyzed, but remember every page's size for its page_info entry.
    page_sizes = []
    images = []
    for index in range(len(dataset)):
        img_dict = dataset.get_page(index).get_image()
        page_sizes.append((img_dict['width'], img_dict['height']))
        if start_page_id <= index <= end_page_id:
            images.append(img_dict['img'])

    # BatchAnalyze returns results indexed like its input images. Consume them
    # in order with an iterator rather than repeated list.pop(0), which is
    # O(n) per call.
    analyze_results = iter(batch_model(images))

    model_json = []
    for index, (page_width, page_height) in enumerate(page_sizes):
        if start_page_id <= index <= end_page_id:
            result = next(analyze_results)
        else:
            result = []

        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
        page_dict = {'layout_dets': result, 'page_info': page_info}
        model_json.append(page_dict)

    # TODO: clean memory when gpu memory is not enough
    clean_memory_start_time = time.time()
    clean_memory(get_device())
    logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')

    return InferenceResult(model_json, dataset)