doc-page-extractor 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of doc-page-extractor might be problematic. Click here for more details.
- doc_page_extractor/__init__.py +1 -1
- doc_page_extractor/downloader.py +4 -1
- doc_page_extractor/extractor.py +6 -7
- doc_page_extractor/ocr.py +110 -58
- doc_page_extractor/ocr_corrector.py +3 -3
- doc_page_extractor/onnxocr/__init__.py +1 -0
- doc_page_extractor/onnxocr/cls_postprocess.py +26 -0
- doc_page_extractor/onnxocr/db_postprocess.py +246 -0
- doc_page_extractor/onnxocr/imaug.py +32 -0
- doc_page_extractor/onnxocr/operators.py +187 -0
- doc_page_extractor/onnxocr/predict_base.py +52 -0
- doc_page_extractor/onnxocr/predict_cls.py +89 -0
- doc_page_extractor/onnxocr/predict_det.py +120 -0
- doc_page_extractor/onnxocr/predict_rec.py +321 -0
- doc_page_extractor/onnxocr/predict_system.py +97 -0
- doc_page_extractor/onnxocr/rec_postprocess.py +896 -0
- doc_page_extractor/onnxocr/utils.py +71 -0
- {doc_page_extractor-0.0.4.dist-info → doc_page_extractor-0.0.6.dist-info}/METADATA +7 -3
- doc_page_extractor-0.0.6.dist-info/RECORD +33 -0
- {doc_page_extractor-0.0.4.dist-info → doc_page_extractor-0.0.6.dist-info}/WHEEL +1 -1
- doc_page_extractor-0.0.4.dist-info/RECORD +0 -21
- {doc_page_extractor-0.0.4.dist-info → doc_page_extractor-0.0.6.dist-info}/LICENSE +0 -0
- {doc_page_extractor-0.0.4.dist-info → doc_page_extractor-0.0.6.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import cv2
|
|
3
|
+
import sys
|
|
4
|
+
import math
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class NormalizeImage(object):
    """Normalize an image: multiply by ``scale``, subtract ``mean``, divide by ``std``.

    Args:
        scale: Multiplier applied to the pixel values before the mean/std
            step. A string is evaluated as a Python expression
            (e.g. ``"1./255."``). ``None`` selects ``1.0 / 255.0``.
        mean: Three per-channel means; ``None`` selects
            ``[0.485, 0.456, 0.406]``.
        std: Three per-channel standard deviations; ``None`` selects
            ``[0.229, 0.224, 0.225]``.
        order: ``'chw'`` reshapes mean/std to ``(3, 1, 1)``; any other value
            reshapes them to ``(1, 1, 3)`` (channel-last broadcasting).
        **kwargs: Accepted and ignored.
    """

    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
        if isinstance(scale, str):
            # NOTE(review): eval() executes arbitrary code. This is only safe
            # while `scale` comes from trusted, in-package configuration;
            # never route user-supplied text here.
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]

        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype('float32')
        self.std = np.array(std).reshape(shape).astype('float32')

    def __call__(self, data):
        """Normalize ``data['image']`` in place and return ``data``.

        Args:
            data: Dict whose ``'image'`` entry is a ``numpy.ndarray`` or a
                ``PIL.Image.Image`` (converted to an array first).

        Returns:
            The same dict, with ``'image'`` replaced by the float32 result of
            ``(image * scale - mean) / std``.

        Raises:
            AssertionError: If the image is neither an ndarray nor a PIL image.
        """
        img = data['image']
        if not isinstance(img, np.ndarray):
            # Pillow is only needed to recognise PIL inputs, so import it
            # lazily instead of on every call.
            from PIL import Image
            if isinstance(img, Image.Image):
                img = np.array(img)
        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
        data['image'] = (
            img.astype('float32') * self.scale - self.mean) / self.std
        return data
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class DetResizeForTest(object):
    """Resize an image for text detection and record the resulting scale ratios.

    The resize strategy is selected from the keyword arguments (first match
    wins):

    * ``image_shape`` (plus optional ``keep_ratio``): ``resize_type`` 1,
      resize to the given ``(height, width)``.
    * ``limit_side_len`` (plus optional ``limit_type``, default ``'min'``):
      ``resize_type`` 0, constrain one side and round to multiples of 32.
    * ``resize_long``: ``resize_type`` 2, scale the longer side to
      ``resize_long`` and round both sides up to multiples of 128.
    * none of the above: ``resize_type`` 0 with ``limit_side_len=736`` and
      ``limit_type='min'``.
    """

    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
        self.resize_type = 0
        self.keep_ratio = False
        if 'image_shape' in kwargs:
            self.image_shape = kwargs['image_shape']
            self.resize_type = 1
            if 'keep_ratio' in kwargs:
                self.keep_ratio = kwargs['keep_ratio']
        elif 'limit_side_len' in kwargs:
            self.limit_side_len = kwargs['limit_side_len']
            self.limit_type = kwargs.get('limit_type', 'min')
        elif 'resize_long' in kwargs:
            self.resize_type = 2
            self.resize_long = kwargs.get('resize_long', 960)
        else:
            self.limit_side_len = 736
            self.limit_type = 'min'

    def __call__(self, data):
        """Resize ``data['image']`` and store the scale information.

        Args:
            data: Dict whose ``'image'`` entry is an ``(h, w, c)`` array.

        Returns:
            The same dict with ``'image'`` replaced by the resized image and
            ``'shape'`` set to ``np.array([src_h, src_w, ratio_h, ratio_w])``,
            where ``src_h``/``src_w`` are the dimensions of the input image
            before any padding.
        """
        img = data['image']
        src_h, src_w, _ = img.shape
        # Very small inputs are zero-padded to at least 32x32 before resizing.
        if src_h + src_w < 64:
            img = self.image_padding(img)

        if self.resize_type == 0:
            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
        elif self.resize_type == 2:
            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
        else:
            img, [ratio_h, ratio_w] = self.resize_image_type1(img)
        data['image'] = img
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def image_padding(self, im, value=0):
        """Return ``im`` placed in the top-left corner of a uint8 canvas.

        The canvas is at least 32x32 and is filled with ``value``.
        """
        h, w, c = im.shape
        im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
        im_pad[:h, :w, :] = im
        return im_pad

    def resize_image_type1(self, img):
        """Resize to ``self.image_shape`` given as ``(height, width)``.

        When ``self.keep_ratio is True`` the target width follows the input
        aspect ratio and is rounded up to a multiple of 32.

        Returns:
            ``(resized_img, [ratio_h, ratio_w])``.
        """
        resize_h, resize_w = self.image_shape
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        if self.keep_ratio is True:
            resize_w = ori_w * resize_h / ori_h
            resize_w = math.ceil(resize_w / 32) * 32
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img, [ratio_h, ratio_w]

    def resize_image_type0(self, img):
        """Resize so that both sides are multiples of 32 (minimum 32).

        ``self.limit_type`` controls the scale factor:

        * ``'max'``: shrink so the longer side is at most ``limit_side_len``.
        * ``'min'``: enlarge so the shorter side is at least
          ``limit_side_len``.
        * ``'resize_long'``: scale the longer side to ``limit_side_len``.

        Args:
            img: Array with shape ``[h, w, c]``.

        Returns:
            ``(resized_img, [ratio_h, ratio_w])``.

        Raises:
            Exception: If ``self.limit_type`` is not one of the values above.
            RuntimeError: If ``cv2.resize`` fails; the original error is
                chained as the cause.
        """
        limit_side_len = self.limit_side_len
        h, w, c = img.shape

        if self.limit_type == 'max':
            longest = max(h, w)
            ratio = float(limit_side_len) / longest if longest > limit_side_len else 1.
        elif self.limit_type == 'min':
            shortest = min(h, w)
            ratio = float(limit_side_len) / shortest if shortest < limit_side_len else 1.
        elif self.limit_type == 'resize_long':
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise Exception('not support limit type, image ')
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        # Round to the nearest multiple of 32; max(..., 32) guarantees that
        # both target sides are strictly positive.
        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        try:
            resized = cv2.resize(img, (resize_w, resize_h))
        except Exception as err:
            # Surface the failure to the caller instead of terminating the
            # interpreter with a success exit code.
            raise RuntimeError(
                'failed to resize image of shape {} to (w={}, h={})'.format(
                    img.shape, resize_w, resize_h)
            ) from err
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return resized, [ratio_h, ratio_w]

    def resize_image_type2(self, img):
        """Scale the longer side to ``self.resize_long``.

        Both sides are then rounded up to multiples of 128.

        Returns:
            ``(resized_img, [ratio_h, ratio_w])``.
        """
        h, w, _ = img.shape

        ratio = float(self.resize_long) / (h if h > w else w)
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)

        return img, [ratio_h, ratio_w]
|
|
162
|
+
|
|
163
|
+
class ToCHWImage(object):
    """Convert a channel-last (HWC) image to channel-first (CHW)."""

    def __init__(self, **kwargs):
        # No configuration; keyword arguments are accepted and ignored.
        pass

    def __call__(self, data):
        """Transpose ``data['image']`` with axes ``(2, 0, 1)`` and return ``data``.

        Args:
            data: Dict whose ``'image'`` entry is a 3-D ``numpy.ndarray`` or a
                ``PIL.Image.Image`` (converted to an array first).

        Returns:
            The same dict with ``'image'`` replaced by the transposed array.
        """
        img = data['image']
        if not isinstance(img, np.ndarray):
            # Pillow is only needed to recognise PIL inputs, so import it
            # lazily instead of on every call.
            from PIL import Image
            if isinstance(img, Image.Image):
                img = np.array(img)
        data['image'] = img.transpose((2, 0, 1))
        return data
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class KeepKeys(object):
    """Pick selected entries out of a data dict, in a fixed order."""

    def __init__(self, keep_keys, **kwargs):
        # Keys to extract, in the order the values should be returned.
        self.keep_keys = keep_keys

    def __call__(self, data):
        """Return ``[data[k] for k in keep_keys]``.

        Raises:
            KeyError: If a requested key is missing from ``data``.
        """
        return [data[wanted] for wanted in self.keep_keys]
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import onnxruntime
|
|
2
|
+
|
|
3
|
+
class PredictBase(object):
    """Small helpers around an ONNX Runtime inference session."""

    def __init__(self):
        pass

    def get_onnx_session(self, model_dir, use_gpu):
        """Create an ``onnxruntime.InferenceSession`` for ``model_dir``.

        :param model_dir: model location passed straight to ``InferenceSession``.
        :param use_gpu: truthy selects ``CUDAExecutionProvider``, otherwise
            ``CPUExecutionProvider``.
        :return: the new inference session.
        """
        execution_providers = (
            ['CUDAExecutionProvider'] if use_gpu else ['CPUExecutionProvider']
        )
        return onnxruntime.InferenceSession(
            model_dir, None, providers=execution_providers
        )

    def get_output_name(self, onnx_session):
        """Collect the names of all session outputs.

        :param onnx_session: object exposing ``get_outputs()``.
        :return: list of output node names.
        """
        return [out_node.name for out_node in onnx_session.get_outputs()]

    def get_input_name(self, onnx_session):
        """Collect the names of all session inputs.

        :param onnx_session: object exposing ``get_inputs()``.
        :return: list of input node names.
        """
        return [in_node.name for in_node in onnx_session.get_inputs()]

    def get_input_feed(self, input_name, image_numpy):
        """Build the feed dict that maps every input name to the same array.

        :param input_name: iterable of input node names.
        :param image_numpy: value to feed to each input.
        :return: ``{name: image_numpy}`` for each name in ``input_name``.
        """
        return {feed_key: image_numpy for feed_key in input_name}
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import cv2
|
|
2
|
+
import copy
|
|
3
|
+
import numpy as np
|
|
4
|
+
import math
|
|
5
|
+
|
|
6
|
+
from .cls_postprocess import ClsPostProcess
|
|
7
|
+
from .predict_base import PredictBase
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TextClassifier(PredictBase):
    """Classify text crops with an ONNX model and rotate the ones labelled "180".

    Args:
        args: Configuration object providing ``cls_image_shape`` (unpacked as
            ``(C, H, W)``), ``cls_batch_num``, ``cls_thresh``, ``label_list``,
            ``cls_model_dir`` and ``use_gpu``.
    """

    def __init__(self, args):
        self.cls_image_shape = args.cls_image_shape
        self.cls_batch_num = args.cls_batch_num
        self.cls_thresh = args.cls_thresh
        self.postprocess_op = ClsPostProcess(label_list=args.label_list)

        # Initialise the model.
        self.cls_onnx_session = self.get_onnx_session(args.cls_model_dir, args.use_gpu)
        self.cls_input_name = self.get_input_name(self.cls_onnx_session)
        self.cls_output_name = self.get_output_name(self.cls_onnx_session)

    def resize_norm_img(self, img):
        """Resize, normalise and right-pad one crop to ``cls_image_shape``.

        The crop is resized to height ``imgH`` keeping its aspect ratio (width
        capped at ``imgW``), divided by 255, shifted and scaled with
        ``(x - 0.5) / 0.5`` and copied into a zero-filled ``(C, H, W)``
        float32 buffer.

        Returns:
            ``numpy.ndarray`` of shape ``(imgC, imgH, imgW)`` and dtype float32.
        """
        imgC, imgH, imgW = self.cls_image_shape
        h = img.shape[0]
        w = img.shape[1]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype("float32")
        if self.cls_image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        """Classify every crop and rotate the confidently upside-down ones.

        Args:
            img_list: Sequence of image arrays; it is deep-copied, so the
                caller's images are never modified.

        Returns:
            ``(img_list, cls_res)``: the (possibly rotated) copies and, for
            each image, ``[label, score]`` in the original order.
        """
        img_list = copy.deepcopy(img_list)
        img_num = len(img_list)
        # Crops are processed in order of ascending width/height ratio;
        # results are written back at their original positions via `indices`.
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        indices = np.argsort(np.array(width_list))

        # Independent placeholder per image (no aliased inner lists).
        cls_res = [["", 0.0] for _ in range(img_num)]
        batch_num = self.cls_batch_num

        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = np.stack([
                self.resize_norm_img(img_list[indices[ino]])
                for ino in range(beg_img_no, end_img_no)
            ])

            input_feed = self.get_input_feed(self.cls_input_name, norm_img_batch)
            outputs = self.cls_onnx_session.run(
                self.cls_output_name, input_feed=input_feed
            )
            prob_out = outputs[0]

            cls_result = self.postprocess_op(prob_out)
            for rno, (label, score) in enumerate(cls_result):
                target = indices[beg_img_no + rno]
                cls_res[target] = [label, score]
                if "180" in label and score > self.cls_thresh:
                    img_list[target] = cv2.rotate(img_list[target], cv2.ROTATE_180)
        return img_list, cls_res
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from .imaug import transform, create_operators
|
|
3
|
+
from .db_postprocess import DBPostProcess
|
|
4
|
+
from .predict_base import PredictBase
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class TextDetector(PredictBase):
    """Detect text boxes in an image with an ONNX model and DB post-processing.

    Args:
        args: Configuration object providing ``det_algorithm``,
            ``det_limit_side_len``, ``det_limit_type``, ``det_db_thresh``,
            ``det_db_box_thresh``, ``det_db_unclip_ratio``, ``use_dilation``,
            ``det_db_score_mode``, ``det_box_type``, ``det_model_dir`` and
            ``use_gpu``.
    """

    def __init__(self, args):
        self.args = args
        self.det_algorithm = args.det_algorithm

        resize_cfg = {
            "limit_side_len": args.det_limit_side_len,
            "limit_type": args.det_limit_type,
        }
        normalize_cfg = {
            "std": [0.229, 0.224, 0.225],
            "mean": [0.485, 0.456, 0.406],
            "scale": "1./255.",
            "order": "hwc",
        }
        pre_process_list = [
            {"DetResizeForTest": resize_cfg},
            {"NormalizeImage": normalize_cfg},
            {"ToCHWImage": None},
            {"KeepKeys": {"keep_keys": ["image", "shape"]}},
        ]
        postprocess_params = {
            "name": "DBPostProcess",
            "thresh": args.det_db_thresh,
            "box_thresh": args.det_db_box_thresh,
            "max_candidates": 1000,
            "unclip_ratio": args.det_db_unclip_ratio,
            "use_dilation": args.use_dilation,
            "score_mode": args.det_db_score_mode,
            "box_type": args.det_box_type,
        }

        # Instantiate the preprocessing operators.
        self.preprocess_op = create_operators(pre_process_list)
        # Instantiate the post-processing operator.
        self.postprocess_op = DBPostProcess(**postprocess_params)

        # Initialise the model.
        self.det_onnx_session = self.get_onnx_session(args.det_model_dir, args.use_gpu)
        self.det_input_name = self.get_input_name(self.det_onnx_session)
        self.det_output_name = self.get_output_name(self.det_onnx_session)

    def order_points_clockwise(self, pts):
        """Arrange four ``(x, y)`` points into a fixed corner order.

        Index 0 is the point with the smallest ``x + y`` and index 2 the one
        with the largest; of the remaining two, index 1 has the smallest
        ``y - x`` and index 3 the largest.

        Returns:
            A new ``(4, 2)`` float32 array.
        """
        ordered = np.zeros((4, 2), dtype="float32")
        coord_sum = pts.sum(axis=1)
        lo_idx = np.argmin(coord_sum)
        hi_idx = np.argmax(coord_sum)
        ordered[0] = pts[lo_idx]
        ordered[2] = pts[hi_idx]
        remaining = np.delete(pts, (lo_idx, hi_idx), axis=0)
        coord_diff = np.diff(remaining, axis=1)
        ordered[1] = remaining[np.argmin(coord_diff)]
        ordered[3] = remaining[np.argmax(coord_diff)]
        return ordered

    def clip_det_res(self, points, img_height, img_width):
        """Clamp each point to the image bounds, in place.

        Column 0 is limited to ``[0, img_width - 1]`` and column 1 to
        ``[0, img_height - 1]``; values are truncated with ``int``.

        Returns:
            The same ``points`` array.
        """
        x_max = img_width - 1
        y_max = img_height - 1
        for point in points:
            point[0] = int(min(max(point[0], 0), x_max))
            point[1] = int(min(max(point[1], 0), y_max))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order, clip and size-filter quadrilateral boxes.

        Boxes whose first edge (points 0-1) or last edge (points 0-3) is 3
        pixels long or less after truncation are dropped.

        Returns:
            ``numpy.ndarray`` built from the remaining boxes.
        """
        img_height, img_width = image_shape[0:2]
        kept = []
        for box in dt_boxes:
            if type(box) is list:
                box = np.array(box)
            box = self.clip_det_res(
                self.order_points_clockwise(box), img_height, img_width
            )
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width > 3 and rect_height > 3:
                kept.append(box)
        return np.array(kept)

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip every box to the image bounds without reordering or filtering.

        Returns:
            ``numpy.ndarray`` built from the clipped boxes.
        """
        img_height, img_width = image_shape[0:2]
        clipped = []
        for box in dt_boxes:
            if type(box) is list:
                box = np.array(box)
            clipped.append(self.clip_det_res(box, img_height, img_width))
        return np.array(clipped)

    def __call__(self, img):
        """Run detection on one image.

        Args:
            img: Image array; only its shape is used after preprocessing.

        Returns:
            Array of detected boxes clipped to the input image.
            NOTE(review): when preprocessing yields ``None`` for the image
            this returns the tuple ``(None, 0)`` instead; callers should
            confirm they handle both shapes.
        """
        ori_im = img.copy()

        img, shape_list = transform({"image": img}, self.preprocess_op)
        if img is None:
            return None, 0

        # Add the batch dimension expected by the model and post-processor.
        shape_list = np.expand_dims(shape_list, axis=0)
        img = np.expand_dims(img, axis=0).copy()

        input_feed = self.get_input_feed(self.det_input_name, img)
        outputs = self.det_onnx_session.run(self.det_output_name, input_feed=input_feed)

        post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
        dt_boxes = post_result[0]["points"]

        if self.args.det_box_type == "poly":
            return self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
        return self.filter_tag_det_res(dt_boxes, ori_im.shape)
|