magic-pdf 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102) hide show
  1. magic_pdf/data/batch_build_dataset.py +156 -0
  2. magic_pdf/data/dataset.py +44 -24
  3. magic_pdf/data/utils.py +108 -9
  4. magic_pdf/dict2md/ocr_mkcontent.py +4 -3
  5. magic_pdf/libs/pdf_image_tools.py +11 -6
  6. magic_pdf/libs/performance_stats.py +12 -1
  7. magic_pdf/libs/version.py +1 -1
  8. magic_pdf/model/batch_analyze.py +175 -201
  9. magic_pdf/model/doc_analyze_by_custom_model.py +137 -92
  10. magic_pdf/model/pdf_extract_kit.py +5 -38
  11. magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
  12. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
  13. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
  14. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
  15. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
  16. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
  17. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
  18. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
  19. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
  20. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
  21. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
  22. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
  23. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
  24. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
  25. magic_pdf/model/sub_modules/model_init.py +50 -37
  26. magic_pdf/model/sub_modules/model_utils.py +17 -11
  27. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
  28. magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
  29. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
  30. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
  31. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
  32. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
  33. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
  34. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
  35. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
  36. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
  37. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
  38. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
  39. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
  40. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
  41. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
  42. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
  43. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
  44. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
  45. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
  46. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
  47. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
  48. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
  49. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
  50. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
  51. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
  52. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
  53. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
  54. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
  55. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
  56. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
  57. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
  58. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
  59. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
  60. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
  61. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
  62. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
  63. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
  64. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
  65. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
  66. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
  67. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
  68. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
  69. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
  70. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
  71. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
  72. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
  73. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
  74. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
  75. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
  76. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
  77. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
  78. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
  79. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +10 -18
  80. magic_pdf/pdf_parse_union_core_v2.py +112 -74
  81. magic_pdf/post_proc/para_split_v3.py +16 -13
  82. magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
  83. magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
  84. magic_pdf/resources/model_config/model_configs.yaml +1 -1
  85. magic_pdf/tools/cli.py +30 -12
  86. magic_pdf/tools/common.py +90 -12
  87. {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/METADATA +51 -41
  88. magic_pdf-1.3.0.dist-info/RECORD +202 -0
  89. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
  90. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
  91. magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
  92. magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
  93. magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
  94. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
  95. magic_pdf-1.2.1.dist-info/RECORD +0 -147
  96. /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
  97. /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
  98. /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
  99. {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/LICENSE.md +0 -0
  100. {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/WHEEL +0 -0
  101. {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/entry_points.txt +0 -0
  102. {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/top_level.txt +0 -0
@@ -5,47 +5,57 @@ from magic_pdf.config.constants import MODEL_NAME
5
5
  from magic_pdf.model.model_list import AtomicModel
6
6
  from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
7
7
  from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
8
- from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
9
8
  from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
10
9
  from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
11
-
12
- try:
13
- from magic_pdf_ascend_plugin.libs.license_verifier import load_license, LicenseFormatError, LicenseSignatureError, LicenseExpiredError
14
- from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
15
- from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
16
- license_key = load_license()
17
- logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
18
- f' License expired at {license_key["payload"]["date"]["end_date"]}')
19
- except Exception as e:
20
- if isinstance(e, ImportError):
21
- pass
22
- elif isinstance(e, LicenseFormatError):
23
- logger.error("Ascend Plugin: Invalid license format. Please check the license file.")
24
- elif isinstance(e, LicenseSignatureError):
25
- logger.error("Ascend Plugin: Invalid signature. The license may be tampered with.")
26
- elif isinstance(e, LicenseExpiredError):
27
- logger.error("Ascend Plugin: License has expired. Please renew your license.")
28
- elif isinstance(e, FileNotFoundError):
29
- logger.error("Ascend Plugin: Not found License file.")
30
- else:
31
- logger.error(f"Ascend Plugin: {e}")
32
- from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
33
- # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
34
- from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
35
-
36
- from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
37
- from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
38
-
39
- def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
10
+ from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
11
+ from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
12
+ # try:
13
+ # from magic_pdf_ascend_plugin.libs.license_verifier import (
14
+ # LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
15
+ # load_license)
16
+ # from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
17
+ # from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
18
+ # license_key = load_license()
19
+ # logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
20
+ # f' License expired at {license_key["payload"]["date"]["end_date"]}')
21
+ # except Exception as e:
22
+ # if isinstance(e, ImportError):
23
+ # pass
24
+ # elif isinstance(e, LicenseFormatError):
25
+ # logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
26
+ # elif isinstance(e, LicenseSignatureError):
27
+ # logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
28
+ # elif isinstance(e, LicenseExpiredError):
29
+ # logger.error('Ascend Plugin: License has expired. Please renew your license.')
30
+ # elif isinstance(e, FileNotFoundError):
31
+ # logger.error('Ascend Plugin: Not found License file.')
32
+ # else:
33
+ # logger.error(f'Ascend Plugin: {e}')
34
+ # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
35
+ # # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
36
+ # from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
37
+
38
+
39
+ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None, table_sub_model_name=None):
40
40
  if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
41
+ from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
41
42
  table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
42
43
  elif table_model_type == MODEL_NAME.TABLE_MASTER:
44
+ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
43
45
  config = {
44
46
  'model_dir': model_path,
45
47
  'device': _device_
46
48
  }
47
49
  table_model = TableMasterPaddleModel(config)
48
50
  elif table_model_type == MODEL_NAME.RAPID_TABLE:
51
+ atom_model_manager = AtomModelSingleton()
52
+ ocr_engine = atom_model_manager.get_atom_model(
53
+ atom_model_name='ocr',
54
+ ocr_show_log=False,
55
+ det_db_box_thresh=0.5,
56
+ det_db_unclip_ratio=1.6,
57
+ lang=lang
58
+ )
49
59
  table_model = RapidTableModel(ocr_engine, table_sub_model_name)
50
60
  else:
51
61
  logger.error('table model type not allow')
@@ -55,7 +65,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
55
65
 
56
66
 
57
67
  def mfd_model_init(weight, device='cpu'):
58
- if str(device).startswith("npu"):
68
+ if str(device).startswith('npu'):
59
69
  device = torch.device(device)
60
70
  mfd_model = YOLOv8MFDModel(weight, device)
61
71
  return mfd_model
@@ -67,19 +77,20 @@ def mfr_model_init(weight_dir, cfg_path, device='cpu'):
67
77
 
68
78
 
69
79
  def layout_model_init(weight, config_file, device):
80
+ from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
70
81
  model = Layoutlmv3_Predictor(weight, config_file, device)
71
82
  return model
72
83
 
73
84
 
74
85
  def doclayout_yolo_model_init(weight, device='cpu'):
75
- if str(device).startswith("npu"):
86
+ if str(device).startswith('npu'):
76
87
  device = torch.device(device)
77
88
  model = DocLayoutYOLOModel(weight, device)
78
89
  return model
79
90
 
80
91
 
81
92
  def langdetect_model_init(langdetect_model_weight, device='cpu'):
82
- if str(device).startswith("npu"):
93
+ if str(device).startswith('npu'):
83
94
  device = torch.device(device)
84
95
  model = YOLOv11LangDetModel(langdetect_model_weight, device)
85
96
  return model
@@ -92,7 +103,8 @@ def ocr_model_init(show_log: bool = False,
92
103
  det_db_unclip_ratio=1.8,
93
104
  ):
94
105
  if lang is not None and lang != '':
95
- model = ModifiedPaddleOCR(
106
+ # model = ModifiedPaddleOCR(
107
+ model = PytorchPaddleOCR(
96
108
  show_log=show_log,
97
109
  det_db_box_thresh=det_db_box_thresh,
98
110
  lang=lang,
@@ -100,7 +112,8 @@ def ocr_model_init(show_log: bool = False,
100
112
  det_db_unclip_ratio=det_db_unclip_ratio,
101
113
  )
102
114
  else:
103
- model = ModifiedPaddleOCR(
115
+ # model = ModifiedPaddleOCR(
116
+ model = PytorchPaddleOCR(
104
117
  show_log=show_log,
105
118
  det_db_box_thresh=det_db_box_thresh,
106
119
  use_dilation=use_dilation,
@@ -129,7 +142,7 @@ class AtomModelSingleton:
129
142
  elif atom_model_name in [AtomicModel.Layout]:
130
143
  key = (atom_model_name, layout_model_name)
131
144
  elif atom_model_name in [AtomicModel.Table]:
132
- key = (atom_model_name, table_model_name)
145
+ key = (atom_model_name, table_model_name, lang)
133
146
  else:
134
147
  key = atom_model_name
135
148
 
@@ -177,7 +190,7 @@ def atom_model_init(model_name: str, **kwargs):
177
190
  kwargs.get('table_model_path'),
178
191
  kwargs.get('table_max_time'),
179
192
  kwargs.get('device'),
180
- kwargs.get('ocr_engine'),
193
+ kwargs.get('lang'),
181
194
  kwargs.get('table_sub_model_name')
182
195
  )
183
196
  elif model_name == AtomicModel.LangDetect:
@@ -1,25 +1,31 @@
1
1
  import time
2
-
3
2
  import torch
4
- from PIL import Image
5
3
  from loguru import logger
6
-
4
+ import numpy as np
7
5
  from magic_pdf.libs.clean_memory import clean_memory
8
6
 
9
7
 
10
- def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
8
+ def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
9
+
11
10
  crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
12
11
  crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
13
- # Create a white background with an additional width and height of 50
12
+
13
+ # Calculate new dimensions
14
14
  crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
15
15
  crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
16
- return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
17
16
 
18
- # Crop image
19
- crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
20
- cropped_img = input_pil_img.crop(crop_box)
21
- return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
22
- return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
17
+ # Create a white background array
18
+ return_image = np.ones((crop_new_height, crop_new_width, 3), dtype=np.uint8) * 255
19
+
20
+ # Crop the original image using numpy slicing
21
+ cropped_img = input_np_img[crop_ymin:crop_ymax, crop_xmin:crop_xmax]
22
+
23
+ # Paste the cropped image onto the white background
24
+ return_image[crop_paste_y:crop_paste_y + (crop_ymax - crop_ymin),
25
+ crop_paste_x:crop_paste_x + (crop_xmax - crop_xmin)] = cropped_img
26
+
27
+ return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width,
28
+ crop_new_height]
23
29
  return return_image, return_list
24
30
 
25
31
 
@@ -0,0 +1 @@
1
+ # Copyright (c) Opendatalab. All rights reserved.
@@ -1,58 +1,67 @@
1
+ # Copyright (c) Opendatalab. All rights reserved.
2
+ import copy
3
+
1
4
  import cv2
2
5
  import numpy as np
3
- from loguru import logger
4
- from io import BytesIO
5
- from PIL import Image
6
- import base64
7
- from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
8
6
  from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
9
-
10
- import importlib.resources
11
- from paddleocr import PaddleOCR
12
- from ppocr.utils.utility import check_and_read
7
+ from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
13
8
 
14
9
 
15
10
  def img_decode(content: bytes):
16
11
  np_arr = np.frombuffer(content, dtype=np.uint8)
17
12
  return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
18
13
 
19
-
20
14
  def check_img(img):
21
15
  if isinstance(img, bytes):
22
16
  img = img_decode(img)
23
- if isinstance(img, str):
24
- image_file = img
25
- img, flag_gif, flag_pdf = check_and_read(image_file)
26
- if not flag_gif and not flag_pdf:
27
- with open(image_file, 'rb') as f:
28
- img_str = f.read()
29
- img = img_decode(img_str)
30
- if img is None:
31
- try:
32
- buf = BytesIO()
33
- image = BytesIO(img_str)
34
- im = Image.open(image)
35
- rgb = im.convert('RGB')
36
- rgb.save(buf, 'jpeg')
37
- buf.seek(0)
38
- image_bytes = buf.read()
39
- data_base64 = str(base64.b64encode(image_bytes),
40
- encoding="utf-8")
41
- image_decode = base64.b64decode(data_base64)
42
- img_array = np.frombuffer(image_decode, np.uint8)
43
- img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
44
- except:
45
- logger.error("error in loading image:{}".format(image_file))
46
- return None
47
- if img is None:
48
- logger.error("error in loading image:{}".format(image_file))
49
- return None
50
17
  if isinstance(img, np.ndarray) and len(img.shape) == 2:
51
18
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
19
+ return img
52
20
 
21
+
22
+ def alpha_to_color(img, alpha_color=(255, 255, 255)):
23
+ if len(img.shape) == 3 and img.shape[2] == 4:
24
+ B, G, R, A = cv2.split(img)
25
+ alpha = A / 255
26
+
27
+ R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
28
+ G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
29
+ B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)
30
+
31
+ img = cv2.merge((B, G, R))
53
32
  return img
54
33
 
55
34
 
35
+ def preprocess_image(_image):
36
+ alpha_color = (255, 255, 255)
37
+ _image = alpha_to_color(_image, alpha_color)
38
+ return _image
39
+
40
+
41
+ def sorted_boxes(dt_boxes):
42
+ """
43
+ Sort text boxes in order from top to bottom, left to right
44
+ args:
45
+ dt_boxes(array):detected text boxes with shape [4, 2]
46
+ return:
47
+ sorted boxes(array) with shape [4, 2]
48
+ """
49
+ num_boxes = dt_boxes.shape[0]
50
+ sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
51
+ _boxes = list(sorted_boxes)
52
+
53
+ for i in range(num_boxes - 1):
54
+ for j in range(i, -1, -1):
55
+ if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
56
+ (_boxes[j + 1][0][0] < _boxes[j][0][0]):
57
+ tmp = _boxes[j]
58
+ _boxes[j] = _boxes[j + 1]
59
+ _boxes[j + 1] = tmp
60
+ else:
61
+ break
62
+ return _boxes
63
+
64
+
56
65
  def bbox_to_points(bbox):
57
66
  """ 将bbox格式转换为四个顶点的数组 """
58
67
  x0, y0, x1, y1 = bbox
@@ -252,9 +261,10 @@ def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
252
261
  return adjusted_mfdetrec_res
253
262
 
254
263
 
255
- def get_ocr_result_list(ocr_res, useful_list):
264
+ def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, lang):
256
265
  paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
257
266
  ocr_result_list = []
267
+ ori_im = new_image.copy()
258
268
  for box_ocr_res in ocr_res:
259
269
 
260
270
  if len(box_ocr_res) == 2:
@@ -266,6 +276,11 @@ def get_ocr_result_list(ocr_res, useful_list):
266
276
  else:
267
277
  p1, p2, p3, p4 = box_ocr_res
268
278
  text, score = "", 1
279
+
280
+ if ocr_enable:
281
+ tmp_box = copy.deepcopy(np.array([p1, p2, p3, p4]).astype('float32'))
282
+ img_crop = get_rotate_crop_image(ori_im, tmp_box)
283
+
269
284
  # average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
270
285
  # if average_angle_degrees > 0.5:
271
286
  poly = [p1, p2, p3, p4]
@@ -288,12 +303,22 @@ def get_ocr_result_list(ocr_res, useful_list):
288
303
  p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
289
304
  p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
290
305
 
291
- ocr_result_list.append({
292
- 'category_id': 15,
293
- 'poly': p1 + p2 + p3 + p4,
294
- 'score': float(round(score, 2)),
295
- 'text': text,
296
- })
306
+ if ocr_enable:
307
+ ocr_result_list.append({
308
+ 'category_id': 15,
309
+ 'poly': p1 + p2 + p3 + p4,
310
+ 'score': 1,
311
+ 'text': text,
312
+ 'np_img': img_crop,
313
+ 'lang': lang,
314
+ })
315
+ else:
316
+ ocr_result_list.append({
317
+ 'category_id': 15,
318
+ 'poly': p1 + p2 + p3 + p4,
319
+ 'score': float(round(score, 2)),
320
+ 'text': text,
321
+ })
297
322
 
298
323
  return ocr_result_list
299
324
 
@@ -308,56 +333,36 @@ def calculate_is_angle(poly):
308
333
  return True
309
334
 
310
335
 
311
- class ONNXModelSingleton:
312
- _instance = None
313
- _models = {}
314
-
315
- def __new__(cls, *args, **kwargs):
316
- if cls._instance is None:
317
- cls._instance = super().__new__(cls)
318
- return cls._instance
319
-
320
- def get_onnx_model(self, **kwargs):
321
-
322
- lang = kwargs.get('lang', None)
323
- det_db_box_thresh = kwargs.get('det_db_box_thresh', 0.3)
324
- use_dilation = kwargs.get('use_dilation', True)
325
- det_db_unclip_ratio = kwargs.get('det_db_unclip_ratio', 1.8)
326
- key = (lang, det_db_box_thresh, use_dilation, det_db_unclip_ratio)
327
- if key not in self._models:
328
- self._models[key] = onnx_model_init(key)
329
- return self._models[key]
330
-
331
- def onnx_model_init(key):
332
- if len(key) < 4:
333
- logger.error('Invalid key length, expected at least 4 elements')
334
- exit(1)
335
-
336
- try:
337
- with importlib.resources.path('rapidocr_onnxruntime.models', '') as resource_path:
338
- additional_ocr_params = {
339
- "use_onnx": True,
340
- "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
341
- "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
342
- "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
343
- "det_db_box_thresh": key[1],
344
- "use_dilation": key[2],
345
- "det_db_unclip_ratio": key[3],
346
- }
347
-
348
- if key[0] is not None:
349
- additional_ocr_params["lang"] = key[0]
350
-
351
- # logger.info(f"additional_ocr_params: {additional_ocr_params}")
352
-
353
- onnx_model = PaddleOCR(**additional_ocr_params)
354
-
355
- if onnx_model is None:
356
- logger.error('model init failed')
357
- exit(1)
358
- else:
359
- return onnx_model
360
-
361
- except Exception as e:
362
- logger.exception(f'Error initializing model: {e}')
363
- exit(1)
336
+ def get_rotate_crop_image(img, points):
337
+ '''
338
+ img_height, img_width = img.shape[0:2]
339
+ left = int(np.min(points[:, 0]))
340
+ right = int(np.max(points[:, 0]))
341
+ top = int(np.min(points[:, 1]))
342
+ bottom = int(np.max(points[:, 1]))
343
+ img_crop = img[top:bottom, left:right, :].copy()
344
+ points[:, 0] = points[:, 0] - left
345
+ points[:, 1] = points[:, 1] - top
346
+ '''
347
+ assert len(points) == 4, "shape of points must be 4*2"
348
+ img_crop_width = int(
349
+ max(
350
+ np.linalg.norm(points[0] - points[1]),
351
+ np.linalg.norm(points[2] - points[3])))
352
+ img_crop_height = int(
353
+ max(
354
+ np.linalg.norm(points[0] - points[3]),
355
+ np.linalg.norm(points[1] - points[2])))
356
+ pts_std = np.float32([[0, 0], [img_crop_width, 0],
357
+ [img_crop_width, img_crop_height],
358
+ [0, img_crop_height]])
359
+ M = cv2.getPerspectiveTransform(points, pts_std)
360
+ dst_img = cv2.warpPerspective(
361
+ img,
362
+ M, (img_crop_width, img_crop_height),
363
+ borderMode=cv2.BORDER_REPLICATE,
364
+ flags=cv2.INTER_CUBIC)
365
+ dst_img_height, dst_img_width = dst_img.shape[0:2]
366
+ if dst_img_height * 1.0 / dst_img_width >= 1.5:
367
+ dst_img = np.rot90(dst_img)
368
+ return dst_img
@@ -0,0 +1,193 @@
1
+ # Copyright (c) Opendatalab. All rights reserved.
2
+ import copy
3
+ import os.path
4
+ import warnings
5
+ from pathlib import Path
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import yaml
10
+ from loguru import logger
11
+
12
+ from magic_pdf.libs.config_reader import get_device, get_local_models_dir
13
+ from .ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
14
+ from .tools.infer.predict_system import TextSystem
15
+ from .tools.infer import pytorchocr_utility as utility
16
+ import argparse
17
+
18
+
19
+ latin_lang = [
20
+ 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr', # noqa: E126
21
+ 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
22
+ 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
23
+ 'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
24
+ ]
25
+ arabic_lang = ['ar', 'fa', 'ug', 'ur']
26
+ cyrillic_lang = [
27
+ 'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava', # noqa: E126
28
+ 'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
29
+ ]
30
+ devanagari_lang = [
31
+ 'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom', # noqa: E126
32
+ 'sa', 'bgc'
33
+ ]
34
+
35
+
36
+ def get_model_params(lang, config):
37
+ if lang in config['lang']:
38
+ params = config['lang'][lang]
39
+ det = params.get('det')
40
+ rec = params.get('rec')
41
+ dict_file = params.get('dict')
42
+ return det, rec, dict_file
43
+ else:
44
+ raise Exception (f'Language {lang} not supported')
45
+
46
+
47
+ root_dir = Path(__file__).resolve().parent
48
+
49
+
50
+ class PytorchPaddleOCR(TextSystem):
51
+ def __init__(self, *args, **kwargs):
52
+ parser = utility.init_args()
53
+ args = parser.parse_args(args)
54
+
55
+ self.lang = kwargs.get('lang', 'ch')
56
+ if self.lang in latin_lang:
57
+ self.lang = 'latin'
58
+ elif self.lang in arabic_lang:
59
+ self.lang = 'arabic'
60
+ elif self.lang in cyrillic_lang:
61
+ self.lang = 'cyrillic'
62
+ elif self.lang in devanagari_lang:
63
+ self.lang = 'devanagari'
64
+ else:
65
+ pass
66
+
67
+ models_config_path = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'models_config.yml')
68
+ with open(models_config_path) as file:
69
+ config = yaml.safe_load(file)
70
+ det, rec, dict_file = get_model_params(self.lang, config)
71
+ ocr_models_dir = os.path.join(get_local_models_dir(), 'OCR', 'paddleocr_torch')
72
+ kwargs['det_model_path'] = os.path.join(ocr_models_dir, det)
73
+ kwargs['rec_model_path'] = os.path.join(ocr_models_dir, rec)
74
+ kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
75
+ # kwargs['rec_batch_num'] = 8
76
+
77
+ kwargs['device'] = get_device()
78
+
79
+ default_args = vars(args)
80
+ default_args.update(kwargs)
81
+ args = argparse.Namespace(**default_args)
82
+
83
+ super().__init__(args)
84
+
85
+ def ocr(self,
86
+ img,
87
+ det=True,
88
+ rec=True,
89
+ mfd_res=None,
90
+ tqdm_enable=False,
91
+ ):
92
+ assert isinstance(img, (np.ndarray, list, str, bytes))
93
+ if isinstance(img, list) and det == True:
94
+ logger.error('When input a list of images, det must be false')
95
+ exit(0)
96
+ img = check_img(img)
97
+ imgs = [img]
98
+ with warnings.catch_warnings():
99
+ warnings.simplefilter("ignore", category=RuntimeWarning)
100
+ if det and rec:
101
+ ocr_res = []
102
+ for img in imgs:
103
+ img = preprocess_image(img)
104
+ dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
105
+ if not dt_boxes and not rec_res:
106
+ ocr_res.append(None)
107
+ continue
108
+ tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
109
+ ocr_res.append(tmp_res)
110
+ return ocr_res
111
+ elif det and not rec:
112
+ ocr_res = []
113
+ for img in imgs:
114
+ img = preprocess_image(img)
115
+ dt_boxes, elapse = self.text_detector(img)
116
+ # logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
117
+ if dt_boxes is None:
118
+ ocr_res.append(None)
119
+ continue
120
+ dt_boxes = sorted_boxes(dt_boxes)
121
+ # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
122
+ dt_boxes = merge_det_boxes(dt_boxes)
123
+ if mfd_res:
124
+ dt_boxes = update_det_boxes(dt_boxes, mfd_res)
125
+ tmp_res = [box.tolist() for box in dt_boxes]
126
+ ocr_res.append(tmp_res)
127
+ return ocr_res
128
+ elif not det and rec:
129
+ ocr_res = []
130
+ for img in imgs:
131
+ if not isinstance(img, list):
132
+ img = preprocess_image(img)
133
+ img = [img]
134
+ rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable)
135
+ # logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
136
+ ocr_res.append(rec_res)
137
+ return ocr_res
138
+
139
+ def __call__(self, img, mfd_res=None):
140
+
141
+ if img is None:
142
+ logger.debug("no valid image provided")
143
+ return None, None
144
+
145
+ ori_im = img.copy()
146
+ dt_boxes, elapse = self.text_detector(img)
147
+
148
+ if dt_boxes is None:
149
+ logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
150
+ return None, None
151
+ else:
152
+ pass
153
+ # logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
154
+ img_crop_list = []
155
+
156
+ dt_boxes = sorted_boxes(dt_boxes)
157
+
158
+ # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
159
+ dt_boxes = merge_det_boxes(dt_boxes)
160
+
161
+ if mfd_res:
162
+ dt_boxes = update_det_boxes(dt_boxes, mfd_res)
163
+
164
+ for bno in range(len(dt_boxes)):
165
+ tmp_box = copy.deepcopy(dt_boxes[bno])
166
+ img_crop = get_rotate_crop_image(ori_im, tmp_box)
167
+ img_crop_list.append(img_crop)
168
+
169
+ rec_res, elapse = self.text_recognizer(img_crop_list)
170
+ # logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
171
+
172
+ filter_boxes, filter_rec_res = [], []
173
+ for box, rec_result in zip(dt_boxes, rec_res):
174
+ text, score = rec_result
175
+ if score >= self.drop_score:
176
+ filter_boxes.append(box)
177
+ filter_rec_res.append(rec_result)
178
+
179
+ return filter_boxes, filter_rec_res
180
+
181
+ if __name__ == '__main__':
182
+ pytorch_paddle_ocr = PytorchPaddleOCR()
183
+ img = cv2.imread("/Users/myhloli/Downloads/screenshot-20250326-194348.png")
184
+ dt_boxes, rec_res = pytorch_paddle_ocr(img)
185
+ ocr_res = []
186
+ if not dt_boxes and not rec_res:
187
+ ocr_res.append(None)
188
+ else:
189
+ tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
190
+ ocr_res.append(tmp_res)
191
+ print(ocr_res)
192
+
193
+