magic-pdf 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/data/batch_build_dataset.py +156 -0
- magic_pdf/data/dataset.py +44 -24
- magic_pdf/data/utils.py +108 -9
- magic_pdf/dict2md/ocr_mkcontent.py +4 -3
- magic_pdf/libs/pdf_image_tools.py +11 -6
- magic_pdf/libs/performance_stats.py +12 -1
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/batch_analyze.py +175 -201
- magic_pdf/model/doc_analyze_by_custom_model.py +137 -92
- magic_pdf/model/pdf_extract_kit.py +5 -38
- magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
- magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
- magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
- magic_pdf/model/sub_modules/model_init.py +50 -37
- magic_pdf/model/sub_modules/model_utils.py +17 -11
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +10 -18
- magic_pdf/pdf_parse_union_core_v2.py +112 -74
- magic_pdf/post_proc/para_split_v3.py +16 -13
- magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
- magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
- magic_pdf/resources/model_config/model_configs.yaml +1 -1
- magic_pdf/tools/cli.py +30 -12
- magic_pdf/tools/common.py +90 -12
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/METADATA +51 -41
- magic_pdf-1.3.0.dist-info/RECORD +202 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
- magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
- magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
- magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
- magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
- magic_pdf-1.2.1.dist-info/RECORD +0 -147
- /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
- /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
- /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/WHEEL +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/top_level.txt +0 -0
@@ -5,47 +5,57 @@ from magic_pdf.config.constants import MODEL_NAME
|
|
5
5
|
from magic_pdf.model.model_list import AtomicModel
|
6
6
|
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
|
7
7
|
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
|
8
|
-
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
|
9
8
|
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
|
10
9
|
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
from magic_pdf.model.sub_modules.
|
37
|
-
from magic_pdf.model.sub_modules.table.
|
38
|
-
|
39
|
-
|
10
|
+
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
|
11
|
+
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
|
12
|
+
# try:
|
13
|
+
# from magic_pdf_ascend_plugin.libs.license_verifier import (
|
14
|
+
# LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
|
15
|
+
# load_license)
|
16
|
+
# from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
|
17
|
+
# from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
|
18
|
+
# license_key = load_license()
|
19
|
+
# logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
|
20
|
+
# f' License expired at {license_key["payload"]["date"]["end_date"]}')
|
21
|
+
# except Exception as e:
|
22
|
+
# if isinstance(e, ImportError):
|
23
|
+
# pass
|
24
|
+
# elif isinstance(e, LicenseFormatError):
|
25
|
+
# logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
|
26
|
+
# elif isinstance(e, LicenseSignatureError):
|
27
|
+
# logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
|
28
|
+
# elif isinstance(e, LicenseExpiredError):
|
29
|
+
# logger.error('Ascend Plugin: License has expired. Please renew your license.')
|
30
|
+
# elif isinstance(e, FileNotFoundError):
|
31
|
+
# logger.error('Ascend Plugin: Not found License file.')
|
32
|
+
# else:
|
33
|
+
# logger.error(f'Ascend Plugin: {e}')
|
34
|
+
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
|
35
|
+
# # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
|
36
|
+
# from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
|
37
|
+
|
38
|
+
|
39
|
+
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None, table_sub_model_name=None):
|
40
40
|
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
|
41
|
+
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
|
41
42
|
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
|
42
43
|
elif table_model_type == MODEL_NAME.TABLE_MASTER:
|
44
|
+
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
|
43
45
|
config = {
|
44
46
|
'model_dir': model_path,
|
45
47
|
'device': _device_
|
46
48
|
}
|
47
49
|
table_model = TableMasterPaddleModel(config)
|
48
50
|
elif table_model_type == MODEL_NAME.RAPID_TABLE:
|
51
|
+
atom_model_manager = AtomModelSingleton()
|
52
|
+
ocr_engine = atom_model_manager.get_atom_model(
|
53
|
+
atom_model_name='ocr',
|
54
|
+
ocr_show_log=False,
|
55
|
+
det_db_box_thresh=0.5,
|
56
|
+
det_db_unclip_ratio=1.6,
|
57
|
+
lang=lang
|
58
|
+
)
|
49
59
|
table_model = RapidTableModel(ocr_engine, table_sub_model_name)
|
50
60
|
else:
|
51
61
|
logger.error('table model type not allow')
|
@@ -55,7 +65,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
|
|
55
65
|
|
56
66
|
|
57
67
|
def mfd_model_init(weight, device='cpu'):
|
58
|
-
if str(device).startswith(
|
68
|
+
if str(device).startswith('npu'):
|
59
69
|
device = torch.device(device)
|
60
70
|
mfd_model = YOLOv8MFDModel(weight, device)
|
61
71
|
return mfd_model
|
@@ -67,19 +77,20 @@ def mfr_model_init(weight_dir, cfg_path, device='cpu'):
|
|
67
77
|
|
68
78
|
|
69
79
|
def layout_model_init(weight, config_file, device):
|
80
|
+
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
|
70
81
|
model = Layoutlmv3_Predictor(weight, config_file, device)
|
71
82
|
return model
|
72
83
|
|
73
84
|
|
74
85
|
def doclayout_yolo_model_init(weight, device='cpu'):
|
75
|
-
if str(device).startswith(
|
86
|
+
if str(device).startswith('npu'):
|
76
87
|
device = torch.device(device)
|
77
88
|
model = DocLayoutYOLOModel(weight, device)
|
78
89
|
return model
|
79
90
|
|
80
91
|
|
81
92
|
def langdetect_model_init(langdetect_model_weight, device='cpu'):
|
82
|
-
if str(device).startswith(
|
93
|
+
if str(device).startswith('npu'):
|
83
94
|
device = torch.device(device)
|
84
95
|
model = YOLOv11LangDetModel(langdetect_model_weight, device)
|
85
96
|
return model
|
@@ -92,7 +103,8 @@ def ocr_model_init(show_log: bool = False,
|
|
92
103
|
det_db_unclip_ratio=1.8,
|
93
104
|
):
|
94
105
|
if lang is not None and lang != '':
|
95
|
-
model = ModifiedPaddleOCR(
|
106
|
+
# model = ModifiedPaddleOCR(
|
107
|
+
model = PytorchPaddleOCR(
|
96
108
|
show_log=show_log,
|
97
109
|
det_db_box_thresh=det_db_box_thresh,
|
98
110
|
lang=lang,
|
@@ -100,7 +112,8 @@ def ocr_model_init(show_log: bool = False,
|
|
100
112
|
det_db_unclip_ratio=det_db_unclip_ratio,
|
101
113
|
)
|
102
114
|
else:
|
103
|
-
model = ModifiedPaddleOCR(
|
115
|
+
# model = ModifiedPaddleOCR(
|
116
|
+
model = PytorchPaddleOCR(
|
104
117
|
show_log=show_log,
|
105
118
|
det_db_box_thresh=det_db_box_thresh,
|
106
119
|
use_dilation=use_dilation,
|
@@ -129,7 +142,7 @@ class AtomModelSingleton:
|
|
129
142
|
elif atom_model_name in [AtomicModel.Layout]:
|
130
143
|
key = (atom_model_name, layout_model_name)
|
131
144
|
elif atom_model_name in [AtomicModel.Table]:
|
132
|
-
key = (atom_model_name, table_model_name)
|
145
|
+
key = (atom_model_name, table_model_name, lang)
|
133
146
|
else:
|
134
147
|
key = atom_model_name
|
135
148
|
|
@@ -177,7 +190,7 @@ def atom_model_init(model_name: str, **kwargs):
|
|
177
190
|
kwargs.get('table_model_path'),
|
178
191
|
kwargs.get('table_max_time'),
|
179
192
|
kwargs.get('device'),
|
180
|
-
kwargs.get('
|
193
|
+
kwargs.get('lang'),
|
181
194
|
kwargs.get('table_sub_model_name')
|
182
195
|
)
|
183
196
|
elif model_name == AtomicModel.LangDetect:
|
@@ -1,25 +1,31 @@
|
|
1
1
|
import time
|
2
|
-
|
3
2
|
import torch
|
4
|
-
from PIL import Image
|
5
3
|
from loguru import logger
|
6
|
-
|
4
|
+
import numpy as np
|
7
5
|
from magic_pdf.libs.clean_memory import clean_memory
|
8
6
|
|
9
7
|
|
10
|
-
def crop_img(input_res,
|
8
|
+
def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
|
9
|
+
|
11
10
|
crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
|
12
11
|
crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
|
13
|
-
|
12
|
+
|
13
|
+
# Calculate new dimensions
|
14
14
|
crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
|
15
15
|
crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
|
16
|
-
return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
|
17
16
|
|
18
|
-
#
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
17
|
+
# Create a white background array
|
18
|
+
return_image = np.ones((crop_new_height, crop_new_width, 3), dtype=np.uint8) * 255
|
19
|
+
|
20
|
+
# Crop the original image using numpy slicing
|
21
|
+
cropped_img = input_np_img[crop_ymin:crop_ymax, crop_xmin:crop_xmax]
|
22
|
+
|
23
|
+
# Paste the cropped image onto the white background
|
24
|
+
return_image[crop_paste_y:crop_paste_y + (crop_ymax - crop_ymin),
|
25
|
+
crop_paste_x:crop_paste_x + (crop_xmax - crop_xmin)] = cropped_img
|
26
|
+
|
27
|
+
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width,
|
28
|
+
crop_new_height]
|
23
29
|
return return_image, return_list
|
24
30
|
|
25
31
|
|
@@ -0,0 +1 @@
|
|
1
|
+
# Copyright (c) Opendatalab. All rights reserved.
|
@@ -1,58 +1,67 @@
|
|
1
|
+
# Copyright (c) Opendatalab. All rights reserved.
|
2
|
+
import copy
|
3
|
+
|
1
4
|
import cv2
|
2
5
|
import numpy as np
|
3
|
-
from loguru import logger
|
4
|
-
from io import BytesIO
|
5
|
-
from PIL import Image
|
6
|
-
import base64
|
7
|
-
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
|
8
6
|
from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
|
9
|
-
|
10
|
-
import importlib.resources
|
11
|
-
from paddleocr import PaddleOCR
|
12
|
-
from ppocr.utils.utility import check_and_read
|
7
|
+
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
|
13
8
|
|
14
9
|
|
15
10
|
def img_decode(content: bytes):
|
16
11
|
np_arr = np.frombuffer(content, dtype=np.uint8)
|
17
12
|
return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
|
18
13
|
|
19
|
-
|
20
14
|
def check_img(img):
|
21
15
|
if isinstance(img, bytes):
|
22
16
|
img = img_decode(img)
|
23
|
-
if isinstance(img, str):
|
24
|
-
image_file = img
|
25
|
-
img, flag_gif, flag_pdf = check_and_read(image_file)
|
26
|
-
if not flag_gif and not flag_pdf:
|
27
|
-
with open(image_file, 'rb') as f:
|
28
|
-
img_str = f.read()
|
29
|
-
img = img_decode(img_str)
|
30
|
-
if img is None:
|
31
|
-
try:
|
32
|
-
buf = BytesIO()
|
33
|
-
image = BytesIO(img_str)
|
34
|
-
im = Image.open(image)
|
35
|
-
rgb = im.convert('RGB')
|
36
|
-
rgb.save(buf, 'jpeg')
|
37
|
-
buf.seek(0)
|
38
|
-
image_bytes = buf.read()
|
39
|
-
data_base64 = str(base64.b64encode(image_bytes),
|
40
|
-
encoding="utf-8")
|
41
|
-
image_decode = base64.b64decode(data_base64)
|
42
|
-
img_array = np.frombuffer(image_decode, np.uint8)
|
43
|
-
img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
|
44
|
-
except:
|
45
|
-
logger.error("error in loading image:{}".format(image_file))
|
46
|
-
return None
|
47
|
-
if img is None:
|
48
|
-
logger.error("error in loading image:{}".format(image_file))
|
49
|
-
return None
|
50
17
|
if isinstance(img, np.ndarray) and len(img.shape) == 2:
|
51
18
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
19
|
+
return img
|
52
20
|
|
21
|
+
|
22
|
+
def alpha_to_color(img, alpha_color=(255, 255, 255)):
|
23
|
+
if len(img.shape) == 3 and img.shape[2] == 4:
|
24
|
+
B, G, R, A = cv2.split(img)
|
25
|
+
alpha = A / 255
|
26
|
+
|
27
|
+
R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
|
28
|
+
G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
|
29
|
+
B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)
|
30
|
+
|
31
|
+
img = cv2.merge((B, G, R))
|
53
32
|
return img
|
54
33
|
|
55
34
|
|
35
|
+
def preprocess_image(_image):
|
36
|
+
alpha_color = (255, 255, 255)
|
37
|
+
_image = alpha_to_color(_image, alpha_color)
|
38
|
+
return _image
|
39
|
+
|
40
|
+
|
41
|
+
def sorted_boxes(dt_boxes):
|
42
|
+
"""
|
43
|
+
Sort text boxes in order from top to bottom, left to right
|
44
|
+
args:
|
45
|
+
dt_boxes(array):detected text boxes with shape [4, 2]
|
46
|
+
return:
|
47
|
+
sorted boxes(array) with shape [4, 2]
|
48
|
+
"""
|
49
|
+
num_boxes = dt_boxes.shape[0]
|
50
|
+
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
|
51
|
+
_boxes = list(sorted_boxes)
|
52
|
+
|
53
|
+
for i in range(num_boxes - 1):
|
54
|
+
for j in range(i, -1, -1):
|
55
|
+
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
|
56
|
+
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
|
57
|
+
tmp = _boxes[j]
|
58
|
+
_boxes[j] = _boxes[j + 1]
|
59
|
+
_boxes[j + 1] = tmp
|
60
|
+
else:
|
61
|
+
break
|
62
|
+
return _boxes
|
63
|
+
|
64
|
+
|
56
65
|
def bbox_to_points(bbox):
|
57
66
|
""" 将bbox格式转换为四个顶点的数组 """
|
58
67
|
x0, y0, x1, y1 = bbox
|
@@ -252,9 +261,10 @@ def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
|
|
252
261
|
return adjusted_mfdetrec_res
|
253
262
|
|
254
263
|
|
255
|
-
def get_ocr_result_list(ocr_res, useful_list):
|
264
|
+
def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, lang):
|
256
265
|
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
|
257
266
|
ocr_result_list = []
|
267
|
+
ori_im = new_image.copy()
|
258
268
|
for box_ocr_res in ocr_res:
|
259
269
|
|
260
270
|
if len(box_ocr_res) == 2:
|
@@ -266,6 +276,11 @@ def get_ocr_result_list(ocr_res, useful_list):
|
|
266
276
|
else:
|
267
277
|
p1, p2, p3, p4 = box_ocr_res
|
268
278
|
text, score = "", 1
|
279
|
+
|
280
|
+
if ocr_enable:
|
281
|
+
tmp_box = copy.deepcopy(np.array([p1, p2, p3, p4]).astype('float32'))
|
282
|
+
img_crop = get_rotate_crop_image(ori_im, tmp_box)
|
283
|
+
|
269
284
|
# average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
|
270
285
|
# if average_angle_degrees > 0.5:
|
271
286
|
poly = [p1, p2, p3, p4]
|
@@ -288,12 +303,22 @@ def get_ocr_result_list(ocr_res, useful_list):
|
|
288
303
|
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
|
289
304
|
p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
|
290
305
|
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
306
|
+
if ocr_enable:
|
307
|
+
ocr_result_list.append({
|
308
|
+
'category_id': 15,
|
309
|
+
'poly': p1 + p2 + p3 + p4,
|
310
|
+
'score': 1,
|
311
|
+
'text': text,
|
312
|
+
'np_img': img_crop,
|
313
|
+
'lang': lang,
|
314
|
+
})
|
315
|
+
else:
|
316
|
+
ocr_result_list.append({
|
317
|
+
'category_id': 15,
|
318
|
+
'poly': p1 + p2 + p3 + p4,
|
319
|
+
'score': float(round(score, 2)),
|
320
|
+
'text': text,
|
321
|
+
})
|
297
322
|
|
298
323
|
return ocr_result_list
|
299
324
|
|
@@ -308,56 +333,36 @@ def calculate_is_angle(poly):
|
|
308
333
|
return True
|
309
334
|
|
310
335
|
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
"use_dilation": key[2],
|
345
|
-
"det_db_unclip_ratio": key[3],
|
346
|
-
}
|
347
|
-
|
348
|
-
if key[0] is not None:
|
349
|
-
additional_ocr_params["lang"] = key[0]
|
350
|
-
|
351
|
-
# logger.info(f"additional_ocr_params: {additional_ocr_params}")
|
352
|
-
|
353
|
-
onnx_model = PaddleOCR(**additional_ocr_params)
|
354
|
-
|
355
|
-
if onnx_model is None:
|
356
|
-
logger.error('model init failed')
|
357
|
-
exit(1)
|
358
|
-
else:
|
359
|
-
return onnx_model
|
360
|
-
|
361
|
-
except Exception as e:
|
362
|
-
logger.exception(f'Error initializing model: {e}')
|
363
|
-
exit(1)
|
336
|
+
def get_rotate_crop_image(img, points):
|
337
|
+
'''
|
338
|
+
img_height, img_width = img.shape[0:2]
|
339
|
+
left = int(np.min(points[:, 0]))
|
340
|
+
right = int(np.max(points[:, 0]))
|
341
|
+
top = int(np.min(points[:, 1]))
|
342
|
+
bottom = int(np.max(points[:, 1]))
|
343
|
+
img_crop = img[top:bottom, left:right, :].copy()
|
344
|
+
points[:, 0] = points[:, 0] - left
|
345
|
+
points[:, 1] = points[:, 1] - top
|
346
|
+
'''
|
347
|
+
assert len(points) == 4, "shape of points must be 4*2"
|
348
|
+
img_crop_width = int(
|
349
|
+
max(
|
350
|
+
np.linalg.norm(points[0] - points[1]),
|
351
|
+
np.linalg.norm(points[2] - points[3])))
|
352
|
+
img_crop_height = int(
|
353
|
+
max(
|
354
|
+
np.linalg.norm(points[0] - points[3]),
|
355
|
+
np.linalg.norm(points[1] - points[2])))
|
356
|
+
pts_std = np.float32([[0, 0], [img_crop_width, 0],
|
357
|
+
[img_crop_width, img_crop_height],
|
358
|
+
[0, img_crop_height]])
|
359
|
+
M = cv2.getPerspectiveTransform(points, pts_std)
|
360
|
+
dst_img = cv2.warpPerspective(
|
361
|
+
img,
|
362
|
+
M, (img_crop_width, img_crop_height),
|
363
|
+
borderMode=cv2.BORDER_REPLICATE,
|
364
|
+
flags=cv2.INTER_CUBIC)
|
365
|
+
dst_img_height, dst_img_width = dst_img.shape[0:2]
|
366
|
+
if dst_img_height * 1.0 / dst_img_width >= 1.5:
|
367
|
+
dst_img = np.rot90(dst_img)
|
368
|
+
return dst_img
|
@@ -0,0 +1,193 @@
|
|
1
|
+
# Copyright (c) Opendatalab. All rights reserved.
|
2
|
+
import copy
|
3
|
+
import os.path
|
4
|
+
import warnings
|
5
|
+
from pathlib import Path
|
6
|
+
|
7
|
+
import cv2
|
8
|
+
import numpy as np
|
9
|
+
import yaml
|
10
|
+
from loguru import logger
|
11
|
+
|
12
|
+
from magic_pdf.libs.config_reader import get_device, get_local_models_dir
|
13
|
+
from .ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
|
14
|
+
from .tools.infer.predict_system import TextSystem
|
15
|
+
from .tools.infer import pytorchocr_utility as utility
|
16
|
+
import argparse
|
17
|
+
|
18
|
+
|
19
|
+
latin_lang = [
|
20
|
+
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr', # noqa: E126
|
21
|
+
'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
|
22
|
+
'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
|
23
|
+
'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
|
24
|
+
]
|
25
|
+
arabic_lang = ['ar', 'fa', 'ug', 'ur']
|
26
|
+
cyrillic_lang = [
|
27
|
+
'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava', # noqa: E126
|
28
|
+
'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
|
29
|
+
]
|
30
|
+
devanagari_lang = [
|
31
|
+
'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom', # noqa: E126
|
32
|
+
'sa', 'bgc'
|
33
|
+
]
|
34
|
+
|
35
|
+
|
36
|
+
def get_model_params(lang, config):
|
37
|
+
if lang in config['lang']:
|
38
|
+
params = config['lang'][lang]
|
39
|
+
det = params.get('det')
|
40
|
+
rec = params.get('rec')
|
41
|
+
dict_file = params.get('dict')
|
42
|
+
return det, rec, dict_file
|
43
|
+
else:
|
44
|
+
raise Exception (f'Language {lang} not supported')
|
45
|
+
|
46
|
+
|
47
|
+
root_dir = Path(__file__).resolve().parent
|
48
|
+
|
49
|
+
|
50
|
+
class PytorchPaddleOCR(TextSystem):
|
51
|
+
def __init__(self, *args, **kwargs):
|
52
|
+
parser = utility.init_args()
|
53
|
+
args = parser.parse_args(args)
|
54
|
+
|
55
|
+
self.lang = kwargs.get('lang', 'ch')
|
56
|
+
if self.lang in latin_lang:
|
57
|
+
self.lang = 'latin'
|
58
|
+
elif self.lang in arabic_lang:
|
59
|
+
self.lang = 'arabic'
|
60
|
+
elif self.lang in cyrillic_lang:
|
61
|
+
self.lang = 'cyrillic'
|
62
|
+
elif self.lang in devanagari_lang:
|
63
|
+
self.lang = 'devanagari'
|
64
|
+
else:
|
65
|
+
pass
|
66
|
+
|
67
|
+
models_config_path = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'models_config.yml')
|
68
|
+
with open(models_config_path) as file:
|
69
|
+
config = yaml.safe_load(file)
|
70
|
+
det, rec, dict_file = get_model_params(self.lang, config)
|
71
|
+
ocr_models_dir = os.path.join(get_local_models_dir(), 'OCR', 'paddleocr_torch')
|
72
|
+
kwargs['det_model_path'] = os.path.join(ocr_models_dir, det)
|
73
|
+
kwargs['rec_model_path'] = os.path.join(ocr_models_dir, rec)
|
74
|
+
kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
|
75
|
+
# kwargs['rec_batch_num'] = 8
|
76
|
+
|
77
|
+
kwargs['device'] = get_device()
|
78
|
+
|
79
|
+
default_args = vars(args)
|
80
|
+
default_args.update(kwargs)
|
81
|
+
args = argparse.Namespace(**default_args)
|
82
|
+
|
83
|
+
super().__init__(args)
|
84
|
+
|
85
|
+
def ocr(self,
|
86
|
+
img,
|
87
|
+
det=True,
|
88
|
+
rec=True,
|
89
|
+
mfd_res=None,
|
90
|
+
tqdm_enable=False,
|
91
|
+
):
|
92
|
+
assert isinstance(img, (np.ndarray, list, str, bytes))
|
93
|
+
if isinstance(img, list) and det == True:
|
94
|
+
logger.error('When input a list of images, det must be false')
|
95
|
+
exit(0)
|
96
|
+
img = check_img(img)
|
97
|
+
imgs = [img]
|
98
|
+
with warnings.catch_warnings():
|
99
|
+
warnings.simplefilter("ignore", category=RuntimeWarning)
|
100
|
+
if det and rec:
|
101
|
+
ocr_res = []
|
102
|
+
for img in imgs:
|
103
|
+
img = preprocess_image(img)
|
104
|
+
dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
|
105
|
+
if not dt_boxes and not rec_res:
|
106
|
+
ocr_res.append(None)
|
107
|
+
continue
|
108
|
+
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
|
109
|
+
ocr_res.append(tmp_res)
|
110
|
+
return ocr_res
|
111
|
+
elif det and not rec:
|
112
|
+
ocr_res = []
|
113
|
+
for img in imgs:
|
114
|
+
img = preprocess_image(img)
|
115
|
+
dt_boxes, elapse = self.text_detector(img)
|
116
|
+
# logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
|
117
|
+
if dt_boxes is None:
|
118
|
+
ocr_res.append(None)
|
119
|
+
continue
|
120
|
+
dt_boxes = sorted_boxes(dt_boxes)
|
121
|
+
# merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
|
122
|
+
dt_boxes = merge_det_boxes(dt_boxes)
|
123
|
+
if mfd_res:
|
124
|
+
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
|
125
|
+
tmp_res = [box.tolist() for box in dt_boxes]
|
126
|
+
ocr_res.append(tmp_res)
|
127
|
+
return ocr_res
|
128
|
+
elif not det and rec:
|
129
|
+
ocr_res = []
|
130
|
+
for img in imgs:
|
131
|
+
if not isinstance(img, list):
|
132
|
+
img = preprocess_image(img)
|
133
|
+
img = [img]
|
134
|
+
rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable)
|
135
|
+
# logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
|
136
|
+
ocr_res.append(rec_res)
|
137
|
+
return ocr_res
|
138
|
+
|
139
|
+
def __call__(self, img, mfd_res=None):
|
140
|
+
|
141
|
+
if img is None:
|
142
|
+
logger.debug("no valid image provided")
|
143
|
+
return None, None
|
144
|
+
|
145
|
+
ori_im = img.copy()
|
146
|
+
dt_boxes, elapse = self.text_detector(img)
|
147
|
+
|
148
|
+
if dt_boxes is None:
|
149
|
+
logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
|
150
|
+
return None, None
|
151
|
+
else:
|
152
|
+
pass
|
153
|
+
# logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
|
154
|
+
img_crop_list = []
|
155
|
+
|
156
|
+
dt_boxes = sorted_boxes(dt_boxes)
|
157
|
+
|
158
|
+
# merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
|
159
|
+
dt_boxes = merge_det_boxes(dt_boxes)
|
160
|
+
|
161
|
+
if mfd_res:
|
162
|
+
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
|
163
|
+
|
164
|
+
for bno in range(len(dt_boxes)):
|
165
|
+
tmp_box = copy.deepcopy(dt_boxes[bno])
|
166
|
+
img_crop = get_rotate_crop_image(ori_im, tmp_box)
|
167
|
+
img_crop_list.append(img_crop)
|
168
|
+
|
169
|
+
rec_res, elapse = self.text_recognizer(img_crop_list)
|
170
|
+
# logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
|
171
|
+
|
172
|
+
filter_boxes, filter_rec_res = [], []
|
173
|
+
for box, rec_result in zip(dt_boxes, rec_res):
|
174
|
+
text, score = rec_result
|
175
|
+
if score >= self.drop_score:
|
176
|
+
filter_boxes.append(box)
|
177
|
+
filter_rec_res.append(rec_result)
|
178
|
+
|
179
|
+
return filter_boxes, filter_rec_res
|
180
|
+
|
181
|
+
if __name__ == '__main__':
|
182
|
+
pytorch_paddle_ocr = PytorchPaddleOCR()
|
183
|
+
img = cv2.imread("/Users/myhloli/Downloads/screenshot-20250326-194348.png")
|
184
|
+
dt_boxes, rec_res = pytorch_paddle_ocr(img)
|
185
|
+
ocr_res = []
|
186
|
+
if not dt_boxes and not rec_res:
|
187
|
+
ocr_res.append(None)
|
188
|
+
else:
|
189
|
+
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
|
190
|
+
ocr_res.append(tmp_res)
|
191
|
+
print(ocr_res)
|
192
|
+
|
193
|
+
|