magic-pdf 0.10.5__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. magic_pdf/config/constants.py +7 -0
  2. magic_pdf/config/exceptions.py +7 -0
  3. magic_pdf/data/data_reader_writer/base.py +13 -1
  4. magic_pdf/data/data_reader_writer/filebase.py +1 -1
  5. magic_pdf/data/data_reader_writer/multi_bucket_s3.py +8 -6
  6. magic_pdf/data/dataset.py +188 -5
  7. magic_pdf/data/read_api.py +59 -12
  8. magic_pdf/data/utils.py +35 -0
  9. magic_pdf/dict2md/ocr_mkcontent.py +16 -15
  10. magic_pdf/filter/__init__.py +32 -0
  11. magic_pdf/filter/pdf_meta_scan.py +3 -2
  12. magic_pdf/libs/clean_memory.py +11 -4
  13. magic_pdf/libs/config_reader.py +9 -0
  14. magic_pdf/libs/draw_bbox.py +19 -22
  15. magic_pdf/libs/language.py +3 -0
  16. magic_pdf/libs/pdf_check.py +30 -30
  17. magic_pdf/libs/version.py +1 -1
  18. magic_pdf/model/__init__.py +1 -1
  19. magic_pdf/model/batch_analyze.py +275 -0
  20. magic_pdf/model/doc_analyze_by_custom_model.py +104 -92
  21. magic_pdf/model/magic_model.py +4 -435
  22. magic_pdf/model/model_list.py +1 -0
  23. magic_pdf/model/pdf_extract_kit.py +35 -5
  24. magic_pdf/model/sub_modules/language_detection/__init__.py +1 -0
  25. magic_pdf/model/sub_modules/language_detection/utils.py +82 -0
  26. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +139 -0
  27. magic_pdf/model/sub_modules/language_detection/yolov11/__init__.py +1 -0
  28. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +44 -7
  29. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +21 -2
  30. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +70 -27
  31. magic_pdf/model/sub_modules/model_init.py +43 -7
  32. magic_pdf/model/sub_modules/model_utils.py +17 -5
  33. magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +51 -1
  34. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +32 -6
  35. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +42 -7
  36. magic_pdf/operators/__init__.py +94 -0
  37. magic_pdf/operators/models.py +154 -0
  38. magic_pdf/operators/pipes.py +191 -0
  39. magic_pdf/pdf_parse_union_core_v2.py +77 -27
  40. magic_pdf/post_proc/__init__.py +1 -0
  41. magic_pdf/post_proc/llm_aided.py +133 -0
  42. magic_pdf/pre_proc/ocr_span_list_modify.py +8 -0
  43. magic_pdf/pre_proc/remove_bbox_overlap.py +1 -1
  44. magic_pdf/resources/yolov11-langdetect/yolo_v11_ft.pt +0 -0
  45. magic_pdf/tools/cli.py +36 -11
  46. magic_pdf/tools/common.py +120 -61
  47. magic_pdf/utils/office_to_pdf.py +29 -0
  48. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/METADATA +78 -25
  49. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/RECORD +54 -55
  50. magic_pdf/para/__init__.py +0 -0
  51. magic_pdf/pdf_parse_by_ocr.py +0 -23
  52. magic_pdf/pdf_parse_by_txt.py +0 -24
  53. magic_pdf/pipe/AbsPipe.py +0 -98
  54. magic_pdf/pipe/OCRPipe.py +0 -41
  55. magic_pdf/pipe/TXTPipe.py +0 -41
  56. magic_pdf/pipe/UNIPipe.py +0 -98
  57. magic_pdf/pipe/__init__.py +0 -0
  58. magic_pdf/rw/AbsReaderWriter.py +0 -17
  59. magic_pdf/rw/DiskReaderWriter.py +0 -74
  60. magic_pdf/rw/S3ReaderWriter.py +0 -142
  61. magic_pdf/rw/__init__.py +0 -0
  62. magic_pdf/user_api.py +0 -121
  63. /magic_pdf/{para → post_proc}/para_split_v3.py +0 -0
  64. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/LICENSE.md +0 -0
  65. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/WHEEL +0 -0
  66. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/entry_points.txt +0 -0
  67. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/top_level.txt +0 -0
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py
@@ -0,0 +1,139 @@
+ # Copyright (c) Opendatalab. All rights reserved.
+ from collections import Counter
+ from uuid import uuid4
+
+ import torch
+ from PIL import Image
+ from loguru import logger
+ from ultralytics import YOLO
+
+ language_dict = {
+     "ch": "中文简体",
+     "en": "英语",
+     "japan": "日语",
+     "korean": "韩语",
+     "fr": "法语",
+     "german": "德语",
+     "ar": "阿拉伯语",
+     "ru": "俄语"
+ }
+
+
+ def split_images(image, result_images=None):
+     """
+     Process the images from the input folder: if an image's vertical (y-direction) resolution exceeds 400, split it,
+     halving the image each time until every resulting sub-image's vertical resolution is under 400, and save the processed images (the split sub-images) to the output folder.
+     Avoid saving invalid black image regions caused by crop boxes that extend beyond the image bounds.
+     """
+     if result_images is None:
+         result_images = []
+
+     width, height = image.size
+     long_side = max(width, height)  # length of the longer side
+
+     if long_side <= 400:
+         result_images.append(image)
+         return result_images
+
+     new_long_side = long_side // 2
+     sub_images = []
+
+     if width >= height:  # width is the longer side
+         for x in range(0, width, new_long_side):
+             # skip the crop if the region would extend beyond the image bounds
+             if x + new_long_side > width:
+                 continue
+             box = (x, 0, x + new_long_side, height)
+             sub_image = image.crop(box)
+             sub_images.append(sub_image)
+     else:  # height is the longer side
+         for y in range(0, height, new_long_side):
+             # skip the crop if the region would extend beyond the image bounds
+             if y + new_long_side > height:
+                 continue
+             box = (0, y, width, y + new_long_side)
+             sub_image = image.crop(box)
+             sub_images.append(sub_image)
+
+     for sub_image in sub_images:
+         split_images(sub_image, result_images)
+
+     return result_images
+
+
+ def resize_images_to_224(image):
+     """
+     If the resolution is smaller than 224, pad the image to 224*224 on a black background; if it is 224 or larger, resize it to 224*224, then save it to the output folder.
+     """
+     try:
+         width, height = image.size
+         if width < 224 or height < 224:
+             new_image = Image.new('RGB', (224, 224), (0, 0, 0))
+             paste_x = (224 - width) // 2
+             paste_y = (224 - height) // 2
+             new_image.paste(image, (paste_x, paste_y))
+             image = new_image
+         else:
+             image = image.resize((224, 224), Image.Resampling.LANCZOS)
+
+         # uuid = str(uuid4())
+         # image.save(f"/tmp/{uuid}.jpg")
+         return image
+     except Exception as e:
+         logger.exception(e)
+
+
+ class YOLOv11LangDetModel(object):
+     def __init__(self, langdetect_model_weight, device):
+
+         self.model = YOLO(langdetect_model_weight)
+
+         if str(device).startswith("npu"):
+             self.device = torch.device(device)
+         else:
+             self.device = device
+     def do_detect(self, images: list):
+         all_images = []
+         for image in images:
+             width, height = image.size
+             # logger.info(f"image size: {width} x {height}")
+             if width < 100 and height < 100:
+                 continue
+             temp_images = split_images(image)
+             for temp_image in temp_images:
+                 all_images.append(resize_images_to_224(temp_image))
+
+         images_lang_res = self.batch_predict(all_images, batch_size=8)
+         # logger.info(f"images_lang_res: {images_lang_res}")
+         if len(images_lang_res) > 0:
+             count_dict = Counter(images_lang_res)
+             language = max(count_dict, key=count_dict.get)
+         else:
+             language = None
+         return language
+
+     def predict(self, image):
+         results = self.model.predict(image, verbose=False, device=self.device)
+         predicted_class_id = int(results[0].probs.top1)
+         predicted_class_name = self.model.names[predicted_class_id]
+         return predicted_class_name
+
+
+     def batch_predict(self, images: list, batch_size: int) -> list:
+         images_lang_res = []
+
+         for index in range(0, len(images), batch_size):
+             lang_res = [
+                 image_res.cpu()
+                 for image_res in self.model.predict(
+                     images[index: index + batch_size],
+                     verbose = False,
+                     device=self.device,
+                 )
+             ]
+             for res in lang_res:
+                 predicted_class_id = int(res.probs.top1)
+                 predicted_class_name = self.model.names[predicted_class_id]
+                 images_lang_res.append(predicted_class_name)
+
+         return images_lang_res
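
For orientation, here is a minimal usage sketch of the language-detection model added above. It is not part of the package diff; the page images and the weight path are placeholders (a weight file ships as magic_pdf/resources/yolov11-langdetect/yolo_v11_ft.pt per the file list, but the exact path you load it from is an assumption).

# Hypothetical usage sketch, not part of the diff.
from PIL import Image

from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel

# Placeholder inputs: one PIL image per rendered PDF page.
pages = [Image.open("page_0.png"), Image.open("page_1.png")]

# Placeholder weight path; the bundled weight file is yolo_v11_ft.pt.
model = YOLOv11LangDetModel("magic_pdf/resources/yolov11-langdetect/yolo_v11_ft.pt", device="cpu")
language = model.do_detect(pages)  # majority vote over 224x224 crops, e.g. "ch" or "en", or None
print(language)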
@@ -0,0 +1 @@
+ # Copyright (c) Opendatalab. All rights reserved.
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py
@@ -8,14 +8,51 @@ class DocLayoutYOLOModel(object):
  
      def predict(self, image):
          layout_res = []
-         doclayout_yolo_res = self.model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
-         for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(),
-                                    doclayout_yolo_res.boxes.cls.cpu()):
+         doclayout_yolo_res = self.model.predict(
+             image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device
+         )[0]
+         for xyxy, conf, cla in zip(
+             doclayout_yolo_res.boxes.xyxy.cpu(),
+             doclayout_yolo_res.boxes.conf.cpu(),
+             doclayout_yolo_res.boxes.cls.cpu(),
+         ):
              xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
              new_item = {
-                 'category_id': int(cla.item()),
-                 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                 'score': round(float(conf.item()), 3),
+                 "category_id": int(cla.item()),
+                 "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                 "score": round(float(conf.item()), 3),
              }
              layout_res.append(new_item)
-         return layout_res
+         return layout_res
+
+     def batch_predict(self, images: list, batch_size: int) -> list:
+         images_layout_res = []
+         for index in range(0, len(images), batch_size):
+             doclayout_yolo_res = [
+                 image_res.cpu()
+                 for image_res in self.model.predict(
+                     images[index : index + batch_size],
+                     imgsz=1024,
+                     conf=0.25,
+                     iou=0.45,
+                     verbose=False,
+                     device=self.device,
+                 )
+             ]
+             for image_res in doclayout_yolo_res:
+                 layout_res = []
+                 for xyxy, conf, cla in zip(
+                     image_res.boxes.xyxy,
+                     image_res.boxes.conf,
+                     image_res.boxes.cls,
+                 ):
+                     xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                     new_item = {
+                         "category_id": int(cla.item()),
+                         "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                         "score": round(float(conf.item()), 3),
+                     }
+                     layout_res.append(new_item)
+                 images_layout_res.append(layout_res)
+
+         return images_layout_res
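
A sketch of how the new batch interface might be driven follows; the weight path, batch size, and dummy inputs are assumptions, and the likely in-package consumer is the new magic_pdf/model/batch_analyze.py.

# Hypothetical usage sketch of DocLayoutYOLOModel.batch_predict, not part of the diff.
import numpy as np

from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel

# Dummy page bitmaps; real callers pass rendered PDF pages.
page_images = [np.zeros((1024, 768, 3), dtype=np.uint8) for _ in range(3)]

layout_model = DocLayoutYOLOModel("doclayout_yolo_ft.pt", "cuda")  # placeholder weight path
results_per_page = layout_model.batch_predict(page_images, batch_size=4)
for page_idx, layout_res in enumerate(results_per_page):
    for item in layout_res:
        print(page_idx, item["category_id"], item["score"], item["poly"])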
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py
@@ -2,11 +2,30 @@ from ultralytics import YOLO
  
  
  class YOLOv8MFDModel(object):
-     def __init__(self, weight, device='cpu'):
+     def __init__(self, weight, device="cpu"):
          self.mfd_model = YOLO(weight)
          self.device = device
  
      def predict(self, image):
-         mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+         mfd_res = self.mfd_model.predict(
+             image, imgsz=1888, conf=0.25, iou=0.45, verbose=False, device=self.device
+         )[0]
          return mfd_res
  
+     def batch_predict(self, images: list, batch_size: int) -> list:
+         images_mfd_res = []
+         for index in range(0, len(images), batch_size):
+             mfd_res = [
+                 image_res.cpu()
+                 for image_res in self.mfd_model.predict(
+                     images[index : index + batch_size],
+                     imgsz=1888,
+                     conf=0.25,
+                     iou=0.45,
+                     verbose=False,
+                     device=self.device,
+                 )
+             ]
+             for image_res in mfd_res:
+                 images_mfd_res.append(image_res)
+         return images_mfd_res
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
@@ -1,13 +1,13 @@
- import os
  import argparse
+ import os
  import re
  
- from PIL import Image
  import torch
- from torch.utils.data import Dataset, DataLoader
+ import unimernet.tasks as tasks
+ from PIL import Image
+ from torch.utils.data import DataLoader, Dataset
  from torchvision import transforms
  from unimernet.common.config import Config
- import unimernet.tasks as tasks
  from unimernet.processors import load_processor
  
  
@@ -31,27 +31,25 @@ class MathDataset(Dataset):
  
  
  def latex_rm_whitespace(s: str):
-     """Remove unnecessary whitespace from LaTeX code.
-     """
-     text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
-     letter = '[a-zA-Z]'
-     noletter = '[\W_^\d]'
-     names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
+     """Remove unnecessary whitespace from LaTeX code."""
+     text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
+     letter = "[a-zA-Z]"
+     noletter = "[\W_^\d]"
+     names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
      s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
      news = s
      while True:
          s = news
-         news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
-         news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
-         news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
+         news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
+         news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
+         news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
          if news == s:
              break
      return s
  
  
  class UnimernetModel(object):
-     def __init__(self, weight_dir, cfg_path, _device_='cpu'):
- 
+     def __init__(self, weight_dir, cfg_path, _device_="cpu"):
          args = argparse.Namespace(cfg_path=cfg_path, options=None)
          cfg = Config(args)
          cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
@@ -62,20 +60,28 @@ class UnimernetModel(object):
          self.device = _device_
          self.model.to(_device_)
          self.model.eval()
-         vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
-         self.mfr_transform = transforms.Compose([vis_processor, ])
+         vis_processor = load_processor(
+             "formula_image_eval",
+             cfg.config.datasets.formula_rec_eval.vis_processor.eval,
+         )
+         self.mfr_transform = transforms.Compose(
+             [
+                 vis_processor,
+             ]
+         )
  
      def predict(self, mfd_res, image):
- 
          formula_list = []
          mf_image_list = []
-         for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
+         for xyxy, conf, cla in zip(
+             mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
+         ):
              xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
              new_item = {
-                 'category_id': 13 + int(cla.item()),
-                 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                 'score': round(float(conf.item()), 2),
-                 'latex': '',
+                 "category_id": 13 + int(cla.item()),
+                 "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                 "score": round(float(conf.item()), 2),
+                 "latex": "",
              }
              formula_list.append(new_item)
          pil_img = Image.fromarray(image)
@@ -88,11 +94,48 @@ class UnimernetModel(object):
          for mf_img in dataloader:
              mf_img = mf_img.to(self.device)
              with torch.no_grad():
-                 output = self.model.generate({'image': mf_img})
-                 mfr_res.extend(output['pred_str'])
+                 output = self.model.generate({"image": mf_img})
+                 mfr_res.extend(output["pred_str"])
          for res, latex in zip(formula_list, mfr_res):
-             res['latex'] = latex_rm_whitespace(latex)
+             res["latex"] = latex_rm_whitespace(latex)
          return formula_list
  
+     def batch_predict(
+         self, images_mfd_res: list, images: list, batch_size: int = 64
+     ) -> list:
+         images_formula_list = []
+         mf_image_list = []
+         backfill_list = []
+         for image_index in range(len(images_mfd_res)):
+             mfd_res = images_mfd_res[image_index]
+             pil_img = Image.fromarray(images[image_index])
+             formula_list = []
+
+             for xyxy, conf, cla in zip(
+                 mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
+             ):
+                 xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                 new_item = {
+                     "category_id": 13 + int(cla.item()),
+                     "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                     "score": round(float(conf.item()), 2),
+                     "latex": "",
+                 }
+                 formula_list.append(new_item)
+                 bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+                 mf_image_list.append(bbox_img)
+
+             images_formula_list.append(formula_list)
+             backfill_list += formula_list
  
- 
+         dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+         dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
+         mfr_res = []
+         for mf_img in dataloader:
+             mf_img = mf_img.to(self.device)
+             with torch.no_grad():
+                 output = self.model.generate({"image": mf_img})
+                 mfr_res.extend(output["pred_str"])
+         for res, latex in zip(backfill_list, mfr_res):
+             res["latex"] = latex_rm_whitespace(latex)
+         return images_formula_list
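
Taken together with the MFD batch method above, the new UnimernetModel.batch_predict allows a page-batched formula pipeline. A sketch follows; the weight/config paths, dummy inputs, and batch sizes are placeholders, and the exact wiring is an assumption about how a caller such as batch_analyze.py would use these methods.

# Hypothetical wiring of the two new batch methods, not part of the diff.
import numpy as np

from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel

pages = [np.zeros((1024, 768, 3), dtype=np.uint8) for _ in range(2)]  # dummy page bitmaps

mfd_model = YOLOv8MFDModel("yolo_v8_mfd.pt", "cuda")  # placeholder weight path
mfr_model = UnimernetModel("unimernet_weights_dir", "unimernet_config.yaml", _device_="cuda")  # placeholder paths

images_mfd_res = mfd_model.batch_predict(pages, batch_size=1)  # per-page formula boxes
images_formula_list = mfr_model.batch_predict(images_mfd_res, pages, batch_size=64)
# images_formula_list[i] holds the formula dicts for page i, with "latex" backfilled in order.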
magic_pdf/model/sub_modules/model_init.py
@@ -1,7 +1,9 @@
+ import torch
  from loguru import logger
  
  from magic_pdf.config.constants import MODEL_NAME
  from magic_pdf.model.model_list import AtomicModel
+ from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
  from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
      DocLayoutYOLOModel
  from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
@@ -19,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
      TableMasterPaddleModel
  
  
- def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
+ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
      if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
          table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
      elif table_model_type == MODEL_NAME.TABLE_MASTER:
@@ -29,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
          }
          table_model = TableMasterPaddleModel(config)
      elif table_model_type == MODEL_NAME.RAPID_TABLE:
-         table_model = RapidTableModel()
+         table_model = RapidTableModel(ocr_engine)
      else:
          logger.error('table model type not allow')
          exit(1)
@@ -38,6 +40,8 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
  
  
  def mfd_model_init(weight, device='cpu'):
+     if str(device).startswith("npu"):
+         device = torch.device(device)
      mfd_model = YOLOv8MFDModel(weight, device)
      return mfd_model
  
@@ -53,16 +57,26 @@ def layout_model_init(weight, config_file, device):
  
  
  def doclayout_yolo_model_init(weight, device='cpu'):
+     if str(device).startswith("npu"):
+         device = torch.device(device)
      model = DocLayoutYOLOModel(weight, device)
      return model
  
  
+ def langdetect_model_init(langdetect_model_weight, device='cpu'):
+     if str(device).startswith("npu"):
+         device = torch.device(device)
+     model = YOLOv11LangDetModel(langdetect_model_weight, device)
+     return model
+
+
  def ocr_model_init(show_log: bool = False,
                     det_db_box_thresh=0.3,
                     lang=None,
                     use_dilation=True,
                     det_db_unclip_ratio=1.8,
                     ):
+
      if lang is not None and lang != '':
          model = ModifiedPaddleOCR(
              show_log=show_log,
@@ -77,7 +91,6 @@ def ocr_model_init(show_log: bool = False,
              det_db_box_thresh=det_db_box_thresh,
              use_dilation=use_dilation,
              det_db_unclip_ratio=det_db_unclip_ratio,
-             # use_angle_cls=True,
          )
      return model
  
@@ -92,14 +105,24 @@ class AtomModelSingleton:
          return cls._instance
  
      def get_atom_model(self, atom_model_name: str, **kwargs):
+
          lang = kwargs.get('lang', None)
          layout_model_name = kwargs.get('layout_model_name', None)
-         key = (atom_model_name, layout_model_name, lang)
+         table_model_name = kwargs.get('table_model_name', None)
+
+         if atom_model_name in [AtomicModel.OCR]:
+             key = (atom_model_name, lang)
+         elif atom_model_name in [AtomicModel.Layout]:
+             key = (atom_model_name, layout_model_name)
+         elif atom_model_name in [AtomicModel.Table]:
+             key = (atom_model_name, table_model_name)
+         else:
+             key = atom_model_name
+
          if key not in self._models:
              self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
          return self._models[key]
  
- 
  def atom_model_init(model_name: str, **kwargs):
      atom_model = None
      if model_name == AtomicModel.Layout:
@@ -114,6 +137,9 @@ def atom_model_init(model_name: str, **kwargs):
                  kwargs.get('doclayout_yolo_weights'),
                  kwargs.get('device')
              )
+         else:
+             logger.error('layout model name not allow')
+             exit(1)
      elif model_name == AtomicModel.MFD:
          atom_model = mfd_model_init(
              kwargs.get('mfd_weights'),
@@ -129,15 +155,25 @@ def atom_model_init(model_name: str, **kwargs):
          atom_model = ocr_model_init(
              kwargs.get('ocr_show_log'),
              kwargs.get('det_db_box_thresh'),
-             kwargs.get('lang')
+             kwargs.get('lang'),
          )
      elif model_name == AtomicModel.Table:
          atom_model = table_model_init(
              kwargs.get('table_model_name'),
              kwargs.get('table_model_path'),
              kwargs.get('table_max_time'),
-             kwargs.get('device')
+             kwargs.get('device'),
+             kwargs.get('ocr_engine')
          )
+     elif model_name == AtomicModel.LangDetect:
+         if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
+             atom_model = langdetect_model_init(
+                 kwargs.get('langdetect_model_weight'),
+                 kwargs.get('device')
+             )
+         else:
+             logger.error('langdetect model name not allow')
+             exit(1)
      else:
          logger.error('model name not allow')
          exit(1)
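
A sketch of requesting the new language-detection atom model through the singleton follows; the weight path is a placeholder, and the keyword names mirror the kwargs read in the branch added above.

# Hypothetical usage sketch of the LangDetect branch, not part of the diff.
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton

atom_models = AtomModelSingleton()
langdetect_model = atom_models.get_atom_model(
    atom_model_name=AtomicModel.LangDetect,
    langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
    langdetect_model_weight="magic_pdf/resources/yolov11-langdetect/yolo_v11_ft.pt",  # placeholder path
    device="cpu",
)
# Repeated calls with the same atom_model_name reuse the cached instance.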
magic_pdf/model/sub_modules/model_utils.py
@@ -42,10 +42,22 @@ def get_res_list_from_layout_res(layout_res)
  
  
  def clean_vram(device, vram_threshold=8):
+     total_memory = get_vram(device)
+     if total_memory and total_memory <= vram_threshold:
+         gc_start = time.time()
+         clean_memory(device)
+         gc_time = round(time.time() - gc_start, 2)
+         logger.info(f"gc time: {gc_time}")
+
+
+ def get_vram(device):
      if torch.cuda.is_available() and device != 'cpu':
          total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
-         if total_memory <= vram_threshold:
-             gc_start = time.time()
-             clean_memory()
-             gc_time = round(time.time() - gc_start, 2)
-             logger.info(f"gc time: {gc_time}")
+         return total_memory
+     elif str(device).startswith("npu"):
+         import torch_npu
+         if torch_npu.npu.is_available():
+             total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3)  # convert to GB
+             return total_memory
+     else:
+         return None
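
A small sketch of the refactored VRAM-guarded cleanup follows; the device string is a placeholder, and torch_npu is only needed on Ascend devices.

# Hypothetical usage sketch, not part of the diff: clean_vram() only triggers
# garbage collection when the device reports <= vram_threshold GB of total memory.
from magic_pdf.model.sub_modules.model_utils import clean_vram

device = "cuda:0"                     # or e.g. "npu:0" on Ascend; "cpu" skips the CUDA branch
clean_vram(device, vram_threshold=8)  # threshold is in GB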
magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py
@@ -303,4 +303,54 @@ def calculate_is_angle(poly):
          return False
      else:
          # logger.info((p3[1] - p1[1])/height)
-         return True
+         return True
+
+
+ class ONNXModelSingleton:
+     _instance = None
+     _models = {}
+
+     def __new__(cls, *args, **kwargs):
+         if cls._instance is None:
+             cls._instance = super().__new__(cls)
+         return cls._instance
+
+     def get_onnx_model(self, **kwargs):
+
+         lang = kwargs.get('lang', None)
+         det_db_box_thresh = kwargs.get('det_db_box_thresh', 0.3)
+         use_dilation = kwargs.get('use_dilation', True)
+         det_db_unclip_ratio = kwargs.get('det_db_unclip_ratio', 1.8)
+         key = (lang, det_db_box_thresh, use_dilation, det_db_unclip_ratio)
+         if key not in self._models:
+             self._models[key] = onnx_model_init(key)
+         return self._models[key]
+
+ def onnx_model_init(key):
+
+     import importlib.resources
+
+     resource_path = importlib.resources.path('rapidocr_onnxruntime.models','')
+
+     onnx_model = None
+     additional_ocr_params = {
+         "use_onnx": True,
+         "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
+         "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
+         "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
+         "det_db_box_thresh": key[1],
+         "use_dilation": key[2],
+         "det_db_unclip_ratio": key[3],
+     }
+     # logger.info(f"additional_ocr_params: {additional_ocr_params}")
+     if key[0] is not None:
+         additional_ocr_params["lang"] = key[0]
+
+     from paddleocr import PaddleOCR
+     onnx_model = PaddleOCR(**additional_ocr_params)
+
+     if onnx_model is None:
+         logger.error('model init failed')
+         exit(1)
+     else:
+         return onnx_model
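
A sketch of obtaining the cached ONNX OCR engine added above follows; the keyword values mirror the defaults read in get_onnx_model(), and availability of the rapidocr_onnxruntime model files is assumed.

# Hypothetical usage sketch of ONNXModelSingleton, not part of the diff.
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import ONNXModelSingleton

onnx_singleton = ONNXModelSingleton()
ocr_engine = onnx_singleton.get_onnx_model(
    lang="ch",                  # optional; when omitted, no "lang" is passed to PaddleOCR
    det_db_box_thresh=0.3,
    use_dilation=True,
    det_db_unclip_ratio=1.8,
)
# Later calls with the same parameters return the same cached PaddleOCR instance.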