doc-page-extractor 0.2.4 (cp310-cp310-macosx_15_0_arm64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of doc-page-extractor might be problematic.
- doc_page_extractor/__init__.py +16 -0
- doc_page_extractor/clipper.py +119 -0
- doc_page_extractor/downloader.py +16 -0
- doc_page_extractor/extractor.py +218 -0
- doc_page_extractor/latex.py +33 -0
- doc_page_extractor/layout_order.py +239 -0
- doc_page_extractor/layoutreader.py +126 -0
- doc_page_extractor/model.py +133 -0
- doc_page_extractor/ocr.py +196 -0
- doc_page_extractor/ocr_corrector.py +126 -0
- doc_page_extractor/onnxocr/__init__.py +1 -0
- doc_page_extractor/onnxocr/cls_postprocess.py +26 -0
- doc_page_extractor/onnxocr/db_postprocess.py +246 -0
- doc_page_extractor/onnxocr/imaug.py +32 -0
- doc_page_extractor/onnxocr/operators.py +187 -0
- doc_page_extractor/onnxocr/predict_base.py +57 -0
- doc_page_extractor/onnxocr/predict_cls.py +109 -0
- doc_page_extractor/onnxocr/predict_det.py +139 -0
- doc_page_extractor/onnxocr/predict_rec.py +344 -0
- doc_page_extractor/onnxocr/predict_system.py +97 -0
- doc_page_extractor/onnxocr/rec_postprocess.py +896 -0
- doc_page_extractor/onnxocr/utils.py +71 -0
- doc_page_extractor/overlap.py +167 -0
- doc_page_extractor/plot.py +93 -0
- doc_page_extractor/raw_optimizer.py +104 -0
- doc_page_extractor/rectangle.py +72 -0
- doc_page_extractor/rotation.py +158 -0
- doc_page_extractor/table.py +60 -0
- doc_page_extractor/types.py +68 -0
- doc_page_extractor/utils.py +32 -0
- doc_page_extractor-0.2.4.dist-info/LICENSE +661 -0
- doc_page_extractor-0.2.4.dist-info/METADATA +88 -0
- doc_page_extractor-0.2.4.dist-info/RECORD +34 -0
- doc_page_extractor-0.2.4.dist-info/WHEEL +4 -0
doc_page_extractor/onnxocr/predict_det.py
@@ -0,0 +1,139 @@
+import numpy as np
+from .imaug import transform, create_operators
+from .db_postprocess import DBPostProcess
+from .predict_base import PredictBase
+
+
+class TextDetector(PredictBase):
+    def __init__(self, args):
+        super().__init__()
+        self._args = args
+        self.det_algorithm = args.det_algorithm
+        pre_process_list = [
+            {
+                "DetResizeForTest": {
+                    "limit_side_len": args.det_limit_side_len,
+                    "limit_type": args.det_limit_type,
+                }
+            },
+            {
+                "NormalizeImage": {
+                    "std": [0.229, 0.224, 0.225],
+                    "mean": [0.485, 0.456, 0.406],
+                    "scale": "1./255.",
+                    "order": "hwc",
+                }
+            },
+            {"ToCHWImage": None},
+            {"KeepKeys": {"keep_keys": ["image", "shape"]}},
+        ]
+        postprocess_params = {}
+        postprocess_params["name"] = "DBPostProcess"
+        postprocess_params["thresh"] = args.det_db_thresh
+        postprocess_params["box_thresh"] = args.det_db_box_thresh
+        postprocess_params["max_candidates"] = 1000
+        postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
+        postprocess_params["use_dilation"] = args.use_dilation
+        postprocess_params["score_mode"] = args.det_db_score_mode
+        postprocess_params["box_type"] = args.det_box_type
+
+        # Instantiate the preprocessing operators
+        self.preprocess_op = create_operators(pre_process_list)
+        # Instantiate the postprocessing operator
+        self.postprocess_op = DBPostProcess(**postprocess_params)
+
+        # Model handles; the ONNX session is created lazily
+        self._det_onnx_session = None
+        self._det_input_name = None
+        self._det_output_name = None
+
+    @property
+    def det_onnx_session(self):
+        if self._det_onnx_session is None:
+            self._det_onnx_session = self.get_onnx_session(self._args.det_model_dir, self._args.use_gpu)
+        return self._det_onnx_session
+
+    @property
+    def det_input_name(self):
+        if self._det_input_name is None:
+            self._det_input_name = self.get_input_name(self.det_onnx_session)
+        return self._det_input_name
+
+    @property
+    def det_output_name(self):
+        if self._det_output_name is None:
+            self._det_output_name = self.get_output_name(self.det_onnx_session)
+        return self._det_output_name
+
+    def order_points_clockwise(self, pts):
+        # Order the four corners: top-left, top-right, bottom-right, bottom-left
+        rect = np.zeros((4, 2), dtype="float32")
+        s = pts.sum(axis=1)
+        rect[0] = pts[np.argmin(s)]
+        rect[2] = pts[np.argmax(s)]
+        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
+        diff = np.diff(np.array(tmp), axis=1)
+        rect[1] = tmp[np.argmin(diff)]
+        rect[3] = tmp[np.argmax(diff)]
+        return rect
+
+    def clip_det_res(self, points, img_height, img_width):
+        for pno in range(points.shape[0]):
+            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
+            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
+        return points
+
+    def filter_tag_det_res(self, dt_boxes, image_shape):
+        img_height, img_width = image_shape[0:2]
+        dt_boxes_new = []
+        for box in dt_boxes:
+            if type(box) is list:
+                box = np.array(box)
+            box = self.order_points_clockwise(box)
+            box = self.clip_det_res(box, img_height, img_width)
+            rect_width = int(np.linalg.norm(box[0] - box[1]))
+            rect_height = int(np.linalg.norm(box[0] - box[3]))
+            if rect_width <= 3 or rect_height <= 3:
+                continue
+            dt_boxes_new.append(box)
+        dt_boxes = np.array(dt_boxes_new)
+        return dt_boxes
+
+    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
+        img_height, img_width = image_shape[0:2]
+        dt_boxes_new = []
+        for box in dt_boxes:
+            if type(box) is list:
+                box = np.array(box)
+            box = self.clip_det_res(box, img_height, img_width)
+            dt_boxes_new.append(box)
+        dt_boxes = np.array(dt_boxes_new)
+        return dt_boxes
+
+    def __call__(self, img):
+        ori_im = img.copy()
+        data = {"image": img}
+
+        data = transform(data, self.preprocess_op)
+        img, shape_list = data
+        if img is None:
+            return None  # single value, matching the normal return path below
+        img = np.expand_dims(img, axis=0)
+        shape_list = np.expand_dims(shape_list, axis=0)
+        img = img.copy()
+
+        input_feed = self.get_input_feed(self.det_input_name, img)
+        outputs = self.det_onnx_session.run(self.det_output_name, input_feed=input_feed)
+
+        preds = {}
+        preds["maps"] = outputs[0]
+
+        post_result = self.postprocess_op(preds, shape_list)
+        dt_boxes = post_result[0]["points"]
+
+        if self._args.det_box_type == "poly":
+            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
+        else:
+            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
+
+        return dt_boxes
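For orientation, here is a minimal sketch of how this detector might be driven. The SimpleNamespace below is a hypothetical stand-in for the package's real configuration object; the attribute names are exactly the ones TextDetector.__init__ reads, but all values and the model path are illustrative only.

from types import SimpleNamespace
import cv2

# Hypothetical args object; TextDetector reads these attributes.
args = SimpleNamespace(
    det_algorithm="DB",
    det_model_dir="models/det.onnx",  # assumed path, not shipped here
    det_limit_side_len=960,
    det_limit_type="max",
    det_db_thresh=0.3,
    det_db_box_thresh=0.6,
    det_db_unclip_ratio=1.5,
    use_dilation=False,
    det_db_score_mode="fast",
    det_box_type="quad",
    use_gpu=False,
)

detector = TextDetector(args)
img = cv2.imread("page.png")  # BGR ndarray, as the pipeline expects
dt_boxes = detector(img)      # (N, 4, 2) array of quadrilateral corners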
doc_page_extractor/onnxocr/predict_rec.py
@@ -0,0 +1,344 @@
+import cv2
+import numpy as np
+import math
+from PIL import Image
+
+
+from .rec_postprocess import CTCLabelDecode
+from .predict_base import PredictBase
+
+
+class TextRecognizer(PredictBase):
+    def __init__(self, args):
+        super().__init__()
+        self._args = args
+        self.rec_image_shape = args.rec_image_shape
+        self.rec_batch_num = args.rec_batch_num
+        self.rec_algorithm = args.rec_algorithm
+        self.postprocess_op = CTCLabelDecode(
+            character_dict_path=args.rec_char_dict_path,
+            use_space_char=args.use_space_char,
+        )
+
+        # Model handles; the ONNX session is created lazily
+        self._rec_onnx_session = None
+        self._rec_input_name = None
+        self._rec_output_name = None
+
+    @property
+    def rec_onnx_session(self):
+        if self._rec_onnx_session is None:
+            self._rec_onnx_session = self.get_onnx_session(
+                self._args.rec_model_dir, self._args.use_gpu
+            )
+        return self._rec_onnx_session
+
+    @property
+    def rec_input_name(self):
+        if self._rec_input_name is None:
+            self._rec_input_name = self.get_input_name(self.rec_onnx_session)
+        return self._rec_input_name
+
+    @property
+    def rec_output_name(self):
+        if self._rec_output_name is None:
+            self._rec_output_name = self.get_output_name(self.rec_onnx_session)
+        return self._rec_output_name
+
+    def resize_norm_img(self, img, max_wh_ratio):
+        imgC, imgH, imgW = self.rec_image_shape
+        if self.rec_algorithm == "NRTR" or self.rec_algorithm == "ViTSTR":
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+            image_pil = Image.fromarray(np.uint8(img))
+            if self.rec_algorithm == "ViTSTR":
+                img = image_pil.resize([imgW, imgH], Image.Resampling.BICUBIC)
+            else:
+                img = image_pil.resize([imgW, imgH], Image.Resampling.LANCZOS)
+            img = np.array(img)
+            norm_img = np.expand_dims(img, -1)
+            norm_img = norm_img.transpose((2, 0, 1))
+            if self.rec_algorithm == "ViTSTR":
+                norm_img = norm_img.astype(np.float32) / 255.0
+            else:
+                norm_img = norm_img.astype(np.float32) / 128.0 - 1.0
+            return norm_img
+        elif self.rec_algorithm == "RFL":
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+            resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
+            resized_image = resized_image.astype("float32")
+            resized_image = resized_image / 255
+            resized_image = resized_image[np.newaxis, :]
+            resized_image -= 0.5
+            resized_image /= 0.5
+            return resized_image
+
+        assert imgC == img.shape[2]
+        imgW = int(imgH * max_wh_ratio)
+
+        h, w = img.shape[:2]
+        ratio = w / float(h)
+        if math.ceil(imgH * ratio) > imgW:
+            resized_w = imgW
+        else:
+            resized_w = int(math.ceil(imgH * ratio))
+        if self.rec_algorithm == "RARE":
+            if resized_w > self.rec_image_shape[2]:
+                resized_w = self.rec_image_shape[2]
+            imgW = self.rec_image_shape[2]
+        resized_image = cv2.resize(img, (resized_w, imgH))
+        resized_image = resized_image.astype("float32")
+        resized_image = resized_image.transpose((2, 0, 1)) / 255
+        resized_image -= 0.5
+        resized_image /= 0.5
+        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+        padding_im[:, :, 0:resized_w] = resized_image
+        return padding_im
+
+    def resize_norm_img_vl(self, img, image_shape):
+        imgC, imgH, imgW = image_shape
+        img = img[:, :, ::-1]  # BGR -> RGB
+        resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+        resized_image = resized_image.astype("float32")
+        resized_image = resized_image.transpose((2, 0, 1)) / 255
+        return resized_image
+
+    def resize_norm_img_srn(self, img, image_shape):
+        imgC, imgH, imgW = image_shape
+
+        img_black = np.zeros((imgH, imgW))
+        im_hei = img.shape[0]
+        im_wid = img.shape[1]
+
+        if im_wid <= im_hei * 1:
+            img_new = cv2.resize(img, (imgH * 1, imgH))
+        elif im_wid <= im_hei * 2:
+            img_new = cv2.resize(img, (imgH * 2, imgH))
+        elif im_wid <= im_hei * 3:
+            img_new = cv2.resize(img, (imgH * 3, imgH))
+        else:
+            img_new = cv2.resize(img, (imgW, imgH))
+
+        img_np = np.asarray(img_new)
+        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+        img_black[:, 0 : img_np.shape[1]] = img_np
+        img_black = img_black[:, :, np.newaxis]
+
+        row, col, c = img_black.shape
+        c = 1
+
+        return np.reshape(img_black, (c, row, col)).astype(np.float32)
+
+    def srn_other_inputs(self, image_shape, num_heads, max_text_length):
+        imgC, imgH, imgW = image_shape
+        feature_dim = int((imgH / 8) * (imgW / 8))
+
+        encoder_word_pos = (
+            np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype("int64")
+        )
+        gsrm_word_pos = (
+            np.array(range(0, max_text_length))
+            .reshape((max_text_length, 1))
+            .astype("int64")
+        )
+
+        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
+        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
+            [-1, 1, max_text_length, max_text_length]
+        )
+        gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [1, num_heads, 1, 1]).astype(
+            "float32"
+        ) * [-1e9]
+
+        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
+            [-1, 1, max_text_length, max_text_length]
+        )
+        gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [1, num_heads, 1, 1]).astype(
+            "float32"
+        ) * [-1e9]
+
+        encoder_word_pos = encoder_word_pos[np.newaxis, :]
+        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
+
+        return [
+            encoder_word_pos,
+            gsrm_word_pos,
+            gsrm_slf_attn_bias1,
+            gsrm_slf_attn_bias2,
+        ]
+
+    def process_image_srn(self, img, image_shape, num_heads, max_text_length):
+        norm_img = self.resize_norm_img_srn(img, image_shape)
+        norm_img = norm_img[np.newaxis, :]
+
+        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = (
+            self.srn_other_inputs(image_shape, num_heads, max_text_length)
+        )
+
+        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
+        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
+        encoder_word_pos = encoder_word_pos.astype(np.int64)
+        gsrm_word_pos = gsrm_word_pos.astype(np.int64)
+
+        return (
+            norm_img,
+            encoder_word_pos,
+            gsrm_word_pos,
+            gsrm_slf_attn_bias1,
+            gsrm_slf_attn_bias2,
+        )
+
+    def resize_norm_img_sar(self, img, image_shape, width_downsample_ratio=0.25):
+        imgC, imgH, imgW_min, imgW_max = image_shape
+        h = img.shape[0]
+        w = img.shape[1]
+        valid_ratio = 1.0
+        # make sure new_width is an integral multiple of width_divisor
+        width_divisor = int(1 / width_downsample_ratio)
+        # resize
+        ratio = w / float(h)
+        resize_w = math.ceil(imgH * ratio)
+        if resize_w % width_divisor != 0:
+            resize_w = round(resize_w / width_divisor) * width_divisor
+        if imgW_min is not None:
+            resize_w = max(imgW_min, resize_w)
+        if imgW_max is not None:
+            valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
+            resize_w = min(imgW_max, resize_w)
+        resized_image = cv2.resize(img, (resize_w, imgH))
+        resized_image = resized_image.astype("float32")
+        # norm
+        if image_shape[0] == 1:
+            resized_image = resized_image / 255
+            resized_image = resized_image[np.newaxis, :]
+        else:
+            resized_image = resized_image.transpose((2, 0, 1)) / 255
+        resized_image -= 0.5
+        resized_image /= 0.5
+        resize_shape = resized_image.shape
+        padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
+        padding_im[:, :, 0:resize_w] = resized_image
+        pad_shape = padding_im.shape
+
+        return padding_im, resize_shape, pad_shape, valid_ratio
+
+    def resize_norm_img_spin(self, img):
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        # pass interpolation by keyword: cv2.resize's third positional slot is dst
+        img = cv2.resize(img, (100, 32), interpolation=cv2.INTER_CUBIC)
+        img = np.array(img, np.float32)
+        img = np.expand_dims(img, -1)
+        img = img.transpose((2, 0, 1))
+        mean = [127.5]
+        std = [127.5]
+        mean = np.array(mean, dtype=np.float32)
+        std = np.array(std, dtype=np.float32)
+        mean = np.float32(mean.reshape(1, -1))
+        stdinv = 1 / np.float32(std.reshape(1, -1))
+        img -= mean
+        img *= stdinv
+        return img
+
+    def resize_norm_img_svtr(self, img, image_shape):
+        imgC, imgH, imgW = image_shape
+        resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+        resized_image = resized_image.astype("float32")
+        resized_image = resized_image.transpose((2, 0, 1)) / 255
+        resized_image -= 0.5
+        resized_image /= 0.5
+        return resized_image
+
+    def resize_norm_img_abinet(self, img, image_shape):
+        imgC, imgH, imgW = image_shape
+
+        resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+        resized_image = resized_image.astype("float32")
+        resized_image = resized_image / 255.0
+
+        mean = np.array([0.485, 0.456, 0.406])
+        std = np.array([0.229, 0.224, 0.225])
+        resized_image = (resized_image - mean[None, None, ...]) / std[None, None, ...]
+        resized_image = resized_image.transpose((2, 0, 1))
+        resized_image = resized_image.astype("float32")
+
+        return resized_image
+
+    def norm_img_can(self, img, image_shape):
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # CAN only predicts grayscale images
+
+        # FIXME
+        # if self.inverse:
+        #     img = 255 - img
+
+        if self.rec_image_shape[0] == 1:
+            h, w = img.shape
+            _, imgH, imgW = self.rec_image_shape
+            if h < imgH or w < imgW:
+                padding_h = max(imgH - h, 0)
+                padding_w = max(imgW - w, 0)
+                img_padded = np.pad(
+                    img,
+                    ((0, padding_h), (0, padding_w)),
+                    "constant",
+                    constant_values=255,
+                )
+                img = img_padded
+
+        img = np.expand_dims(img, 0) / 255.0  # h,w -> 1,h,w and scale to [0, 1]
+        img = img.astype("float32")
+
+        return img
+
+    def __call__(self, img_list):
+        img_num = len(img_list)
+        # Compute the aspect ratio of every text crop
+        width_list = []
+        for img in img_list:
+            width_list.append(img.shape[1] / float(img.shape[0]))
+        # Sorting by aspect ratio speeds up recognition
+        indices = np.argsort(np.array(width_list))
+        rec_res = [["", 0.0]] * img_num
+        batch_num = self.rec_batch_num
+
+        for beg_img_no in range(0, img_num, batch_num):
+            end_img_no = min(img_num, beg_img_no + batch_num)
+            norm_img_batch = []
+            imgC, imgH, imgW = self.rec_image_shape[:3]
+            max_wh_ratio = imgW / imgH
+            for ino in range(beg_img_no, end_img_no):
+                h, w = img_list[indices[ino]].shape[0:2]
+                wh_ratio = w * 1.0 / h
+                max_wh_ratio = max(max_wh_ratio, wh_ratio)
+            for ino in range(beg_img_no, end_img_no):
+                norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
+                norm_img = norm_img[np.newaxis, :]
+                norm_img_batch.append(norm_img)
+
+            norm_img_batch = np.concatenate(norm_img_batch)
+            norm_img_batch = norm_img_batch.copy()
+
+            input_feed = self.get_input_feed(self.rec_input_name, norm_img_batch)
+            outputs = self.rec_onnx_session.run(
+                self.rec_output_name, input_feed=input_feed
+            )
+
+            preds = outputs[0]
+
+            rec_result = self.postprocess_op(preds)
+            for rno in range(len(rec_result)):
+                rec_res[indices[beg_img_no + rno]] = rec_result[rno]
+
+        return rec_res
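A note on TextRecognizer.__call__: crops are sorted by aspect ratio so each batch groups similarly shaped images, and the widest ratio in a batch (never less than the configured imgW/imgH) sets the padded width for the whole batch, keeping padding waste low. A small illustration of that width rule, with made-up crop sizes and assuming the common (3, 48, 320) rec_image_shape:

# Illustration only: how the batch padding width is chosen.
imgC, imgH, imgW = 3, 48, 320
crops = [(32, 100), (48, 600), (40, 260)]  # (h, w) of three text crops
max_wh_ratio = max(imgW / imgH, *(w / h for h, w in crops))  # 600/48 = 12.5
padded_w = int(imgH * max_wh_ratio)
print(padded_w)  # 600 -> every crop in this batch is padded to (3, 48, 600)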
doc_page_extractor/onnxocr/predict_system.py
@@ -0,0 +1,97 @@
+import os
+import cv2
+import copy
+
+from . import predict_det
+from . import predict_cls
+from . import predict_rec
+from .utils import get_rotate_crop_image, get_minarea_rect_crop
+
+
+class TextSystem:
+    def __init__(self, args):
+        self.text_detector = predict_det.TextDetector(args)
+        self.text_recognizer = predict_rec.TextRecognizer(args)
+        self.use_angle_cls = True
+        self.drop_score = args.drop_score
+        if self.use_angle_cls:
+            self.text_classifier = predict_cls.TextClassifier(args)
+
+        self.args = args
+        self.crop_image_res_index = 0
+
+    def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
+        os.makedirs(output_dir, exist_ok=True)
+        bbox_num = len(img_crop_list)
+        for bno in range(bbox_num):
+            cv2.imwrite(
+                os.path.join(
+                    output_dir, f"img_crop_{bno + self.crop_image_res_index}.jpg"
+                ),
+                img_crop_list[bno],
+            )
+
+        self.crop_image_res_index += bbox_num
+
+    def __call__(self, img, cls=True):
+        ori_im = img.copy()
+        # Text detection
+        dt_boxes = self.text_detector(img)
+
+        if dt_boxes is None:
+            return None, None
+
+        img_crop_list = []
+
+        dt_boxes = sorted_boxes(dt_boxes)
+
+        # Crop each detected box out of the original image
+        for bno in range(len(dt_boxes)):
+            tmp_box = copy.deepcopy(dt_boxes[bno])
+            if self.args.det_box_type == "quad":
+                img_crop = get_rotate_crop_image(ori_im, tmp_box)
+            else:
+                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
+            img_crop_list.append(img_crop)
+
+        # Text-direction classification
+        if self.use_angle_cls and cls:
+            img_crop_list, angle_list = self.text_classifier(img_crop_list)
+
+        # Text recognition
+        rec_res = self.text_recognizer(img_crop_list)
+
+        if self.args.save_crop_res:
+            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)
+        filter_boxes, filter_rec_res = [], []
+        for box, rec_result in zip(dt_boxes, rec_res):
+            text, score = rec_result
+            if score >= self.drop_score:
+                filter_boxes.append(box)
+                filter_rec_res.append(rec_result)
+
+        return filter_boxes, filter_rec_res
+
+
+def sorted_boxes(dt_boxes):
+    """
+    Sort text boxes in order from top to bottom, left to right.
+    args:
+        dt_boxes(array): detected text boxes, shape [N, 4, 2]
+    return:
+        list of boxes, each of shape [4, 2], in reading order
+    """
+    num_boxes = dt_boxes.shape[0]
+    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+    _boxes = list(sorted_boxes)
+
+    for i in range(num_boxes - 1):
+        for j in range(i, -1, -1):
+            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
+                _boxes[j + 1][0][0] < _boxes[j][0][0]
+            ):
+                tmp = _boxes[j]
+                _boxes[j] = _boxes[j + 1]
+                _boxes[j + 1] = tmp
+            else:
+                break
+    return _boxes
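The sorted_boxes helper first sorts boxes by their top-left corner's y (then x), then bubbles neighbors whose tops sit within 10 pixels of each other into left-to-right order, approximating reading order. A small demonstration using the function above, with three hypothetical boxes (two on one visual line, one below):

import numpy as np

boxes = np.array([
    [[200, 12], [300, 12], [300, 40], [200, 40]],  # line 1, right
    [[10, 15], [110, 15], [110, 43], [10, 43]],    # line 1, left (top 3 px lower)
    [[10, 80], [110, 80], [110, 108], [10, 108]],  # line 2
], dtype=np.float32)

ordered = sorted_boxes(boxes)
print([b[0].tolist() for b in ordered])
# [[10.0, 15.0], [200.0, 12.0], [10.0, 80.0]] -- left-to-right, then next line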