doc-page-extractor 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of doc-page-extractor might be problematic; see the package registry's advisory page for more details.
- doc_page_extractor/__init__.py +1 -1
- doc_page_extractor/downloader.py +4 -1
- doc_page_extractor/extractor.py +7 -13
- doc_page_extractor/ocr.py +110 -58
- doc_page_extractor/ocr_corrector.py +3 -3
- doc_page_extractor/onnxocr/__init__.py +1 -0
- doc_page_extractor/onnxocr/cls_postprocess.py +26 -0
- doc_page_extractor/onnxocr/db_postprocess.py +246 -0
- doc_page_extractor/onnxocr/imaug.py +32 -0
- doc_page_extractor/onnxocr/operators.py +187 -0
- doc_page_extractor/onnxocr/predict_base.py +52 -0
- doc_page_extractor/onnxocr/predict_cls.py +89 -0
- doc_page_extractor/onnxocr/predict_det.py +120 -0
- doc_page_extractor/onnxocr/predict_rec.py +321 -0
- doc_page_extractor/onnxocr/predict_system.py +97 -0
- doc_page_extractor/onnxocr/rec_postprocess.py +896 -0
- doc_page_extractor/onnxocr/utils.py +71 -0
- {doc_page_extractor-0.0.5.dist-info → doc_page_extractor-0.0.7.dist-info}/METADATA +17 -5
- doc_page_extractor-0.0.7.dist-info/RECORD +33 -0
- doc_page_extractor-0.0.5.dist-info/RECORD +0 -21
- {doc_page_extractor-0.0.5.dist-info → doc_page_extractor-0.0.7.dist-info}/LICENSE +0 -0
- {doc_page_extractor-0.0.5.dist-info → doc_page_extractor-0.0.7.dist-info}/WHEEL +0 -0
- {doc_page_extractor-0.0.5.dist-info → doc_page_extractor-0.0.7.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,896 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
paddle = None
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BaseRecLabelDecode(object):
    """Convert between text-label and text-index.

    Builds the recognition character table, either from a dictionary file
    or from a built-in digits+lowercase fallback, and decodes batches of
    index sequences into (text, confidence) pairs.
    """

    def __init__(self, character_dict_path=None, use_space_char=False):
        # Token names used by attention-style subclasses.
        self.beg_str = "sos"
        self.end_str = "eos"
        self.reverse = False
        self.character_str = []

        if character_dict_path is None:
            # Fallback alphabet: digits followed by lowercase letters.
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            dict_character = list(self.character_str)
        else:
            with open(character_dict_path, "rb") as fin:
                for raw in fin.readlines():
                    self.character_str.append(
                        raw.decode("utf-8").strip("\n").strip("\r\n")
                    )
            if use_space_char:
                self.character_str.append(" ")
            dict_character = list(self.character_str)
            if "arabic" in character_dict_path:
                # Arabic is rendered right-to-left: reverse decoded text
                # (see pred_reverse).
                self.reverse = True

        dict_character = self.add_special_char(dict_character)
        self.dict = {char: idx for idx, char in enumerate(dict_character)}
        self.character = dict_character

    def pred_reverse(self, pred):
        """Reverse an RTL prediction while keeping LTR runs
        (latin letters, digits and a few symbols) in reading order."""
        segments = []
        run = ""
        for ch in pred:
            if re.search("[a-zA-Z0-9 :*./%+-]", ch):
                # Part of an LTR run: accumulate.
                run += ch
            else:
                # RTL character: flush the pending LTR run first.
                if run:
                    segments.append(run)
                segments.append(ch)
                run = ""
        if run:
            segments.append(run)

        return "".join(reversed(segments))

    def add_special_char(self, dict_character):
        # Base class adds no special tokens; subclasses override.
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label."""
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        for batch_idx in range(len(text_index)):
            indices = text_index[batch_idx]
            keep = np.ones(len(indices), dtype=bool)
            if is_remove_duplicate:
                # CTC collapse: keep only the first of each repeated index.
                keep[1:] = indices[1:] != indices[:-1]
            for ignored_token in ignored_tokens:
                keep &= indices != ignored_token

            char_list = [self.character[text_id] for text_id in indices[keep]]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][keep]
            else:
                conf_list = [1] * len(keep)
            if len(conf_list) == 0:
                conf_list = [0]

            text = "".join(char_list)

            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)

            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def get_ignored_tokens(self):
        return [0]  # for ctc blank
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class CTCLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index (CTC decoding)."""

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        # Some heads return several outputs; the CTC logits come last.
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        best_idx = preds.argmax(axis=2)
        best_prob = preds.max(axis=2)
        # CTC collapse: drop repeated indices, then blanks (handled in decode).
        text = self.decode(best_idx, best_prob, is_remove_duplicate=True)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        # Index 0 is reserved for the CTC blank token.
        return ["blank"] + dict_character
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class DistillationCTCLabelDecode(CTCLabelDecode):
    """CTC decoding for distillation models.

    Applies :class:`CTCLabelDecode` to one or more named sub-model
    outputs and returns a dict keyed by sub-model name.
    """

    def __init__(
        self,
        character_dict_path=None,
        use_space_char=False,
        model_name=["student"],
        key=None,
        multi_head=False,
        **kwargs
    ):
        super(DistillationCTCLabelDecode, self).__init__(
            character_dict_path, use_space_char
        )
        # Accept either a single name or a list of names.
        self.model_name = (
            model_name if isinstance(model_name, list) else [model_name]
        )
        self.key = key  # optional sub-dict key inside each prediction
        self.multi_head = multi_head  # multi-head outputs keep CTC under "ctc"

    def __call__(self, preds, label=None, *args, **kwargs):
        output = dict()
        for name in self.model_name:
            branch = preds[name]
            if self.key is not None:
                branch = branch[self.key]
            if self.multi_head and isinstance(branch, dict):
                branch = branch["ctc"]
            output[name] = super().__call__(branch, label=label, *args, **kwargs)
        return output
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class AttnLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index for attention heads.

    The dictionary is wrapped as [sos, chars..., eos]; decoding skips the
    begin/end tokens.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(AttnLabelDecode, self).__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = [self.beg_str] + dict_character + [self.end_str]
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label.

        text_index: per-batch token indices; text_prob: matching per-token
        probabilities or None (confidence defaults to 1). Returns a list
        of (text, mean_confidence) tuples.
        """
        result_list = []
        # Reuse one call (the original called get_ignored_tokens() twice
        # for the same [beg_idx, end_idx] pair).
        ignored_tokens = self.get_ignored_tokens()
        [beg_idx, end_idx] = ignored_tokens
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                # NOTE(review): end_idx is also in ignored_tokens, so the
                # `continue` below fires before the `break` can — kept
                # as-is to preserve upstream behavior.
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if int(text_index[batch_idx][idx]) == int(end_idx):
                    break
                if is_remove_duplicate:
                    # only for predict
                    if (
                        idx > 0
                        and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
                    ):
                        continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = "".join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        # Fix: this module sets ``paddle = None`` (ONNX port without a
        # paddle dependency); guard the isinstance check so plain numpy
        # inputs don't raise AttributeError on ``paddle.Tensor``.
        if paddle is not None and isinstance(preds, paddle.Tensor):
            preds = preds.numpy()

        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the dictionary index of the begin or end token."""
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
        return idx
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class RFLLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index for RFL.

    With (cnt, seq) outputs it decodes text; with a counting-only output
    it returns predicted character counts.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(RFLLabelDecode, self).__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = [self.beg_str] + dict_character + [self.end_str]
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label."""
        result_list = []
        # Reuse one call (originally get_ignored_tokens() was called twice).
        ignored_tokens = self.get_ignored_tokens()
        [beg_idx, end_idx] = ignored_tokens
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if int(text_index[batch_idx][idx]) == int(end_idx):
                    break
                if is_remove_duplicate:
                    # only for predict
                    if (
                        idx > 0
                        and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
                    ):
                        continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = "".join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        if isinstance(preds, (tuple, list)):
            cnt_outputs, seq_outputs = preds
            # Fix: ``paddle`` is None in this ONNX port; the unguarded
            # isinstance check raised AttributeError on ``paddle.Tensor``.
            if paddle is not None and isinstance(seq_outputs, paddle.Tensor):
                seq_outputs = seq_outputs.numpy()
            preds_idx = seq_outputs.argmax(axis=2)
            preds_prob = seq_outputs.max(axis=2)
            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

            if label is None:
                return text
            label = self.decode(label, is_remove_duplicate=False)
            return text, label

        else:
            cnt_outputs = preds
            # Fix: same paddle guard as above.
            if paddle is not None and isinstance(cnt_outputs, paddle.Tensor):
                cnt_outputs = cnt_outputs.numpy()
            cnt_length = []
            for lens in cnt_outputs:
                # Predicted character count per sample.
                length = round(np.sum(lens))
                cnt_length.append(length)
            if label is None:
                return cnt_length
            label = self.decode(label, is_remove_duplicate=False)
            length = [len(res[0]) for res in label]
            return cnt_length, length

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the dictionary index of the begin or end token."""
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
        return idx
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
class SEEDLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index for SEED.

    The dictionary is extended with [eos, padding, unknown]; decoding
    stops at the end-of-sequence token.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(SEEDLabelDecode, self).__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        self.padding_str = "padding"
        self.end_str = "eos"
        self.unknown = "unknown"
        dict_character = dict_character + [self.end_str, self.padding_str, self.unknown]
        return dict_character

    def get_ignored_tokens(self):
        end_idx = self.get_beg_end_flag_idx("eos")
        return [end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the dictionary index of the requested flag token.

        NOTE(review): the "sos" branch reads self.beg_str, which this
        class never adds to the dictionary — it would raise KeyError.
        Only the "eos" branch is used here.
        """
        if beg_or_end == "sos":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "eos":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
        return idx

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label."""
        result_list = []
        [end_idx] = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if int(text_index[batch_idx][idx]) == int(end_idx):
                    break
                if is_remove_duplicate:
                    # only for predict
                    if (
                        idx > 0
                        and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
                    ):
                        continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = "".join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        preds_idx = preds["rec_pred"]
        # Fix: ``paddle`` is None in this ONNX port; guard the check so
        # numpy inputs don't raise AttributeError on ``paddle.Tensor``.
        if paddle is not None and isinstance(preds_idx, paddle.Tensor):
            preds_idx = preds_idx.numpy()
        if "rec_pred_scores" in preds:
            # Beam-search path: indices and scores are precomputed.
            preds_idx = preds["rec_pred"]
            preds_prob = preds["rec_pred_scores"]
        else:
            preds_idx = preds["rec_pred"].argmax(axis=2)
            preds_prob = preds["rec_pred"].max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
class SRNLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index for SRN.

    Predictions arrive flattened; they are reshaped to
    (batch, max_text_length) before decoding.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(SRNLabelDecode, self).__init__(character_dict_path, use_space_char)
        self.max_text_length = kwargs.get("max_text_length", 25)

    def __call__(self, preds, label=None, *args, **kwargs):
        pred = preds["predict"]
        # +2 for the appended beg/end tokens (see add_special_char).
        char_num = len(self.character_str) + 2
        # Fix: ``paddle`` is None in this ONNX port; guard the check so
        # numpy inputs don't raise AttributeError on ``paddle.Tensor``.
        if paddle is not None and isinstance(pred, paddle.Tensor):
            pred = pred.numpy()
        pred = np.reshape(pred, [-1, char_num])

        preds_idx = np.argmax(pred, axis=1)
        preds_prob = np.max(pred, axis=1)

        preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
        preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])

        # Fix: decode once — the original decoded, then decoded again
        # with identical arguments when label is None.
        text = self.decode(preds_idx, preds_prob)

        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label."""
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)

        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if (
                        idx > 0
                        and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
                    ):
                        continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)

            text = "".join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def add_special_char(self, dict_character):
        dict_character = dict_character + [self.beg_str, self.end_str]
        return dict_character

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the dictionary index of the begin or end token."""
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
        return idx
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
class SARLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index for SAR.

    Adds <UKN>, <BOS/EOS> (shared begin/end index) and <PAD> tokens to
    the dictionary; optionally strips symbols from decoded text.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(SARLabelDecode, self).__init__(character_dict_path, use_space_char)

        self.rm_symbol = kwargs.get("rm_symbol", False)

    def add_special_char(self, dict_character):
        beg_end_str = "<BOS/EOS>"
        unknown_str = "<UKN>"
        padding_str = "<PAD>"
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        # Begin and end share a single token index.
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label."""
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        # Fix: compile the symbol-stripping pattern once per call instead
        # of once per batch item.
        comp = re.compile("[^A-Z^a-z^0-9^\u4e00-\u9fa5]") if self.rm_symbol else None

        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if int(text_index[batch_idx][idx]) == int(self.end_idx):
                    # A leading end token with no probabilities is treated
                    # as the (identical) begin token and skipped.
                    if text_prob is None and idx == 0:
                        continue
                    else:
                        break
                if is_remove_duplicate:
                    # only for predict
                    if (
                        idx > 0
                        and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
                    ):
                        continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = "".join(char_list)
            if self.rm_symbol:
                text = text.lower()
                text = comp.sub("", text)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        # Fix: ``paddle`` is None in this ONNX port; guard the check so
        # numpy inputs don't raise AttributeError on ``paddle.Tensor``.
        if paddle is not None and isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        return [self.padding_idx]
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
class DistillationSARLabelDecode(SARLabelDecode):
    """SAR decoding for distillation models.

    Applies :class:`SARLabelDecode` to one or more named sub-model
    outputs and returns a dict keyed by sub-model name.
    """

    def __init__(
        self,
        character_dict_path=None,
        use_space_char=False,
        model_name=["student"],
        key=None,
        multi_head=False,
        **kwargs
    ):
        super(DistillationSARLabelDecode, self).__init__(
            character_dict_path, use_space_char
        )
        # Accept either a single name or a list of names.
        self.model_name = (
            model_name if isinstance(model_name, list) else [model_name]
        )
        self.key = key  # optional sub-dict key inside each prediction
        self.multi_head = multi_head  # multi-head outputs keep SAR under "sar"

    def __call__(self, preds, label=None, *args, **kwargs):
        output = dict()
        for name in self.model_name:
            branch = preds[name]
            if self.key is not None:
                branch = branch[self.key]
            if self.multi_head and isinstance(branch, dict):
                branch = branch["sar"]
            output[name] = super().__call__(branch, label=label, *args, **kwargs)
        return output
|
|
581
|
+
|
|
582
|
+
|
|
583
|
+
class PRENLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index for PREN.

    Prepends <PAD>/<EOS>/<UNK> (indices 0/1/2) to the dictionary;
    decoding stops at <EOS> and skips <PAD>/<UNK>.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(PRENLabelDecode, self).__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        padding_str = "<PAD>"  # 0
        end_str = "<EOS>"  # 1
        unknown_str = "<UNK>"  # 2

        dict_character = [padding_str, end_str, unknown_str] + dict_character
        self.padding_idx = 0
        self.end_idx = 1
        self.unknown_idx = 2

        return dict_character

    def decode(self, text_index, text_prob=None):
        """convert text-index into text-label."""
        result_list = []
        batch_size = len(text_index)

        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] == self.end_idx:
                    break
                if text_index[batch_idx][idx] in [self.padding_idx, self.unknown_idx]:
                    continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)

            text = "".join(char_list)
            if len(text) > 0:
                result_list.append((text, np.mean(conf_list).tolist()))
            else:
                # here confidence of empty recog result is 1
                result_list.append(("", 1))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        # Fix: ``paddle`` is None in this ONNX port; guard the check so
        # numpy inputs don't raise AttributeError on ``paddle.Tensor``.
        if paddle is not None and isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob)
        if label is None:
            return text
        label = self.decode(label)
        return text, label
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
class NRTRLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index for NRTR.

    The dictionary is prefixed with ["blank", "<unk>", "<s>", "</s>"];
    decoding stops at "</s>" and lowercases the result.
    """

    def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
        super(NRTRLabelDecode, self).__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):

        if len(preds) == 2:
            # (indices, probabilities) pair from a beam/greedy search head.
            preds_id = preds[0]
            preds_prob = preds[1]
            # Fix: ``paddle`` is None in this ONNX port; guard the checks
            # so numpy inputs don't raise AttributeError on ``paddle.Tensor``.
            if paddle is not None and isinstance(preds_id, paddle.Tensor):
                preds_id = preds_id.numpy()
            if paddle is not None and isinstance(preds_prob, paddle.Tensor):
                preds_prob = preds_prob.numpy()
            if preds_id[0][0] == 2:
                # Sequence starts with the <s> token (index 2): drop it.
                preds_idx = preds_id[:, 1:]
                preds_prob = preds_prob[:, 1:]
            else:
                preds_idx = preds_id
            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
            if label is None:
                return text
            label = self.decode(label[:, 1:])
        else:
            if paddle is not None and isinstance(preds, paddle.Tensor):
                preds = preds.numpy()
            preds_idx = preds.argmax(axis=2)
            preds_prob = preds.max(axis=2)
            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
            if label is None:
                return text
            label = self.decode(label[:, 1:])
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ["blank", "<unk>", "<s>", "</s>"] + dict_character
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label."""
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                try:
                    char_idx = self.character[int(text_index[batch_idx][idx])]
                except IndexError:
                    # Fix: was a bare ``except``; only out-of-range
                    # indices should be skipped, not every error.
                    continue
                if char_idx == "</s>":  # end
                    break
                char_list.append(char_idx)
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = "".join(char_list)
            result_list.append((text.lower(), np.mean(conf_list).tolist()))
        return result_list
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
class ViTSTRLabelDecode(NRTRLabelDecode):
    """Convert between text-label and text-index for ViTSTR.

    The dictionary is prefixed with ["<s>", "</s>"]; the first time step
    (the start token) is dropped before decoding.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(ViTSTRLabelDecode, self).__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        # Fix: ``paddle`` is None in this ONNX port; guard the check so
        # numpy inputs don't raise AttributeError on ``paddle.Tensor``.
        if paddle is not None and isinstance(preds, paddle.Tensor):
            preds = preds[:, 1:].numpy()
        else:
            preds = preds[:, 1:]
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label[:, 1:])
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ["<s>", "</s>"] + dict_character
        return dict_character
|
|
725
|
+
|
|
726
|
+
|
|
727
|
+
class ABINetLabelDecode(NRTRLabelDecode):
    """Convert between text-label and text-index for ABINet.

    The dictionary is prefixed with ["</s>"]; dict inputs use the last
    alignment output.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(ABINetLabelDecode, self).__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        if isinstance(preds, dict):
            # NOTE(review): assumes the alignment output exposes a
            # ``.numpy()`` method (tensor-like) — confirm against caller.
            preds = preds["align"][-1].numpy()
        # Fix: ``paddle`` is None in this ONNX port; guard the check so
        # numpy inputs don't raise AttributeError on ``paddle.Tensor``.
        elif paddle is not None and isinstance(preds, paddle.Tensor):
            preds = preds.numpy()

        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ["</s>"] + dict_character
        return dict_character
|
|
752
|
+
|
|
753
|
+
|
|
754
|
+
class SPINLabelDecode(AttnLabelDecode):
    """Convert between text-label and text-index for SPIN.

    Same decoding as AttnLabelDecode, but both special tokens are placed
    at the front of the dictionary: [sos, eos, chars...].
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(SPINLabelDecode, self).__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        return [self.beg_str, self.end_str] + dict_character
|
|
766
|
+
|
|
767
|
+
|
|
768
|
+
# class VLLabelDecode(BaseRecLabelDecode):
|
|
769
|
+
# """ Convert between text-label and text-index """
|
|
770
|
+
#
|
|
771
|
+
# def __init__(self, character_dict_path=None, use_space_char=False,
|
|
772
|
+
# **kwargs):
|
|
773
|
+
# super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
|
|
774
|
+
# self.max_text_length = kwargs.get('max_text_length', 25)
|
|
775
|
+
# self.nclass = len(self.character) + 1
|
|
776
|
+
#
|
|
777
|
+
# def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
|
778
|
+
# """ convert text-index into text-label. """
|
|
779
|
+
# result_list = []
|
|
780
|
+
# ignored_tokens = self.get_ignored_tokens()
|
|
781
|
+
# batch_size = len(text_index)
|
|
782
|
+
# for batch_idx in range(batch_size):
|
|
783
|
+
# selection = np.ones(len(text_index[batch_idx]), dtype=bool)
|
|
784
|
+
# if is_remove_duplicate:
|
|
785
|
+
# selection[1:] = text_index[batch_idx][1:] != text_index[
|
|
786
|
+
# batch_idx][:-1]
|
|
787
|
+
# for ignored_token in ignored_tokens:
|
|
788
|
+
# selection &= text_index[batch_idx] != ignored_token
|
|
789
|
+
#
|
|
790
|
+
# char_list = [
|
|
791
|
+
# self.character[text_id - 1]
|
|
792
|
+
# for text_id in text_index[batch_idx][selection]
|
|
793
|
+
# ]
|
|
794
|
+
# if text_prob is not None:
|
|
795
|
+
# conf_list = text_prob[batch_idx][selection]
|
|
796
|
+
# else:
|
|
797
|
+
# conf_list = [1] * len(selection)
|
|
798
|
+
# if len(conf_list) == 0:
|
|
799
|
+
# conf_list = [0]
|
|
800
|
+
#
|
|
801
|
+
# text = ''.join(char_list)
|
|
802
|
+
# result_list.append((text, np.mean(conf_list).tolist()))
|
|
803
|
+
# return result_list
|
|
804
|
+
#
|
|
805
|
+
# def __call__(self, preds, label=None, length=None, *args, **kwargs):
|
|
806
|
+
# if len(preds) == 2: # eval mode
|
|
807
|
+
# text_pre, x = preds
|
|
808
|
+
# b = text_pre.shape[1]
|
|
809
|
+
# lenText = self.max_text_length
|
|
810
|
+
# nsteps = self.max_text_length
|
|
811
|
+
#
|
|
812
|
+
# if not isinstance(text_pre, paddle.Tensor):
|
|
813
|
+
# text_pre = paddle.to_tensor(text_pre, dtype='float32')
|
|
814
|
+
#
|
|
815
|
+
# out_res = paddle.zeros(
|
|
816
|
+
# shape=[lenText, b, self.nclass], dtype=x.dtype)
|
|
817
|
+
# out_length = paddle.zeros(shape=[b], dtype=x.dtype)
|
|
818
|
+
# now_step = 0
|
|
819
|
+
# for _ in range(nsteps):
|
|
820
|
+
# if 0 in out_length and now_step < nsteps:
|
|
821
|
+
# tmp_result = text_pre[now_step, :, :]
|
|
822
|
+
# out_res[now_step] = tmp_result
|
|
823
|
+
# tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
|
|
824
|
+
# for j in range(b):
|
|
825
|
+
# if out_length[j] == 0 and tmp_result[j] == 0:
|
|
826
|
+
# out_length[j] = now_step + 1
|
|
827
|
+
# now_step += 1
|
|
828
|
+
# for j in range(0, b):
|
|
829
|
+
# if int(out_length[j]) == 0:
|
|
830
|
+
# out_length[j] = nsteps
|
|
831
|
+
# start = 0
|
|
832
|
+
# output = paddle.zeros(
|
|
833
|
+
# shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
|
|
834
|
+
# for i in range(0, b):
|
|
835
|
+
# cur_length = int(out_length[i])
|
|
836
|
+
# output[start:start + cur_length] = out_res[0:cur_length, i, :]
|
|
837
|
+
# start += cur_length
|
|
838
|
+
# net_out = output
|
|
839
|
+
# length = out_length
|
|
840
|
+
#
|
|
841
|
+
# else: # train mode
|
|
842
|
+
# net_out = preds[0]
|
|
843
|
+
# length = length
|
|
844
|
+
# net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)])
|
|
845
|
+
# text = []
|
|
846
|
+
# if not isinstance(net_out, paddle.Tensor):
|
|
847
|
+
# net_out = paddle.to_tensor(net_out, dtype='float32')
|
|
848
|
+
# net_out = F.softmax(net_out, axis=1)
|
|
849
|
+
# for i in range(0, length.shape[0]):
|
|
850
|
+
# preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
|
|
851
|
+
# ) + length[i])].topk(1)[1][:, 0].tolist()
|
|
852
|
+
# preds_text = ''.join([
|
|
853
|
+
# self.character[idx - 1]
|
|
854
|
+
# if idx > 0 and idx <= len(self.character) else ''
|
|
855
|
+
# for idx in preds_idx
|
|
856
|
+
# ])
|
|
857
|
+
# preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum(
|
|
858
|
+
# ) + length[i])].topk(1)[0][:, 0]
|
|
859
|
+
# preds_prob = paddle.exp(
|
|
860
|
+
# paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
|
|
861
|
+
# text.append((preds_text, preds_prob.numpy()[0]))
|
|
862
|
+
# if label is None:
|
|
863
|
+
# return text
|
|
864
|
+
# label = self.decode(label)
|
|
865
|
+
# return text, label
|
|
866
|
+
|
|
867
|
+
|
|
868
|
+
class CANLabelDecode(BaseRecLabelDecode):
    """Convert between latex-symbol and symbol-index (CAN formula recognition)."""

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(CANLabelDecode, self).__init__(character_dict_path, use_space_char)

    def decode(self, text_index, preds_prob=None):
        """Turn index sequences into space-joined symbol strings.

        Each sequence is cut at its first 0 index (end-of-sequence marker).
        Returns a list of ``[symbol_string, probability_list]`` pairs; the
        probability list is empty when ``preds_prob`` is not supplied.
        """
        result_list = []
        for batch_idx, indices in enumerate(text_index):
            # argmin over the sequence finds the first 0 (EOS) position.
            stop = indices.argmin(0)
            symbols = [self.character[i] for i in indices[:stop].tolist()]
            if preds_prob is not None:
                probs = preds_prob[batch_idx][: len(symbols)].tolist()
            else:
                probs = []
            result_list.append([" ".join(symbols), probs])
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode model output; ``preds`` is a 4-tuple whose first element
        holds the class probabilities of shape (batch, seq, n_classes)."""
        pred_prob, _, _, _ = preds
        preds_idx = pred_prob.argmax(axis=2)

        text = self.decode(preds_idx)
        if label is None:
            return text
        return text, self.decode(label)
|