magic-pdf 1.2.2__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (102)
  1. magic_pdf/data/batch_build_dataset.py +156 -0
  2. magic_pdf/data/dataset.py +56 -25
  3. magic_pdf/data/utils.py +108 -9
  4. magic_pdf/dict2md/ocr_mkcontent.py +4 -3
  5. magic_pdf/libs/pdf_image_tools.py +11 -6
  6. magic_pdf/libs/performance_stats.py +12 -1
  7. magic_pdf/libs/version.py +1 -1
  8. magic_pdf/model/batch_analyze.py +175 -201
  9. magic_pdf/model/doc_analyze_by_custom_model.py +142 -92
  10. magic_pdf/model/pdf_extract_kit.py +5 -38
  11. magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
  12. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
  13. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
  14. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
  15. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
  16. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
  17. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
  18. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
  19. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
  20. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
  21. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
  22. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
  23. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
  24. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
  25. magic_pdf/model/sub_modules/model_init.py +50 -37
  26. magic_pdf/model/sub_modules/model_utils.py +18 -12
  27. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
  28. magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
  29. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
  30. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
  31. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
  32. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
  33. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
  34. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
  35. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
  36. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
  37. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
  38. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
  39. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
  40. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
  41. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
  42. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
  43. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
  44. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
  45. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
  46. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
  47. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
  48. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
  49. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
  50. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
  51. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
  52. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
  53. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
  54. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
  55. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
  56. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
  57. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
  58. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
  59. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
  60. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
  61. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
  62. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
  63. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
  64. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
  65. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
  66. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
  67. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
  68. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
  69. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
  70. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
  71. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
  72. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
  73. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
  74. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
  75. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
  76. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
  77. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
  78. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
  79. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +15 -19
  80. magic_pdf/pdf_parse_union_core_v2.py +112 -74
  81. magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
  82. magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
  83. magic_pdf/resources/model_config/model_configs.yaml +1 -1
  84. magic_pdf/resources/slanet_plus/slanet-plus.onnx +0 -0
  85. magic_pdf/tools/cli.py +30 -12
  86. magic_pdf/tools/common.py +90 -12
  87. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/METADATA +92 -59
  88. magic_pdf-1.3.1.dist-info/RECORD +203 -0
  89. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/WHEEL +1 -1
  90. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
  91. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
  92. magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
  93. magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
  94. magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
  95. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
  96. magic_pdf-1.2.2.dist-info/RECORD +0 -147
  97. /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
  98. /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
  99. /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
  100. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/LICENSE.md +0 -0
  101. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/entry_points.txt +0 -0
  102. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/top_level.txt +0 -0
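
The two largest additions, predict_rec.py (+440) and predict_system.py (+104), are reproduced in full below. The core batching idea in the new recognizer is worth isolating before reading it: text crops are sorted by aspect ratio so each batch pads to a similar width, and per-image results are scattered back through the argsort indices to restore the caller's order. A minimal standalone sketch of that pattern, assuming nothing beyond NumPy (batched_in_width_order and run_batch are hypothetical names, not magic-pdf APIs):

import numpy as np

def batched_in_width_order(img_list, batch_num, run_batch):
    # run_batch: callable taking a list of images, returning one result per image
    width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
    indices = np.argsort(np.array(width_list))  # narrowest crops first
    results = [None] * len(img_list)
    for beg in range(0, len(img_list), batch_num):
        batch_idx = indices[beg:beg + batch_num]
        batch_out = run_batch([img_list[i] for i in batch_idx])
        for i, out in zip(batch_idx, batch_out):
            results[i] = out  # scatter back to the caller's original order
    return results

# Example: an identity "model" on dummy crops of varying width.
crops = [np.zeros((32, w, 3), dtype=np.uint8) for w in (320, 48, 128)]
out = batched_in_width_order(crops, batch_num=2, run_batch=lambda b: [x.shape for x in b])
print(out)  # shapes come back in the original crop order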
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py
@@ -0,0 +1,440 @@
+ from PIL import Image
+ import cv2
+ import numpy as np
+ import math
+ import time
+ import torch
+ from tqdm import tqdm
+
+ from ...pytorchocr.base_ocr_v20 import BaseOCRV20
+ from . import pytorchocr_utility as utility
+ from ...pytorchocr.postprocess import build_post_process
+
+
+ class TextRecognizer(BaseOCRV20):
+     def __init__(self, args, **kwargs):
+         self.device = args.device
+         self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
+         self.character_type = args.rec_char_type
+         self.rec_batch_num = args.rec_batch_num
+         self.rec_algorithm = args.rec_algorithm
+         self.max_text_length = args.max_text_length
+         postprocess_params = {
+             'name': 'CTCLabelDecode',
+             "character_type": args.rec_char_type,
+             "character_dict_path": args.rec_char_dict_path,
+             "use_space_char": args.use_space_char
+         }
+         if self.rec_algorithm == "SRN":
+             postprocess_params = {
+                 'name': 'SRNLabelDecode',
+                 "character_type": args.rec_char_type,
+                 "character_dict_path": args.rec_char_dict_path,
+                 "use_space_char": args.use_space_char
+             }
+         elif self.rec_algorithm == "RARE":
+             postprocess_params = {
+                 'name': 'AttnLabelDecode',
+                 "character_type": args.rec_char_type,
+                 "character_dict_path": args.rec_char_dict_path,
+                 "use_space_char": args.use_space_char
+             }
+         elif self.rec_algorithm == 'NRTR':
+             postprocess_params = {
+                 'name': 'NRTRLabelDecode',
+                 "character_dict_path": args.rec_char_dict_path,
+                 "use_space_char": args.use_space_char
+             }
+         elif self.rec_algorithm == "SAR":
+             postprocess_params = {
+                 'name': 'SARLabelDecode',
+                 "character_dict_path": args.rec_char_dict_path,
+                 "use_space_char": args.use_space_char
+             }
+         elif self.rec_algorithm == 'ViTSTR':
+             postprocess_params = {
+                 'name': 'ViTSTRLabelDecode',
+                 "character_dict_path": args.rec_char_dict_path,
+                 "use_space_char": args.use_space_char
+             }
+         elif self.rec_algorithm == "CAN":
+             self.inverse = args.rec_image_inverse
+             postprocess_params = {
+                 'name': 'CANLabelDecode',
+                 "character_dict_path": args.rec_char_dict_path,
+                 "use_space_char": args.use_space_char
+             }
+         elif self.rec_algorithm == 'RFL':
+             postprocess_params = {
+                 'name': 'RFLLabelDecode',
+                 "character_dict_path": None,
+                 "use_space_char": args.use_space_char
+             }
+         self.postprocess_op = build_post_process(postprocess_params)
+
+         self.limited_max_width = args.limited_max_width
+         self.limited_min_width = args.limited_min_width
+
+         self.weights_path = args.rec_model_path
+         self.yaml_path = args.rec_yaml_path
+
+         network_config = utility.get_arch_config(self.weights_path)
+         weights = self.read_pytorch_weights(self.weights_path)
+
+         self.out_channels = self.get_out_channels(weights)
+         if self.rec_algorithm == 'NRTR':
+             self.out_channels = list(weights.values())[-1].numpy().shape[0]
+         elif self.rec_algorithm == 'SAR':
+             self.out_channels = list(weights.values())[-3].numpy().shape[0]
+
+         kwargs['out_channels'] = self.out_channels
+         super(TextRecognizer, self).__init__(network_config, **kwargs)
+
+         self.load_state_dict(weights)
+         self.net.eval()
+         self.net.to(self.device)
+
+     def resize_norm_img(self, img, max_wh_ratio):
+         imgC, imgH, imgW = self.rec_image_shape
+         if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR':
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+             # return padding_im
+             image_pil = Image.fromarray(np.uint8(img))
+             if self.rec_algorithm == 'ViTSTR':
+                 img = image_pil.resize([imgW, imgH], Image.BICUBIC)
+             else:
+                 img = image_pil.resize([imgW, imgH], Image.ANTIALIAS)
+             img = np.array(img)
+             norm_img = np.expand_dims(img, -1)
+             norm_img = norm_img.transpose((2, 0, 1))
+             if self.rec_algorithm == 'ViTSTR':
+                 norm_img = norm_img.astype(np.float32) / 255.
+             else:
+                 norm_img = norm_img.astype(np.float32) / 128. - 1.
+             return norm_img
+         elif self.rec_algorithm == 'RFL':
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+             resized_image = cv2.resize(
+                 img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
+             resized_image = resized_image.astype('float32')
+             resized_image = resized_image / 255
+             resized_image = resized_image[np.newaxis, :]
+             resized_image -= 0.5
+             resized_image /= 0.5
+             return resized_image
+
+         assert imgC == img.shape[2]
+         max_wh_ratio = max(max_wh_ratio, imgW / imgH)
+         imgW = int((imgH * max_wh_ratio))
+         imgW = max(min(imgW, self.limited_max_width), self.limited_min_width)
+         h, w = img.shape[:2]
+         ratio = w / float(h)
+         ratio_imgH = math.ceil(imgH * ratio)
+         ratio_imgH = max(ratio_imgH, self.limited_min_width)
+         if ratio_imgH > imgW:
+             resized_w = imgW
+         else:
+             resized_w = int(ratio_imgH)
+         resized_image = cv2.resize(img, (resized_w, imgH))
+         resized_image = resized_image.astype('float32')
+         resized_image = resized_image.transpose((2, 0, 1)) / 255
+         resized_image -= 0.5
+         resized_image /= 0.5
+         padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+         padding_im[:, :, 0:resized_w] = resized_image
+         return padding_im
+
+     def resize_norm_img_svtr(self, img, image_shape):
+
+         imgC, imgH, imgW = image_shape
+         resized_image = cv2.resize(
+             img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+         resized_image = resized_image.astype('float32')
+         resized_image = resized_image.transpose((2, 0, 1)) / 255
+         resized_image -= 0.5
+         resized_image /= 0.5
+         return resized_image
+
+
+     def resize_norm_img_srn(self, img, image_shape):
+         imgC, imgH, imgW = image_shape
+
+         img_black = np.zeros((imgH, imgW))
+         im_hei = img.shape[0]
+         im_wid = img.shape[1]
+
+         if im_wid <= im_hei * 1:
+             img_new = cv2.resize(img, (imgH * 1, imgH))
+         elif im_wid <= im_hei * 2:
+             img_new = cv2.resize(img, (imgH * 2, imgH))
+         elif im_wid <= im_hei * 3:
+             img_new = cv2.resize(img, (imgH * 3, imgH))
+         else:
+             img_new = cv2.resize(img, (imgW, imgH))
+
+         img_np = np.asarray(img_new)
+         img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+         img_black[:, 0:img_np.shape[1]] = img_np
+         img_black = img_black[:, :, np.newaxis]
+
+         row, col, c = img_black.shape
+         c = 1
+
+         return np.reshape(img_black, (c, row, col)).astype(np.float32)
+
+     def srn_other_inputs(self, image_shape, num_heads, max_text_length):
+
+         imgC, imgH, imgW = image_shape
+         feature_dim = int((imgH / 8) * (imgW / 8))
+
+         encoder_word_pos = np.array(range(0, feature_dim)).reshape(
+             (feature_dim, 1)).astype('int64')
+         gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
+             (max_text_length, 1)).astype('int64')
+
+         gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
+         gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
+             [-1, 1, max_text_length, max_text_length])
+         gsrm_slf_attn_bias1 = np.tile(
+             gsrm_slf_attn_bias1,
+             [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+
+         gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
+             [-1, 1, max_text_length, max_text_length])
+         gsrm_slf_attn_bias2 = np.tile(
+             gsrm_slf_attn_bias2,
+             [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+
+         encoder_word_pos = encoder_word_pos[np.newaxis, :]
+         gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
+
+         return [
+             encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+             gsrm_slf_attn_bias2
+         ]
+
+     def process_image_srn(self, img, image_shape, num_heads, max_text_length):
+         norm_img = self.resize_norm_img_srn(img, image_shape)
+         norm_img = norm_img[np.newaxis, :]
+
+         [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
+             self.srn_other_inputs(image_shape, num_heads, max_text_length)
+
+         gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
+         gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
+         encoder_word_pos = encoder_word_pos.astype(np.int64)
+         gsrm_word_pos = gsrm_word_pos.astype(np.int64)
+
+         return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+                 gsrm_slf_attn_bias2)
+
+     def resize_norm_img_sar(self, img, image_shape,
+                             width_downsample_ratio=0.25):
+         imgC, imgH, imgW_min, imgW_max = image_shape
+         h = img.shape[0]
+         w = img.shape[1]
+         valid_ratio = 1.0
+         # make sure new_width is an integral multiple of width_divisor.
+         width_divisor = int(1 / width_downsample_ratio)
+         # resize
+         ratio = w / float(h)
+         resize_w = math.ceil(imgH * ratio)
+         if resize_w % width_divisor != 0:
+             resize_w = round(resize_w / width_divisor) * width_divisor
+         if imgW_min is not None:
+             resize_w = max(imgW_min, resize_w)
+         if imgW_max is not None:
+             valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
+             resize_w = min(imgW_max, resize_w)
+         resized_image = cv2.resize(img, (resize_w, imgH))
+         resized_image = resized_image.astype('float32')
+         # norm
+         if image_shape[0] == 1:
+             resized_image = resized_image / 255
+             resized_image = resized_image[np.newaxis, :]
+         else:
+             resized_image = resized_image.transpose((2, 0, 1)) / 255
+         resized_image -= 0.5
+         resized_image /= 0.5
+         resize_shape = resized_image.shape
+         padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
+         padding_im[:, :, 0:resize_w] = resized_image
+         pad_shape = padding_im.shape
+
+         return padding_im, resize_shape, pad_shape, valid_ratio
+
+
+     def norm_img_can(self, img, image_shape):
+
+         img = cv2.cvtColor(
+             img, cv2.COLOR_BGR2GRAY)  # CAN only predict gray scale image
+
+         if self.inverse:
+             img = 255 - img
+
+         if self.rec_image_shape[0] == 1:
+             h, w = img.shape
+             _, imgH, imgW = self.rec_image_shape
+             if h < imgH or w < imgW:
+                 padding_h = max(imgH - h, 0)
+                 padding_w = max(imgW - w, 0)
+                 img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
+                                     'constant',
+                                     constant_values=(255))
+                 img = img_padded
+
+         img = np.expand_dims(img, 0) / 255.0  # h,w,c -> c,h,w
+         img = img.astype('float32')
+
+         return img
+
+     def __call__(self, img_list, tqdm_enable=False):
+         img_num = len(img_list)
+         # Calculate the aspect ratio of all text bars
+         width_list = []
+         for img in img_list:
+             width_list.append(img.shape[1] / float(img.shape[0]))
+         # Sorting can speed up the recognition process
+         indices = np.argsort(np.array(width_list))
+
+         # rec_res = []
+         rec_res = [['', 0.0]] * img_num
+         batch_num = self.rec_batch_num
+         elapse = 0
+         # for beg_img_no in range(0, img_num, batch_num):
+         with tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar:
+             index = 0
+             for beg_img_no in range(0, img_num, batch_num):
+                 end_img_no = min(img_num, beg_img_no + batch_num)
+                 norm_img_batch = []
+                 max_wh_ratio = 0
+                 for ino in range(beg_img_no, end_img_no):
+                     # h, w = img_list[ino].shape[0:2]
+                     h, w = img_list[indices[ino]].shape[0:2]
+                     wh_ratio = w * 1.0 / h
+                     max_wh_ratio = max(max_wh_ratio, wh_ratio)
+                 for ino in range(beg_img_no, end_img_no):
+                     if self.rec_algorithm == "SAR":
+                         norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
+                             img_list[indices[ino]], self.rec_image_shape)
+                         norm_img = norm_img[np.newaxis, :]
+                         valid_ratio = np.expand_dims(valid_ratio, axis=0)
+                         valid_ratios = []
+                         valid_ratios.append(valid_ratio)
+                         norm_img_batch.append(norm_img)
+
+                     elif self.rec_algorithm == "SVTR":
+                         norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
+                                                              self.rec_image_shape)
+                         norm_img = norm_img[np.newaxis, :]
+                         norm_img_batch.append(norm_img)
+                     elif self.rec_algorithm == "SRN":
+                         norm_img = self.process_image_srn(img_list[indices[ino]],
+                                                           self.rec_image_shape, 8,
+                                                           self.max_text_length)
+                         encoder_word_pos_list = []
+                         gsrm_word_pos_list = []
+                         gsrm_slf_attn_bias1_list = []
+                         gsrm_slf_attn_bias2_list = []
+                         encoder_word_pos_list.append(norm_img[1])
+                         gsrm_word_pos_list.append(norm_img[2])
+                         gsrm_slf_attn_bias1_list.append(norm_img[3])
+                         gsrm_slf_attn_bias2_list.append(norm_img[4])
+                         norm_img_batch.append(norm_img[0])
+                     elif self.rec_algorithm == "CAN":
+                         norm_img = self.norm_img_can(img_list[indices[ino]],
+                                                      max_wh_ratio)
+                         norm_img = norm_img[np.newaxis, :]
+                         norm_img_batch.append(norm_img)
+                         norm_image_mask = np.ones(norm_img.shape, dtype='float32')
+                         word_label = np.ones([1, 36], dtype='int64')
+                         norm_img_mask_batch = []
+                         word_label_list = []
+                         norm_img_mask_batch.append(norm_image_mask)
+                         word_label_list.append(word_label)
+                     else:
+                         norm_img = self.resize_norm_img(img_list[indices[ino]],
+                                                         max_wh_ratio)
+                         norm_img = norm_img[np.newaxis, :]
+                         norm_img_batch.append(norm_img)
+                 norm_img_batch = np.concatenate(norm_img_batch)
+                 norm_img_batch = norm_img_batch.copy()
+
+                 if self.rec_algorithm == "SRN":
+                     starttime = time.time()
+                     encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
+                     gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
+                     gsrm_slf_attn_bias1_list = np.concatenate(
+                         gsrm_slf_attn_bias1_list)
+                     gsrm_slf_attn_bias2_list = np.concatenate(
+                         gsrm_slf_attn_bias2_list)
+
+                     with torch.no_grad():
+                         inp = torch.from_numpy(norm_img_batch)
+                         encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list)
+                         gsrm_word_pos_inp = torch.from_numpy(gsrm_word_pos_list)
+                         gsrm_slf_attn_bias1_inp = torch.from_numpy(gsrm_slf_attn_bias1_list)
+                         gsrm_slf_attn_bias2_inp = torch.from_numpy(gsrm_slf_attn_bias2_list)
+
+                         inp = inp.to(self.device)
+                         encoder_word_pos_inp = encoder_word_pos_inp.to(self.device)
+                         gsrm_word_pos_inp = gsrm_word_pos_inp.to(self.device)
+                         gsrm_slf_attn_bias1_inp = gsrm_slf_attn_bias1_inp.to(self.device)
+                         gsrm_slf_attn_bias2_inp = gsrm_slf_attn_bias2_inp.to(self.device)
+
+                         backbone_out = self.net.backbone(inp)  # backbone_feat
+                         prob_out = self.net.head(backbone_out, [encoder_word_pos_inp, gsrm_word_pos_inp, gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp])
+                     # preds = {"predict": prob_out[2]}
+                     preds = {"predict": prob_out["predict"]}
+
+                 elif self.rec_algorithm == "SAR":
+                     starttime = time.time()
+                     # valid_ratios = np.concatenate(valid_ratios)
+                     # inputs = [
+                     #     norm_img_batch,
+                     #     valid_ratios,
+                     # ]
+
+                     with torch.no_grad():
+                         inp = torch.from_numpy(norm_img_batch)
+                         inp = inp.to(self.device)
+                         preds = self.net(inp)
+
+                 elif self.rec_algorithm == "CAN":
+                     starttime = time.time()
+                     norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
+                     word_label_list = np.concatenate(word_label_list)
+                     inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
+
+                     inp = [torch.from_numpy(e_i) for e_i in inputs]
+                     inp = [e_i.to(self.device) for e_i in inp]
+                     with torch.no_grad():
+                         outputs = self.net(inp)
+                         outputs = [v.cpu().numpy() for k, v in enumerate(outputs)]
+
+                     preds = outputs
+
+                 else:
+                     starttime = time.time()
+
+                     with torch.no_grad():
+                         inp = torch.from_numpy(norm_img_batch)
+                         inp = inp.to(self.device)
+                         prob_out = self.net(inp)
+
+                     if isinstance(prob_out, list):
+                         preds = [v.cpu().numpy() for v in prob_out]
+                     else:
+                         preds = prob_out.cpu().numpy()
+
+                 rec_result = self.postprocess_op(preds)
+                 for rno in range(len(rec_result)):
+                     rec_res[indices[beg_img_no + rno]] = rec_result[rno]
+                 elapse += time.time() - starttime
+
+                 # Update the progress bar by one batch; the last batch may be smaller than batch_num
+                 current_batch_size = min(batch_num, img_num - index * batch_num)
+                 index += 1
+                 pbar.update(current_batch_size)
+
+         return rec_res, elapse
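
The recognizer above takes its configuration from an argparse-style namespace rather than explicit parameters. Below is a hedged sketch of the attribute set that __init__ actually reads; every value is an illustrative placeholder, not a shipped default (the real defaults presumably live in pytorchocr_utility.py):

from types import SimpleNamespace

args = SimpleNamespace(
    device='cpu',
    rec_image_shape='3, 48, 320',            # parsed into [C, H, W]
    rec_char_type='ch',                      # placeholder
    rec_batch_num=6,
    rec_algorithm='SVTR',                    # falls through to the default CTCLabelDecode above
    max_text_length=25,
    rec_char_dict_path='ppocr_keys_v1.txt',  # placeholder dictionary path
    use_space_char=True,
    limited_max_width=1280,
    limited_min_width=16,
    rec_model_path='rec_model.pth',          # placeholder: converted PyTorch weights
    rec_yaml_path=None,
)
# Instantiation needs real converted weight files on disk:
# recognizer = TextRecognizer(args)
# rec_res, elapse = recognizer(crops, tqdm_enable=True)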
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py
@@ -0,0 +1,104 @@
+ import cv2
+ import copy
+ import numpy as np
+
+ from . import predict_rec
+ from . import predict_det
+ from . import predict_cls
+
+
+ class TextSystem(object):
+     def __init__(self, args, **kwargs):
+         self.text_detector = predict_det.TextDetector(args, **kwargs)
+         self.text_recognizer = predict_rec.TextRecognizer(args, **kwargs)
+         self.use_angle_cls = args.use_angle_cls
+         self.drop_score = args.drop_score
+         if self.use_angle_cls:
+             self.text_classifier = predict_cls.TextClassifier(args, **kwargs)
+
+     def get_rotate_crop_image(self, img, points):
+         '''
+         img_height, img_width = img.shape[0:2]
+         left = int(np.min(points[:, 0]))
+         right = int(np.max(points[:, 0]))
+         top = int(np.min(points[:, 1]))
+         bottom = int(np.max(points[:, 1]))
+         img_crop = img[top:bottom, left:right, :].copy()
+         points[:, 0] = points[:, 0] - left
+         points[:, 1] = points[:, 1] - top
+         '''
+         img_crop_width = int(
+             max(
+                 np.linalg.norm(points[0] - points[1]),
+                 np.linalg.norm(points[2] - points[3])))
+         img_crop_height = int(
+             max(
+                 np.linalg.norm(points[0] - points[3]),
+                 np.linalg.norm(points[1] - points[2])))
+         pts_std = np.float32([[0, 0], [img_crop_width, 0],
+                               [img_crop_width, img_crop_height],
+                               [0, img_crop_height]])
+         M = cv2.getPerspectiveTransform(points, pts_std)
+         dst_img = cv2.warpPerspective(
+             img,
+             M, (img_crop_width, img_crop_height),
+             borderMode=cv2.BORDER_REPLICATE,
+             flags=cv2.INTER_CUBIC)
+         dst_img_height, dst_img_width = dst_img.shape[0:2]
+         if dst_img_height * 1.0 / dst_img_width >= 1.5:
+             dst_img = np.rot90(dst_img)
+         return dst_img
+
+     def __call__(self, img):
+         ori_im = img.copy()
+         dt_boxes, elapse = self.text_detector(img)
+         print("dt_boxes num : {}, elapse : {}".format(
+             len(dt_boxes), elapse))
+         if dt_boxes is None:
+             return None, None
+         img_crop_list = []
+
+         dt_boxes = sorted_boxes(dt_boxes)
+
+         for bno in range(len(dt_boxes)):
+             tmp_box = copy.deepcopy(dt_boxes[bno])
+             img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
+             img_crop_list.append(img_crop)
+         if self.use_angle_cls:
+             img_crop_list, angle_list, elapse = self.text_classifier(
+                 img_crop_list)
+             print("cls num : {}, elapse : {}".format(
+                 len(img_crop_list), elapse))
+
+         rec_res, elapse = self.text_recognizer(img_crop_list)
+         print("rec_res num : {}, elapse : {}".format(
+             len(rec_res), elapse))
+         # self.print_draw_crop_rec_res(img_crop_list, rec_res)
+         filter_boxes, filter_rec_res = [], []
+         for box, rec_result in zip(dt_boxes, rec_res):
+             text, score = rec_result
+             if score >= self.drop_score:
+                 filter_boxes.append(box)
+                 filter_rec_res.append(rec_result)
+         return filter_boxes, filter_rec_res
+
+
+ def sorted_boxes(dt_boxes):
+     """
+     Sort text boxes in order from top to bottom, left to right
+     args:
+         dt_boxes(array): detected text boxes with shape [4, 2]
+     return:
+         sorted boxes(array) with shape [4, 2]
+     """
+     num_boxes = dt_boxes.shape[0]
+     sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+     _boxes = list(sorted_boxes)
+
+     for i in range(num_boxes - 1):
+         if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
+                 (_boxes[i + 1][0][0] < _boxes[i][0][0]):
+             tmp = _boxes[i]
+             _boxes[i] = _boxes[i + 1]
+             _boxes[i + 1] = tmp
+     return _boxes
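
The sorted_boxes helper at the end of predict_system.py implements a simple reading-order heuristic: a primary sort by the top-left corner's (y, x), then one swap pass that restores left-to-right order for boxes whose tops differ by less than 10 pixels. A small worked example (the import path follows the 1.3.1 layout above and requires the wheel to be installed; the coordinates are made up):

import numpy as np

from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.tools.infer.predict_system import sorted_boxes

# Three quadrilaterals, shape [N, 4, 2], corners clockwise from top-left.
# B is on the same visual line as A but starts 4 px lower, so the plain
# (y, x) sort puts A first; the same-line check (dy < 10) swaps them back
# into left-to-right order.
boxes = np.array([
    [[200, 100], [300, 100], [300, 128], [200, 128]],  # A: right side of line 1
    [[ 20, 104], [120, 104], [120, 132], [ 20, 132]],  # B: left side of line 1
    [[ 30, 200], [150, 200], [150, 230], [ 30, 230]],  # C: line 2
], dtype=np.float32)

for box in sorted_boxes(boxes):
    print(box[0])  # top-left corners print in order B, A, C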