xfmr-zem 0.2.2__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xfmr_zem/cli.py +32 -3
- xfmr_zem/client.py +59 -8
- xfmr_zem/server.py +21 -4
- xfmr_zem/servers/data_juicer/server.py +1 -1
- xfmr_zem/servers/instruction_gen/server.py +1 -1
- xfmr_zem/servers/io/server.py +1 -1
- xfmr_zem/servers/llm/parameters.yml +10 -0
- xfmr_zem/servers/nemo_curator/server.py +1 -1
- xfmr_zem/servers/ocr/deepdoc_vietocr/__init__.py +90 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/implementations.py +1286 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/layout_recognizer.py +562 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/ocr.py +512 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/.gitattributes +35 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/README.md +5 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/ocr.res +6623 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/operators.py +725 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/phases.py +191 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/pipeline.py +561 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/postprocess.py +370 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/recognizer.py +436 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/table_structure_recognizer.py +569 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/utils/__init__.py +81 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/utils/file_utils.py +246 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/__init__.py +0 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/base.yml +58 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/vgg-seq2seq.yml +38 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/__init__.py +0 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/cnn.py +25 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/vgg.py +51 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/seqmodel/seq2seq.py +175 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/transformerocr.py +29 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/vocab.py +36 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/config.py +37 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/translate.py +111 -0
- xfmr_zem/servers/ocr/engines.py +242 -0
- xfmr_zem/servers/ocr/install_models.py +63 -0
- xfmr_zem/servers/ocr/parameters.yml +4 -0
- xfmr_zem/servers/ocr/server.py +44 -0
- xfmr_zem/servers/profiler/parameters.yml +4 -0
- xfmr_zem/servers/sinks/parameters.yml +6 -0
- xfmr_zem/servers/unstructured/parameters.yml +6 -0
- xfmr_zem/servers/unstructured/server.py +62 -0
- xfmr_zem/zenml_wrapper.py +20 -7
- {xfmr_zem-0.2.2.dist-info → xfmr_zem-0.2.5.dist-info}/METADATA +19 -1
- xfmr_zem-0.2.5.dist-info/RECORD +58 -0
- xfmr_zem-0.2.2.dist-info/RECORD +0 -23
- /xfmr_zem/servers/data_juicer/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/instruction_gen/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/io/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/nemo_curator/{parameter.yaml → parameters.yml} +0 -0
- {xfmr_zem-0.2.2.dist-info → xfmr_zem-0.2.5.dist-info}/WHEEL +0 -0
- {xfmr_zem-0.2.2.dist-info → xfmr_zem-0.2.5.dist-info}/entry_points.txt +0 -0
- {xfmr_zem-0.2.2.dist-info → xfmr_zem-0.2.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,436 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
#
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
import math
|
|
20
|
+
import numpy as np
|
|
21
|
+
import cv2
|
|
22
|
+
from functools import cmp_to_key
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
|
|
25
|
+
from .operators import * # noqa: F403
|
|
26
|
+
from .operators import preprocess
|
|
27
|
+
from . import operators
|
|
28
|
+
from .ocr import load_model
|
|
29
|
+
|
|
30
|
+
def get_project_base_directory():
    """Return the absolute directory that contains this module (symlinks resolved)."""
    module_path = Path(__file__).resolve()
    return module_path.parent
|
|
32
|
+
|
|
33
|
+
class Recognizer:
    """Base wrapper around an ONNX detection model plus box-geometry helpers.

    Loads a model session via ``load_model``, converts images into model feeds
    (``preprocess``), runs them in batches (``__call__``) and turns raw outputs
    into ``{"type", "bbox", "score"}`` dicts (``postprocess``). The static
    helpers operate on box dicts carrying ``x0``/``x1``/``top``/``bottom`` keys.
    """

    def __init__(self, label_list, task_name, model_dir=None):
        """Load the ONNX model for ``task_name`` and cache its I/O metadata.

        Args:
            label_list: class names, indexed by the model's class id.
            task_name: model name handed to ``load_model``.
            model_dir: directory holding the model files; defaults to the
                ``onnx`` directory next to this module.

        NOTE(review): if ``load_model`` downloads weights from the Hugging Face
        Hub (not visible from this file) and that fails, pointing
        ``HF_ENDPOINT`` at a mirror such as https://hf-mirror.com
        (``export`` on Linux/macOS, ``set`` in a Windows shell) may help.
        """
        if not model_dir:
            model_dir = os.path.join(get_project_base_directory(), "onnx")
        self.ort_sess, self.run_options = load_model(model_dir, task_name)
        input_meta = self.ort_sess.get_inputs()
        self.input_names = [node.name for node in input_meta]
        self.output_names = [node.name for node in self.ort_sess.get_outputs()]
        # (H, W) slice of the first input's shape; preprocess() resizes to it
        # when the model has no "scale_factor" input.
        self.input_shape = input_meta[0].shape[2:4]
        self.label_list = label_list

    @staticmethod
    def sort_Y_firstly(arr, threashold):
        """Return ``arr`` sorted by ``top``; rows closer than ``threashold`` are ordered by ``x0``."""
        def cmp(c1, c2):
            diff = c1["top"] - c2["top"]
            if abs(diff) < threashold:
                diff = c1["x0"] - c2["x0"]
            return diff

        return sorted(arr, key=cmp_to_key(cmp))

    @staticmethod
    def sort_X_firstly(arr, threashold):
        """Return ``arr`` sorted by ``x0``; columns closer than ``threashold`` are ordered by ``top``."""
        def cmp(c1, c2):
            diff = c1["x0"] - c2["x0"]
            if abs(diff) < threashold:
                diff = c1["top"] - c2["top"]
            return diff

        return sorted(arr, key=cmp_to_key(cmp))

    @staticmethod
    def sort_C_firstly(arr, thr=0):
        """Order boxes by ``x0`` (with ``thr`` tolerance), then refine by column id ``"C"``.

        After the x-first sort, adjacent boxes that both carry ``"C"`` are
        bubbled so lower column ids come first and, within a column, smaller
        ``top`` comes first. Pairs where either box lacks ``"C"`` are not swapped.
        """
        arr = Recognizer.sort_X_firstly(arr, thr)
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                if "C" not in arr[j] or "C" not in arr[j + 1]:
                    continue
                if arr[j + 1]["C"] < arr[j]["C"] or (
                    arr[j + 1]["C"] == arr[j]["C"]
                    and arr[j + 1]["top"] < arr[j]["top"]
                ):
                    arr[j], arr[j + 1] = arr[j + 1], arr[j]
        return arr

    @staticmethod
    def sort_R_firstly(arr, thr=0):
        """Order boxes by ``top`` (with ``thr`` tolerance), then refine by row id ``"R"``.

        After the y-first sort, adjacent boxes that both carry ``"R"`` are
        bubbled so lower row ids come first and, within a row, smaller ``x0``
        comes first. Pairs where either box lacks ``"R"`` are not swapped.
        """
        arr = Recognizer.sort_Y_firstly(arr, thr)
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                if "R" not in arr[j] or "R" not in arr[j + 1]:
                    continue
                if arr[j + 1]["R"] < arr[j]["R"] or (
                    arr[j + 1]["R"] == arr[j]["R"]
                    and arr[j + 1]["x0"] < arr[j]["x0"]
                ):
                    arr[j], arr[j + 1] = arr[j + 1], arr[j]
        return arr

    @staticmethod
    def overlapped_area(a, b, ratio=True):
        """Return the intersection area of boxes ``a`` and ``b``.

        With ``ratio=True`` the area is divided by ``a``'s area, i.e. the
        fraction of ``a`` covered by ``b``. Returns 0 when the boxes are
        disjoint or ``a`` has zero width or height.
        """
        tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"]
        if b["x0"] > x1 or b["x1"] < x0:
            return 0
        if b["bottom"] < tp or b["top"] > btm:
            return 0
        x0_ = max(b["x0"], x0)
        x1_ = min(b["x1"], x1)
        assert x0_ <= x1_, "Bbox mismatch! T:{},B:{},X0:{},X1:{} ==> {}".format(
            tp, btm, x0, x1, b)
        tp_ = max(b["top"], tp)
        btm_ = min(b["bottom"], btm)
        assert tp_ <= btm_, "Bbox mismatch! T:{},B:{},X0:{},X1:{} => {}".format(
            tp, btm, x0, x1, b)
        # Degenerate (zero-width/height) `a` would divide by zero below.
        ov = (btm_ - tp_) * (x1_ - x0_) if x1 - x0 != 0 and btm - tp != 0 else 0
        if ov > 0 and ratio:
            ov /= (x1 - x0) * (btm - tp)
        return ov

    @staticmethod
    def layouts_cleanup(boxes, layouts, far=2, thr=0.7):
        """Drop near-duplicate layouts of the same type; mutates and returns ``layouts``.

        Each layout is compared with the following ones inside a window of
        ``far`` positions. When two layouts of the same type overlap by at
        least ``thr`` (in either direction), the one with the lower score is
        removed. If either has no truthy score, the one covering less area of
        ``boxes`` is removed.
        """
        def not_overlapped(a, b):
            return (a["x1"] < b["x0"] or a["x0"] > b["x1"]
                    or a["bottom"] < b["top"] or a["top"] > b["bottom"])

        i = 0
        while i + 1 < len(layouts):
            window_end = min(i + far, len(layouts))
            j = i + 1
            while j < window_end and (
                layouts[i].get("type", "") != layouts[j].get("type", "")
                or not_overlapped(layouts[i], layouts[j])
            ):
                j += 1
            if j >= window_end:
                i += 1
                continue
            if (Recognizer.overlapped_area(layouts[i], layouts[j]) < thr
                    and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr):
                i += 1
                continue

            if layouts[i].get("score") and layouts[j].get("score"):
                layouts.pop(j if layouts[i]["score"] > layouts[j]["score"] else i)
                continue

            # No usable scores: keep the layout that covers more box area.
            area_i, area_j = 0, 0
            for b in boxes:
                if not not_overlapped(b, layouts[i]):
                    area_i += Recognizer.overlapped_area(b, layouts[i], False)
                if not not_overlapped(b, layouts[j]):
                    area_j += Recognizer.overlapped_area(b, layouts[j], False)
            layouts.pop(j if area_i > area_j else i)

        return layouts

    def create_inputs(self, imgs, im_info):
        """Build a model feed dict from preprocessed images.

        Args:
            imgs: list of channel-first images (``np.ndarray`` of shape (C, H, W)).
            im_info: list of dicts with ``im_shape`` and ``scale_factor`` per image.
        Returns:
            dict with ``image`` (batched, zero-padded to the largest H/W),
            ``im_shape`` and ``scale_factor`` (one row per image).
        """
        inputs = {}
        if len(imgs) == 1:
            inputs['image'] = np.array((imgs[0],)).astype('float32')
            inputs['im_shape'] = np.array(
                (im_info[0]['im_shape'],)).astype('float32')
            inputs['scale_factor'] = np.array(
                (im_info[0]['scale_factor'],)).astype('float32')
            return inputs

        # Already one row per image; concatenating these rows would flatten
        # them into a 1-D vector, which is inconsistent with the single-image case.
        inputs['im_shape'] = np.array([info['im_shape'] for info in im_info], dtype='float32')
        inputs['scale_factor'] = np.array([info['scale_factor'] for info in im_info], dtype='float32')

        max_h = max(img.shape[1] for img in imgs)
        max_w = max(img.shape[2] for img in imgs)
        padded = []
        for img in imgs:
            im_c, im_h, im_w = img.shape[:]
            canvas = np.zeros((im_c, max_h, max_w), dtype=np.float32)
            canvas[:, :im_h, :im_w] = img
            padded.append(canvas)
        inputs['image'] = np.stack(padded, axis=0)
        return inputs

    @staticmethod
    def find_overlapped(box, boxes_sorted_by_y, naive=False):
        """Return the index of the box in ``boxes_sorted_by_y`` overlapping ``box`` most.

        Unless ``naive``, a binary search on ``top``/``bottom`` narrows the
        scanned range first. Returns None when the list is empty or nothing
        overlaps.
        """
        if not boxes_sorted_by_y:
            return
        bxs = boxes_sorted_by_y
        s, e, ii = 0, len(bxs), 0
        while s < e and not naive:
            ii = (e + s) // 2
            pv = bxs[ii]
            if box["bottom"] < pv["top"]:
                e = ii
                continue
            if box["top"] > pv["bottom"]:
                s = ii + 1
                continue
            break
        # Trim at most one non-overlapping box from each end of the window.
        if s < ii and box["top"] > bxs[s]["bottom"]:
            s += 1
        if e - 1 > ii and box["bottom"] < bxs[e - 1]["top"]:
            e -= 1

        best_i, best_ov = None, 0
        for i in range(s, e):
            ov = Recognizer.overlapped_area(bxs[i], box)
            if ov > best_ov:
                best_i, best_ov = i, ov
        return best_i

    @staticmethod
    def find_horizontally_tightest_fit(box, boxes):
        """Return the index of the same-``layoutno`` box whose x-extent is closest to ``box``.

        Distance is the minimum of the left-edge gap, right-edge gap and
        centre gap. Returns None for empty ``boxes`` or when no candidate is
        closer than 1000000.
        """
        if not boxes:
            return
        min_dis, min_i = 1000000, None
        for i, b in enumerate(boxes):
            if box.get("layoutno", "0") != b.get("layoutno", "0"):
                continue
            dis = min(abs(box["x0"] - b["x0"]),
                      abs(box["x1"] - b["x1"]),
                      abs(box["x0"] + box["x1"] - b["x1"] - b["x0"]) / 2)
            if dis < min_dis:
                min_i, min_dis = i, dis
        return min_i

    @staticmethod
    def find_overlapped_with_threashold(box, boxes, thr=0.3):
        """Return the index of the box best overlapping ``box`` with coverage of at least ``thr``.

        Candidates are ranked by the pair (fraction of ``box`` covered,
        fraction of the candidate covered). Returns None if none qualifies.
        """
        if not boxes:
            return
        max_overlapped_i, max_overlapped, _max_overlapped = None, thr, 0
        for i, candidate in enumerate(boxes):
            ov = Recognizer.overlapped_area(box, candidate)
            _ov = Recognizer.overlapped_area(candidate, box)
            if (ov, _ov) < (max_overlapped, _max_overlapped):
                continue
            max_overlapped_i, max_overlapped, _max_overlapped = i, ov, _ov
        return max_overlapped_i

    def preprocess(self, image_list):
        """Turn BGR images into one feed dict per image.

        Models with a ``scale_factor`` input use the operator pipeline below.
        Other models get a plain resize to ``self.input_shape``, 0-1 scaling
        and an NCHW layout, and the (w, h) rescale factors are stored under
        ``scale_factor`` for postprocess().
        """
        inputs = []
        if "scale_factor" in self.input_names:
            ops = []
            for op_info in [
                {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
                {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
                {'type': 'Permute'},
                {'stride': 32, 'type': 'PadStride'}
            ]:
                cfg = op_info.copy()
                op_type = cfg.pop('type')
                ops.append(getattr(operators, op_type)(**cfg))

            for img in image_list:
                # module-level `preprocess` (from .operators), not this method
                im, im_info = preprocess(img, ops)
                inputs.append({"image": np.array((im,)).astype('float32'),
                               "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
        else:
            hh, ww = self.input_shape
            for img in image_list:
                h, w = img.shape[:2]
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(np.array(img).astype('float32'), (ww, hh))
                img /= 255.0
                img = img.transpose(2, 0, 1)
                img = img[np.newaxis, :, :, :].astype(np.float32)
                inputs.append({self.input_names[0]: img, "scale_factor": [w/ww, h/hh]})
        return inputs

    @staticmethod
    def _xywh2xyxy(x):
        """Convert rows of centre-format [cx, cy, w, h] into corner-format [x0, y0, x1, y1]."""
        y = np.copy(x)
        y[:, 0] = x[:, 0] - x[:, 2] / 2
        y[:, 1] = x[:, 1] - x[:, 3] / 2
        y[:, 2] = x[:, 0] + x[:, 2] / 2
        y[:, 3] = x[:, 1] + x[:, 3] / 2
        return y

    @staticmethod
    def _compute_iou(box, boxes):
        """Return the IoU of one corner-format ``box`` against each row of ``boxes``."""
        xmin = np.maximum(box[0], boxes[:, 0])
        ymin = np.maximum(box[1], boxes[:, 1])
        xmax = np.minimum(box[2], boxes[:, 2])
        ymax = np.minimum(box[3], boxes[:, 3])
        intersection = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
        box_area = (box[2] - box[0]) * (box[3] - box[1])
        boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        return intersection / (box_area + boxes_area - intersection)

    @staticmethod
    def _iou_filter(boxes, scores, iou_threshold):
        """Greedy NMS: return row indices kept, highest score first."""
        order = np.argsort(scores)[::-1]
        keep = []
        while order.size > 0:
            # Keep the highest-scoring remaining box...
            best = order[0]
            keep.append(best)
            # ...and drop every remaining box overlapping it too much.
            ious = Recognizer._compute_iou(boxes[best, :], boxes[order[1:], :])
            order = order[np.where(ious < iou_threshold)[0] + 1]
        return keep

    def postprocess(self, boxes, inputs, thr):
        """Convert one image's raw model output into detection dicts.

        Args:
            boxes: first model output for a single image.
            inputs: the feed dict produced by preprocess() for that image.
            thr: minimum confidence to keep a detection.
        Returns:
            list of ``{"type": str, "bbox": [float, ...], "score": float}``.
            Detections whose class id is outside ``label_list`` are skipped.
        """
        if "scale_factor" in self.input_names:
            # Rows: [class_id, score, *bbox].
            bb = []
            for b in boxes:
                clsid, bbox, score = int(b[0]), b[2:], b[1]
                if score < thr:
                    continue
                if clsid >= len(self.label_list):
                    continue
                bb.append({
                    "type": self.label_list[clsid].lower(),
                    "bbox": [float(t) for t in bbox.tolist()],
                    "score": float(score)
                })
            return bb

        # YOLO-style output: columns are [cx, cy, w, h, class scores...] once
        # transposed. Presumably (1, 4 + num_classes, num_candidates) — confirm.
        # Drop only the batch axis: np.squeeze would also collapse a single
        # candidate and break the 2-D indexing below.
        pred = np.asarray(boxes)
        if pred.ndim == 3:
            pred = pred[0]
        pred = pred.T

        scores = np.max(pred[:, 4:], axis=1)
        confident = scores > thr
        pred = pred[confident, :]
        scores = scores[confident]
        if len(pred) == 0:
            return []

        class_ids = np.argmax(pred[:, 4:], axis=1)
        sx, sy = inputs["scale_factor"][0], inputs["scale_factor"][1]
        scale = np.array([sx, sy, sx, sy])
        xyxy = Recognizer._xywh2xyxy(np.multiply(pred[:, :4], scale, dtype=np.float32))

        indices = []
        for class_id in np.unique(class_ids):
            if class_id >= len(self.label_list):
                continue  # same policy as the scale_factor branch
            class_indices = np.where(class_ids == class_id)[0]
            kept = Recognizer._iou_filter(xyxy[class_indices, :], scores[class_indices], 0.2)
            indices.extend(class_indices[kept])

        return [{
            "type": self.label_list[class_ids[i]].lower(),
            "bbox": [float(t) for t in xyxy[i].tolist()],
            "score": float(scores[i])
        } for i in indices]

    def __call__(self, image_list, thr=0.7, batch_size=16):
        """Detect objects in ``image_list``; returns one list of detections per image."""
        imgs = [img if isinstance(img, np.ndarray) else np.array(img) for img in image_list]
        res = []
        batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
        for b in range(batch_loop_cnt):
            batch = imgs[b * batch_size:(b + 1) * batch_size]
            inputs = self.preprocess(batch)
            logging.debug("preprocess")
            for ins in inputs:
                feed = {k: v for k, v in ins.items() if k in self.input_names}
                raw = self.ort_sess.run(None, feed, self.run_options)[0]
                res.append(self.postprocess(raw, ins, thr))
        return res
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
|