magic-pdf 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (101)
  1. magic_pdf/data/batch_build_dataset.py +156 -0
  2. magic_pdf/data/dataset.py +44 -24
  3. magic_pdf/data/utils.py +108 -9
  4. magic_pdf/dict2md/ocr_mkcontent.py +4 -3
  5. magic_pdf/libs/pdf_image_tools.py +11 -6
  6. magic_pdf/libs/performance_stats.py +12 -1
  7. magic_pdf/libs/version.py +1 -1
  8. magic_pdf/model/batch_analyze.py +175 -201
  9. magic_pdf/model/doc_analyze_by_custom_model.py +137 -92
  10. magic_pdf/model/pdf_extract_kit.py +5 -38
  11. magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
  12. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
  13. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
  14. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
  15. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
  16. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
  17. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
  18. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
  19. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
  20. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
  21. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
  22. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
  23. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
  24. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
  25. magic_pdf/model/sub_modules/model_init.py +50 -37
  26. magic_pdf/model/sub_modules/model_utils.py +17 -11
  27. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
  28. magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
  29. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
  30. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
  31. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
  32. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
  33. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
  34. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
  35. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
  36. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
  37. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
  38. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
  39. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
  40. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
  41. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
  42. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
  43. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
  44. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
  45. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
  46. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
  47. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
  48. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
  49. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
  50. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
  51. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
  52. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
  53. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
  54. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
  55. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
  56. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
  57. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
  58. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
  59. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
  60. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
  61. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
  62. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
  63. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
  64. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
  65. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
  66. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
  67. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
  68. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
  69. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
  70. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
  71. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
  72. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
  73. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
  74. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
  75. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
  76. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
  77. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
  78. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
  79. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +10 -18
  80. magic_pdf/pdf_parse_union_core_v2.py +112 -74
  81. magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
  82. magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
  83. magic_pdf/resources/model_config/model_configs.yaml +1 -1
  84. magic_pdf/tools/cli.py +30 -12
  85. magic_pdf/tools/common.py +90 -12
  86. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/METADATA +50 -40
  87. magic_pdf-1.3.0.dist-info/RECORD +202 -0
  88. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
  89. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
  90. magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
  91. magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
  92. magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
  93. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
  94. magic_pdf-1.2.2.dist-info/RECORD +0 -147
  95. /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
  96. /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
  97. /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
  98. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/LICENSE.md +0 -0
  99. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/WHEEL +0 -0
  100. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/entry_points.txt +0 -0
  101. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/top_level.txt +0 -0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py
@@ -0,0 +1,39 @@
+ import os
+ import torch
+ from .modeling.architectures.base_model import BaseModel
+
+ class BaseOCRV20:
+     def __init__(self, config, **kwargs):
+         self.config = config
+         self.build_net(**kwargs)
+         self.net.eval()
+
+
+     def build_net(self, **kwargs):
+         self.net = BaseModel(self.config, **kwargs)
+
+     def read_pytorch_weights(self, weights_path):
+         if not os.path.exists(weights_path):
+             raise FileNotFoundError('{} does not exist.'.format(weights_path))
+         weights = torch.load(weights_path)
+         return weights
+
+     def get_out_channels(self, weights):
+         if list(weights.keys())[-1].endswith('.weight') and len(list(weights.values())[-1].shape) == 2:
+             out_channels = list(weights.values())[-1].numpy().shape[1]
+         else:
+             out_channels = list(weights.values())[-1].numpy().shape[0]
+         return out_channels
+
+     def load_state_dict(self, weights):
+         self.net.load_state_dict(weights)
+         # print('weights is loaded.')
+
+     def load_pytorch_weights(self, weights_path):
+         self.net.load_state_dict(torch.load(weights_path, weights_only=True))
+         # print('model is loaded: {}'.format(weights_path))
+
+     def inference(self, inputs):
+         with torch.no_grad():
+             infer = self.net(inputs)
+         return infer
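A minimal sketch of how this wrapper is meant to be driven. The real architecture configs ship in pytorchocr/utils/resources/arch_config.yaml; the component names and the checkpoint path below are illustrative assumptions, not values taken from this diff:

```python
import torch
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorchocr.base_ocr_v20 import BaseOCRV20

# Illustrative detection config; real ones are loaded from arch_config.yaml.
# Backbone/Neck/Head names follow the upstream PaddleOCR-to-PyTorch layout
# and are assumptions here.
config = {
    'model_type': 'det',
    'Backbone': {'name': 'MobileNetV3', 'model_name': 'large', 'scale': 0.5},
    'Neck': {'name': 'RSEFPN', 'out_channels': 96, 'shortcut': True},
    'Head': {'name': 'DBHead', 'k': 50},
}
model = BaseOCRV20(config)                    # builds BaseModel and switches to eval()
model.load_pytorch_weights('det_model.pth')  # hypothetical checkpoint path
scores = model.inference(torch.zeros(1, 3, 960, 960))
```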
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py
@@ -0,0 +1,8 @@
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+ from __future__ import unicode_literals
+
+ from .imaug import transform, create_operators
+
+
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py
@@ -0,0 +1,48 @@
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+ from __future__ import unicode_literals
+
+ # from .iaa_augment import IaaAugment
+ # from .make_border_map import MakeBorderMap
+ # from .make_shrink_map import MakeShrinkMap
+ # from .random_crop_data import EastRandomCropData, PSERandomCrop
+
+ # from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg
+ # from .randaugment import RandAugment
+ from .operators import *
+ # from .label_ops import *
+
+ # from .east_process import *
+ # from .sast_process import *
+ # from .gen_table_mask import *
+
+ def transform(data, ops=None):
+     """ transform """
+     if ops is None:
+         ops = []
+     for op in ops:
+         data = op(data)
+         if data is None:
+             return None
+     return data
+
+
+ def create_operators(op_param_list, global_config=None):
+     """
+     create operators based on the config
+     Args:
+         op_param_list(list): a list of single-key dicts used to create the operators
+     """
+     assert isinstance(op_param_list, list), ('operator config should be a list')
+     ops = []
+     for operator in op_param_list:
+         assert isinstance(operator,
+                           dict) and len(operator) == 1, "yaml format error"
+         op_name = list(operator)[0]
+         param = {} if operator[op_name] is None else operator[op_name]
+         if global_config is not None:
+             param.update(global_config)
+         op = eval(op_name)(**param)
+         ops.append(op)
+     return ops
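These two helpers are the preprocessing entry point: create_operators turns a YAML-style list of single-key dicts into operator instances (resolved by class name via eval against the names pulled in by `from .operators import *`), and transform threads a data dict through them. A usage sketch against the operators defined later in this diff; the input filename is a placeholder:

```python
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorchocr.data import (
    create_operators, transform)

pre_process_list = [
    {'DecodeImage': {'img_mode': 'BGR', 'channel_first': False}},
    {'DetResizeForTest': {'limit_side_len': 960, 'limit_type': 'max'}},
    {'NormalizeImage': {'scale': '1./255.',
                        'mean': [0.485, 0.456, 0.406],
                        'std': [0.229, 0.224, 0.225],
                        'order': 'hwc'}},
    {'ToCHWImage': None},
    {'KeepKeys': {'keep_keys': ['image', 'shape']}},
]
ops = create_operators(pre_process_list)

with open('page.png', 'rb') as f:        # placeholder input image
    data = {'image': f.read()}
img, shape_info = transform(data, ops)   # CHW float32 image + resize metadata
```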
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py
@@ -0,0 +1,418 @@
+ """
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+ from __future__ import unicode_literals
+
+ import sys
+ import six
+ import cv2
+ import numpy as np
+
+
+ class DecodeImage(object):
+     """ decode image """
+
+     def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
+         self.img_mode = img_mode
+         self.channel_first = channel_first
+
+     def __call__(self, data):
+         img = data['image']
+         if six.PY2:
+             assert type(img) is str and len(
+                 img) > 0, "invalid input 'img' in DecodeImage"
+         else:
+             assert type(img) is bytes and len(
+                 img) > 0, "invalid input 'img' in DecodeImage"
+         img = np.frombuffer(img, dtype='uint8')
+         img = cv2.imdecode(img, 1)
+         if img is None:
+             return None
+         if self.img_mode == 'GRAY':
+             img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+         elif self.img_mode == 'RGB':
+             assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
+             img = img[:, :, ::-1]
+
+         if self.channel_first:
+             img = img.transpose((2, 0, 1))
+
+         data['image'] = img
+         return data
+
+
+ class NRTRDecodeImage(object):
+     """ decode image """
+
+     def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
+         self.img_mode = img_mode
+         self.channel_first = channel_first
+
+     def __call__(self, data):
+         img = data['image']
+         if six.PY2:
+             assert type(img) is str and len(
+                 img) > 0, "invalid input 'img' in DecodeImage"
+         else:
+             assert type(img) is bytes and len(
+                 img) > 0, "invalid input 'img' in DecodeImage"
+         img = np.frombuffer(img, dtype='uint8')
+
+         img = cv2.imdecode(img, 1)
+
+         if img is None:
+             return None
+         if self.img_mode == 'GRAY':
+             img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+         elif self.img_mode == 'RGB':
+             assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
+             img = img[:, :, ::-1]
+         img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+         if self.channel_first:
+             img = img.transpose((2, 0, 1))
+         data['image'] = img
+         return data
+
+
+ class NormalizeImage(object):
+     """ normalize image, e.g. subtract mean, divide std
+     """
+
+     def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
+         if isinstance(scale, str):
+             scale = eval(scale)
+         self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+         mean = mean if mean is not None else [0.485, 0.456, 0.406]
+         std = std if std is not None else [0.229, 0.224, 0.225]
+
+         shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
+         self.mean = np.array(mean).reshape(shape).astype('float32')
+         self.std = np.array(std).reshape(shape).astype('float32')
+
+     def __call__(self, data):
+         img = data['image']
+         from PIL import Image
+         if isinstance(img, Image.Image):
+             img = np.array(img)
+         assert isinstance(img,
+                           np.ndarray), "invalid input 'img' in NormalizeImage"
+         data['image'] = (
+             img.astype('float32') * self.scale - self.mean) / self.std
+         return data
+
+
+ class ToCHWImage(object):
+     """ convert hwc image to chw image
+     """
+
+     def __init__(self, **kwargs):
+         pass
+
+     def __call__(self, data):
+         img = data['image']
+         from PIL import Image
+         if isinstance(img, Image.Image):
+             img = np.array(img)
+         data['image'] = img.transpose((2, 0, 1))
+         return data
+
+
+ class Fasttext(object):
+     def __init__(self, path="None", **kwargs):
+         import fasttext
+         self.fast_model = fasttext.load_model(path)
+
+     def __call__(self, data):
+         label = data['label']
+         fast_label = self.fast_model[label]
+         data['fast_label'] = fast_label
+         return data
+
+
+ class KeepKeys(object):
+     def __init__(self, keep_keys, **kwargs):
+         self.keep_keys = keep_keys
+
+     def __call__(self, data):
+         data_list = []
+         for key in self.keep_keys:
+             data_list.append(data[key])
+         return data_list
+
+
+ class Resize(object):
+     def __init__(self, size=(640, 640), **kwargs):
+         self.size = size
+
+     def resize_image(self, img):
+         resize_h, resize_w = self.size
+         ori_h, ori_w = img.shape[:2]  # (h, w, c)
+         ratio_h = float(resize_h) / ori_h
+         ratio_w = float(resize_w) / ori_w
+         img = cv2.resize(img, (int(resize_w), int(resize_h)))
+         return img, [ratio_h, ratio_w]
+
+     def __call__(self, data):
+         img = data['image']
+         text_polys = data['polys']
+
+         img_resize, [ratio_h, ratio_w] = self.resize_image(img)
+         new_boxes = []
+         for box in text_polys:
+             new_box = []
+             for cord in box:
+                 new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
+             new_boxes.append(new_box)
+         data['image'] = img_resize
+         data['polys'] = np.array(new_boxes, dtype=np.float32)
+         return data
+
+
+ class DetResizeForTest(object):
+     def __init__(self, **kwargs):
+         super(DetResizeForTest, self).__init__()
+         self.resize_type = 0
+         if 'image_shape' in kwargs:
+             self.image_shape = kwargs['image_shape']
+             self.resize_type = 1
+         elif 'limit_side_len' in kwargs:
+             self.limit_side_len = kwargs['limit_side_len']
+             self.limit_type = kwargs.get('limit_type', 'min')
+         elif 'resize_long' in kwargs:
+             self.resize_type = 2
+             self.resize_long = kwargs.get('resize_long', 960)
+         else:
+             self.limit_side_len = 736
+             self.limit_type = 'min'
+
+     def __call__(self, data):
+         img = data['image']
+         src_h, src_w, _ = img.shape
+
+         if self.resize_type == 0:
+             # img, shape = self.resize_image_type0(img)
+             img, [ratio_h, ratio_w] = self.resize_image_type0(img)
+         elif self.resize_type == 2:
+             img, [ratio_h, ratio_w] = self.resize_image_type2(img)
+         else:
+             # img, shape = self.resize_image_type1(img)
+             img, [ratio_h, ratio_w] = self.resize_image_type1(img)
+         data['image'] = img
+         data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+         return data
+
+     def resize_image_type1(self, img):
+         resize_h, resize_w = self.image_shape
+         ori_h, ori_w = img.shape[:2]  # (h, w, c)
+         ratio_h = float(resize_h) / ori_h
+         ratio_w = float(resize_w) / ori_w
+         img = cv2.resize(img, (int(resize_w), int(resize_h)))
+         # return img, np.array([ori_h, ori_w])
+         return img, [ratio_h, ratio_w]
+
+     def resize_image_type0(self, img):
+         """
+         resize image to a size multiple of 32, which is required by the network
+         args:
+             img(array): array with shape [h, w, c]
+         return(tuple):
+             img, (ratio_h, ratio_w)
+         """
+         limit_side_len = self.limit_side_len
+         h, w, c = img.shape
+
+         # limit the max side
+         if self.limit_type == 'max':
+             if max(h, w) > limit_side_len:
+                 if h > w:
+                     ratio = float(limit_side_len) / h
+                 else:
+                     ratio = float(limit_side_len) / w
+             else:
+                 ratio = 1.
+         elif self.limit_type == 'min':
+             if min(h, w) < limit_side_len:
+                 if h < w:
+                     ratio = float(limit_side_len) / h
+                 else:
+                     ratio = float(limit_side_len) / w
+             else:
+                 ratio = 1.
+         elif self.limit_type == 'resize_long':
+             ratio = float(limit_side_len) / max(h, w)
+         else:
+             raise Exception('unsupported limit_type when resizing image')
+         resize_h = int(h * ratio)
+         resize_w = int(w * ratio)
+
+         resize_h = max(int(round(resize_h / 32) * 32), 32)
+         resize_w = max(int(round(resize_w / 32) * 32), 32)
+
+         try:
+             if int(resize_w) <= 0 or int(resize_h) <= 0:
+                 return None, (None, None)
+             img = cv2.resize(img, (int(resize_w), int(resize_h)))
+         except:
+             print(img.shape, resize_w, resize_h)
+             sys.exit(0)
+         ratio_h = resize_h / float(h)
+         ratio_w = resize_w / float(w)
+         return img, [ratio_h, ratio_w]
+
+     def resize_image_type2(self, img):
+         h, w, _ = img.shape
+
+         resize_w = w
+         resize_h = h
+
+         if resize_h > resize_w:
+             ratio = float(self.resize_long) / resize_h
+         else:
+             ratio = float(self.resize_long) / resize_w
+
+         resize_h = int(resize_h * ratio)
+         resize_w = int(resize_w * ratio)
+
+         max_stride = 128
+         resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+         resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+         img = cv2.resize(img, (int(resize_w), int(resize_h)))
+         ratio_h = resize_h / float(h)
+         ratio_w = resize_w / float(w)
+
+         return img, [ratio_h, ratio_w]
+
+
+ class E2EResizeForTest(object):
+     def __init__(self, **kwargs):
+         super(E2EResizeForTest, self).__init__()
+         self.max_side_len = kwargs['max_side_len']
+         self.valid_set = kwargs['valid_set']
+
+     def __call__(self, data):
+         img = data['image']
+         src_h, src_w, _ = img.shape
+         if self.valid_set == 'totaltext':
+             im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
+                 img, max_side_len=self.max_side_len)
+         else:
+             im_resized, (ratio_h, ratio_w) = self.resize_image(
+                 img, max_side_len=self.max_side_len)
+         data['image'] = im_resized
+         data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+         return data
+
+     def resize_image_for_totaltext(self, im, max_side_len=512):
+
+         h, w, _ = im.shape
+         resize_w = w
+         resize_h = h
+         ratio = 1.25
+         if h * ratio > max_side_len:
+             ratio = float(max_side_len) / resize_h
+         resize_h = int(resize_h * ratio)
+         resize_w = int(resize_w * ratio)
+
+         max_stride = 128
+         resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+         resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+         im = cv2.resize(im, (int(resize_w), int(resize_h)))
+         ratio_h = resize_h / float(h)
+         ratio_w = resize_w / float(w)
+         return im, (ratio_h, ratio_w)
+
+     def resize_image(self, im, max_side_len=512):
+         """
+         resize image to a size multiple of max_stride, which is required by the network
+         :param im: the image to resize
+         :param max_side_len: limit of max image size, to avoid running out of GPU memory
+         :return: the resized image and the resize ratio
+         """
+         h, w, _ = im.shape
+
+         resize_w = w
+         resize_h = h
+
+         # Fix the longer side
+         if resize_h > resize_w:
+             ratio = float(max_side_len) / resize_h
+         else:
+             ratio = float(max_side_len) / resize_w
+
+         resize_h = int(resize_h * ratio)
+         resize_w = int(resize_w * ratio)
+
+         max_stride = 128
+         resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+         resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+         im = cv2.resize(im, (int(resize_w), int(resize_h)))
+         ratio_h = resize_h / float(h)
+         ratio_w = resize_w / float(w)
+
+         return im, (ratio_h, ratio_w)
+
+
+ class KieResize(object):
+     def __init__(self, **kwargs):
+         super(KieResize, self).__init__()
+         self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
+             'img_scale'][1]
+
+     def __call__(self, data):
+         img = data['image']
+         points = data['points']
+         src_h, src_w, _ = img.shape
+         im_resized, scale_factor, [ratio_h, ratio_w
+                                    ], [new_h, new_w] = self.resize_image(img)
+         resize_points = self.resize_boxes(img, points, scale_factor)
+         data['ori_image'] = img
+         data['ori_boxes'] = points
+         data['points'] = resize_points
+         data['image'] = im_resized
+         data['shape'] = np.array([new_h, new_w])
+         return data
+
+     def resize_image(self, img):
+         norm_img = np.zeros([1024, 1024, 3], dtype='float32')
+         scale = [512, 1024]
+         h, w = img.shape[:2]
+         max_long_edge = max(scale)
+         max_short_edge = min(scale)
+         scale_factor = min(max_long_edge / max(h, w),
+                            max_short_edge / min(h, w))
+         resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
+             scale_factor) + 0.5)
+         max_stride = 32
+         resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+         resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+         im = cv2.resize(img, (resize_w, resize_h))
+         new_h, new_w = im.shape[:2]
+         w_scale = new_w / w
+         h_scale = new_h / h
+         scale_factor = np.array(
+             [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
+         norm_img[:new_h, :new_w, :] = im
+         return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
+
+     def resize_boxes(self, im, points, scale_factor):
+         points = points * scale_factor
+         img_shape = im.shape[:2]
+         points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
+         points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
+         return points
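To make the rounding behaviour of DetResizeForTest concrete, here is a small check with the values computed from the code above:

```python
import numpy as np
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorchocr.data.imaug.operators import DetResizeForTest

op = DetResizeForTest(limit_side_len=960, limit_type='max')
data = op({'image': np.zeros((1080, 1920, 3), dtype=np.uint8)})

# ratio = 960 / 1920 = 0.5 -> 540 x 960, then snapped to multiples of 32:
# round(540 / 32) * 32 = 544, round(960 / 32) * 32 = 960
assert data['image'].shape == (544, 960, 3)
# data['shape'] holds [src_h, src_w, ratio_h, ratio_w] for un-scaling boxes later
```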
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py
@@ -0,0 +1,25 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import copy
+
+ __all__ = ["build_model"]
+
+
+ def build_model(config, **kwargs):
+     from .base_model import BaseModel
+
+     config = copy.deepcopy(config)
+     module_class = BaseModel(config, **kwargs)
+     return module_class
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py
@@ -0,0 +1,105 @@
+ from torch import nn
+
+ from ..backbones import build_backbone
+ from ..heads import build_head
+ from ..necks import build_neck
+
+
+ class BaseModel(nn.Module):
+     def __init__(self, config, **kwargs):
+         """
+         the module for OCR.
+         args:
+             config (dict): the hyperparameters for the module.
+         """
+         super(BaseModel, self).__init__()
+
+         in_channels = config.get("in_channels", 3)
+         model_type = config["model_type"]
+         # build the backbone; a backbone is needed for det, rec and cls
+         if "Backbone" not in config or config["Backbone"] is None:
+             self.use_backbone = False
+         else:
+             self.use_backbone = True
+             config["Backbone"]["in_channels"] = in_channels
+             self.backbone = build_backbone(config["Backbone"], model_type)
+             in_channels = self.backbone.out_channels
+
+         # build the neck
+         # for rec, the neck can be cnn, rnn or reshape (None)
+         # for det, the neck can be FPN, BIFPN and so on
+         # for cls, the neck should be None
+         if "Neck" not in config or config["Neck"] is None:
+             self.use_neck = False
+         else:
+             self.use_neck = True
+             config["Neck"]["in_channels"] = in_channels
+             self.neck = build_neck(config["Neck"])
+             in_channels = self.neck.out_channels
+
+         # build the head; a head is needed for det, rec and cls
+         if "Head" not in config or config["Head"] is None:
+             self.use_head = False
+         else:
+             self.use_head = True
+             config["Head"]["in_channels"] = in_channels
+             self.head = build_head(config["Head"], **kwargs)
+
+         self.return_all_feats = config.get("return_all_feats", False)
+
+         self._initialize_weights()
+
+     def _initialize_weights(self):
+         # weight initialization
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode="fan_out")
+                 if m.bias is not None:
+                     nn.init.zeros_(m.bias)
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.ones_(m.weight)
+                 nn.init.zeros_(m.bias)
+             elif isinstance(m, nn.Linear):
+                 nn.init.normal_(m.weight, 0, 0.01)
+                 if m.bias is not None:
+                     nn.init.zeros_(m.bias)
+             elif isinstance(m, nn.ConvTranspose2d):
+                 nn.init.kaiming_normal_(m.weight, mode="fan_out")
+                 if m.bias is not None:
+                     nn.init.zeros_(m.bias)
+
+     def forward(self, x):
+         y = dict()
+         if self.use_backbone:
+             x = self.backbone(x)
+         if isinstance(x, dict):
+             y.update(x)
+         else:
+             y["backbone_out"] = x
+         final_name = "backbone_out"
+         if self.use_neck:
+             x = self.neck(x)
+             if isinstance(x, dict):
+                 y.update(x)
+             else:
+                 y["neck_out"] = x
+             final_name = "neck_out"
+         if self.use_head:
+             x = self.head(x)
+             # for a multi head, save the ctc neck out for udml
+             if isinstance(x, dict) and "ctc_neck" in x.keys():
+                 y["neck_out"] = x["ctc_neck"]
+                 y["head_out"] = x
+             elif isinstance(x, dict):
+                 y.update(x)
+             else:
+                 y["head_out"] = x
+         if self.return_all_feats:
+             if self.training:
+                 return y
+             elif isinstance(x, dict):
+                 return x
+             else:
+                 return {final_name: x}
+         else:
+             return x
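BaseModel is a thin compositor: it wires Backbone → Neck → Head from the config, and in inference mode with return_all_feats left False, forward returns just the raw head output. A hedged recognition-side sketch; the component names mirror the rec_svtrnet/rnn/rec_ctc_head modules added in this release, but the exact parameter sets are assumptions, since the shipped values live in models_config.yml and arch_config.yaml:

```python
import torch
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorchocr.modeling.architectures import build_model

rec_config = {
    'model_type': 'rec',
    'in_channels': 3,
    'Backbone': {'name': 'SVTRNet'},                    # assumed defaults
    'Neck': {'name': 'SequenceEncoder', 'encoder_type': 'svtr'},
    'Head': {'name': 'CTCHead', 'out_channels': 6625},  # assumed: charset size + CTC blank
}
net = build_model(rec_config).eval()
with torch.no_grad():
    logits = net(torch.zeros(1, 3, 48, 320))  # raw CTCHead output in eval mode
```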