openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,587 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ from openrec.modeling.common import Mlp
6
+ from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, Embeddings, MultiheadAttention
7
+
8
+
9
class MDiffDecoder(nn.Module):
    """Mask-denoising transformer decoder for text recognition.

    A sequence of ``max_len + 1`` token ids, some or all of which are the
    special mask token, is embedded and refined by a stack of
    ``TransformerBlock`` layers (self-attention over the tokens followed by
    cross-attention to the visual features). A bias-free linear head maps
    every position to ``out_channels - 2`` classes.

    Id layout (set in ``__init__``): ``0`` is EOS, ``out_channels - 2`` is
    the mask token and ``out_channels - 1`` is the ignore index of the loss.

    Args:
        in_channels: feature width of the encoder output; used as the model
            width ``d_model``.
        out_channels: size of the embedding vocabulary (see id layout above).
        nhead: number of attention heads; defaults to ``d_model // 32``.
        num_decoder_layers: number of stacked ``TransformerBlock`` layers.
        max_len: maximum text length; sequences hold ``max_len + 1`` tokens.
        attention_dropout_rate: dropout passed to the attention modules.
        residual_dropout_rate: dropout used on residual branches and in the
            positional encoding.
        scale_embedding: forwarded to ``Embeddings``.
        parallel_decoding, autoregressive_decoding, low_confidence_decoding,
        random_mask_decoding, semi_autoregressive_decoding,
        cloze_mask_decoding: inference strategy switches. They are checked
            in exactly this order and the first true one wins; if none is
            set, parallel decoding is used.
        sampler_step: number of refinement iterations of the iterative
            samplers.
        rec_loss_weight: weight of the denoising loss.
        reflect_loss_weight: weight of the reflect loss.
        sample_k: when ``> 0`` training uses ``forward_train_all`` with this
            many noisy samples per image.
        temperature: logit temperature used by semi-autoregressive decoding.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 nhead=None,
                 num_decoder_layers=6,
                 max_len=25,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1,
                 scale_embedding=True,
                 parallel_decoding=False,
                 autoregressive_decoding=False,
                 sampler_step=5,
                 low_confidence_decoding=False,
                 random_mask_decoding=False,
                 semi_autoregressive_decoding=False,
                 cloze_mask_decoding=False,
                 rec_loss_weight=1.0,
                 reflect_loss_weight=1.0,
                 sample_k=0,
                 temperature=1.0):
        super(MDiffDecoder, self).__init__()
        self.out_channels = out_channels
        # The two highest ids are reserved: ignore index and mask token.
        self.ignore_index = out_channels - 1
        self.mask_token_id = out_channels - 2
        self.eos = 0
        self.max_len = max_len
        d_model = in_channels
        dim_feedforward = d_model * 4

        # Inference strategy switches (see ``forward`` for their priority).
        self.pd = parallel_decoding
        self.ar = autoregressive_decoding
        self.sampler_step = sampler_step
        self.lc = low_confidence_decoding
        self.rm = random_mask_decoding
        self.semiar = semi_autoregressive_decoding
        self.cm = cloze_mask_decoding

        self.rec_loss_weight = rec_loss_weight
        self.reflect_loss_weight = reflect_loss_weight
        self.temperature = temperature
        self.sample_k = sample_k
        nhead = nhead if nhead is not None else d_model // 32

        self.embedding = Embeddings(
            d_model=d_model,
            vocab=self.out_channels,
            padding_idx=0,
            scale_embedding=scale_embedding,
        )
        # Learned positional embedding, added on top of the output of
        # ``PositionalEncoding``.
        self.pos_embed = nn.Parameter(torch.zeros(
            [1, self.max_len + 1, d_model], dtype=torch.float32),
                                      requires_grad=True)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate, dim=d_model)

        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model,
                nhead,
                dim_feedforward,
                attention_dropout_rate,
                residual_dropout_rate,
                with_self_attn=True,
                with_cross_attn=True,
            ) for _ in range(num_decoder_layers)
        ])

        self.num_decoder_layers = num_decoder_layers

        self.d_model = d_model
        self.nhead = nhead
        # The head has no outputs for the mask token and the ignore index.
        self.tgt_word_prj = nn.Linear(d_model,
                                      self.out_channels - 2,
                                      bias=False)
        # NOTE(review): the values written here are overwritten by the
        # Xavier initialisation applied through ``self.apply`` below (the
        # head is an ``nn.Linear``). The assignment is kept so that RNG
        # consumption and the weight's memory layout stay as before.
        w0 = np.random.normal(0.0, d_model**-0.5,
                              (d_model, self.out_channels - 2)).astype(
                                  np.float32)
        self.tgt_word_prj.weight.data = torch.from_numpy(w0.transpose())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-normal weights and zero bias for every ``nn.Linear``."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def _init_noisy_batch(self, src):
        """Return a ``(bs, max_len + 1)`` int64 tensor full of mask tokens.

        The tensor is created on ``src.device``. ``Tensor.get_device()`` must
        not be used for this: it returns ``-1`` for CPU tensors, which is not
        a valid ``device`` argument and made CPU inference fail.
        """
        return torch.full((src.shape[0], self.max_len + 1),
                          self.mask_token_id,
                          dtype=torch.int64,
                          device=src.device)

    def _predict_probs(self, src, tgts, step_i=0):
        """Decode ``tgts`` and return per-position class probabilities."""
        hidden = self.forward_decoding(src, tgts, step_i=step_i)
        return F.softmax(self.tgt_word_prj(hidden), -1)

    @staticmethod
    def _above_average_confidence(pred_step_prob, valid_indices):
        """Select valid positions whose confidence beats the row average.

        The average is taken over the valid positions only. The count is
        clamped to 1 so that a row without valid positions yields an average
        of 0 instead of NaN; such a row selects nothing either way.
        """
        masked_prob = pred_step_prob * valid_indices.float()
        valid_count = valid_indices.sum(dim=1, keepdim=True).clamp(min=1)
        prob_avg = masked_prob.sum(dim=1, keepdim=True) / valid_count
        return (masked_prob > prob_avg) & valid_indices

    def forward_train(self, memory, data=None):
        """Compute the training loss for one noisy sample per image.

        Args:
            memory: visual features used as cross-attention memory.
            data: sequence unpacked as ``(labels, reflect_ids, noisy_batch,
                masked_indices, p_mask, length)``. ``labels`` is 2-D;
                ``p_mask`` and ``length`` hold one value per sample
                (``p_mask`` is presumably the masking ratio used to build
                ``noisy_batch`` -- confirm against the label encoder).

        Returns:
            Scalar loss. The denoising term is the cross entropy at the
            masked positions, each divided by ``p_mask`` and ``length + 1``,
            summed and divided by the batch size. If ``reflect_ids`` is not
            ``None`` a mean cross entropy over all positions of the reflect
            pass is added, both terms weighted by their loss weights.
        """
        labels, reflect_ids, noisy_batch, masked_indices, p_mask, length = data
        seq_len = labels.shape[1]
        p_mask = p_mask[:, None].repeat(1, seq_len)
        noisy_data_length = (length + 1)[:, None].repeat(1, seq_len)

        tgts = self.embedding(noisy_batch)
        tgts = self.positional_encoding(tgts) + self.pos_embed

        for decoder_layer in self.decoder:
            tgts = decoder_layer(tgts, memory, self_mask=None)
        logits = self.tgt_word_prj(tgts)
        token_loss = F.cross_entropy(
            logits[masked_indices],
            labels[masked_indices],
            reduction='none',
            ignore_index=self.ignore_index) / p_mask[masked_indices]
        loss = torch.sum(
            token_loss / noisy_data_length[masked_indices]) / labels.shape[0]

        if reflect_ids is not None:
            reflect_tgts = self.embedding(reflect_ids)
            reflect_tgts = self.positional_encoding(
                reflect_tgts) + self.pos_embed
            for decoder_layer in self.decoder:
                reflect_tgts = decoder_layer(reflect_tgts,
                                             memory,
                                             self_mask=None)
            reflect_logits = self.tgt_word_prj(reflect_tgts)
            reflect_loss = F.cross_entropy(reflect_logits.flatten(0, 1),
                                           labels.flatten(0, 1),
                                           reduction='mean',
                                           ignore_index=self.ignore_index)
            loss = (self.rec_loss_weight * loss +
                    self.reflect_loss_weight * reflect_loss)

        return loss

    def forward_train_all(self, memory, data=None):
        """Compute the training loss for ``sample_k`` noisy samples per image.

        ``data`` is unpacked as ``(labels, reflect_ids_all, noisy_batch_all,
        masked_indices_all, p_mask_all, length)`` where the ``*_all`` tensors
        carry an extra sample axis at dimension 1. All samples are decoded in
        one pass; the per-sample loss is computed as in ``forward_train`` and
        the mean over the samples is returned.
        """
        (labels, reflect_ids_all, noisy_batch_all, masked_indices_all,
         p_mask_all, length) = data
        bs, L = labels.shape

        tgts = self.embedding(noisy_batch_all.flatten(0, 1))
        tgts = self.positional_encoding(tgts) + self.pos_embed
        tgts = tgts.reshape(bs, self.sample_k, L, -1)
        for decoder_layer in self.decoder:
            tgts = decoder_layer(tgts,
                                 memory,
                                 self_mask=None,
                                 sample_k=self.sample_k)
        logits_all = self.tgt_word_prj(tgts)  # bs, sample_k, L, c_num

        reflect_tgts = self.embedding(reflect_ids_all.flatten(0, 1))
        reflect_tgts = self.positional_encoding(reflect_tgts) + self.pos_embed
        reflect_tgts = reflect_tgts.reshape(bs, self.sample_k, L, -1)
        for decoder_layer in self.decoder:
            reflect_tgts = decoder_layer(reflect_tgts,
                                         memory,
                                         self_mask=None,
                                         sample_k=self.sample_k)
        reflect_logits_all = self.tgt_word_prj(reflect_tgts)

        # Identical for every sample, so it is built once.
        noisy_data_length = (length + 1)[:, None].repeat(1, L)

        loss = []
        for i in range(self.sample_k):
            p_mask = p_mask_all[:, i][:, None].repeat(1, L)
            masked_indices = masked_indices_all[:, i]
            logits = logits_all[:, i]

            token_loss = F.cross_entropy(
                logits[masked_indices],
                labels[masked_indices],
                reduction='none',
                ignore_index=self.ignore_index) / p_mask[masked_indices]
            denoise_loss_i = torch.sum(
                token_loss /
                noisy_data_length[masked_indices]) / labels.shape[0]

            reflect_logits = reflect_logits_all[:, i]
            reflect_loss_i = F.cross_entropy(reflect_logits.flatten(0, 1),
                                             labels.flatten(0, 1),
                                             reduction='mean',
                                             ignore_index=self.ignore_index)
            loss.append(self.rec_loss_weight * denoise_loss_i +
                        self.reflect_loss_weight * reflect_loss_i)

        return sum(loss) / len(loss)

    def forward(self, src, data=None):
        """Run training or inference depending on ``self.training``.

        Args:
            src: visual features used as cross-attention memory; the first
                dimension is the batch.
            data: training targets, see ``forward_train`` /
                ``forward_train_all``. Unused at inference time.

        Returns:
            The scalar loss in training mode (``forward_train_all`` when
            ``sample_k > 0``, else ``forward_train``); otherwise the class
            probabilities produced by the selected decoding strategy.
        """
        if self.training:
            if self.sample_k > 0:
                res = self.forward_train_all(src, data)
            else:
                res = self.forward_train(src, data)
        else:
            if self.pd:
                res = self.forward_parallel_decoding(src)
            elif self.ar:
                res = self.forward_autoregressive_decoding(src)
            elif self.lc:
                res = self.forward_low_confidence_decoding(src)
            elif self.rm:
                res = self.forward_random_mask_decoding(src)
            elif self.semiar:
                res = self.forward_semi_autoregressive_decoding(src)
            elif self.cm:
                res = self.forward_cloze_mask_decoding(src)
            else:
                # No strategy selected: fall back to parallel decoding.
                res = self.forward_parallel_decoding(src)

        return res

    def forward_decoding(self, src, tgts, step_i=0):
        """Embed the token ids ``tgts`` and run them through the decoder.

        ``step_i`` is accepted for interface compatibility and is not used.
        Returns the hidden states before the classification head.
        """
        tgts = self.embedding(tgts)
        tgts = self.positional_encoding(tgts) + self.pos_embed
        for decoder_layer in self.decoder:
            tgts = decoder_layer(tgts, src, self_mask=None)

        return tgts

    def forward_reflect(self, src, pred_indexs, step_i=0):
        """Reflect decoding: re-predict every position from a prediction.

        Positions selected by ``get_masked_indice_after_eos`` are replaced
        by the mask token before decoding. Note that ``pred_indexs`` is
        modified in place.

        Returns:
            Class probabilities for every position.
        """
        masked_indices_eos = self.get_masked_indice_after_eos(
            pred_indexs
        )  # [bs, max_len + 1] bool tensor False False(eos) True True ..
        # Everything after the EOS is fed as mask token.
        pred_indexs[masked_indices_eos] = self.mask_token_id

        return self._predict_probs(src, pred_indexs, step_i=step_i)

    def forward_parallel_decoding(self, src):
        """Predict all positions in one pass from a fully masked sequence.

        Returns:
            Probabilities of shape ``(bs, max_len + 1, out_channels - 2)``.
        """
        return self._predict_probs(src, self._init_noisy_batch(src))

    def get_masked_indice_after_eos(self, noisy_batch):
        """Return a bool mask marking the positions after the first EOS.

        Args:
            noisy_batch: ``[batch_size, max_len + 1]`` token ids.

        Returns:
            Bool tensor of the same shape. In rows containing an EOS it is
            True strictly after the first EOS (the EOS itself is False).
            Rows without any EOS are True everywhere.

        NOTE(review): the original comment said rows without an EOS should
        *not* be masked, yet the code masks them completely. The behaviour
        of the code is kept unchanged here -- confirm which one is intended.
        """
        eos_mask = noisy_batch == self.eos  # [batch_size, seq_len]

        # Index of the first EOS per row; argmax yields 0 for rows without
        # an EOS, hence the separate existence flag.
        eos_indices = eos_mask.float().argmax(dim=1)  # [batch_size]
        eos_exists = eos_mask.any(dim=1)  # [batch_size]

        seq_len = noisy_batch.shape[1]
        positions = torch.arange(seq_len, device=noisy_batch.device)

        masked_indices = positions.unsqueeze(0) > eos_indices.unsqueeze(1)
        masked_indices = masked_indices | ~eos_exists.unsqueeze(1)

        return masked_indices

    def forward_low_confidence_decoding(self, src):
        """Iteratively commit tokens with above-average confidence.

        Starting from a fully masked sequence, every step predicts all
        positions and commits those still-masked positions up to the
        predicted EOS whose confidence exceeds the row average; the rest
        stay masked. At most ``sampler_step`` steps are run; once every row
        has no more than one masked position left in front of the EOS, one
        final step is executed.

        Returns:
            Probabilities of shape ``(bs, max_len + 1, out_channels - 2)``;
            each position holds the distribution of the last step in which
            it was still masked on entry.
        """
        noisy_batch = self._init_noisy_batch(src)
        masked_indices_pre = torch.ones_like(noisy_batch, dtype=torch.bool)
        flag_exit = False
        for step_i in range(self.sampler_step):
            pred_step = self._predict_probs(src, noisy_batch, step_i=step_i)
            if step_i == 0:
                logits = pred_step.clone()
            else:
                logits[masked_indices_pre] = pred_step[masked_indices_pre]
            pred_step_prob, pred_step_index = torch.max(
                pred_step, dim=-1)  # [bs, max_len + 1], [bs, max_len + 1]
            masked_indices_eos = self.get_masked_indice_after_eos(
                pred_step_index
            )  # [bs, max_len + 1] bool tensor False False(eos) True True ..

            # Only masked positions up to the EOS take part in the average.
            valid_indices = masked_indices_pre & ~masked_indices_eos
            top_confidence_mask = self._above_average_confidence(
                pred_step_prob, valid_indices)
            noisy_batch[top_confidence_mask] = pred_step_index[
                top_confidence_mask]

            # Low-confidence positions and positions after the EOS remain
            # mask tokens.
            masked_indices_pre = noisy_batch == self.mask_token_id
            masked_indices_vaild = masked_indices_pre & ~masked_indices_eos
            if flag_exit:
                # The exit condition was met in the previous step.
                break
            if (masked_indices_vaild.sum(dim=-1) <= 1).all():
                # At most one masked token is left per row: run one more
                # step and stop.
                flag_exit = True

        return logits

    def forward_random_mask_decoding(self, src):
        """Iteratively commit a random half of the still-masked tokens.

        Same loop as ``forward_low_confidence_decoding`` except that every
        valid position is committed with probability 0.5 instead of by
        confidence.

        Returns:
            Probabilities of shape ``(bs, max_len + 1, out_channels - 2)``.
        """
        bs = src.shape[0]
        noisy_batch = self._init_noisy_batch(src)
        masked_indices_pre = torch.ones_like(noisy_batch, dtype=torch.bool)
        flag_exit = False
        for step_i in range(self.sampler_step):
            pred_step = self._predict_probs(src, noisy_batch, step_i=step_i)
            if step_i == 0:
                logits = pred_step.clone()
            else:
                logits[masked_indices_pre] = pred_step[masked_indices_pre]
            pred_step_prob, pred_step_index = torch.max(
                pred_step, dim=-1)  # [bs, max_len + 1], [bs, max_len + 1]
            masked_indices_eos = self.get_masked_indice_after_eos(
                pred_step_index)  # [bs, max_len + 1] bool tensor

            # Candidates: masked positions up to the EOS.
            valid_indices = masked_indices_pre & ~masked_indices_eos
            # Each candidate is committed with a probability of 50 %.
            rand_mask_prob = torch.rand((bs, self.max_len + 1),
                                        device=src.device)
            random_res = (rand_mask_prob > 0.5) & valid_indices
            noisy_batch[random_res] = pred_step_index[random_res]

            # Unselected positions and positions after the EOS remain mask
            # tokens.
            masked_indices_pre = noisy_batch == self.mask_token_id
            masked_indices_vaild = masked_indices_pre & ~masked_indices_eos
            if flag_exit:
                # The exit condition was met in the previous step.
                break
            if (masked_indices_vaild.sum(dim=-1) <= 1).all():
                # At most one masked token is left per row: run one more
                # step and stop.
                flag_exit = True

        return logits

    def forward_semi_autoregressive_decoding(self, src):
        """Low-confidence decoding restricted to a moving window of blocks.

        The sequence is divided into blocks of
        ``(max_len + 1) // sampler_step`` positions. In every step only the
        positions inside the current window may be committed (see the
        window construction below); within the window the above-average
        confidence rule of ``forward_low_confidence_decoding`` applies.
        Logits are divided by ``self.temperature`` before the softmax.

        Returns:
            Probabilities of shape ``(bs, max_len + 1, out_channels - 2)``.
        """
        noisy_batch = self._init_noisy_batch(src)
        block_size = (self.max_len + 1) // self.sampler_step
        masked_indices_pre = torch.ones_like(noisy_batch, dtype=torch.bool)
        flag_exit = False
        for step_i in range(self.sampler_step):
            tgts = self.forward_decoding(src, noisy_batch, step_i=step_i)
            pred_step = self.tgt_word_prj(tgts) / self.temperature
            pred_step = F.softmax(pred_step, -1)
            if step_i == 0:
                logits = pred_step.clone()
            else:
                logits[masked_indices_pre] = pred_step[masked_indices_pre]
            pred_step_prob, pred_step_index = torch.max(
                pred_step, dim=-1)  # [bs, max_len + 1], [bs, max_len + 1]
            masked_indices_eos = self.get_masked_indice_after_eos(
                pred_step_index
            )  # [bs, max_len + 1] bool tensor False False(eos) True True ..

            # Window of positions that may be committed in this step.
            block_vaild_indices = torch.zeros_like(noisy_batch,
                                                   dtype=torch.bool)
            if step_i <= 2:
                if self.sampler_step > 2:
                    # First steps: prefix of ``step_i + 1`` blocks.
                    block_vaild_indices[:, :block_size * (step_i + 1)] = True
                else:
                    # Too few steps for a window: allow every position.
                    block_vaild_indices = ~block_vaild_indices
            elif step_i >= self.sampler_step - 2:
                # Last steps: from the previous block to the end.
                block_vaild_indices[:, block_size * (step_i - 1):] = True
            else:
                # Middle steps: previous and current block.
                block_vaild_indices[:, block_size * (step_i - 1):block_size *
                                    (step_i + 1)] = True

            # Only masked positions up to the EOS and inside the window take
            # part in the average.
            valid_indices = (masked_indices_pre & ~masked_indices_eos
                             & block_vaild_indices)
            top_confidence_mask = self._above_average_confidence(
                pred_step_prob, valid_indices)
            noisy_batch[top_confidence_mask] = pred_step_index[
                top_confidence_mask]

            # Low-confidence positions and positions after the EOS remain
            # mask tokens.
            masked_indices_pre = noisy_batch == self.mask_token_id
            masked_indices_vaild = masked_indices_pre & ~masked_indices_eos
            if flag_exit:
                # The exit condition was met in the previous step.
                break
            if (masked_indices_vaild.sum(dim=-1) <= 1).all():
                # At most one masked token is left per row: run one more
                # step and stop.
                flag_exit = True

        return logits

    def forward_autoregressive_decoding(self, src):
        """Decode left to right, one position per decoder pass.

        Position ``step_i`` is predicted from the tokens committed so far
        (later positions are mask tokens) and committed greedily. Decoding
        stops as soon as every row contains an EOS.

        Returns:
            Probabilities of shape ``(bs, T, out_channels - 2)`` where ``T``
            is the number of executed steps (``<= max_len + 1``).
        """
        noisy_batch = self._init_noisy_batch(src)
        logits = []
        for step_i in range(self.max_len + 1):
            tgts = self.forward_decoding(src, noisy_batch, step_i=step_i)
            pred_step = F.softmax(
                self.tgt_word_prj(tgts[:, step_i:step_i + 1, :]), -1)
            logits.append(pred_step)
            noisy_batch[:, step_i] = torch.argmax(pred_step, dim=-1)[:, 0]
            if (noisy_batch == self.eos).any(dim=-1).all():
                break
        return torch.cat(logits, dim=1)

    def forward_cloze_mask_decoding(self, src, noisy_batch=None):
        """Cloze mask decoding.

        A first pass predicts all positions from ``noisy_batch`` (fully
        masked when ``None``); positions after the predicted EOS are reset
        to the mask token. Then, from left to right, each position is
        masked on its own and re-predicted given all other tokens. The loop
        stops once every row has an EOS among the re-predicted positions.

        Two defects of the previous implementation are fixed here:
        the result buffer was filled with ``torch.rand`` and the stop
        condition looked at *all* positions, so random values at positions
        not yet re-predicted could end the loop early and leak into the
        result. The stop condition now only inspects re-predicted positions
        and the buffer starts from the first-pass probabilities.

        Returns:
            Probabilities of shape ``(bs, max_len + 1, out_channels - 2)``.
        """
        if noisy_batch is None:
            noisy_batch = self._init_noisy_batch(src)
        first_pass = self._predict_probs(src, noisy_batch)
        noisy_batch = torch.argmax(first_pass, dim=-1)
        masked_indices_eos = self.get_masked_indice_after_eos(
            noisy_batch)  # [bs, max_len + 1] bool tensor
        # Everything after the EOS is fed as mask token.
        noisy_batch[masked_indices_eos] = self.mask_token_id

        # Positions that are never re-predicted keep the first-pass result.
        logits = first_pass.clone().to(torch.float32)
        for step_i in range(self.max_len + 1):
            noisy_batch[:, step_i] = self.mask_token_id

            tgts = self.forward_decoding(src, noisy_batch, step_i=step_i)
            pred_step = F.softmax(
                self.tgt_word_prj(tgts[:, step_i:step_i + 1, :]), -1)
            logits[:, step_i:step_i + 1, :] = pred_step
            noisy_batch[:, step_i] = torch.argmax(pred_step, dim=-1)[:, 0]
            # ``noisy_batch[:, :step_i + 1]`` holds exactly the argmax of
            # the re-predicted prefix of ``logits``.
            if (noisy_batch[:, :step_i + 1] == self.eos).any(dim=-1).all():
                break
        return logits
520
+
521
+
522
class TransformerBlock(nn.Module):
    """Post-norm transformer layer: self-attention, cross-attention, MLP.

    Each enabled sub-layer is applied as ``norm(x + dropout(sublayer(x)))``.
    The MLP sub-layer is always present.

    Args:
        d_model: feature width of the tokens.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the MLP.
        attention_dropout_rate: dropout passed to the attention modules.
        residual_dropout_rate: dropout on the residual branches and inside
            the MLP.
        with_self_attn: build and apply the self-attention sub-layer.
        with_cross_attn: build and apply the cross-attention sub-layer.
        epsilon: ``eps`` of the layer norms.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        with_self_attn=True,
        with_cross_attn=False,
        epsilon=1e-5,
    ):
        super(TransformerBlock, self).__init__()
        self.with_self_attn = with_self_attn
        if with_self_attn:
            self.self_attn = MultiheadAttention(d_model,
                                                nhead,
                                                dropout=attention_dropout_rate,
                                                self_attn=with_self_attn)
            self.norm1 = nn.LayerNorm(d_model, eps=epsilon)
            self.dropout1 = nn.Dropout(residual_dropout_rate)
        self.with_cross_attn = with_cross_attn
        if with_cross_attn:
            # Attention whose keys come from ``memory`` (see ``forward``).
            self.cross_attn = MultiheadAttention(
                d_model, nhead, dropout=attention_dropout_rate)
            self.norm2 = nn.LayerNorm(d_model, eps=epsilon)
            self.dropout2 = nn.Dropout(residual_dropout_rate)

        self.mlp = Mlp(
            in_features=d_model,
            hidden_features=dim_feedforward,
            act_layer=nn.ReLU,
            drop=residual_dropout_rate,
        )

        self.norm3 = nn.LayerNorm(d_model, eps=epsilon)

        self.dropout3 = nn.Dropout(residual_dropout_rate)

    def forward(self,
                tgt,
                memory=None,
                self_mask=None,
                cross_mask=None,
                sample_k=0):
        """Apply the layer to ``tgt``.

        Args:
            tgt: token features. With ``sample_k == 0`` they are passed to
                the sub-layers as given. With ``sample_k > 0`` a 4-D tensor
                ``(bs, sample_k, L, Dim)`` is expected.
            memory: keys for the cross-attention.
            self_mask: attention mask for the self-attention.
            cross_mask: attention mask for the cross-attention.
            sample_k: number of samples per image stored along dimension 1
                of ``tgt``. Self-attention then runs per sample on
                ``(bs * sample_k, L, Dim)`` while cross-attention runs on
                ``(bs, sample_k * L, Dim)`` so that all samples of an image
                share its ``memory``.

        Returns:
            Tensor shaped like ``tgt``.
        """
        multi_sample = sample_k > 0
        if multi_sample:
            # Read the shape up front: it is needed by the cross-attention
            # and by the final reshape even when self-attention is disabled
            # (it used to be defined only inside the self-attention branch).
            bs, _, L, Dim = tgt.shape

        if self.with_self_attn:
            if multi_sample:
                tgt = tgt.flatten(0, 1)
            tgt1 = self.self_attn(tgt, attn_mask=self_mask)
            tgt = self.norm1(tgt + self.dropout1(tgt1))

        if self.with_cross_attn:
            if multi_sample:
                tgt = tgt.reshape(bs, sample_k, L, Dim).flatten(1, 2)
            tgt2 = self.cross_attn(tgt, key=memory, attn_mask=cross_mask)
            tgt = self.norm2(tgt + self.dropout2(tgt2))
        tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))

        if multi_sample:
            tgt = tgt.reshape(bs, sample_k, L, Dim)

        return tgt