xfmr-zem 0.2.2__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xfmr_zem/cli.py +32 -3
- xfmr_zem/client.py +59 -8
- xfmr_zem/server.py +21 -4
- xfmr_zem/servers/data_juicer/server.py +1 -1
- xfmr_zem/servers/instruction_gen/server.py +1 -1
- xfmr_zem/servers/io/server.py +1 -1
- xfmr_zem/servers/llm/parameters.yml +10 -0
- xfmr_zem/servers/nemo_curator/server.py +1 -1
- xfmr_zem/servers/ocr/deepdoc_vietocr/__init__.py +90 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/implementations.py +1286 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/layout_recognizer.py +562 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/ocr.py +512 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/.gitattributes +35 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/README.md +5 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/ocr.res +6623 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/operators.py +725 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/phases.py +191 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/pipeline.py +561 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/postprocess.py +370 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/recognizer.py +436 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/table_structure_recognizer.py +569 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/utils/__init__.py +81 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/utils/file_utils.py +246 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/__init__.py +0 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/base.yml +58 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/vgg-seq2seq.yml +38 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/__init__.py +0 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/cnn.py +25 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/vgg.py +51 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/seqmodel/seq2seq.py +175 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/transformerocr.py +29 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/vocab.py +36 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/config.py +37 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/translate.py +111 -0
- xfmr_zem/servers/ocr/engines.py +242 -0
- xfmr_zem/servers/ocr/install_models.py +63 -0
- xfmr_zem/servers/ocr/parameters.yml +4 -0
- xfmr_zem/servers/ocr/server.py +44 -0
- xfmr_zem/servers/profiler/parameters.yml +4 -0
- xfmr_zem/servers/sinks/parameters.yml +6 -0
- xfmr_zem/servers/unstructured/parameters.yml +6 -0
- xfmr_zem/servers/unstructured/server.py +62 -0
- xfmr_zem/zenml_wrapper.py +20 -7
- {xfmr_zem-0.2.2.dist-info → xfmr_zem-0.2.5.dist-info}/METADATA +19 -1
- xfmr_zem-0.2.5.dist-info/RECORD +58 -0
- xfmr_zem-0.2.2.dist-info/RECORD +0 -23
- /xfmr_zem/servers/data_juicer/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/instruction_gen/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/io/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/nemo_curator/{parameter.yaml → parameters.yml} +0 -0
- {xfmr_zem-0.2.2.dist-info → xfmr_zem-0.2.5.dist-info}/WHEEL +0 -0
- {xfmr_zem-0.2.2.dist-info → xfmr_zem-0.2.5.dist-info}/entry_points.txt +0 -0
- {xfmr_zem-0.2.2.dist-info → xfmr_zem-0.2.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,370 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
#
|
|
16
|
+
|
|
17
|
+
import copy
|
|
18
|
+
import re
|
|
19
|
+
import numpy as np
|
|
20
|
+
import cv2
|
|
21
|
+
from shapely.geometry import Polygon
|
|
22
|
+
import pyclipper
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def build_post_process(config, global_config=None):
    """Instantiate a post-processing operator described by ``config``.

    Args:
        config: mapping with a ``'name'`` key selecting the operator class;
            every other key is passed to that class's constructor. The
            mapping is deep-copied, so the caller's object is not modified.
        global_config: optional mapping merged over the per-operator keys
            (its values win on conflicts).

    Returns:
        The constructed operator, or None when ``name`` is the string "None".

    Raises:
        KeyError: if ``config`` has no ``'name'`` key.
        ValueError: if ``name`` is not a supported operator.
    """
    registry = {
        'DBPostProcess': DBPostProcess,
        'CTCLabelDecode': CTCLabelDecode,
    }

    params = copy.deepcopy(config)
    name = params.pop('name')
    if name == "None":
        return None

    if global_config is not None:
        params.update(global_config)

    if name not in registry:
        raise ValueError(
            'post process only support {}'.format(list(registry)))
    return registry[name](**params)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class DBPostProcess:
    """
    The post process for Differentiable Binarization (DB).

    Turns the probability map of a DB text detector into text regions. The
    map is thresholded, each connected region is scored against the raw
    probabilities, expanded outward ("unclipped") with a polygon offset, and
    rescaled to the source-image size.
    """

    def __init__(self,
                 thresh=0.3,
                 box_thresh=0.7,
                 max_candidates=1000,
                 unclip_ratio=2.0,
                 use_dilation=False,
                 score_mode="fast",
                 box_type='quad',
                 **kwargs):
        """
        Args:
            thresh: probability above which a pixel is marked as text.
            box_thresh: minimum mean probability a candidate must reach to be
                kept (candidates scoring below it are dropped).
            max_candidates: maximum number of contours examined per image.
            unclip_ratio: multiplier for the outward offset distance
                (area * unclip_ratio / perimeter).
            use_dilation: dilate the binary map with a 2x2 kernel first.
            score_mode: "fast" (mean over the box) or "slow" (mean over the
                contour); only consulted on the 'quad' path.
            box_type: 'quad' or 'poly'; validated when the instance is called.
            **kwargs: accepted and ignored, so a full config can be forwarded.

        Raises:
            ValueError: if score_mode is not "slow" or "fast".
        """
        # An explicit raise (rather than assert) still validates under -O.
        if score_mode not in ("slow", "fast"):
            raise ValueError(
                "Score mode must be in [slow, fast] but got: {}".format(
                    score_mode))
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.score_mode = score_mode
        self.box_type = box_type

        self.dilation_kernel = None if not use_dilation else np.array(
            [[1, 1], [1, 1]])

    @staticmethod
    def _find_contours(bitmap):
        """Return the contours of a binary map on both OpenCV 3.x and 4.x."""
        outs = cv2.findContours((bitmap * 255).astype(np.uint8),
                                cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
        # OpenCV 3.x returns (image, contours, hierarchy) and 4.x returns
        # (contours, hierarchy); the contours are second-to-last in both.
        return outs[-2]

    @staticmethod
    def _scale_to_dest(box, width, height, dest_width, dest_height):
        """Rescale (N, 2) ``box`` from bitmap to destination coords, in place."""
        box[:, 0] = np.clip(
            np.round(box[:, 0] / width * dest_width), 0, dest_width)
        box[:, 1] = np.clip(
            np.round(box[:, 1] / height * dest_height), 0, dest_height)
        return box

    def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        Extract polygonal text regions.

        Args:
            pred: probability map with the same (H, W) size as _bitmap.
            _bitmap: single map with shape (H, W), whose values are
                binarized as {0, 1}.
            dest_width, dest_height: size of the image the points are
                rescaled to.

        Returns:
            (boxes, scores): boxes is a list of polygons, each a list of
            [x, y] points; scores holds the matching mean probabilities.
        '''
        bitmap = _bitmap
        height, width = bitmap.shape

        boxes = []
        scores = []

        contours = self._find_contours(bitmap)

        for contour in contours[:self.max_candidates]:
            epsilon = 0.002 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            points = approx.reshape((-1, 2))
            if points.shape[0] < 4:
                continue

            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if self.box_thresh > score:
                continue

            # Skip regions whose expansion is empty or splits into pieces.
            box = self._unclip_single(points, self.unclip_ratio)
            if box is None:
                continue

            _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
            if sside < self.min_size + 2:
                continue

            self._scale_to_dest(box, width, height, dest_width, dest_height)
            boxes.append(box.tolist())
            scores.append(score)
        return boxes, scores

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        Extract 4-point text boxes.

        Args:
            pred: probability map with the same (H, W) size as _bitmap.
            _bitmap: single map with shape (H, W), whose values are
                binarized as {0, 1}.
            dest_width, dest_height: size of the image the points are
                rescaled to.

        Returns:
            (boxes, scores): boxes is an int32 array of 4-point boxes ordered
            as returned by get_mini_boxes; scores holds the mean
            probabilities.
        '''
        bitmap = _bitmap
        height, width = bitmap.shape

        contours = self._find_contours(bitmap)
        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        for contour in contours[:num_contours]:
            points, sside = self.get_mini_boxes(contour)
            if sside < self.min_size:
                continue
            points = np.array(points)
            if self.score_mode == "fast":
                score = self.box_score_fast(pred, points.reshape(-1, 2))
            else:
                score = self.box_score_slow(pred, contour)
            if self.box_thresh > score:
                continue

            # Previously several offset paths were merged into one point
            # cloud (or an empty result crashed minAreaRect); skip instead.
            expanded = self._unclip_single(points, self.unclip_ratio)
            if expanded is None:
                continue
            box, sside = self.get_mini_boxes(expanded.reshape(-1, 1, 2))
            if sside < self.min_size + 2:
                continue
            box = np.array(box)

            self._scale_to_dest(box, width, height, dest_width, dest_height)
            boxes.append(box.astype("int32"))
            scores.append(score)
        return np.array(boxes, dtype="int32"), scores

    def _offset_paths(self, box, unclip_ratio):
        """Return the raw path list of ``box`` offset outward by
        area * unclip_ratio / perimeter ([] for a zero-perimeter polygon)."""
        poly = Polygon(box)
        if poly.length == 0:
            return []
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        return offset.Execute(distance)

    def _unclip_single(self, box, unclip_ratio):
        """Return the expanded outline as an (N, 2) array, or None unless
        the offset produced exactly one non-empty path."""
        paths = self._offset_paths(box, unclip_ratio)
        if len(paths) != 1 or len(paths[0]) == 0:
            return None
        return np.array(paths[0]).reshape(-1, 2)

    def unclip(self, box, unclip_ratio):
        """Expand polygon ``box`` outward and return ``np.array`` of the
        offset paths (shape (1, N, 2) for the usual single-path result).

        Kept for external callers; the methods in this class use
        _unclip_single, which also copes with zero or several paths.
        """
        return np.array(self._offset_paths(box, unclip_ratio))

    def get_mini_boxes(self, contour):
        """Return the minimum-area rectangle of ``contour`` and its short side.

        The four corners are ordered top-left, top-right, bottom-right,
        bottom-left, where "top" means the smaller y of each left/right pair.
        """
        bounding_box = cv2.minAreaRect(contour)
        # Sorting by x puts the two left corners first, the two right last.
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        if points[1][1] > points[0][1]:
            top_left, bottom_left = points[0], points[1]
        else:
            top_left, bottom_left = points[1], points[0]
        if points[3][1] > points[2][1]:
            top_right, bottom_right = points[2], points[3]
        else:
            top_right, bottom_right = points[3], points[2]

        box = [top_left, top_right, bottom_right, bottom_left]
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        '''
        box_score_fast: use bbox mean score as the mean score

        Mean of ``bitmap`` over the filled polygon ``_box`` ((N, 2) points),
        computed inside the polygon's clipped bounding rectangle.
        '''
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        # Shift the polygon into the cropped window's coordinate frame.
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def box_score_slow(self, bitmap, contour):
        '''
        box_score_slow: use polygon mean score as the mean score

        Mean of ``bitmap`` over the filled contour itself rather than its
        minimum-area box.
        '''
        h, w = bitmap.shape[:2]
        contour = contour.copy()
        contour = np.reshape(contour, (-1, 2))

        xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
        xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
        ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
        ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)

        contour[:, 0] = contour[:, 0] - xmin
        contour[:, 1] = contour[:, 1] - ymin

        cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def __call__(self, outs_dict, shape_list):
        """Detect text regions for every image in a batch.

        Args:
            outs_dict: mapping whose 'maps' entry is indexed as
                (batch, channel, H, W); channel 0 is used as the probability
                map. Non-ndarray values must provide ``.numpy()``.
            shape_list: per image, (src_h, src_w, ratio_h, ratio_w); src_w
                and src_h are the size the output points are scaled to.

        Returns:
            A list with one ``{'points': boxes}`` dict per image.

        Raises:
            ValueError: if box_type is neither 'quad' nor 'poly'.
        """
        pred = outs_dict['maps']
        if not isinstance(pred, np.ndarray):
            pred = pred.numpy()
        pred = pred[:, 0, :, :]
        segmentation = pred > self.thresh

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            src_h, src_w, _, _ = shape_list[batch_index]
            if self.dilation_kernel is not None:
                mask = cv2.dilate(
                    np.array(segmentation[batch_index]).astype(np.uint8),
                    self.dilation_kernel)
            else:
                mask = segmentation[batch_index]
            if self.box_type == 'poly':
                boxes, scores = self.polygons_from_bitmap(pred[batch_index],
                                                          mask, src_w, src_h)
            elif self.box_type == 'quad':
                boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
                                                       src_w, src_h)
            else:
                raise ValueError(
                    "box_type can only be one of ['quad', 'poly']")

            boxes_batch.append({'points': boxes})
        return boxes_batch
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
class BaseRecLabelDecode:
    """ Convert between text-label and text-index """

    # Characters that pred_reverse keeps together as one left-to-right run;
    # compiled once instead of on every character.
    _LTR_CHAR = re.compile('[a-zA-Z0-9 :*./%+-]')

    def __init__(self, character_dict_path=None, use_space_char=False):
        """Build the index <-> character tables.

        Args:
            character_dict_path: path (str or os.PathLike) to a UTF-8 file
                with one character per line. When None, the table is the
                ten digits followed by the lowercase ASCII letters.
            use_space_char: append " " to a file-based table (has no effect
                when character_dict_path is None).
        """
        self.beg_str = "sos"
        self.end_str = "eos"
        self.reverse = False
        self.character_str = []

        if character_dict_path is None:
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            dict_character = list(self.character_str)
        else:
            with open(character_dict_path, "rb") as fin:
                for raw_line in fin:
                    line = raw_line.decode('utf-8').strip("\n").strip("\r\n")
                    self.character_str.append(line)
            if use_space_char:
                self.character_str.append(" ")
            dict_character = list(self.character_str)
            # str() so pathlib.Path paths work: `'arabic' in Path(...)`
            # raises TypeError even though open() accepts the Path.
            if 'arabic' in str(character_dict_path):
                self.reverse = True

        dict_character = self.add_special_char(dict_character)
        # Character -> index; for duplicate characters the last index wins.
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character

    def pred_reverse(self, pred):
        """Reverse the segment order of ``pred`` (used for Arabic output).

        Consecutive characters matching _LTR_CHAR (ASCII letters, digits,
        space and ``:*./%+-``) form one segment that keeps its internal
        order; every other character is a segment of its own. The segments
        are joined in reverse order.
        """
        pred_re = []
        c_current = ''
        for c in pred:
            if not self._LTR_CHAR.search(c):
                if c_current != '':
                    pred_re.append(c_current)
                pred_re.append(c)
                c_current = ''
            else:
                c_current += c
        if c_current != '':
            pred_re.append(c_current)

        return ''.join(pred_re[::-1])

    def add_special_char(self, dict_character):
        """Hook for subclasses to add special tokens; identity here."""
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label.

        Args:
            text_index: per batch item, a 1-D numpy integer array of indices
                into self.character.
            text_prob: optional per-item arrays of per-step confidences,
                aligned with text_index.
            is_remove_duplicate: collapse consecutive repeated indices before
                ignored tokens are dropped (CTC-style decoding).

        Returns:
            A list of (text, mean_confidence) tuples, one per batch item.
            Without text_prob each emitted character counts as 1.0; an
            empty result scores 0.0 in either case.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = text_index[batch_idx][1:] != text_index[
                    batch_idx][:-1]
            for ignored_token in ignored_tokens:
                selection &= text_index[batch_idx] != ignored_token

            char_list = [
                self.character[text_id]
                for text_id in text_index[batch_idx][selection]
            ]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                # One unit per emitted character (not per input position),
                # so an empty decode scores 0 like the text_prob branch.
                conf_list = [1] * len(char_list)
            if len(conf_list) == 0:
                conf_list = [0]

            text = ''.join(char_list)

            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)

            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def get_ignored_tokens(self):
        """Indices dropped by decode(); index 0 is the CTC blank."""
        return [0]  # for ctc blank
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
class CTCLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index

    CTC variant: index 0 of the character table is a 'blank' token, and
    predictions are greedily decoded (arg-max per step, consecutive repeats
    collapsed, blanks dropped).
    """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super().__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode model scores (and optionally ground-truth labels).

        ``preds`` may be a tuple/list, in which case its last element is
        used; a non-ndarray value must provide ``.numpy()``. Axis 2 is the
        class axis that is arg-maxed. Returns the decoded
        (text, confidence) list, or a (decoded, decoded_label) pair when
        ``label`` is given.
        """
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        if not isinstance(preds, np.ndarray):
            preds = preds.numpy()

        decoded = self.decode(preds.argmax(axis=2), preds.max(axis=2),
                              is_remove_duplicate=True)
        if label is not None:
            return decoded, self.decode(label)
        return decoded

    def add_special_char(self, dict_character):
        """Reserve index 0 for the CTC blank token."""
        return ['blank'] + dict_character
|