doc-page-extractor 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of doc-page-extractor might be problematic. Click here for more details.

@@ -0,0 +1,896 @@
1
+ import re
2
+ import numpy as np
3
+
4
# Placeholder for the optional ``paddle`` module. This file never imports
# paddle, so the name is bound to None; any expression that evaluates
# ``paddle.Tensor`` while it is None raises AttributeError.
paddle = None
5
+
6
+
7
class BaseRecLabelDecode(object):
    """Convert between text-label and text-index.

    Holds the character table used to map recognition-model indices back to
    text. Subclasses add their special symbols via ``add_special_char`` and
    choose which indices are dropped via ``get_ignored_tokens``.

    Args:
        character_dict_path: path (``str`` or ``pathlib.Path``) to a UTF-8
            file with one character per line. When ``None`` the built-in
            table ``0-9a-z`` is used and ``use_space_char`` has no effect.
        use_space_char: when loading from ``character_dict_path``, append a
            space character to the table.

    Attributes:
        character: index -> character list (after ``add_special_char``).
        dict: character -> index mapping (last occurrence wins on duplicates).
        reverse: True when the dictionary path contains ``"arabic"``; decoded
            text is then passed through ``pred_reverse``.
    """

    # Characters that pred_reverse keeps together as one left-to-right run.
    _LTR_RUN_PATTERN = re.compile("[a-zA-Z0-9 :*./%+-]")

    def __init__(self, character_dict_path=None, use_space_char=False):
        self.beg_str = "sos"
        self.end_str = "eos"
        self.reverse = False
        self.character_str = []

        if character_dict_path is None:
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            dict_character = list(self.character_str)
        else:
            with open(character_dict_path, "rb") as fin:
                for line in fin:
                    # Strip CR/LF characters from both ends of the line.
                    line = line.decode("utf-8").strip("\n").strip("\r\n")
                    self.character_str.append(line)
            if use_space_char:
                self.character_str.append(" ")
            dict_character = list(self.character_str)
            # str() so pathlib.Path arguments work: ``"x" in Path(...)``
            # raises TypeError.
            if "arabic" in str(character_dict_path):
                self.reverse = True

        dict_character = self.add_special_char(dict_character)
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character

    def pred_reverse(self, pred):
        """Reverse the order of runs in ``pred`` for right-to-left scripts.

        Consecutive characters matching ``_LTR_RUN_PATTERN`` (ASCII letters,
        digits and a few symbols) stay together as one run in their original
        order; every other character forms its own run. The runs are then
        concatenated in reverse order.
        """
        pred_re = []
        c_current = ""
        for c in pred:
            if not self._LTR_RUN_PATTERN.search(c):
                if c_current != "":
                    pred_re.append(c_current)
                pred_re.append(c)
                c_current = ""
            else:
                c_current += c
        if c_current != "":
            pred_re.append(c_current)

        return "".join(pred_re[::-1])

    def add_special_char(self, dict_character):
        """Hook for subclasses to add special symbols; identity here."""
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index sequences into ``(text, mean_confidence)`` tuples.

        Args:
            text_index: batch of index rows. Rows are compared element-wise
                and boolean-masked, so they are expected to be numpy arrays.
            text_prob: optional per-position scores laid out like
                ``text_index``.
            is_remove_duplicate: drop an index equal to its predecessor
                (CTC-style collapsing).

        Returns:
            One ``(text, mean_confidence)`` tuple per row. Without
            ``text_prob`` every position counts as confidence 1. An empty
            confidence list is replaced by ``[0]``.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
            for ignored_token in ignored_tokens:
                selection &= text_index[batch_idx] != ignored_token

            char_list = [
                self.character[text_id] for text_id in text_index[batch_idx][selection]
            ]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                conf_list = [1] * len(selection)
            if len(conf_list) == 0:
                conf_list = [0]

            text = "".join(char_list)

            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)

            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def get_ignored_tokens(self):
        """Indices dropped by ``decode``."""
        return [0]  # for ctc blank
88
+
89
+
90
class CTCLabelDecode(BaseRecLabelDecode):
    """Decode CTC recognition output into text.

    Index 0 of the character table is the ``"blank"`` symbol inserted by
    ``add_special_char``; decoding drops blanks and collapses repeats.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode model scores, and optionally a label batch.

        Args:
            preds: scores reduced with argmax/max over axis 2, or a
                tuple/list whose last element is such scores. Objects that
                expose ``.numpy()`` (framework tensors) are converted first.
            label: optional batch of index rows to decode as well.

        Returns:
            ``decode`` output for ``preds``, or ``(text, label)`` when
            ``label`` is given.
        """
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        # Duck-typed tensor conversion: this module does not import paddle,
        # and numpy arrays have no ``numpy`` attribute so they pass through.
        if hasattr(preds, "numpy"):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        """Prepend the CTC ``"blank"`` symbol so it gets index 0."""
        dict_character = ["blank"] + dict_character
        return dict_character
112
+
113
+
114
class DistillationCTCLabelDecode(CTCLabelDecode):
    """Apply CTC decoding to each sub-model output of a distillation model.

    Args:
        model_name: key, or list of keys, selecting entries of ``preds``.
        key: optional second-level key applied to each selected entry.
        multi_head: when True and the selected entry is a dict, decode its
            ``"ctc"`` entry.
    """

    def __init__(
        self,
        character_dict_path=None,
        use_space_char=False,
        model_name=["student"],
        key=None,
        multi_head=False,
        **kwargs
    ):
        super(DistillationCTCLabelDecode, self).__init__(
            character_dict_path, use_space_char
        )
        # Copy so the instance never aliases the shared default list or the
        # caller's list.
        if isinstance(model_name, list):
            model_name = list(model_name)
        else:
            model_name = [model_name]
        self.model_name = model_name

        self.key = key
        self.multi_head = multi_head

    def __call__(self, preds, label=None, *args, **kwargs):
        """Return ``{name: CTCLabelDecode result}`` for every model name."""
        output = dict()
        for name in self.model_name:
            pred = preds[name]
            if self.key is not None:
                pred = pred[self.key]
            if self.multi_head and isinstance(pred, dict):
                pred = pred["ctc"]
            # Pass label positionally: ``f(pred, label=label, *args)`` binds
            # the first extra positional arg to ``label`` too and raises
            # TypeError ("multiple values for argument 'label'").
            output[name] = super().__call__(pred, label, *args, **kwargs)
        return output
149
+
150
+
151
class AttnLabelDecode(BaseRecLabelDecode):
    """Decode attention-model output into text.

    The character table is ``["sos", *characters, "eos"]``. Decoding skips
    begin indices and stops at the first end index.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(AttnLabelDecode, self).__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        return [self.beg_str] + dict_character + [self.end_str]

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index rows into ``(text, mean_confidence)`` tuples.

        Each row is read up to, not including, its first end index. Other
        ignored indices are skipped. Without ``text_prob`` each kept
        character counts as confidence 1.
        """
        result_list = []
        ignored = {int(t) for t in self.get_ignored_tokens()}
        end_idx = int(self.get_beg_end_flag_idx("end"))
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            row = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(row)):
                token = int(row[idx])
                # Test for the end marker BEFORE skipping ignored tokens.
                # end_idx is itself an ignored token, so checking it second
                # made the break unreachable and let characters emitted after
                # "eos" leak into the text.
                if token == end_idx:
                    break
                if token in ignored:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and row[idx - 1] == row[idx]:
                        continue
                char_list.append(self.character[token])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = "".join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode ``preds`` (argmax/max over axis 2) and optionally ``label``.

        Objects exposing ``.numpy()`` (framework tensors) are converted first.
        """
        if hasattr(preds, "numpy"):
            preds = preds.numpy()

        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the table index of the begin ("beg") or end ("end") symbol.

        Raises:
            AssertionError: for any other argument. It is raised explicitly
                so that it also fires under ``python -O``.
        """
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            raise AssertionError(
                "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
            )
        return idx
227
+
228
+
229
class RFLLabelDecode(BaseRecLabelDecode):
    """Decode RFL-model output: either text or per-sample character counts.

    The character table is ``["sos", *characters, "eos"]``.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(RFLLabelDecode, self).__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        return [self.beg_str] + dict_character + [self.end_str]

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index rows into ``(text, mean_confidence)`` tuples.

        Each row is read up to, not including, its first end index. Other
        ignored indices are skipped.
        """
        result_list = []
        ignored = {int(t) for t in self.get_ignored_tokens()}
        end_idx = int(self.get_beg_end_flag_idx("end"))
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            row = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(row)):
                token = int(row[idx])
                # End check must precede the ignored check: end_idx is also
                # an ignored token, so the reverse order never breaks.
                if token == end_idx:
                    break
                if token in ignored:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and row[idx - 1] == row[idx]:
                        continue
                char_list.append(self.character[token])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = "".join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode RFL outputs.

        A tuple/list ``(cnt_outputs, seq_outputs)`` decodes the sequence
        branch into text. Anything else is treated as count outputs and
        returns the rounded sum of each row. Objects exposing ``.numpy()``
        are converted first.
        """
        if isinstance(preds, (tuple, list)):
            _cnt_outputs, seq_outputs = preds
            if hasattr(seq_outputs, "numpy"):
                seq_outputs = seq_outputs.numpy()
            preds_idx = seq_outputs.argmax(axis=2)
            preds_prob = seq_outputs.max(axis=2)
            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

            if label is None:
                return text
            label = self.decode(label, is_remove_duplicate=False)
            return text, label

        cnt_outputs = preds
        if hasattr(cnt_outputs, "numpy"):
            cnt_outputs = cnt_outputs.numpy()
        cnt_length = [round(np.sum(lens)) for lens in cnt_outputs]
        if label is None:
            return cnt_length
        label = self.decode(label, is_remove_duplicate=False)
        length = [len(res[0]) for res in label]
        return cnt_length, length

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the table index of the begin ("beg") or end ("end") symbol.

        Raises:
            AssertionError: for any other argument (raised explicitly so it
                survives ``python -O``).
        """
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            raise AssertionError(
                "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
            )
        return idx
314
+
315
+
316
class SEEDLabelDecode(BaseRecLabelDecode):
    """Decode SEED-model output into text.

    The character table is ``[*characters, "eos", "padding", "unknown"]``.
    Decoding stops at the first "eos".
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(SEEDLabelDecode, self).__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        self.padding_str = "padding"
        self.end_str = "eos"
        self.unknown = "unknown"
        dict_character = dict_character + [self.end_str, self.padding_str, self.unknown]
        return dict_character

    def get_ignored_tokens(self):
        end_idx = self.get_beg_end_flag_idx("eos")
        return [end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the table index of "sos" or "eos".

        Note: "sos" looks up ``self.beg_str`` ("sos"), which
        ``add_special_char`` does not add here. That branch therefore raises
        KeyError unless the dictionary file itself contains a "sos" line.

        Raises:
            AssertionError: for any other argument (raised explicitly so it
                survives ``python -O``).
        """
        if beg_or_end == "sos":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "eos":
            idx = np.array(self.dict[self.end_str])
        else:
            raise AssertionError(
                "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
            )
        return idx

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index rows into ``(text, mean_confidence)`` tuples, stopping at "eos"."""
        result_list = []
        [end_idx] = self.get_ignored_tokens()
        end_idx = int(end_idx)
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            row = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(row)):
                if int(row[idx]) == end_idx:
                    break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and row[idx - 1] == row[idx]:
                        continue
                char_list.append(self.character[int(row[idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = "".join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode ``preds["rec_pred"]`` (and optionally ``label``).

        If ``preds`` has ``"rec_pred_scores"``, ``"rec_pred"`` already holds
        indices and those scores are the confidences. Otherwise
        ``"rec_pred"`` holds scores reduced with argmax/max over axis 2.
        Objects exposing ``.numpy()`` are converted first.
        """
        # Convert once and keep using the converted value; previously the
        # converted array was discarded by re-reading preds["rec_pred"].
        rec_pred = preds["rec_pred"]
        if hasattr(rec_pred, "numpy"):
            rec_pred = rec_pred.numpy()
        if "rec_pred_scores" in preds:
            preds_idx = rec_pred
            preds_prob = preds["rec_pred_scores"]
            if hasattr(preds_prob, "numpy"):
                preds_prob = preds_prob.numpy()
        else:
            preds_idx = rec_pred.argmax(axis=2)
            preds_prob = rec_pred.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label
392
+
393
+
394
class SRNLabelDecode(BaseRecLabelDecode):
    """Decode SRN-model output into text.

    The character table is ``[*characters, "sos", "eos"]``. Both special
    symbols are skipped during decoding rather than terminating it.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(SRNLabelDecode, self).__init__(character_dict_path, use_space_char)
        self.max_text_length = kwargs.get("max_text_length", 25)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode ``preds["predict"]``.

        The scores are reshaped to ``[-1, char_num]`` and then to rows of
        ``max_text_length`` steps. Objects exposing ``.numpy()`` are
        converted first.
        """
        pred = preds["predict"]
        # characters + "sos" + "eos"
        char_num = len(self.character_str) + 2
        if hasattr(pred, "numpy"):
            pred = pred.numpy()
        pred = np.reshape(pred, [-1, char_num])

        preds_idx = np.argmax(pred, axis=1)
        preds_prob = np.max(pred, axis=1)

        preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
        preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])

        # Decode once; the original decoded the same inputs twice.
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index rows into ``(text, mean_confidence)`` tuples, skipping ignored indices."""
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)

        for batch_idx in range(batch_size):
            row = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(row)):
                if row[idx] in ignored_tokens:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and row[idx - 1] == row[idx]:
                        continue
                char_list.append(self.character[int(row[idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)

            text = "".join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def add_special_char(self, dict_character):
        dict_character = dict_character + [self.beg_str, self.end_str]
        return dict_character

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the table index of the begin ("beg") or end ("end") symbol.

        Raises:
            AssertionError: for any other argument (raised explicitly so it
                survives ``python -O``).
        """
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            raise AssertionError(
                "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
            )
        return idx
469
+
470
+
471
class SARLabelDecode(BaseRecLabelDecode):
    """Decode SAR-model output into text.

    The character table is ``[*characters, "<UKN>", "<BOS/EOS>", "<PAD>"]``.
    Begin and end share one index. Padding is ignored.

    Keyword Args:
        rm_symbol: when True, lowercase the text and remove every character
            not matched by ``_RM_SYMBOL_PATTERN``'s complement.
    """

    # Negated class: keeps ASCII letters, digits, CJK U+4E00-U+9FA5 and the
    # literal "^" (a "^" after the first position of a class is literal).
    _RM_SYMBOL_PATTERN = re.compile("[^A-Z^a-z^0-9^\u4e00-\u9fa5]")

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(SARLabelDecode, self).__init__(character_dict_path, use_space_char)

        self.rm_symbol = kwargs.get("rm_symbol", False)

    def add_special_char(self, dict_character):
        beg_end_str = "<BOS/EOS>"
        unknown_str = "<UKN>"
        padding_str = "<PAD>"
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index rows into ``(text, mean_confidence)`` tuples.

        The shared BOS/EOS index ends a row, except at position 0 when
        ``text_prob`` is None, where it is skipped (a leading BOS in labels).
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()

        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            row = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(row)):
                if row[idx] in ignored_tokens:
                    continue
                if int(row[idx]) == int(self.end_idx):
                    if text_prob is None and idx == 0:
                        continue
                    break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and row[idx - 1] == row[idx]:
                        continue
                char_list.append(self.character[int(row[idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = "".join(char_list)
            if self.rm_symbol:
                text = self._RM_SYMBOL_PATTERN.sub("", text.lower())
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode ``preds`` (argmax/max over axis 2) and optionally ``label``.

        Objects exposing ``.numpy()`` (framework tensors) are converted first.
        """
        if hasattr(preds, "numpy"):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        return [self.padding_idx]
544
+
545
+
546
class DistillationSARLabelDecode(SARLabelDecode):
    """Apply SAR decoding to each sub-model output of a distillation model.

    Args:
        model_name: key, or list of keys, selecting entries of ``preds``.
        key: optional second-level key applied to each selected entry.
        multi_head: when True and the selected entry is a dict, decode its
            ``"sar"`` entry.
    """

    def __init__(
        self,
        character_dict_path=None,
        use_space_char=False,
        model_name=["student"],
        key=None,
        multi_head=False,
        **kwargs
    ):
        super(DistillationSARLabelDecode, self).__init__(
            character_dict_path, use_space_char
        )
        # Copy so the instance never aliases the shared default list or the
        # caller's list.
        if isinstance(model_name, list):
            model_name = list(model_name)
        else:
            model_name = [model_name]
        self.model_name = model_name

        self.key = key
        self.multi_head = multi_head

    def __call__(self, preds, label=None, *args, **kwargs):
        """Return ``{name: SARLabelDecode result}`` for every model name."""
        output = dict()
        for name in self.model_name:
            pred = preds[name]
            if self.key is not None:
                pred = pred[self.key]
            if self.multi_head and isinstance(pred, dict):
                pred = pred["sar"]
            # Pass label positionally: ``f(pred, label=label, *args)`` binds
            # the first extra positional arg to ``label`` too and raises
            # TypeError.
            output[name] = super().__call__(pred, label, *args, **kwargs)
        return output
581
+
582
+
583
class PRENLabelDecode(BaseRecLabelDecode):
    """Decode PREN-model output into text.

    The character table is ``["<PAD>", "<EOS>", "<UNK>", *characters]``.
    Decoding stops at "<EOS>" and skips "<PAD>" and "<UNK>".
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(PRENLabelDecode, self).__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        padding_str = "<PAD>"  # 0
        end_str = "<EOS>"  # 1
        unknown_str = "<UNK>"  # 2

        dict_character = [padding_str, end_str, unknown_str] + dict_character
        self.padding_idx = 0
        self.end_idx = 1
        self.unknown_idx = 2

        return dict_character

    def decode(self, text_index, text_prob=None):
        """Convert index rows into ``(text, mean_confidence)`` tuples.

        An empty result is reported with confidence 1.
        """
        result_list = []
        batch_size = len(text_index)

        for batch_idx in range(batch_size):
            row = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(row)):
                if row[idx] == self.end_idx:
                    break
                if row[idx] in [self.padding_idx, self.unknown_idx]:
                    continue
                char_list.append(self.character[int(row[idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)

            text = "".join(char_list)
            if len(text) > 0:
                result_list.append((text, np.mean(conf_list).tolist()))
            else:
                # here confidence of empty recog result is 1
                result_list.append(("", 1))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode ``preds`` (argmax/max over axis 2) and optionally ``label``.

        Objects exposing ``.numpy()`` (framework tensors) are converted first.
        """
        if hasattr(preds, "numpy"):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob)
        if label is None:
            return text
        label = self.decode(label)
        return text, label
638
+
639
+
640
class NRTRLabelDecode(BaseRecLabelDecode):
    """Decode NRTR-model output into lowercased text.

    The character table is ``["blank", "<unk>", "<s>", "</s>", *characters]``.
    Decoding stops at "</s>".
    """

    def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
        super(NRTRLabelDecode, self).__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode model output and optionally ``label``.

        A list/tuple ``(ids, probs)`` is decoded directly. A leading "<s>"
        (index 2), detected from the first row, is dropped from every row.
        Anything else is treated as scores reduced with argmax/max over
        axis 2. Objects exposing ``.numpy()`` are converted first.
        """
        # Only a real pair counts as (ids, probs): a plain score array whose
        # batch size happens to be 2 also has len() == 2.
        if isinstance(preds, (list, tuple)) and len(preds) == 2:
            preds_id, preds_prob = preds
            if hasattr(preds_id, "numpy"):
                preds_id = preds_id.numpy()
            if hasattr(preds_prob, "numpy"):
                preds_prob = preds_prob.numpy()
            if preds_id[0][0] == 2:
                preds_idx = preds_id[:, 1:]
                preds_prob = preds_prob[:, 1:]
            else:
                preds_idx = preds_id
        else:
            if hasattr(preds, "numpy"):
                preds = preds.numpy()
            preds_idx = preds.argmax(axis=2)
            preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label[:, 1:])
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ["blank", "<unk>", "<s>", "</s>"] + dict_character
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index rows into ``(lowercased text, mean_confidence)`` tuples.

        Reading stops at the first "</s>". Indices outside the table, or not
        convertible to int, are skipped. ``is_remove_duplicate`` is accepted
        for interface compatibility and is not used.
        """
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            row = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(row)):
                try:
                    char = self.character[int(row[idx])]
                except (IndexError, TypeError, ValueError):
                    # Previously a bare ``except:``, which also swallowed
                    # KeyboardInterrupt/SystemExit.
                    continue
                if char == "</s>":  # end
                    break
                char_list.append(char)
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = "".join(char_list)
            result_list.append((text.lower(), np.mean(conf_list).tolist()))
        return result_list
701
+
702
+
703
class ViTSTRLabelDecode(NRTRLabelDecode):
    """Decode ViTSTR-model output into lowercased text.

    The character table is ``["<s>", "</s>", *characters]``. The first time
    step of ``preds`` and of ``label`` is dropped before decoding.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(ViTSTRLabelDecode, self).__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode ``preds[:, 1:]`` (argmax/max over axis 2) and optionally ``label``.

        Objects exposing ``.numpy()`` (framework tensors) are converted after
        slicing.
        """
        preds = preds[:, 1:]
        if hasattr(preds, "numpy"):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label[:, 1:])
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ["<s>", "</s>"] + dict_character
        return dict_character
725
+
726
+
727
class ABINetLabelDecode(NRTRLabelDecode):
    """Decode ABINet-model output into lowercased text.

    The character table is ``["</s>", *characters]``.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(ABINetLabelDecode, self).__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode scores (argmax/max over axis 2) and optionally ``label``.

        A dict input uses its last ``"align"`` output. Objects exposing
        ``.numpy()`` (framework tensors) are converted. Numpy arrays,
        including the dict case, are used as-is.
        """
        if isinstance(preds, dict):
            preds = preds["align"][-1]
        if hasattr(preds, "numpy"):
            preds = preds.numpy()

        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ["</s>"] + dict_character
        return dict_character
752
+
753
+
754
class SPINLabelDecode(AttnLabelDecode):
    """Attention-style label decoder for SPIN models.

    Decoding is inherited unchanged from ``AttnLabelDecode``. Only the
    table layout differs: "sos" at index 0 and "eos" at index 1, followed
    by the dictionary characters.
    """

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(SPINLabelDecode, self).__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        self.beg_str, self.end_str = "sos", "eos"
        return [self.beg_str, self.end_str, *dict_character]
766
+
767
+
768
+ # class VLLabelDecode(BaseRecLabelDecode):
769
+ # """ Convert between text-label and text-index """
770
+ #
771
+ # def __init__(self, character_dict_path=None, use_space_char=False,
772
+ # **kwargs):
773
+ # super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
774
+ # self.max_text_length = kwargs.get('max_text_length', 25)
775
+ # self.nclass = len(self.character) + 1
776
+ #
777
+ # def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
778
+ # """ convert text-index into text-label. """
779
+ # result_list = []
780
+ # ignored_tokens = self.get_ignored_tokens()
781
+ # batch_size = len(text_index)
782
+ # for batch_idx in range(batch_size):
783
+ # selection = np.ones(len(text_index[batch_idx]), dtype=bool)
784
+ # if is_remove_duplicate:
785
+ # selection[1:] = text_index[batch_idx][1:] != text_index[
786
+ # batch_idx][:-1]
787
+ # for ignored_token in ignored_tokens:
788
+ # selection &= text_index[batch_idx] != ignored_token
789
+ #
790
+ # char_list = [
791
+ # self.character[text_id - 1]
792
+ # for text_id in text_index[batch_idx][selection]
793
+ # ]
794
+ # if text_prob is not None:
795
+ # conf_list = text_prob[batch_idx][selection]
796
+ # else:
797
+ # conf_list = [1] * len(selection)
798
+ # if len(conf_list) == 0:
799
+ # conf_list = [0]
800
+ #
801
+ # text = ''.join(char_list)
802
+ # result_list.append((text, np.mean(conf_list).tolist()))
803
+ # return result_list
804
+ #
805
+ # def __call__(self, preds, label=None, length=None, *args, **kwargs):
806
+ # if len(preds) == 2: # eval mode
807
+ # text_pre, x = preds
808
+ # b = text_pre.shape[1]
809
+ # lenText = self.max_text_length
810
+ # nsteps = self.max_text_length
811
+ #
812
+ # if not isinstance(text_pre, paddle.Tensor):
813
+ # text_pre = paddle.to_tensor(text_pre, dtype='float32')
814
+ #
815
+ # out_res = paddle.zeros(
816
+ # shape=[lenText, b, self.nclass], dtype=x.dtype)
817
+ # out_length = paddle.zeros(shape=[b], dtype=x.dtype)
818
+ # now_step = 0
819
+ # for _ in range(nsteps):
820
+ # if 0 in out_length and now_step < nsteps:
821
+ # tmp_result = text_pre[now_step, :, :]
822
+ # out_res[now_step] = tmp_result
823
+ # tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
824
+ # for j in range(b):
825
+ # if out_length[j] == 0 and tmp_result[j] == 0:
826
+ # out_length[j] = now_step + 1
827
+ # now_step += 1
828
+ # for j in range(0, b):
829
+ # if int(out_length[j]) == 0:
830
+ # out_length[j] = nsteps
831
+ # start = 0
832
+ # output = paddle.zeros(
833
+ # shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
834
+ # for i in range(0, b):
835
+ # cur_length = int(out_length[i])
836
+ # output[start:start + cur_length] = out_res[0:cur_length, i, :]
837
+ # start += cur_length
838
+ # net_out = output
839
+ # length = out_length
840
+ #
841
+ # else: # train mode
842
+ # net_out = preds[0]
843
+ # length = length
844
+ # net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)])
845
+ # text = []
846
+ # if not isinstance(net_out, paddle.Tensor):
847
+ # net_out = paddle.to_tensor(net_out, dtype='float32')
848
+ # net_out = F.softmax(net_out, axis=1)
849
+ # for i in range(0, length.shape[0]):
850
+ # preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
851
+ # ) + length[i])].topk(1)[1][:, 0].tolist()
852
+ # preds_text = ''.join([
853
+ # self.character[idx - 1]
854
+ # if idx > 0 and idx <= len(self.character) else ''
855
+ # for idx in preds_idx
856
+ # ])
857
+ # preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum(
858
+ # ) + length[i])].topk(1)[0][:, 0]
859
+ # preds_prob = paddle.exp(
860
+ # paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
861
+ # text.append((preds_text, preds_prob.numpy()[0]))
862
+ # if label is None:
863
+ # return text
864
+ # label = self.decode(label)
865
+ # return text, label
866
+
867
+
868
class CANLabelDecode(BaseRecLabelDecode):
    """Convert between latex-symbol and symbol-index"""

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(CANLabelDecode, self).__init__(character_dict_path, use_space_char)

    def decode(self, text_index, preds_prob=None):
        """Convert index rows into ``[space-joined symbols, probs]`` pairs.

        Each row is cut at its first index 0. That index is presumably the
        end symbol, i.e. the first line of the dictionary file; confirm
        against the dict in use. A row without any 0 is kept whole.
        ``probs`` holds the leading ``preds_prob`` values for the kept
        symbols, or ``[]`` when ``preds_prob`` is None.
        """
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            seq = np.asarray(text_index[batch_idx])
            # argmin() only locates the first 0 when one exists; otherwise it
            # would cut at the smallest non-zero symbol.
            zeros = np.flatnonzero(seq == 0)
            seq_end = int(zeros[0]) if zeros.size else len(seq)
            idx_list = seq[:seq_end].tolist()
            symbol_list = [self.character[idx] for idx in idx_list]
            probs = []
            if preds_prob is not None:
                probs = preds_prob[batch_idx][: len(symbol_list)].tolist()

            result_list.append([" ".join(symbol_list), probs])
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode the first of the four model outputs (argmax over axis 2).

        Probabilities are not forwarded to ``decode``, so the returned
        ``probs`` lists are empty.
        """
        pred_prob, _, _, _ = preds
        preds_idx = pred_prob.argmax(axis=2)

        text = self.decode(preds_idx)
        if label is None:
            return text
        label = self.decode(label)
        return text, label