magic-pdf 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/data/batch_build_dataset.py +156 -0
- magic_pdf/data/dataset.py +44 -24
- magic_pdf/data/utils.py +108 -9
- magic_pdf/dict2md/ocr_mkcontent.py +4 -3
- magic_pdf/libs/pdf_image_tools.py +11 -6
- magic_pdf/libs/performance_stats.py +12 -1
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/batch_analyze.py +175 -201
- magic_pdf/model/doc_analyze_by_custom_model.py +137 -92
- magic_pdf/model/pdf_extract_kit.py +5 -38
- magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
- magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
- magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
- magic_pdf/model/sub_modules/model_init.py +50 -37
- magic_pdf/model/sub_modules/model_utils.py +17 -11
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +10 -18
- magic_pdf/pdf_parse_union_core_v2.py +112 -74
- magic_pdf/post_proc/para_split_v3.py +16 -13
- magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
- magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
- magic_pdf/resources/model_config/model_configs.yaml +1 -1
- magic_pdf/tools/cli.py +30 -12
- magic_pdf/tools/common.py +90 -12
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/METADATA +51 -41
- magic_pdf-1.3.0.dist-info/RECORD +202 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
- magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
- magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
- magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
- magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
- magic_pdf-1.2.1.dist-info/RECORD +0 -147
- /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
- /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
- /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/WHEEL +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/top_level.txt +0 -0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py
ADDED
@@ -0,0 +1,39 @@
+import os
+import torch
+from .modeling.architectures.base_model import BaseModel
+
+class BaseOCRV20:
+    def __init__(self, config, **kwargs):
+        self.config = config
+        self.build_net(**kwargs)
+        self.net.eval()
+
+
+    def build_net(self, **kwargs):
+        self.net = BaseModel(self.config, **kwargs)
+
+    def read_pytorch_weights(self, weights_path):
+        if not os.path.exists(weights_path):
+            raise FileNotFoundError('{} is not existed.'.format(weights_path))
+        weights = torch.load(weights_path)
+        return weights
+
+    def get_out_channels(self, weights):
+        if list(weights.keys())[-1].endswith('.weight') and len(list(weights.values())[-1].shape) == 2:
+            out_channels = list(weights.values())[-1].numpy().shape[1]
+        else:
+            out_channels = list(weights.values())[-1].numpy().shape[0]
+        return out_channels
+
+    def load_state_dict(self, weights):
+        self.net.load_state_dict(weights)
+        # print('weights is loaded.')
+
+    def load_pytorch_weights(self, weights_path):
+        self.net.load_state_dict(torch.load(weights_path, weights_only=True))
+        # print('model is loaded: {}'.format(weights_path))
+
+    def inference(self, inputs):
+        with torch.no_grad():
+            infer = self.net(inputs)
+        return infer
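The `get_out_channels` helper above infers a recognition checkpoint's class count from its final weight tensor. A minimal sketch of that logic against a toy state dict (the layer names and the 6625-class figure are illustrative; the hunk assumes the class dimension sits last on the final 2-D `.weight` tensor):

```python
import torch

# Toy stand-in for a converted recognition checkpoint. The final entry
# mimics a CTC projection weight stored as (in_features, num_classes),
# which is the layout get_out_channels assumes when it reads shape[1].
weights = {
    'backbone.conv1.weight': torch.zeros(32, 3, 3, 3),
    'head.ctc.fc.weight': torch.zeros(192, 6625),  # illustrative class count
}

last_key = list(weights.keys())[-1]
last_val = list(weights.values())[-1]
if last_key.endswith('.weight') and len(last_val.shape) == 2:
    out_channels = last_val.numpy().shape[1]  # 2-D projection: classes on dim 1
else:
    out_channels = last_val.numpy().shape[0]  # e.g. a bias vector: classes on dim 0
print(out_channels)  # 6625
```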
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py
ADDED
@@ -0,0 +1,48 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+# from .iaa_augment import IaaAugment
+# from .make_border_map import MakeBorderMap
+# from .make_shrink_map import MakeShrinkMap
+# from .random_crop_data import EastRandomCropData, PSERandomCrop
+
+# from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg
+# from .randaugment import RandAugment
+from .operators import *
+# from .label_ops import *
+
+# from .east_process import *
+# from .sast_process import *
+# from .gen_table_mask import *
+
+def transform(data, ops=None):
+    """ transform """
+    if ops is None:
+        ops = []
+    for op in ops:
+        data = op(data)
+        if data is None:
+            return None
+    return data
+
+
+def create_operators(op_param_list, global_config=None):
+    """
+    create operators based on the config
+    Args:
+        params(list): a dict list, used to create some operators
+    """
+    assert isinstance(op_param_list, list), ('operator config should be a list')
+    ops = []
+    for operator in op_param_list:
+        assert isinstance(operator,
+                          dict) and len(operator) == 1, "yaml format error"
+        op_name = list(operator)[0]
+        param = {} if operator[op_name] is None else operator[op_name]
+        if global_config is not None:
+            param.update(global_config)
+        op = eval(op_name)(**param)
+        ops.append(op)
+    return ops
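Together, `create_operators` and `transform` implement the config-driven preprocessing pipeline: a list of single-key dicts is resolved to operator instances by name (via `eval` over the names pulled in by `from .operators import *`), then a data dict is threaded through them in order. A minimal sketch, assuming magic-pdf 1.3.0 is installed so the module path added in this release resolves (the operator parameters below are illustrative):

```python
import cv2
import numpy as np
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorchocr.data.imaug import (
    create_operators, transform)

# Single-key dicts, as the "yaml format error" assert requires.
op_config = [
    {'DecodeImage': {'img_mode': 'BGR', 'channel_first': False}},
    {'NormalizeImage': {'scale': '1./255.', 'order': 'hwc'}},
    {'ToCHWImage': None},
    {'KeepKeys': {'keep_keys': ['image']}},
]
ops = create_operators(op_config)

# Encode a dummy image to bytes, the input DecodeImage expects.
ok, buf = cv2.imencode('.png', np.zeros((32, 100, 3), dtype=np.uint8))
image, = transform({'image': buf.tobytes()}, ops)
print(image.shape)  # (3, 32, 100)
```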
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py
ADDED
@@ -0,0 +1,418 @@
+"""
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import sys
+import six
+import cv2
+import numpy as np
+
+
+class DecodeImage(object):
+    """ decode image """
+
+    def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
+        self.img_mode = img_mode
+        self.channel_first = channel_first
+
+    def __call__(self, data):
+        img = data['image']
+        if six.PY2:
+            assert type(img) is str and len(
+                img) > 0, "invalid input 'img' in DecodeImage"
+        else:
+            assert type(img) is bytes and len(
+                img) > 0, "invalid input 'img' in DecodeImage"
+        img = np.frombuffer(img, dtype='uint8')
+        img = cv2.imdecode(img, 1)
+        if img is None:
+            return None
+        if self.img_mode == 'GRAY':
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+        elif self.img_mode == 'RGB':
+            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
+            img = img[:, :, ::-1]
+
+        if self.channel_first:
+            img = img.transpose((2, 0, 1))
+
+        data['image'] = img
+        return data
+
+
+class NRTRDecodeImage(object):
+    """ decode image """
+
+    def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
+        self.img_mode = img_mode
+        self.channel_first = channel_first
+
+    def __call__(self, data):
+        img = data['image']
+        if six.PY2:
+            assert type(img) is str and len(
+                img) > 0, "invalid input 'img' in DecodeImage"
+        else:
+            assert type(img) is bytes and len(
+                img) > 0, "invalid input 'img' in DecodeImage"
+        img = np.frombuffer(img, dtype='uint8')
+
+        img = cv2.imdecode(img, 1)
+
+        if img is None:
+            return None
+        if self.img_mode == 'GRAY':
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+        elif self.img_mode == 'RGB':
+            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
+            img = img[:, :, ::-1]
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        if self.channel_first:
+            img = img.transpose((2, 0, 1))
+        data['image'] = img
+        return data
+
+
+class NormalizeImage(object):
+    """ normalize image such as substract mean, divide std
+    """
+
+    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
+        if isinstance(scale, str):
+            scale = eval(scale)
+        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        mean = mean if mean is not None else [0.485, 0.456, 0.406]
+        std = std if std is not None else [0.229, 0.224, 0.225]
+
+        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
+        self.mean = np.array(mean).reshape(shape).astype('float32')
+        self.std = np.array(std).reshape(shape).astype('float32')
+
+    def __call__(self, data):
+        img = data['image']
+        from PIL import Image
+        if isinstance(img, Image.Image):
+            img = np.array(img)
+        assert isinstance(img,
+                          np.ndarray), "invalid input 'img' in NormalizeImage"
+        data['image'] = (
+            img.astype('float32') * self.scale - self.mean) / self.std
+        return data
+
+
+class ToCHWImage(object):
+    """ convert hwc image to chw image
+    """
+
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, data):
+        img = data['image']
+        from PIL import Image
+        if isinstance(img, Image.Image):
+            img = np.array(img)
+        data['image'] = img.transpose((2, 0, 1))
+        return data
+
+
+class Fasttext(object):
+    def __init__(self, path="None", **kwargs):
+        import fasttext
+        self.fast_model = fasttext.load_model(path)
+
+    def __call__(self, data):
+        label = data['label']
+        fast_label = self.fast_model[label]
+        data['fast_label'] = fast_label
+        return data
+
+
+class KeepKeys(object):
+    def __init__(self, keep_keys, **kwargs):
+        self.keep_keys = keep_keys
+
+    def __call__(self, data):
+        data_list = []
+        for key in self.keep_keys:
+            data_list.append(data[key])
+        return data_list
+
+
+class Resize(object):
+    def __init__(self, size=(640, 640), **kwargs):
+        self.size = size
+
+    def resize_image(self, img):
+        resize_h, resize_w = self.size
+        ori_h, ori_w = img.shape[:2]  # (h, w, c)
+        ratio_h = float(resize_h) / ori_h
+        ratio_w = float(resize_w) / ori_w
+        img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        return img, [ratio_h, ratio_w]
+
+    def __call__(self, data):
+        img = data['image']
+        text_polys = data['polys']
+
+        img_resize, [ratio_h, ratio_w] = self.resize_image(img)
+        new_boxes = []
+        for box in text_polys:
+            new_box = []
+            for cord in box:
+                new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
+            new_boxes.append(new_box)
+        data['image'] = img_resize
+        data['polys'] = np.array(new_boxes, dtype=np.float32)
+        return data
+
+
+class DetResizeForTest(object):
+    def __init__(self, **kwargs):
+        super(DetResizeForTest, self).__init__()
+        self.resize_type = 0
+        if 'image_shape' in kwargs:
+            self.image_shape = kwargs['image_shape']
+            self.resize_type = 1
+        elif 'limit_side_len' in kwargs:
+            self.limit_side_len = kwargs['limit_side_len']
+            self.limit_type = kwargs.get('limit_type', 'min')
+        elif 'resize_long' in kwargs:
+            self.resize_type = 2
+            self.resize_long = kwargs.get('resize_long', 960)
+        else:
+            self.limit_side_len = 736
+            self.limit_type = 'min'
+
+    def __call__(self, data):
+        img = data['image']
+        src_h, src_w, _ = img.shape
+
+        if self.resize_type == 0:
+            # img, shape = self.resize_image_type0(img)
+            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
+        elif self.resize_type == 2:
+            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
+        else:
+            # img, shape = self.resize_image_type1(img)
+            img, [ratio_h, ratio_w] = self.resize_image_type1(img)
+        data['image'] = img
+        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+        return data
+
+    def resize_image_type1(self, img):
+        resize_h, resize_w = self.image_shape
+        ori_h, ori_w = img.shape[:2]  # (h, w, c)
+        ratio_h = float(resize_h) / ori_h
+        ratio_w = float(resize_w) / ori_w
+        img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        # return img, np.array([ori_h, ori_w])
+        return img, [ratio_h, ratio_w]
+
+    def resize_image_type0(self, img):
+        """
+        resize image to a size multiple of 32 which is required by the network
+        args:
+            img(array): array with shape [h, w, c]
+        return(tuple):
+            img, (ratio_h, ratio_w)
+        """
+        limit_side_len = self.limit_side_len
+        h, w, c = img.shape
+
+        # limit the max side
+        if self.limit_type == 'max':
+            if max(h, w) > limit_side_len:
+                if h > w:
+                    ratio = float(limit_side_len) / h
+                else:
+                    ratio = float(limit_side_len) / w
+            else:
+                ratio = 1.
+        elif self.limit_type == 'min':
+            if min(h, w) < limit_side_len:
+                if h < w:
+                    ratio = float(limit_side_len) / h
+                else:
+                    ratio = float(limit_side_len) / w
+            else:
+                ratio = 1.
+        elif self.limit_type == 'resize_long':
+            ratio = float(limit_side_len) / max(h, w)
+        else:
+            raise Exception('not support limit type, image ')
+        resize_h = int(h * ratio)
+        resize_w = int(w * ratio)
+
+        resize_h = max(int(round(resize_h / 32) * 32), 32)
+        resize_w = max(int(round(resize_w / 32) * 32), 32)
+
+        try:
+            if int(resize_w) <= 0 or int(resize_h) <= 0:
+                return None, (None, None)
+            img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        except:
+            print(img.shape, resize_w, resize_h)
+            sys.exit(0)
+        ratio_h = resize_h / float(h)
+        ratio_w = resize_w / float(w)
+        return img, [ratio_h, ratio_w]
+
+    def resize_image_type2(self, img):
+        h, w, _ = img.shape
+
+        resize_w = w
+        resize_h = h
+
+        if resize_h > resize_w:
+            ratio = float(self.resize_long) / resize_h
+        else:
+            ratio = float(self.resize_long) / resize_w
+
+        resize_h = int(resize_h * ratio)
+        resize_w = int(resize_w * ratio)
+
+        max_stride = 128
+        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+        img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        ratio_h = resize_h / float(h)
+        ratio_w = resize_w / float(w)
+
+        return img, [ratio_h, ratio_w]
+
+
+class E2EResizeForTest(object):
+    def __init__(self, **kwargs):
+        super(E2EResizeForTest, self).__init__()
+        self.max_side_len = kwargs['max_side_len']
+        self.valid_set = kwargs['valid_set']
+
+    def __call__(self, data):
+        img = data['image']
+        src_h, src_w, _ = img.shape
+        if self.valid_set == 'totaltext':
+            im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
+                img, max_side_len=self.max_side_len)
+        else:
+            im_resized, (ratio_h, ratio_w) = self.resize_image(
+                img, max_side_len=self.max_side_len)
+        data['image'] = im_resized
+        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+        return data
+
+    def resize_image_for_totaltext(self, im, max_side_len=512):
+
+        h, w, _ = im.shape
+        resize_w = w
+        resize_h = h
+        ratio = 1.25
+        if h * ratio > max_side_len:
+            ratio = float(max_side_len) / resize_h
+        resize_h = int(resize_h * ratio)
+        resize_w = int(resize_w * ratio)
+
+        max_stride = 128
+        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+        im = cv2.resize(im, (int(resize_w), int(resize_h)))
+        ratio_h = resize_h / float(h)
+        ratio_w = resize_w / float(w)
+        return im, (ratio_h, ratio_w)
+
+    def resize_image(self, im, max_side_len=512):
+        """
+        resize image to a size multiple of max_stride which is required by the network
+        :param im: the resized image
+        :param max_side_len: limit of max image size to avoid out of memory in gpu
+        :return: the resized image and the resize ratio
+        """
+        h, w, _ = im.shape
+
+        resize_w = w
+        resize_h = h
+
+        # Fix the longer side
+        if resize_h > resize_w:
+            ratio = float(max_side_len) / resize_h
+        else:
+            ratio = float(max_side_len) / resize_w
+
+        resize_h = int(resize_h * ratio)
+        resize_w = int(resize_w * ratio)
+
+        max_stride = 128
+        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+        im = cv2.resize(im, (int(resize_w), int(resize_h)))
+        ratio_h = resize_h / float(h)
+        ratio_w = resize_w / float(w)
+
+        return im, (ratio_h, ratio_w)
+
+
+class KieResize(object):
+    def __init__(self, **kwargs):
+        super(KieResize, self).__init__()
+        self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
+            'img_scale'][1]
+
+    def __call__(self, data):
+        img = data['image']
+        points = data['points']
+        src_h, src_w, _ = img.shape
+        im_resized, scale_factor, [ratio_h, ratio_w
+                                   ], [new_h, new_w] = self.resize_image(img)
+        resize_points = self.resize_boxes(img, points, scale_factor)
+        data['ori_image'] = img
+        data['ori_boxes'] = points
+        data['points'] = resize_points
+        data['image'] = im_resized
+        data['shape'] = np.array([new_h, new_w])
+        return data
+
+    def resize_image(self, img):
+        norm_img = np.zeros([1024, 1024, 3], dtype='float32')
+        scale = [512, 1024]
+        h, w = img.shape[:2]
+        max_long_edge = max(scale)
+        max_short_edge = min(scale)
+        scale_factor = min(max_long_edge / max(h, w),
+                           max_short_edge / min(h, w))
+        resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
+            scale_factor) + 0.5)
+        max_stride = 32
+        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+        im = cv2.resize(img, (resize_w, resize_h))
+        new_h, new_w = im.shape[:2]
+        w_scale = new_w / w
+        h_scale = new_h / h
+        scale_factor = np.array(
+            [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
+        norm_img[:new_h, :new_w, :] = im
+        return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
+
+    def resize_boxes(self, im, points, scale_factor):
+        points = points * scale_factor
+        img_shape = im.shape[:2]
+        points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
+        points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
+        return points
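Most of the operators above are straightforward, but the detection resizers carry the one invariant that matters downstream: both sides are snapped to a multiple of the network stride (32 for `DetResizeForTest`, 128 for the E2E variants), and the returned `ratio_h`/`ratio_w` let callers map predicted boxes back to source coordinates. A quick worked example of the `limit_type='max'` branch of `resize_image_type0` (the 1200×800 input and `limit_side_len=960` are illustrative):

```python
# Worked example of resize_image_type0 with limit_type='max'.
h, w = 1200, 800
limit_side_len = 960
# Longer side exceeds the limit, so scale it down to limit_side_len.
ratio = limit_side_len / h if h > w else limit_side_len / w  # 0.8
# Snap both sides to multiples of 32 (floor at 32).
resize_h = max(int(round(h * ratio / 32) * 32), 32)  # 960
resize_w = max(int(round(w * ratio / 32) * 32), 32)  # 640
ratio_h, ratio_w = resize_h / h, resize_w / w  # 0.8, 0.8
print(resize_h, resize_w, ratio_h, ratio_w)
```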
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py
ADDED
@@ -0,0 +1,25 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+__all__ = ["build_model"]
+
+
+def build_model(config, **kwargs):
+    from .base_model import BaseModel
+
+    config = copy.deepcopy(config)
+    module_class = BaseModel(config, **kwargs)
+    return module_class
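`build_model` deep-copies the config before constructing `BaseModel`, which matters because `BaseModel` (next hunk) writes the threaded `in_channels` back into the `Backbone`/`Neck`/`Head` sub-dicts. Without the copy, a caller's config dict would be silently mutated between builds; a small illustration of the hazard being avoided (the `Head` parameters are made up):

```python
import copy

config = {'Head': {'fc_decay': 4e-05}}  # illustrative sub-config

def careless_build(cfg):
    cfg['Head']['in_channels'] = 192  # BaseModel-style in-place write
    return cfg

careless_build(config)
print(config['Head'])  # {'fc_decay': 4e-05, 'in_channels': 192} -- caller's dict changed

config = {'Head': {'fc_decay': 4e-05}}
careless_build(copy.deepcopy(config))  # what build_model does
print(config['Head'])  # {'fc_decay': 4e-05} -- caller's dict intact
```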
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py
ADDED
@@ -0,0 +1,105 @@
+from torch import nn
+
+from ..backbones import build_backbone
+from ..heads import build_head
+from ..necks import build_neck
+
+
+class BaseModel(nn.Module):
+    def __init__(self, config, **kwargs):
+        """
+        the module for OCR.
+        args:
+            config (dict): the super parameters for module.
+        """
+        super(BaseModel, self).__init__()
+
+        in_channels = config.get("in_channels", 3)
+        model_type = config["model_type"]
+        # build backbone, backbone is need for del, rec and cls
+        if "Backbone" not in config or config["Backbone"] is None:
+            self.use_backbone = False
+        else:
+            self.use_backbone = True
+            config["Backbone"]["in_channels"] = in_channels
+            self.backbone = build_backbone(config["Backbone"], model_type)
+            in_channels = self.backbone.out_channels
+
+        # build neck
+        # for rec, neck can be cnn,rnn or reshape(None)
+        # for det, neck can be FPN, BIFPN and so on.
+        # for cls, neck should be none
+        if "Neck" not in config or config["Neck"] is None:
+            self.use_neck = False
+        else:
+            self.use_neck = True
+            config["Neck"]["in_channels"] = in_channels
+            self.neck = build_neck(config["Neck"])
+            in_channels = self.neck.out_channels
+
+        # # build head, head is need for det, rec and cls
+        if "Head" not in config or config["Head"] is None:
+            self.use_head = False
+        else:
+            self.use_head = True
+            config["Head"]["in_channels"] = in_channels
+            self.head = build_head(config["Head"], **kwargs)
+
+        self.return_all_feats = config.get("return_all_feats", False)
+
+        self._initialize_weights()
+
+    def _initialize_weights(self):
+        # weight initialization
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out")
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.ConvTranspose2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out")
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def forward(self, x):
+        y = dict()
+        if self.use_backbone:
+            x = self.backbone(x)
+        if isinstance(x, dict):
+            y.update(x)
+        else:
+            y["backbone_out"] = x
+        final_name = "backbone_out"
+        if self.use_neck:
+            x = self.neck(x)
+            if isinstance(x, dict):
+                y.update(x)
+            else:
+                y["neck_out"] = x
+            final_name = "neck_out"
+        if self.use_head:
+            x = self.head(x)
+            # for multi head, save ctc neck out for udml
+            if isinstance(x, dict) and "ctc_nect" in x.keys():
+                y["neck_out"] = x["ctc_neck"]
+                y["head_out"] = x
+            elif isinstance(x, dict):
+                y.update(x)
+            else:
+                y["head_out"] = x
+        if self.return_all_feats:
+            if self.training:
+                return y
+            elif isinstance(x, dict):
+                return x
+            else:
+                return {final_name: x}
+        else:
+            return x
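`BaseModel` assembles backbone → neck → head from name-keyed sub-configs, threading each stage's `out_channels` into the next stage's `in_channels`. The shipped component combinations live in `pytorchocr/utils/resources/arch_config.yaml` from the file list above; the sketch below only shows the config shape `BaseModel` consumes, with guessed component names patterned on the det modules added in this release:

```python
# Illustrative config in the shape BaseModel expects; real values come from
# pytorchocr/utils/resources/arch_config.yaml (the name/parameter values here
# are guesses based on the det_mobilenet_v3 / db_fpn / det_db_head modules
# added in this release, not confirmed by the diff).
det_config = {
    'model_type': 'det',
    'in_channels': 3,
    'Backbone': {'name': 'MobileNetV3', 'scale': 0.5, 'model_name': 'large'},
    'Neck': {'name': 'DBFPN', 'out_channels': 256},
    'Head': {'name': 'DBHead', 'k': 50},
}

# Hypothetical usage, assuming magic-pdf 1.3.0 is installed:
# from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorchocr.modeling.architectures import build_model
# model = build_model(det_config)           # builds backbone -> neck -> head
# out = model(torch.zeros(1, 3, 640, 640))  # head output (det probability maps)
```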