doc-page-extractor 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of doc-page-extractor might be problematic. Click here for more details.

@@ -0,0 +1,187 @@
1
+ import numpy as np
2
+ import cv2
3
+ import sys
4
+ import math
5
+
6
+
7
class NormalizeImage(object):
    """Normalize an image: scale it, subtract the mean, divide by the std.

    Args:
        scale: Multiplier applied to the raw pixel values. Either a number or
            a string holding a Python expression such as ``"1./255."``.
            ``None`` means ``1/255``.
        mean: Three per-channel means; ``None`` means
            ``[0.485, 0.456, 0.406]``.
        std: Three per-channel standard deviations; ``None`` means
            ``[0.229, 0.224, 0.225]``.
        order: ``'chw'`` stores mean/std with shape ``(3, 1, 1)``; any other
            value stores them with shape ``(1, 1, 3)``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
        if isinstance(scale, str):
            # NOTE(review): eval() runs arbitrary Python. This is only safe
            # while `scale` comes from trusted, in-package configuration;
            # never feed it user-controlled text.
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]

        # Shape the statistics so they broadcast along the channel axis.
        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype('float32')
        self.std = np.array(std).reshape(shape).astype('float32')

    def __call__(self, data):
        """Normalize ``data['image']`` in the dict and return the dict.

        Args:
            data: Mapping with an ``'image'`` entry that is a NumPy array or
                a PIL image.

        Returns:
            The same mapping, with ``'image'`` replaced by the float32 result
            of ``(image * scale - mean) / std``.

        Raises:
            AssertionError: If the image is neither an ndarray nor a PIL
                image.
        """
        img = data['image']
        if not isinstance(img, np.ndarray):
            # Pillow is only needed to recognise PIL images, so import it
            # lazily: ndarray inputs must not require Pillow to be installed.
            from PIL import Image
            if isinstance(img, Image.Image):
                img = np.array(img)
        assert isinstance(img,
                          np.ndarray), "invalid input 'img' in NormalizeImage"
        data['image'] = (
            img.astype('float32') * self.scale - self.mean) / self.std
        return data
32
+
33
+
34
class DetResizeForTest(object):
    """Resize an image before it is fed to the text-detection network.

    The keyword arguments select one of three strategies:

    * ``image_shape`` (optionally ``keep_ratio``): resize to a fixed
      ``(height, width)`` -- resize type 1.
    * ``limit_side_len`` (optionally ``limit_type``, default ``'min'``):
      bound one side, then round both sides to a multiple of 32 -- resize
      type 0.
    * ``resize_long``: scale the longer side to that length, then round both
      sides up to a multiple of 128 -- resize type 2.

    With none of these keys present the operator uses resize type 0 with
    ``limit_side_len=736`` and ``limit_type='min'``.
    """

    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
        self.resize_type = 0
        self.keep_ratio = False
        if 'image_shape' in kwargs:
            self.image_shape = kwargs['image_shape']
            self.resize_type = 1
            if 'keep_ratio' in kwargs:
                self.keep_ratio = kwargs['keep_ratio']
        elif 'limit_side_len' in kwargs:
            self.limit_side_len = kwargs['limit_side_len']
            self.limit_type = kwargs.get('limit_type', 'min')
        elif 'resize_long' in kwargs:
            self.resize_type = 2
            self.resize_long = kwargs.get('resize_long', 960)
        else:
            self.limit_side_len = 736
            self.limit_type = 'min'

    def __call__(self, data):
        """Resize ``data['image']`` and record the scale factors.

        Args:
            data: Mapping whose ``'image'`` entry is an ``(h, w, c)`` array.

        Returns:
            The same mapping with ``'image'`` replaced by the resized image
            and ``'shape'`` set to
            ``np.array([src_h, src_w, ratio_h, ratio_w])``, where ``src_h``
            and ``src_w`` are the sizes read before any padding.
        """
        img = data['image']
        src_h, src_w, _ = img.shape
        # Very small inputs are zero-padded up to 32 px per side first.
        if src_h + src_w < 64:
            img = self.image_padding(img)

        if self.resize_type == 0:
            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
        elif self.resize_type == 2:
            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
        else:
            img, [ratio_h, ratio_w] = self.resize_image_type1(img)
        data['image'] = img
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def image_padding(self, im, value=0):
        """Return ``im`` placed top-left on a canvas at least 32x32.

        Args:
            im: ``(h, w, c)`` image array.
            value: Fill value for the padded area.
        """
        h, w, c = im.shape
        im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
        im_pad[:h, :w, :] = im
        return im_pad

    def resize_image_type1(self, img):
        """Resize to ``self.image_shape``.

        When ``keep_ratio`` is ``True`` the width is derived from the target
        height and the source aspect ratio, rounded up to a multiple of 32.

        Returns:
            ``(img, [ratio_h, ratio_w])``.
        """
        resize_h, resize_w = self.image_shape
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        if self.keep_ratio is True:
            resize_w = ori_w * resize_h / ori_h
            resize_w = math.ceil(resize_w / 32) * 32
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img, [ratio_h, ratio_w]

    def resize_image_type0(self, img):
        """Resize so both sides are multiples of 32, as the network requires.

        ``limit_type`` controls the scale factor:

        * ``'max'``: shrink only if the longer side exceeds
          ``limit_side_len``.
        * ``'min'``: enlarge only if the shorter side is below
          ``limit_side_len``.
        * ``'resize_long'``: always scale the longer side to
          ``limit_side_len``.

        Args:
            img: Array with shape ``[h, w, c]``.

        Returns:
            ``(img, [ratio_h, ratio_w])``.

        Raises:
            Exception: If ``limit_type`` is not one of the values above.
                Errors raised by ``cv2.resize`` propagate to the caller.
        """
        limit_side_len = self.limit_side_len
        h, w, c = img.shape

        if self.limit_type == 'max':
            longest = max(h, w)
            if longest > limit_side_len:
                ratio = float(limit_side_len) / longest
            else:
                ratio = 1.
        elif self.limit_type == 'min':
            shortest = min(h, w)
            if shortest < limit_side_len:
                ratio = float(limit_side_len) / shortest
            else:
                ratio = 1.
        elif self.limit_type == 'resize_long':
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise Exception('not support limit type, image ')
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        # Round to the nearest multiple of 32, never going below 32.
        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        # Defensive guard kept from the original; the clamp above already
        # keeps both sizes >= 32.
        if resize_w <= 0 or resize_h <= 0:
            return None, (None, None)
        try:
            img = cv2.resize(img, (resize_w, resize_h))
        except Exception:
            # Print the failing geometry, then let the error propagate. The
            # previous bare `except` called sys.exit(0), which terminated
            # the host process with a *success* status.
            print(img.shape, resize_w, resize_h)
            raise
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]

    def resize_image_type2(self, img):
        """Scale the longer side to ``self.resize_long``.

        Both sides are then rounded up to a multiple of 128.

        Returns:
            ``(img, [ratio_h, ratio_w])``.
        """
        h, w, _ = img.shape

        ratio = float(self.resize_long) / max(h, w)
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)

        return img, [ratio_h, ratio_w]
162
+
163
class ToCHWImage(object):
    """Convert an HWC image to CHW layout."""

    def __init__(self, **kwargs):
        # No configuration; keyword arguments are accepted and ignored.
        pass

    def __call__(self, data):
        """Transpose ``data['image']`` from ``(h, w, c)`` to ``(c, h, w)``.

        Args:
            data: Mapping with an ``'image'`` entry that is a NumPy array or
                a PIL image.

        Returns:
            The same mapping with ``'image'`` replaced by the transposed
            array.
        """
        img = data['image']
        if not isinstance(img, np.ndarray):
            # Pillow is only needed to recognise PIL images, so import it
            # lazily: ndarray inputs must not require Pillow to be installed.
            from PIL import Image
            if isinstance(img, Image.Image):
                img = np.array(img)
        data['image'] = img.transpose((2, 0, 1))
        return data
177
+
178
+
179
class KeepKeys(object):
    """Reduce a data dict to an ordered list of selected values.

    Args:
        keep_keys: Keys to extract, in output order.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, keep_keys, **kwargs):
        self.keep_keys = keep_keys

    def __call__(self, data):
        """Return ``[data[k] for k in keep_keys]``.

        A missing key raises ``KeyError``.
        """
        return [data[name] for name in self.keep_keys]
@@ -0,0 +1,52 @@
1
+ import onnxruntime
2
+
3
class PredictBase(object):
    """Helpers for creating and querying ONNX Runtime inference sessions."""

    def __init__(self):
        pass

    def get_onnx_session(self, model_dir, use_gpu):
        """Create an inference session for the model at ``model_dir``.

        :param model_dir: model location passed to ``InferenceSession``
        :param use_gpu: truthy selects ``CUDAExecutionProvider``, otherwise
            ``CPUExecutionProvider``
        :return: the new ``onnxruntime.InferenceSession``
        """
        # Use the GPU provider only when requested.
        provider = 'CUDAExecutionProvider' if use_gpu else 'CPUExecutionProvider'
        return onnxruntime.InferenceSession(model_dir, None, providers=[provider])

    def get_output_name(self, onnx_session):
        """Collect the names of every output node of the session.

        :param onnx_session: object exposing ``get_outputs()``
        :return: list of output names, in session order
        """
        return [node.name for node in onnx_session.get_outputs()]

    def get_input_name(self, onnx_session):
        """Collect the names of every input node of the session.

        :param onnx_session: object exposing ``get_inputs()``
        :return: list of input names, in session order
        """
        return [node.name for node in onnx_session.get_inputs()]

    def get_input_feed(self, input_name, image_numpy):
        """Build the feed dict, mapping every input name to the same array.

        :param input_name: iterable of input names
        :param image_numpy: value fed to each input
        :return: ``{name: image_numpy}`` for each name
        """
        return {name: image_numpy for name in input_name}
@@ -0,0 +1,89 @@
1
+ import cv2
2
+ import copy
3
+ import numpy as np
4
+ import math
5
+
6
+ from .cls_postprocess import ClsPostProcess
7
+ from .predict_base import PredictBase
8
+
9
+
10
class TextClassifier(PredictBase):
    """Classify text-line crops with an ONNX model and un-rotate flipped ones.

    Args:
        args: Configuration object providing ``cls_image_shape``,
            ``cls_batch_num``, ``cls_thresh``, ``label_list``,
            ``cls_model_dir`` and ``use_gpu``.
    """

    def __init__(self, args):
        self.cls_image_shape = args.cls_image_shape
        self.cls_batch_num = args.cls_batch_num
        self.cls_thresh = args.cls_thresh
        self.postprocess_op = ClsPostProcess(label_list=args.label_list)

        # Initialise the model.
        self.cls_onnx_session = self.get_onnx_session(args.cls_model_dir, args.use_gpu)
        self.cls_input_name = self.get_input_name(self.cls_onnx_session)
        self.cls_output_name = self.get_output_name(self.cls_onnx_session)

    def resize_norm_img(self, img):
        """Resize, normalise and right-pad one crop to ``cls_image_shape``.

        The crop is resized to the target height, keeping its aspect ratio
        but capping the width at the target width. Pixel values are mapped
        with ``(x / 255 - 0.5) / 0.5`` and the result is written into the
        left part of a zero-filled ``(C, H, W)`` float32 array.

        Args:
            img: Image array whose first two axes are height and width.

        Returns:
            float32 array of shape ``cls_image_shape``.
        """
        imgC, imgH, imgW = self.cls_image_shape
        h = img.shape[0]
        w = img.shape[1]
        ratio = w / float(h)
        # Keep the aspect ratio unless that would exceed the target width.
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype("float32")
        if self.cls_image_shape[0] == 1:
            # Single-channel target: add the channel axis in front.
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        """Classify every crop and rotate those labelled as upside down.

        Args:
            img_list: List of image arrays. The list is deep-copied, so the
                caller's images are not modified.

        Returns:
            ``(img_list, cls_res)``: the copied images, with crops whose
            label contains ``"180"`` and whose score exceeds ``cls_thresh``
            rotated by 180 degrees, and a list of ``[label, score]`` pairs
            in the input order.
        """
        img_list = copy.deepcopy(img_list)
        img_num = len(img_list)
        # Aspect ratio (width / height) of every crop.
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        # Sorting can speed up the cls process.
        indices = np.argsort(np.array(width_list))

        # One independent placeholder per image; `[[...]] * n` would alias a
        # single inner list across every slot.
        cls_res = [["", 0.0] for _ in range(img_num)]
        batch_num = self.cls_batch_num

        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)

            norm_img_batch = [
                self.resize_norm_img(img_list[indices[ino]])[np.newaxis, :]
                for ino in range(beg_img_no, end_img_no)
            ]
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()

            input_feed = self.get_input_feed(self.cls_input_name, norm_img_batch)
            outputs = self.cls_onnx_session.run(
                self.cls_output_name, input_feed=input_feed
            )

            prob_out = outputs[0]

            cls_result = self.postprocess_op(prob_out)
            for rno, (label, score) in enumerate(cls_result):
                # Map the position in the sorted batch back to input order.
                src_idx = indices[beg_img_no + rno]
                cls_res[src_idx] = [label, score]
                if "180" in label and score > self.cls_thresh:
                    # cv2.ROTATE_180 is the named form of the literal 1.
                    img_list[src_idx] = cv2.rotate(
                        img_list[src_idx], cv2.ROTATE_180
                    )
        return img_list, cls_res
@@ -0,0 +1,120 @@
1
+ import numpy as np
2
+ from .imaug import transform, create_operators
3
+ from .db_postprocess import DBPostProcess
4
+ from .predict_base import PredictBase
5
+
6
+
7
class TextDetector(PredictBase):
    """Detect text regions in an image with a DB-style ONNX model.

    Args:
        args: Configuration object providing ``det_algorithm``,
            ``det_limit_side_len``, ``det_limit_type``, ``det_db_thresh``,
            ``det_db_box_thresh``, ``det_db_unclip_ratio``, ``use_dilation``,
            ``det_db_score_mode``, ``det_box_type``, ``det_model_dir`` and
            ``use_gpu``.
    """

    def __init__(self, args):
        self.args = args
        self.det_algorithm = args.det_algorithm
        pre_process_list = [
            {
                "DetResizeForTest": {
                    "limit_side_len": args.det_limit_side_len,
                    "limit_type": args.det_limit_type,
                }
            },
            {
                "NormalizeImage": {
                    "std": [0.229, 0.224, 0.225],
                    "mean": [0.485, 0.456, 0.406],
                    "scale": "1./255.",
                    "order": "hwc",
                }
            },
            {"ToCHWImage": None},
            {"KeepKeys": {"keep_keys": ["image", "shape"]}},
        ]
        postprocess_params = {
            "name": "DBPostProcess",
            "thresh": args.det_db_thresh,
            "box_thresh": args.det_db_box_thresh,
            "max_candidates": 1000,
            "unclip_ratio": args.det_db_unclip_ratio,
            "use_dilation": args.use_dilation,
            "score_mode": args.det_db_score_mode,
            "box_type": args.det_box_type,
        }

        # Instantiate the pre-processing operators.
        self.preprocess_op = create_operators(pre_process_list)
        # Instantiate the post-processing operator.
        self.postprocess_op = DBPostProcess(**postprocess_params)

        # Initialise the model.
        self.det_onnx_session = self.get_onnx_session(args.det_model_dir, args.use_gpu)
        self.det_input_name = self.get_input_name(self.det_onnx_session)
        self.det_output_name = self.get_output_name(self.det_onnx_session)

    def order_points_clockwise(self, pts):
        """Return the four corners of ``pts`` as a ``(4, 2)`` float32 array.

        Slot 0 takes the point with the smallest ``x + y`` and slot 2 the one
        with the largest. Of the two remaining points, the one with the
        smaller ``y - x`` goes to slot 1 and the other to slot 3.

        Args:
            pts: ``(4, 2)`` array of ``(x, y)`` points.
        """
        rect = np.zeros((4, 2), dtype="float32")
        s = pts.sum(axis=1)
        rect[0] = pts[np.argmin(s)]
        rect[2] = pts[np.argmax(s)]
        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
        diff = np.diff(np.array(tmp), axis=1)
        rect[1] = tmp[np.argmin(diff)]
        rect[3] = tmp[np.argmax(diff)]
        return rect

    def clip_det_res(self, points, img_height, img_width):
        """Clamp every point into the image, in place.

        x is limited to ``[0, img_width - 1]`` and y to
        ``[0, img_height - 1]``; each value is truncated with ``int()``.

        Returns:
            The same ``points`` array.
        """
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order, clip and size-filter quadrilateral boxes.

        Boxes whose integer width or height (after clipping) is 3 pixels or
        less are dropped.

        Args:
            dt_boxes: Iterable of boxes, each a list or array of four points.
            image_shape: Image shape; only the first two entries
                (height, width) are used.

        Returns:
            ``np.array`` of the surviving boxes.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            if isinstance(box, list):
                box = np.array(box)
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        return np.array(dt_boxes_new)

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip boxes to the image without reordering or size filtering.

        Args:
            dt_boxes: Iterable of boxes, each a list or array of points.
            image_shape: Image shape; only the first two entries
                (height, width) are used.

        Returns:
            ``np.array`` of the clipped boxes.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            if isinstance(box, list):
                box = np.array(box)
            dt_boxes_new.append(self.clip_det_res(box, img_height, img_width))
        return np.array(dt_boxes_new)

    def __call__(self, img):
        """Run detection on one image.

        Args:
            img: Image array; ``img.shape`` supplies the bounds used to clip
                the detected boxes.

        Returns:
            Array of detected boxes, or ``None`` when pre-processing yields
            no image. The original returned the tuple ``(None, 0)`` in that
            case, which did not match the bare array returned on the normal
            path.
        """
        # Only the original shape is needed later, so no image copy is made.
        ori_shape = img.shape
        data = {"image": img}

        data = transform(data, self.preprocess_op)
        img, shape_list = data
        if img is None:
            return None
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()

        input_feed = self.get_input_feed(self.det_input_name, img)
        outputs = self.det_onnx_session.run(self.det_output_name, input_feed=input_feed)

        preds = {"maps": outputs[0]}

        post_result = self.postprocess_op(preds, shape_list)
        dt_boxes = post_result[0]["points"]

        if self.args.det_box_type == "poly":
            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_shape)
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_shape)

        return dt_boxes