doc-page-extractor 0.2.0__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (52)
  1. doc_page_extractor/__init__.py +5 -15
  2. doc_page_extractor/check_env.py +40 -0
  3. doc_page_extractor/extractor.py +88 -215
  4. doc_page_extractor/model.py +97 -0
  5. doc_page_extractor/parser.py +51 -0
  6. doc_page_extractor/plot.py +52 -79
  7. doc_page_extractor/redacter.py +111 -0
  8. doc_page_extractor-1.0.2.dist-info/METADATA +120 -0
  9. doc_page_extractor-1.0.2.dist-info/RECORD +11 -0
  10. {doc_page_extractor-0.2.0.dist-info → doc_page_extractor-1.0.2.dist-info}/WHEEL +1 -2
  11. doc_page_extractor-1.0.2.dist-info/licenses/LICENSE +21 -0
  12. doc_page_extractor/clipper.py +0 -119
  13. doc_page_extractor/downloader.py +0 -16
  14. doc_page_extractor/latex.py +0 -31
  15. doc_page_extractor/layout_order.py +0 -237
  16. doc_page_extractor/layoutreader.py +0 -126
  17. doc_page_extractor/models.py +0 -92
  18. doc_page_extractor/ocr.py +0 -200
  19. doc_page_extractor/ocr_corrector.py +0 -126
  20. doc_page_extractor/onnxocr/__init__.py +0 -1
  21. doc_page_extractor/onnxocr/cls_postprocess.py +0 -26
  22. doc_page_extractor/onnxocr/db_postprocess.py +0 -246
  23. doc_page_extractor/onnxocr/imaug.py +0 -32
  24. doc_page_extractor/onnxocr/operators.py +0 -187
  25. doc_page_extractor/onnxocr/predict_base.py +0 -57
  26. doc_page_extractor/onnxocr/predict_cls.py +0 -109
  27. doc_page_extractor/onnxocr/predict_det.py +0 -139
  28. doc_page_extractor/onnxocr/predict_rec.py +0 -344
  29. doc_page_extractor/onnxocr/predict_system.py +0 -97
  30. doc_page_extractor/onnxocr/rec_postprocess.py +0 -896
  31. doc_page_extractor/onnxocr/utils.py +0 -71
  32. doc_page_extractor/overlap.py +0 -167
  33. doc_page_extractor/raw_optimizer.py +0 -104
  34. doc_page_extractor/rectangle.py +0 -72
  35. doc_page_extractor/rotation.py +0 -158
  36. doc_page_extractor/struct_eqtable/__init__.py +0 -49
  37. doc_page_extractor/struct_eqtable/internvl/__init__.py +0 -2
  38. doc_page_extractor/struct_eqtable/internvl/conversation.py +0 -394
  39. doc_page_extractor/struct_eqtable/internvl/internvl.py +0 -198
  40. doc_page_extractor/struct_eqtable/internvl/internvl_lmdeploy.py +0 -81
  41. doc_page_extractor/struct_eqtable/pix2s/__init__.py +0 -3
  42. doc_page_extractor/struct_eqtable/pix2s/pix2s.py +0 -76
  43. doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py +0 -1047
  44. doc_page_extractor/table.py +0 -70
  45. doc_page_extractor/types.py +0 -91
  46. doc_page_extractor/utils.py +0 -32
  47. doc_page_extractor-0.2.0.dist-info/METADATA +0 -85
  48. doc_page_extractor-0.2.0.dist-info/RECORD +0 -45
  49. doc_page_extractor-0.2.0.dist-info/licenses/LICENSE +0 -661
  50. doc_page_extractor-0.2.0.dist-info/top_level.txt +0 -2
  51. tests/__init__.py +0 -0
  52. tests/test_history_bus.py +0 -55
@@ -1,139 +0,0 @@
1
- import numpy as np
2
- from .imaug import transform, create_operators
3
- from .db_postprocess import DBPostProcess
4
- from .predict_base import PredictBase
5
-
6
-
7
class TextDetector(PredictBase):
    """ONNX-backed DB text detector.

    Calling an instance with an image runs the preprocessing pipeline
    (``DetResizeForTest`` -> ``NormalizeImage`` -> ``ToCHWImage`` ->
    ``KeepKeys``), feeds the result to a lazily created ONNX session, decodes
    the first model output with ``DBPostProcess`` and returns the detected
    boxes clipped to the original image bounds.
    """

    def __init__(self, args):
        super().__init__()
        self._args = args
        self.det_algorithm = args.det_algorithm
        pre_process_list = [
            {
                "DetResizeForTest": {
                    "limit_side_len": args.det_limit_side_len,
                    "limit_type": args.det_limit_type,
                }
            },
            {
                "NormalizeImage": {
                    "std": [0.229, 0.224, 0.225],
                    "mean": [0.485, 0.456, 0.406],
                    "scale": "1./255.",
                    "order": "hwc",
                }
            },
            {"ToCHWImage": None},
            {"KeepKeys": {"keep_keys": ["image", "shape"]}},
        ]
        postprocess_params = {
            "name": "DBPostProcess",
            "thresh": args.det_db_thresh,
            "box_thresh": args.det_db_box_thresh,
            "max_candidates": 1000,
            "unclip_ratio": args.det_db_unclip_ratio,
            "use_dilation": args.use_dilation,
            "score_mode": args.det_db_score_mode,
            "box_type": args.det_box_type,
        }

        # Instantiate the preprocessing operators and the post-processor.
        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = DBPostProcess(**postprocess_params)

        # The ONNX session and its I/O names are created lazily on first use.
        self._det_onnx_session = None
        self._det_input_name = None
        self._det_output_name = None

    @property
    def det_onnx_session(self):
        """ONNX session for the detection model, created on first access."""
        if self._det_onnx_session is None:
            self._det_onnx_session = self.get_onnx_session(
                self._args.det_model_dir, self._args.use_gpu
            )
        return self._det_onnx_session

    @property
    def det_input_name(self):
        """Input name(s) of the detection session, resolved on first access."""
        if self._det_input_name is None:
            self._det_input_name = self.get_input_name(self.det_onnx_session)
        return self._det_input_name

    @property
    def det_output_name(self):
        """Output name(s) of the detection session, resolved on first access."""
        if self._det_output_name is None:
            self._det_output_name = self.get_output_name(self.det_onnx_session)
        return self._det_output_name

    def order_points_clockwise(self, pts):
        """Reorder four (x, y) points as [top-left, top-right, bottom-right, bottom-left].

        The top-left/bottom-right corners are the points with the smallest and
        largest ``x + y``. Of the remaining two, the smaller ``y - x`` is the
        top-right and the larger is the bottom-left.

        Returns:
            A new float32 array of shape (4, 2).
        """
        rect = np.zeros((4, 2), dtype="float32")
        s = pts.sum(axis=1)
        rect[0] = pts[np.argmin(s)]
        rect[2] = pts[np.argmax(s)]
        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
        # np.diff along axis 1 of (x, y) rows yields y - x.
        diff = np.diff(np.array(tmp), axis=1)
        rect[1] = tmp[np.argmin(diff)]
        rect[3] = tmp[np.argmax(diff)]
        return rect

    def clip_det_res(self, points, img_height, img_width):
        """Clamp every point into the image and truncate it to an integer value.

        ``points`` is modified in place and also returned.
        """
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order, clip and filter quadrilateral boxes.

        Boxes whose width or height (after ordering and clipping) is 3 pixels
        or less are dropped.

        Args:
            dt_boxes: iterable of 4-point boxes (lists or arrays).
            image_shape: shape of the original image; only ``[0:2]`` (H, W)
                is used.

        Returns:
            numpy array of the surviving boxes.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            if isinstance(box, list):
                box = np.array(box)
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        return np.array(dt_boxes_new)

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip boxes to the image without reordering or size filtering (poly mode)."""
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            if isinstance(box, list):
                box = np.array(box)
            dt_boxes_new.append(self.clip_det_res(box, img_height, img_width))
        return np.array(dt_boxes_new)

    def __call__(self, img):
        """Detect text boxes in ``img``.

        Args:
            img: input image array (presumably H x W x C BGR — TODO confirm
                against callers).

        Returns:
            A numpy array of detected boxes, or ``None`` when preprocessing
            yields no image. In non-``poly`` mode each box is an ordered,
            clipped 4x2 point array.
        """
        # Only the original shape is needed later, so no full copy is made.
        ori_shape = img.shape

        data = transform({"image": img}, self.preprocess_op)
        # Guard before unpacking in case the transform pipeline yields None.
        if data is None:
            return None
        img, shape_list = data
        if img is None:
            # Return a plain None (not a tuple) so callers' `is None`
            # checks work.
            return None
        img = np.expand_dims(img, axis=0).copy()
        shape_list = np.expand_dims(shape_list, axis=0)

        input_feed = self.get_input_feed(self.det_input_name, img)
        outputs = self.det_onnx_session.run(self.det_output_name, input_feed=input_feed)

        post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
        dt_boxes = post_result[0]["points"]

        if self._args.det_box_type == "poly":
            return self.filter_tag_det_res_only_clip(dt_boxes, ori_shape)
        return self.filter_tag_det_res(dt_boxes, ori_shape)
@@ -1,344 +0,0 @@
1
- import cv2
2
- import numpy as np
3
- import math
4
- from PIL import Image
5
-
6
-
7
- from .rec_postprocess import CTCLabelDecode
8
- from .predict_base import PredictBase
9
-
10
-
11
class TextRecognizer(PredictBase):
    """ONNX-backed text-line recognizer with CTC decoding.

    Crops are resized and normalized according to ``args.rec_algorithm`` and
    ``args.rec_image_shape``. They are grouped into batches of
    ``args.rec_batch_num`` after sorting by aspect ratio, run through a lazily
    created ONNX session, and decoded with ``CTCLabelDecode``.
    """

    def __init__(self, args):
        super().__init__()
        self._args = args
        self.rec_image_shape = args.rec_image_shape
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        self.postprocess_op = CTCLabelDecode(
            character_dict_path=args.rec_char_dict_path,
            use_space_char=args.use_space_char,
        )

        # The ONNX session and its I/O names are created lazily on first use.
        self._rec_onnx_session = None
        self._rec_input_name = None
        self._rec_output_name = None

    @property
    def rec_onnx_session(self):
        """ONNX session for the recognition model, created on first access."""
        if self._rec_onnx_session is None:
            self._rec_onnx_session = self.get_onnx_session(
                self._args.rec_model_dir, self._args.use_gpu
            )
        return self._rec_onnx_session

    @property
    def rec_input_name(self):
        """Input name(s) of the recognition session, resolved on first access."""
        if self._rec_input_name is None:
            self._rec_input_name = self.get_input_name(self.rec_onnx_session)
        return self._rec_input_name

    @property
    def rec_output_name(self):
        """Output name(s) of the recognition session, resolved on first access."""
        if self._rec_output_name is None:
            self._rec_output_name = self.get_output_name(self.rec_onnx_session)
        return self._rec_output_name

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize and normalize one crop for the configured algorithm.

        Args:
            img: crop array of shape (H, W, C) (presumably BGR, given the
                ``COLOR_BGR2GRAY`` conversions below — TODO confirm).
            max_wh_ratio: largest width/height ratio in the current batch.
                In the default branch the target width is
                ``int(imgH * max_wh_ratio)``, so all crops in a batch share
                one padded width.

        Returns:
            float32 array in CHW layout:
            - NRTR / ViTSTR: grayscale (1, imgH, imgW), scaled to [0, 1]
              (ViTSTR) or [-1, 1) (NRTR).
            - RFL: grayscale (1, imgH, imgW) scaled to [-1, 1].
            - otherwise: (imgC, imgH, W) scaled to [-1, 1] and right-padded
              with zeros.
        """
        imgC, imgH, imgW = self.rec_image_shape
        if self.rec_algorithm == "NRTR" or self.rec_algorithm == "ViTSTR":
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            image_pil = Image.fromarray(np.uint8(img))
            if self.rec_algorithm == "ViTSTR":
                img = image_pil.resize([imgW, imgH], Image.Resampling.BICUBIC)
            else:
                img = image_pil.resize([imgW, imgH], Image.Resampling.LANCZOS)
            img = np.array(img)
            norm_img = np.expand_dims(img, -1)
            norm_img = norm_img.transpose((2, 0, 1))
            if self.rec_algorithm == "ViTSTR":
                norm_img = norm_img.astype(np.float32) / 255.0
            else:
                norm_img = norm_img.astype(np.float32) / 128.0 - 1.0
            return norm_img
        elif self.rec_algorithm == "RFL":
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
            resized_image = resized_image.astype("float32")
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
            resized_image -= 0.5
            resized_image /= 0.5
            return resized_image

        assert imgC == img.shape[2]
        imgW = int(imgH * max_wh_ratio)

        h, w = img.shape[:2]
        ratio = w / float(h)
        # Keep the aspect ratio at height imgH, but never exceed the batch width.
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))
        if self.rec_algorithm == "RARE":
            if resized_w > self.rec_image_shape[2]:
                resized_w = self.rec_image_shape[2]
            imgW = self.rec_image_shape[2]
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype("float32")
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def resize_norm_img_vl(self, img, image_shape):
        """Reverse the channel order, resize to (imgW, imgH), return CHW in [0, 1]."""
        imgC, imgH, imgW = image_shape
        img = img[:, :, ::-1]  # bgr2rgb
        resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype("float32")
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        return resized_image

    def resize_norm_img_srn(self, img, image_shape):
        """Resize to a width bucket (1x/2x/3x height, or imgW) and gray-pad to (1, imgH, imgW).

        The values are not rescaled (raw 0-255 grayscale, float32).
        """
        imgC, imgH, imgW = image_shape

        img_black = np.zeros((imgH, imgW))
        im_hei = img.shape[0]
        im_wid = img.shape[1]

        if im_wid <= im_hei * 1:
            img_new = cv2.resize(img, (imgH * 1, imgH))
        elif im_wid <= im_hei * 2:
            img_new = cv2.resize(img, (imgH * 2, imgH))
        elif im_wid <= im_hei * 3:
            img_new = cv2.resize(img, (imgH * 3, imgH))
        else:
            img_new = cv2.resize(img, (imgW, imgH))

        img_np = np.asarray(img_new)
        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
        img_black[:, 0 : img_np.shape[1]] = img_np
        img_black = img_black[:, :, np.newaxis]

        row, col, c = img_black.shape
        c = 1

        return np.reshape(img_black, (c, row, col)).astype(np.float32)

    def srn_other_inputs(self, image_shape, num_heads, max_text_length):
        """Build the positional-index and attention-bias auxiliary inputs for SRN.

        Returns:
            [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2]. The bias tensors have shape
            (1, num_heads, max_text_length, max_text_length). They hold -1e9
            above the diagonal (bias1) or below it (bias2).
        """
        imgC, imgH, imgW = image_shape
        feature_dim = int((imgH / 8) * (imgW / 8))

        encoder_word_pos = (
            np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype("int64")
        )
        gsrm_word_pos = (
            np.array(range(0, max_text_length))
            .reshape((max_text_length, 1))
            .astype("int64")
        )

        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
            [-1, 1, max_text_length, max_text_length]
        )
        gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [1, num_heads, 1, 1]).astype(
            "float32"
        ) * [-1e9]

        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
            [-1, 1, max_text_length, max_text_length]
        )
        gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [1, num_heads, 1, 1]).astype(
            "float32"
        ) * [-1e9]

        encoder_word_pos = encoder_word_pos[np.newaxis, :]
        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]

        return [
            encoder_word_pos,
            gsrm_word_pos,
            gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2,
        ]

    def process_image_srn(self, img, image_shape, num_heads, max_text_length):
        """Return the batched SRN image together with its auxiliary inputs, with fixed dtypes."""
        norm_img = self.resize_norm_img_srn(img, image_shape)
        norm_img = norm_img[np.newaxis, :]

        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = (
            self.srn_other_inputs(image_shape, num_heads, max_text_length)
        )

        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
        encoder_word_pos = encoder_word_pos.astype(np.int64)
        gsrm_word_pos = gsrm_word_pos.astype(np.int64)

        return (
            norm_img,
            encoder_word_pos,
            gsrm_word_pos,
            gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2,
        )

    def resize_norm_img_sar(self, img, image_shape, width_downsample_ratio=0.25):
        """Resize for SAR, keeping the width a multiple of ``1 / width_downsample_ratio``.

        Returns:
            (padding_im, resize_shape, pad_shape, valid_ratio). ``padding_im``
            is padded with -1.0 up to ``imgW_max``. ``valid_ratio`` is the
            fraction of that width covered by real image content (capped at 1).
        """
        imgC, imgH, imgW_min, imgW_max = image_shape
        h = img.shape[0]
        w = img.shape[1]
        valid_ratio = 1.0
        # Make sure new_width is an integral multiple of width_divisor.
        width_divisor = int(1 / width_downsample_ratio)
        ratio = w / float(h)
        resize_w = math.ceil(imgH * ratio)
        if resize_w % width_divisor != 0:
            resize_w = round(resize_w / width_divisor) * width_divisor
        if imgW_min is not None:
            resize_w = max(imgW_min, resize_w)
        if imgW_max is not None:
            valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
            resize_w = min(imgW_max, resize_w)
        resized_image = cv2.resize(img, (resize_w, imgH))
        resized_image = resized_image.astype("float32")
        # Normalize to [-1, 1].
        if image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        resize_shape = resized_image.shape
        padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
        padding_im[:, :, 0:resize_w] = resized_image
        pad_shape = padding_im.shape

        return padding_im, resize_shape, pad_shape, valid_ratio

    def resize_norm_img_spin(self, img):
        """Convert to gray, resize to 100x32 (bicubic), return (1, 32, 100) scaled to [-1, 1]."""
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # Pass the interpolation flag by keyword: the third positional
        # parameter of cv2.resize is `dst`, not `interpolation`.
        img = cv2.resize(img, (100, 32), interpolation=cv2.INTER_CUBIC)
        img = np.array(img, np.float32)
        img = np.expand_dims(img, -1)
        img = img.transpose((2, 0, 1))
        mean = np.array([127.5], dtype=np.float32)
        std = np.array([127.5], dtype=np.float32)
        mean = np.float32(mean.reshape(1, -1))
        stdinv = 1 / np.float32(std.reshape(1, -1))
        img -= mean
        img *= stdinv
        return img

    def resize_norm_img_svtr(self, img, image_shape):
        """Resize to exactly (imgW, imgH) and return CHW scaled to [-1, 1]."""
        imgC, imgH, imgW = image_shape
        resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype("float32")
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        return resized_image

    def resize_norm_img_abinet(self, img, image_shape):
        """Resize to (imgW, imgH) and standardize with fixed per-channel mean/std; return CHW float32."""
        imgC, imgH, imgW = image_shape

        resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype("float32")
        resized_image = resized_image / 255.0

        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        resized_image = (resized_image - mean[None, None, ...]) / std[None, None, ...]
        resized_image = resized_image.transpose((2, 0, 1))
        resized_image = resized_image.astype("float32")

        return resized_image

    def norm_img_can(self, img, image_shape):
        """Convert to gray and pad with white up to the configured size (CAN).

        Note: padding dimensions come from ``self.rec_image_shape``; the
        ``image_shape`` argument is currently unused.

        Returns:
            float32 array of shape (1, H, W) scaled to [0, 1].
        """
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # CAN only predicts grayscale images

        # FIXME: optional colour inversion (`img = 255 - img`) is not supported.

        if self.rec_image_shape[0] == 1:
            h, w = img.shape
            _, imgH, imgW = self.rec_image_shape
            if h < imgH or w < imgW:
                padding_h = max(imgH - h, 0)
                padding_w = max(imgW - w, 0)
                img = np.pad(
                    img,
                    ((0, padding_h), (0, padding_w)),
                    "constant",
                    constant_values=(255),
                )

        img = np.expand_dims(img, 0) / 255.0  # (h, w) -> (1, h, w)
        return img.astype("float32")

    def __call__(self, img_list):
        """Recognize the text in each crop of ``img_list``.

        Returns:
            A list, in the same order as ``img_list``, of the post-processor
            results (one per crop).
        """
        img_num = len(img_list)
        # Sort by aspect ratio so each batch pads to a similar width.
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        indices = np.argsort(np.array(width_list))
        # One independent placeholder per slot; `[x] * n` would alias one list.
        rec_res = [["", 0.0] for _ in range(img_num)]
        batch_num = self.rec_batch_num
        _, imgH, imgW = self.rec_image_shape[:3]

        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            batch_indices = indices[beg_img_no:end_img_no]

            # The batch width must fit the widest crop, and it is never
            # narrower than the model's nominal ratio.
            max_wh_ratio = imgW / imgH
            for idx in batch_indices:
                h, w = img_list[idx].shape[0:2]
                max_wh_ratio = max(max_wh_ratio, w * 1.0 / h)

            norm_img_batch = np.concatenate(
                [
                    self.resize_norm_img(img_list[idx], max_wh_ratio)[np.newaxis, :]
                    for idx in batch_indices
                ]
            ).copy()

            input_feed = self.get_input_feed(self.rec_input_name, norm_img_batch)
            outputs = self.rec_onnx_session.run(
                self.rec_output_name, input_feed=input_feed
            )

            rec_result = self.postprocess_op(outputs[0])
            for rno, result in enumerate(rec_result):
                rec_res[indices[beg_img_no + rno]] = result

        return rec_res
@@ -1,97 +0,0 @@
1
- import os
2
- import cv2
3
- import copy
4
-
5
- from . import predict_det
6
- from . import predict_cls
7
- from . import predict_rec
8
- from .utils import get_rotate_crop_image, get_minarea_rect_crop
9
-
10
class TextSystem:
    """Full OCR pipeline: detection -> cropping -> orientation -> recognition.

    The angle classifier is always enabled (``use_angle_cls`` is hard-wired
    to True). It runs only when ``__call__`` is invoked with ``cls=True``.
    """

    def __init__(self, args):
        self.text_detector = predict_det.TextDetector(args)
        self.text_recognizer = predict_rec.TextRecognizer(args)
        self.use_angle_cls = True
        self.drop_score = args.drop_score
        if self.use_angle_cls:
            self.text_classifier = predict_cls.TextClassifier(args)

        self.args = args
        # Running counter so successive calls write distinct crop file names.
        self.crop_image_res_index = 0

    def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
        """Write every crop as a JPEG into ``output_dir`` (created if missing).

        ``rec_res`` is accepted but not used. File numbering continues from
        ``self.crop_image_res_index``, which is advanced by the number of
        crops written.
        """
        os.makedirs(output_dir, exist_ok=True)
        start = self.crop_image_res_index
        for offset, crop in enumerate(img_crop_list):
            target = os.path.join(output_dir, f"mg_crop_{offset + start}.jpg")
            cv2.imwrite(target, crop)
        self.crop_image_res_index += len(img_crop_list)

    def __call__(self, img, cls=True):
        """Run OCR on ``img``.

        Args:
            img: input image array.
            cls: whether to run the angle classifier on the crops.

        Returns:
            ``(boxes, results)`` for every recognition whose score is at least
            ``self.drop_score``, or ``(None, None)`` if the detector returned
            None.
        """
        source = img.copy()

        # Text detection.
        boxes = self.text_detector(img)
        if boxes is None:
            return None, None
        boxes = sorted_boxes(boxes)

        # Crop each box from the untouched copy of the input.
        crops = []
        for box in boxes:
            box_copy = copy.deepcopy(box)
            if self.args.det_box_type == "quad":
                crops.append(get_rotate_crop_image(source, box_copy))
            else:
                crops.append(get_minarea_rect_crop(source, box_copy))

        # Orientation classification.
        if self.use_angle_cls and cls:
            crops, _angles = self.text_classifier(crops)

        # Text recognition.
        recognized = self.text_recognizer(crops)

        if self.args.save_crop_res:
            self.draw_crop_rec_res(self.args.crop_res_save_dir, crops, recognized)

        kept_boxes, kept_results = [], []
        for box, result in zip(boxes, recognized):
            _text, score = result
            if score < self.drop_score:
                continue
            kept_boxes.append(box)
            kept_results.append(result)
        return kept_boxes, kept_results
73
-
74
-
75
def sorted_boxes(dt_boxes, y_tolerance=10):
    """Sort text boxes top-to-bottom, then left-to-right within a line.

    Boxes are first sorted by the (y, x) of their first point. An
    insertion-style pass then moves a box before its predecessor when both
    first points lie within ``y_tolerance`` pixels vertically (treated as the
    same line) and the box starts further left.

    Args:
        dt_boxes: sequence of N boxes, each indexable as ``box[0] == (x, y)``
            (e.g. an array of shape [N, 4, 2] or a list of 4x2 point lists).
        y_tolerance: maximum vertical distance between first points for two
            boxes to be treated as on the same line (default 10).

    Returns:
        A new list containing the boxes in reading order.
    """
    # len() rather than .shape so that plain lists of boxes work too.
    boxes = sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0]))

    for i in range(len(boxes) - 1):
        for j in range(i, -1, -1):
            same_line = abs(boxes[j + 1][0][1] - boxes[j][0][1]) < y_tolerance
            if same_line and boxes[j + 1][0][0] < boxes[j][0][0]:
                boxes[j], boxes[j + 1] = boxes[j + 1], boxes[j]
            else:
                break
    return boxes