openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +35 -1
- openocr/configs/dataset/rec/evaluation.yaml +41 -0
- openocr/configs/dataset/rec/ltb.yaml +9 -0
- openocr/configs/dataset/rec/mjsynth.yaml +11 -0
- openocr/configs/dataset/rec/openvino.yaml +25 -0
- openocr/configs/dataset/rec/ost.yaml +17 -0
- openocr/configs/dataset/rec/synthtext.yaml +7 -0
- openocr/configs/dataset/rec/test.yaml +77 -0
- openocr/configs/dataset/rec/textocr.yaml +13 -0
- openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
- openocr/configs/dataset/rec/union14m_b.yaml +47 -0
- openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
- openocr/configs/rec/cmer/cmer.yml +127 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
- openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
- openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
- openocr/demo_gradio.py +28 -8
- openocr/demo_opendoc.py +572 -0
- openocr/demo_unirec.py +392 -0
- openocr/opendet/losses/__init__.py +5 -7
- openocr/opendet/preprocess/crop_resize.py +2 -1
- openocr/openocr.py +685 -0
- openocr/openrec/losses/__init__.py +8 -3
- openocr/openrec/losses/cmer_loss.py +12 -0
- openocr/openrec/losses/mdiff_loss.py +11 -0
- openocr/openrec/losses/unirec_loss.py +12 -0
- openocr/openrec/metrics/__init__.py +4 -1
- openocr/openrec/metrics/rec_metric_cmer.py +328 -0
- openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
- openocr/openrec/modeling/decoders/__init__.py +1 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
- openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
- openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
- openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
- openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
- openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
- openocr/openrec/optimizer/__init__.py +4 -3
- openocr/openrec/optimizer/lr.py +49 -0
- openocr/openrec/postprocess/__init__.py +2 -0
- openocr/openrec/postprocess/abinet_postprocess.py +1 -1
- openocr/openrec/postprocess/ar_postprocess.py +1 -1
- openocr/openrec/postprocess/cmer_postprocess.py +86 -0
- openocr/openrec/postprocess/cppd_postprocess.py +1 -1
- openocr/openrec/postprocess/igtr_postprocess.py +1 -1
- openocr/openrec/postprocess/lister_postprocess.py +1 -1
- openocr/openrec/postprocess/mgp_postprocess.py +1 -1
- openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
- openocr/openrec/postprocess/smtr_postprocess.py +1 -1
- openocr/openrec/postprocess/srn_postprocess.py +1 -1
- openocr/openrec/postprocess/unirec_postprocess.py +58 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
- openocr/openrec/preprocess/__init__.py +5 -0
- openocr/openrec/preprocess/ce_label_encode.py +1 -1
- openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
- openocr/openrec/preprocess/ctc_label_encode.py +1 -1
- openocr/openrec/preprocess/dptr_label_encode.py +177 -157
- openocr/openrec/preprocess/igtr_label_encode.py +4 -2
- openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
- openocr/openrec/preprocess/rec_aug.py +128 -2
- openocr/openrec/preprocess/resize.py +57 -0
- openocr/openrec/preprocess/unirec_label_encode.py +62 -0
- openocr/tools/data/__init__.py +78 -55
- openocr/tools/data/cmer_web_dataset.py +310 -0
- openocr/tools/data/native_size_dataset.py +753 -0
- openocr/tools/data/native_size_sampler.py +158 -0
- openocr/tools/data/ratio_dataset_tvresize.py +2 -0
- openocr/tools/data/ratio_sampler.py +2 -1
- openocr/tools/download/download_dataset.py +38 -0
- openocr/tools/download/utils.py +28 -0
- openocr/tools/download_example_images.py +236 -0
- openocr/tools/engine/trainer.py +155 -39
- openocr/tools/eval_rec_all_ch.py +2 -2
- openocr/tools/infer_det.py +20 -2
- openocr/tools/infer_doc.py +898 -0
- openocr/tools/infer_doc_onnx.py +1172 -0
- openocr/tools/infer_e2e.py +27 -10
- openocr/tools/infer_rec.py +64 -15
- openocr/tools/infer_unirec_onnx.py +730 -0
- openocr/tools/to_markdown.py +468 -0
- openocr/tools/utils/ckpt.py +17 -5
- openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
- openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
- openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- openocr_python-0.0.9.dist-info/METADATA +0 -149
- /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,587 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from torch import nn
|
|
5
|
+
from openrec.modeling.common import Mlp
|
|
6
|
+
from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, Embeddings, MultiheadAttention
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MDiffDecoder(nn.Module):
    """Masked-diffusion transformer decoder for text recognition.

    A stack of ``TransformerBlock`` layers (self-attention over the target
    tokens, cross-attention to the visual features, MLP) without a causal
    mask. Training recovers masked label tokens from a partially masked
    sequence. Inference starts from an all-[MASK] sequence and fills it with
    one of several strategies: parallel, autoregressive, low-confidence
    remasking, random remasking, semi-autoregressive or cloze refinement.

    Token ids derived from ``out_channels``:
        * ``0`` is EOS.
        * ``out_channels - 2`` is the [MASK] token fed to the decoder.
        * ``out_channels - 1`` is ignored by the loss.
        * The output projection predicts ``out_channels - 2`` classes, so
          [MASK] and the ignore id are never predicted.

    Args:
        in_channels: width of the visual features; also used as ``d_model``.
        out_channels: vocabulary size including the [MASK] and ignore ids.
        nhead: attention heads (defaults to ``d_model // 32``).
        num_decoder_layers: number of transformer blocks.
        max_len: maximum number of characters. Decoder sequences have
            ``max_len + 1`` positions (one extra slot for EOS).
        attention_dropout_rate: dropout inside attention.
        residual_dropout_rate: dropout on residual branches and in the
            positional encoding.
        scale_embedding: forwarded to ``Embeddings``.
        parallel_decoding, autoregressive_decoding, low_confidence_decoding,
        random_mask_decoding, semi_autoregressive_decoding,
        cloze_mask_decoding: select the inference strategy. They are checked
            in that order; parallel decoding is the fallback.
        sampler_step: maximum number of refinement steps for the iterative
            strategies.
        rec_loss_weight: weight of the masked-denoising loss.
        reflect_loss_weight: weight of the reflect (re-decoding) loss.
        sample_k: if > 0, training consumes ``sample_k`` noisy samples per
            image via ``forward_train_all``.
        temperature: softmax temperature for semi-autoregressive decoding.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 nhead=None,
                 num_decoder_layers=6,
                 max_len=25,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1,
                 scale_embedding=True,
                 parallel_decoding=False,
                 autoregressive_decoding=False,
                 sampler_step=5,
                 low_confidence_decoding=False,
                 random_mask_decoding=False,
                 semi_autoregressive_decoding=False,
                 cloze_mask_decoding=False,
                 rec_loss_weight=1.0,
                 reflect_loss_weight=1.0,
                 sample_k=0,
                 temperature=1.0):
        super(MDiffDecoder, self).__init__()
        self.out_channels = out_channels
        # Special ids live at the top of the vocabulary.
        self.ignore_index = out_channels - 1
        self.mask_token_id = out_channels - 2
        self.eos = 0
        self.max_len = max_len
        d_model = in_channels
        dim_feedforward = d_model * 4
        self.pd = parallel_decoding
        self.ar = autoregressive_decoding
        self.sampler_step = sampler_step
        self.lc = low_confidence_decoding
        self.rm = random_mask_decoding
        self.semiar = semi_autoregressive_decoding
        self.cm = cloze_mask_decoding
        self.rec_loss_weight = rec_loss_weight
        self.reflect_loss_weight = reflect_loss_weight
        self.temperature = temperature
        self.sample_k = sample_k
        nhead = nhead if nhead is not None else d_model // 32
        # NOTE(review): padding_idx=0 coincides with the EOS id; confirm
        # this is intended.
        self.embedding = Embeddings(
            d_model=d_model,
            vocab=self.out_channels,
            padding_idx=0,
            scale_embedding=scale_embedding,
        )
        # Learned position embedding added on top of the positional
        # encoding; one slot per decoder position (max_len chars + EOS).
        self.pos_embed = nn.Parameter(torch.zeros(
            [1, self.max_len + 1, d_model], dtype=torch.float32),
                                      requires_grad=True)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate, dim=d_model)

        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model,
                nhead,
                dim_feedforward,
                attention_dropout_rate,
                residual_dropout_rate,
                with_self_attn=True,
                with_cross_attn=True,
            ) for _ in range(num_decoder_layers)
        ])

        self.num_decoder_layers = num_decoder_layers

        self.d_model = d_model
        self.nhead = nhead
        # [MASK] and the ignore id are never predicted.
        self.tgt_word_prj = nn.Linear(d_model,
                                      self.out_channels - 2,
                                      bias=False)
        # NOTE(review): this normal init is overwritten by the
        # ``self.apply(self._init_weights)`` call below, which re-initialises
        # every nn.Linear (including tgt_word_prj) with xavier_normal_. It is
        # kept so the NumPy RNG stream consumed at construction is unchanged.
        w0 = np.random.normal(0.0, d_model**-0.5,
                              (d_model, self.out_channels - 2)).astype(
                                  np.float32)
        self.tgt_word_prj.weight.data = torch.from_numpy(w0.transpose())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-normal weights and zero bias for every ``nn.Linear``."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def _init_noisy_batch(self, src):
        """Return an all-[MASK] int64 id tensor of shape ``[bs, max_len + 1]``.

        The tensor is created on ``src.device``. ``Tensor.get_device()``
        returns -1 for CPU tensors, which factory functions reject.
        """
        return torch.full((src.shape[0], self.max_len + 1),
                          self.mask_token_id,
                          dtype=torch.int64,
                          device=src.device)

    def forward_train(self, memory, data=None):
        """Masked-denoising loss (plus optional reflect loss), one sample/image.

        Args:
            memory: visual features used as cross-attention memory.
            data: tuple ``(labels, reflect_ids, noisy_batch, masked_indices,
                p_mask, length)``.
                - ``labels`` and ``noisy_batch`` are token ids, presumably
                  ``[bs, max_len + 1]`` (they must broadcast with
                  ``pos_embed``).
                - ``masked_indices`` is a bool mask of the noised positions.
                - ``p_mask`` (``[bs]``) is the per-sample masking probability.
                - ``length`` (``[bs]``) is the text length.
                - ``reflect_ids`` may be ``None`` to skip the reflect loss.

        Returns:
            Scalar loss tensor.
        """
        labels, reflect_ids, noisy_batch, masked_indices, p_mask, length = data
        seq_len = labels.shape[1]
        p_mask = p_mask[:, None].repeat(1, seq_len)
        # Each masked token's loss is normalised by (text length + 1 for EOS).
        noisy_data_length = (length + 1)[:, None].repeat(1, seq_len)

        tgts = self.forward_decoding(memory, noisy_batch)
        logits = self.tgt_word_prj(tgts)
        # Importance-weight by 1 / p_mask, as in masked-diffusion objectives.
        token_loss = F.cross_entropy(
            logits[masked_indices],
            labels[masked_indices],
            reduction='none',
            ignore_index=self.ignore_index) / p_mask[masked_indices]
        loss = torch.sum(
            token_loss / noisy_data_length[masked_indices]) / labels.shape[0]

        if reflect_ids is not None:
            reflect_tgts = self.forward_decoding(memory, reflect_ids)
            reflect_logits = self.tgt_word_prj(reflect_tgts)
            reflect_loss = F.cross_entropy(reflect_logits.flatten(0, 1),
                                           labels.flatten(0, 1),
                                           reduction='mean',
                                           ignore_index=self.ignore_index)
            loss = (self.rec_loss_weight * loss +
                    self.reflect_loss_weight * reflect_loss)

        return loss

    def forward_train_all(self, memory, data=None):
        """Like ``forward_train`` but with ``sample_k`` noisy samples per image.

        ``data`` is ``(labels, reflect_ids_all, noisy_batch_all,
        masked_indices_all, p_mask_all, length)``. The ``*_all`` tensors carry
        a sample axis at dim 1: ids are ``[bs, sample_k, L]`` and ``p_mask_all``
        is ``[bs, sample_k]``. Returns the per-sample losses averaged over
        ``sample_k``.
        """
        labels, reflect_ids_all, noisy_batch_all, masked_indices_all, p_mask_all, length = data
        bs, L = labels.shape

        # Embed all samples at once, then restore the sample axis so each
        # TransformerBlock can run self-attention per sample.
        tgts = self.embedding(noisy_batch_all.flatten(0, 1))
        tgts = self.positional_encoding(tgts) + self.pos_embed
        tgts = tgts.reshape(bs, self.sample_k, L, -1)
        for decoder_layer in self.decoder:
            tgts = decoder_layer(tgts,
                                 memory,
                                 self_mask=None,
                                 sample_k=self.sample_k)
        logits_all = self.tgt_word_prj(tgts)  # bs, sample_k, L, c_num

        reflect_tgts = self.embedding(reflect_ids_all.flatten(0, 1))
        reflect_tgts = self.positional_encoding(reflect_tgts) + self.pos_embed
        reflect_tgts = reflect_tgts.reshape(bs, self.sample_k, L, -1)
        for decoder_layer in self.decoder:
            reflect_tgts = decoder_layer(reflect_tgts,
                                         memory,
                                         self_mask=None,
                                         sample_k=self.sample_k)
        reflect_logits_all = self.tgt_word_prj(reflect_tgts)

        # Same normaliser for every sample; computed once.
        noisy_data_length = (length + 1)[:, None].repeat(1, L)
        flat_labels = labels.flatten(0, 1)

        losses = []
        for i in range(self.sample_k):
            p_mask = p_mask_all[:, i][:, None].repeat(1, L)
            masked_indices = masked_indices_all[:, i]
            logits = logits_all[:, i]
            token_loss = F.cross_entropy(
                logits[masked_indices],
                labels[masked_indices],
                reduction='none',
                ignore_index=self.ignore_index) / p_mask[masked_indices]
            denoise_loss_i = torch.sum(
                token_loss /
                noisy_data_length[masked_indices]) / labels.shape[0]

            reflect_logits = reflect_logits_all[:, i]
            reflect_loss_i = F.cross_entropy(reflect_logits.flatten(0, 1),
                                             flat_labels,
                                             reduction='mean',
                                             ignore_index=self.ignore_index)
            losses.append(self.rec_loss_weight * denoise_loss_i +
                          self.reflect_loss_weight * reflect_loss_i)

        return sum(losses) / len(losses)

    def forward(self, src, data=None):
        """Compute the training loss, or decode with the configured strategy.

        Args:
            src: visual features used as cross-attention memory; presumably
                ``[B, sN, C]`` with ``C == d_model`` (TODO confirm).
            data: training targets (see ``forward_train`` /
                ``forward_train_all``); unused at inference.

        Returns:
            In training mode, a scalar loss. Otherwise, softmax probabilities
            of shape ``[B, T, out_channels - 2]``, where ``T`` is
            ``max_len + 1``. Autoregressive decoding may return fewer
            positions if it stops early.
        """
        if self.training:
            if self.sample_k > 0:
                res = self.forward_train_all(src, data)
            else:
                res = self.forward_train(src, data)
        else:
            if self.pd:
                res = self.forward_parallel_decoding(src)
            elif self.ar:
                res = self.forward_autoregressive_decoding(src)
            elif self.lc:
                res = self.forward_low_confidence_decoding(src)
            elif self.rm:
                res = self.forward_random_mask_decoding(src)
            elif self.semiar:
                res = self.forward_semi_autoregressive_decoding(src)
            elif self.cm:
                res = self.forward_cloze_mask_decoding(src)
            else:
                res = self.forward_parallel_decoding(src)

        return res

    def forward_decoding(self, src, tgts, step_i=0):
        """Embed token ids ``tgts`` and run them through all decoder layers.

        ``step_i`` is accepted for API symmetry and is not used here.
        Returns the final hidden states (before ``tgt_word_prj``).
        """
        tgts = self.embedding(tgts)
        tgts = self.positional_encoding(tgts) + self.pos_embed
        for decoder_layer in self.decoder:
            tgts = decoder_layer(tgts, src, self_mask=None)

        return tgts

    def forward_reflect(self, src, pred_indexs, step_i=0):
        """Re-decode a predicted id sequence once (reflect decoding).

        Positions after the first predicted EOS are replaced by [MASK]; the
        whole row is replaced when it has no EOS. The sequence is then decoded
        again. ``pred_indexs`` is not modified; a masked copy is used.

        Returns:
            Softmax probabilities over the ``out_channels - 2`` classes.
        """
        pred_indexs = pred_indexs.clone()
        masked_indices_eos = self.get_masked_indice_after_eos(pred_indexs)
        # Keep only tokens up to and including EOS; the rest become [MASK].
        pred_indexs[masked_indices_eos] = self.mask_token_id

        reflect_tgts = self.forward_decoding(src, pred_indexs, step_i=step_i)
        return F.softmax(self.tgt_word_prj(reflect_tgts), -1)

    def forward_parallel_decoding(self, src):
        """Decode all positions in one pass from an all-[MASK] sequence."""
        noisy_batch = self._init_noisy_batch(src)
        tgts = self.forward_decoding(src, noisy_batch)
        return F.softmax(self.tgt_word_prj(tgts), -1)

    def get_masked_indice_after_eos(self, noisy_batch):
        """Return a bool mask that is True strictly after the first EOS per row.

        Args:
            noisy_batch: token ids of shape ``[batch_size, seq_len]``.

        Returns:
            Bool tensor of the same shape, e.g. ``F F(eos) T T ...``. Rows that
            contain no EOS are entirely True.
        """
        eos_mask = noisy_batch == self.eos  # [batch_size, seq_len]

        # Position of the first EOS in each row (argmax returns the first
        # maximal index).
        eos_indices = eos_mask.float().argmax(dim=1)  # [batch_size]

        # For rows without EOS, argmax returns 0, which would be misread as
        # "EOS at position 0"; those rows are handled separately below.
        eos_exists = eos_mask.any(dim=1)  # [batch_size]

        batch_size, seq_len = noisy_batch.shape
        arange = torch.arange(seq_len,
                              device=noisy_batch.device).unsqueeze(0).expand(
                                  batch_size, -1)  # [batch_size, seq_len]

        # True only for tokens after the EOS.
        masked_indices = arange > eos_indices.unsqueeze(1)
        # NOTE(review): rows without an EOS become fully masked here, so the
        # iterative decoders never commit tokens in such rows; confirm this
        # is intended.
        masked_indices = masked_indices | ~eos_exists.unsqueeze(1)

        return masked_indices

    def forward_low_confidence_decoding(self, src):
        """Iteratively commit above-average-confidence tokens.

        Each step decodes the current sequence and considers positions that
        are still [MASK] and lie before the first predicted EOS. Among those,
        it commits predictions whose confidence exceeds the row mean; the rest
        stay masked. Probabilities of still-masked positions are refreshed
        every step.

        The loop runs at most ``sampler_step`` steps. Once every row has at
        most one valid masked position left, it runs one more step and stops.
        """
        noisy_batch = self._init_noisy_batch(src)
        masked_indices_pre = torch.ones_like(noisy_batch, dtype=torch.bool)
        flag_exit = False
        for step_i in range(self.sampler_step):
            tgts = self.forward_decoding(src, noisy_batch, step_i=step_i)
            pred_step = F.softmax(self.tgt_word_prj(tgts), -1)
            if step_i == 0:
                logits = pred_step.clone()
            else:
                logits[masked_indices_pre] = pred_step[masked_indices_pre]
            pred_step_prob, pred_step_index = torch.max(
                pred_step, dim=-1)  # [bs, max_len + 1], [bs, max_len + 1]
            masked_indices_eos = self.get_masked_indice_after_eos(
                pred_step_index)

            # Average confidence over still-masked positions before the EOS.
            valid_indices = masked_indices_pre & ~masked_indices_eos
            pred_step_prob = pred_step_prob * valid_indices.float()
            # clamp avoids 0/0 -> NaN for rows with no valid positions; such
            # rows have nothing to commit either way.
            pred_step_prob_avg = pred_step_prob.sum(
                dim=1, keepdim=True) / valid_indices.sum(
                    dim=1, keepdim=True).clamp(min=1)  # [bs, 1]

            # Commit tokens whose confidence is above the row average.
            top_confidence_mask = (pred_step_prob >
                                   pred_step_prob_avg) & valid_indices
            noisy_batch[top_confidence_mask] = pred_step_index[
                top_confidence_mask]
            # Low-confidence tokens and tokens after EOS stay [MASK].
            masked_indices_pre = noisy_batch == self.mask_token_id
            masked_indices_vaild = masked_indices_pre & ~masked_indices_eos
            if flag_exit:
                break
            if (masked_indices_vaild.sum(dim=-1) <= 1).all():
                # At most one masked token left per row: one final step.
                flag_exit = True

        return logits

    def forward_random_mask_decoding(self, src):
        """Iteratively commit a random ~50% of the eligible tokens per step.

        Eligible positions are those still [MASK] and before the first
        predicted EOS. The stopping rule matches
        ``forward_low_confidence_decoding``.
        """
        bs = src.shape[0]
        noisy_batch = self._init_noisy_batch(src)
        masked_indices_pre = torch.ones_like(noisy_batch, dtype=torch.bool)
        flag_exit = False
        for step_i in range(self.sampler_step):
            tgts = self.forward_decoding(src, noisy_batch, step_i=step_i)
            pred_step = F.softmax(self.tgt_word_prj(tgts), -1)
            if step_i == 0:
                logits = pred_step.clone()
            else:
                logits[masked_indices_pre] = pred_step[masked_indices_pre]
            pred_step_prob, pred_step_index = torch.max(
                pred_step, dim=-1)  # [bs, max_len + 1], [bs, max_len + 1]
            masked_indices_eos = self.get_masked_indice_after_eos(
                pred_step_index)  # [bs, max_len + 1] bool tensor

            # Still-masked positions before the EOS are eligible.
            valid_indices = masked_indices_pre & ~masked_indices_eos
            # Commit each eligible position with probability 0.5.
            rand_mask_prob = torch.rand((bs, self.max_len + 1),
                                        device=src.device)
            random_res = (rand_mask_prob > 0.5) & valid_indices
            noisy_batch[random_res] = pred_step_index[random_res]
            # Uncommitted tokens and tokens after EOS stay [MASK].
            masked_indices_pre = noisy_batch == self.mask_token_id
            masked_indices_vaild = masked_indices_pre & ~masked_indices_eos
            if flag_exit:
                break
            if (masked_indices_vaild.sum(dim=-1) <= 1).all():
                # At most one masked token left per row: one final step.
                flag_exit = True

        return logits

    def forward_semi_autoregressive_decoding(self, src):
        """Low-confidence remasking restricted to a moving block window.

        Positions are split into blocks of ``(max_len + 1) // sampler_step``.

        When ``sampler_step > 2``:
            * Steps 0-2 allow a growing prefix of ``step_i + 1`` blocks.
            * Steps ``>= sampler_step - 2`` allow everything from block
              ``step_i - 1`` onward.
            * Other steps allow blocks ``step_i - 1`` and ``step_i``.

        When ``sampler_step <= 2``, the first steps allow every position.

        Within the window, above-average-confidence tokens are committed as in
        ``forward_low_confidence_decoding``. Logits are divided by
        ``temperature`` before the softmax.
        """
        noisy_batch = self._init_noisy_batch(src)
        block_size = (self.max_len + 1) // self.sampler_step
        masked_indices_pre = torch.ones_like(noisy_batch, dtype=torch.bool)
        flag_exit = False
        for step_i in range(self.sampler_step):
            tgts = self.forward_decoding(src, noisy_batch, step_i=step_i)
            pred_step = self.tgt_word_prj(tgts) / self.temperature
            pred_step = F.softmax(pred_step, -1)
            if step_i == 0:
                logits = pred_step.clone()
            else:
                logits[masked_indices_pre] = pred_step[masked_indices_pre]
            pred_step_prob, pred_step_index = torch.max(
                pred_step, dim=-1)  # [bs, max_len + 1], [bs, max_len + 1]
            masked_indices_eos = self.get_masked_indice_after_eos(
                pred_step_index)

            block_vaild_indices = torch.zeros_like(noisy_batch,
                                                   dtype=torch.bool)
            if step_i <= 2:
                if self.sampler_step > 2:
                    block_vaild_indices[:, :block_size * (step_i + 1)] = True
                else:
                    block_vaild_indices = ~block_vaild_indices
            elif step_i >= self.sampler_step - 2:
                block_vaild_indices[:, block_size * (step_i - 1):] = True
            else:
                block_vaild_indices[:, block_size *
                                    (step_i - 1):block_size *
                                    (step_i + 1)] = True

            # Average over still-masked, pre-EOS positions in the window.
            valid_indices = masked_indices_pre & ~masked_indices_eos & block_vaild_indices
            pred_step_prob = pred_step_prob * valid_indices.float()
            # clamp avoids 0/0 -> NaN; rows without valid positions commit
            # nothing either way.
            pred_step_prob_avg = pred_step_prob.sum(
                dim=1, keepdim=True) / valid_indices.sum(
                    dim=1, keepdim=True).clamp(min=1)  # [bs, 1]

            # Commit tokens above the row's average confidence.
            top_confidence_mask = (pred_step_prob >
                                   pred_step_prob_avg) & valid_indices
            noisy_batch[top_confidence_mask] = pred_step_index[
                top_confidence_mask]

            # Low-confidence tokens and tokens after EOS stay [MASK].
            masked_indices_pre = noisy_batch == self.mask_token_id
            masked_indices_vaild = masked_indices_pre & ~masked_indices_eos
            if flag_exit:
                break
            if (masked_indices_vaild.sum(dim=-1) <= 1).all():
                # At most one masked token left per row: one final step.
                flag_exit = True

        return logits

    def forward_autoregressive_decoding(self, src):
        """Greedy left-to-right decoding, one position per step.

        Stops once every row contains an EOS. The result may therefore have
        fewer than ``max_len + 1`` positions.
        """
        noisy_batch = self._init_noisy_batch(src)
        logits = []
        for step_i in range(self.max_len + 1):
            tgts = self.forward_decoding(src, noisy_batch, step_i=step_i)
            pred_step = F.softmax(
                self.tgt_word_prj(tgts[:, step_i:step_i + 1, :]), -1)
            logits.append(pred_step)
            noisy_batch[:, step_i] = torch.argmax(pred_step, dim=-1)[:, 0]
            if (noisy_batch == self.eos).any(dim=-1).all():
                break
        return torch.cat(logits, dim=1)

    def forward_cloze_mask_decoding(self, src, noisy_batch=None):
        """Cloze refinement: re-predict each position with only it masked.

        A first parallel pass fills the sequence; tokens after its first EOS
        become [MASK]. Then each position, left to right, is masked,
        re-decoded and replaced by its argmax. The loop stops once every row
        has an EOS among the positions refined so far.

        Args:
            src: visual features used as cross-attention memory.
            noisy_batch: optional initial ids ``[bs, max_len + 1]``; defaults
                to all-[MASK]. It is not modified.
        """
        if noisy_batch is None:
            noisy_batch = self._init_noisy_batch(src)
        tgts = self.forward_decoding(src, noisy_batch)
        first_pass = F.softmax(self.tgt_word_prj(tgts), -1)
        noisy_batch = torch.argmax(first_pass, dim=-1)
        masked_indices_eos = self.get_masked_indice_after_eos(
            noisy_batch)  # [bs, max_len + 1] bool tensor
        # Tokens after EOS become [MASK].
        noisy_batch[masked_indices_eos] = self.mask_token_id

        # Start from the first-pass distribution. Positions not revisited
        # before the early exit then still hold real probabilities, and the
        # exit test cannot be triggered by uninitialised values.
        logits = first_pass.clone()
        for step_i in range(self.max_len + 1):
            noisy_batch[:, step_i] = self.mask_token_id
            tgts = self.forward_decoding(src, noisy_batch, step_i=step_i)
            pred_step = F.softmax(
                self.tgt_word_prj(tgts[:, step_i:step_i + 1, :]), -1)
            logits[:, step_i:step_i + 1, :] = pred_step
            noisy_batch[:, step_i] = torch.argmax(pred_step, dim=-1)[:, 0]
            # Stop once every row has an EOS within the refined prefix.
            refined = torch.argmax(logits[:, :step_i + 1], dim=-1)
            if (refined == self.eos).any(dim=-1).all():
                break
        return logits
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
class TransformerBlock(nn.Module):
    """Post-norm transformer block: self-attention, cross-attention, MLP.

    Each enabled sub-layer is applied as ``norm(x + dropout(sublayer(x)))``
    in the order self-attention, cross-attention to ``memory``, then a ReLU
    MLP. No causal mask is created here; masks are applied only if passed in.

    Args:
        d_model: feature width.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the MLP.
        attention_dropout_rate: dropout inside attention.
        residual_dropout_rate: dropout on each residual branch and in the MLP.
        with_self_attn: build and run the self-attention sub-layer.
        with_cross_attn: build and run the cross-attention sub-layer.
        epsilon: LayerNorm epsilon.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        with_self_attn=True,
        with_cross_attn=False,
        epsilon=1e-5,
    ):
        super(TransformerBlock, self).__init__()
        self.with_self_attn = with_self_attn
        if with_self_attn:
            self.self_attn = MultiheadAttention(d_model,
                                                nhead,
                                                dropout=attention_dropout_rate,
                                                self_attn=with_self_attn)
            self.norm1 = nn.LayerNorm(d_model, eps=epsilon)
            self.dropout1 = nn.Dropout(residual_dropout_rate)
        self.with_cross_attn = with_cross_attn
        if with_cross_attn:
            self.cross_attn = MultiheadAttention(
                d_model, nhead, dropout=attention_dropout_rate
            )  # for self_attn of encoder or cross_attn of decoder
            self.norm2 = nn.LayerNorm(d_model, eps=epsilon)
            self.dropout2 = nn.Dropout(residual_dropout_rate)

        self.mlp = Mlp(
            in_features=d_model,
            hidden_features=dim_feedforward,
            act_layer=nn.ReLU,
            drop=residual_dropout_rate,
        )

        self.norm3 = nn.LayerNorm(d_model, eps=epsilon)

        self.dropout3 = nn.Dropout(residual_dropout_rate)

    def forward(self,
                tgt,
                memory=None,
                self_mask=None,
                cross_mask=None,
                sample_k=0):
        """Apply the block.

        Args:
            tgt: target features. ``[bs, L, Dim]`` when ``sample_k == 0``;
                ``[bs, sample_k, L, Dim]`` otherwise.
            memory: cross-attention keys (used only with cross-attention).
            self_mask: optional self-attention mask.
            cross_mask: optional cross-attention mask.
            sample_k: number of noisy samples stacked on dim 1 of ``tgt``.

        Returns:
            Tensor with the same shape as ``tgt``.
        """
        if sample_k > 0:
            # Capture the shape up front so the sample axis can be restored
            # whichever sub-layers are enabled.
            bs, _, L, Dim = tgt.shape

        if self.with_self_attn:
            if sample_k > 0:
                # Each sample attends only to its own tokens.
                tgt = tgt.flatten(0, 1)
            tgt1 = self.self_attn(tgt, attn_mask=self_mask)
            tgt = self.norm1(tgt + self.dropout1(tgt1))

        if self.with_cross_attn:
            if sample_k > 0:
                # Merge the sample and token axes so all samples of an image
                # query the same memory row.
                tgt = tgt.reshape(bs, sample_k, L, Dim).flatten(1, 2)
            tgt2 = self.cross_attn(tgt, key=memory, attn_mask=cross_mask)
            tgt = self.norm2(tgt + self.dropout2(tgt2))
        tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))

        if sample_k > 0:
            tgt = tgt.reshape(bs, sample_k, L, Dim)

        return tgt
|