xinference 0.13.1__py3-none-any.whl → 0.13.3__py3-none-any.whl
This diff shows the contents of publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of xinference might be problematic.
- xinference/__init__.py +0 -1
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +99 -5
- xinference/client/restful/restful_client.py +98 -1
- xinference/core/chat_interface.py +2 -2
- xinference/core/model.py +85 -26
- xinference/core/scheduler.py +4 -4
- xinference/model/audio/chattts.py +40 -8
- xinference/model/audio/core.py +5 -2
- xinference/model/audio/cosyvoice.py +136 -0
- xinference/model/audio/model_spec.json +24 -0
- xinference/model/audio/model_spec_modelscope.json +27 -0
- xinference/model/flexible/launchers/__init__.py +1 -0
- xinference/model/flexible/launchers/image_process_launcher.py +70 -0
- xinference/model/image/core.py +3 -0
- xinference/model/image/model_spec.json +21 -0
- xinference/model/image/stable_diffusion/core.py +49 -7
- xinference/model/llm/llm_family.json +1065 -106
- xinference/model/llm/llm_family.py +26 -6
- xinference/model/llm/llm_family_csghub.json +39 -0
- xinference/model/llm/llm_family_modelscope.json +460 -47
- xinference/model/llm/pytorch/chatglm.py +243 -5
- xinference/model/llm/pytorch/cogvlm2.py +1 -1
- xinference/model/llm/sglang/core.py +7 -2
- xinference/model/llm/utils.py +78 -1
- xinference/model/llm/vllm/core.py +11 -0
- xinference/thirdparty/cosyvoice/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
- xinference/thirdparty/cosyvoice/bin/train.py +136 -0
- xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
- xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
- xinference/thirdparty/cosyvoice/cli/model.py +60 -0
- xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
- xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
- xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
- xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
- xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
- xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
- xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
- xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
- xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
- xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
- xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
- xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
- xinference/thirdparty/cosyvoice/utils/common.py +103 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
- xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/METADATA +18 -8
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/RECORD +80 -36
- xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
- /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
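
The headline change in this range is built-in support for the CosyVoice text-to-speech model: a new xinference/model/audio/cosyvoice.py wrapper, the vendored xinference/thirdparty/cosyvoice package, and matching additions to the RESTful API and client. For orientation only, launching and calling such an audio model from the client might look like the sketch below; the model name, endpoint, and speech() arguments are assumptions modelled on how the existing ChatTTS audio models are used, not details confirmed by this diff.

```python
# Hypothetical usage sketch for the newly added CosyVoice TTS support.
# The model name, host/port and speech() arguments below are assumptions;
# check the xinference documentation for the exact values.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # locally running xinference supervisor

model_uid = client.launch_model(
    model_name="CosyVoice-300M-SFT",  # assumed registry name for the new audio model
    model_type="audio",
)
model = client.get_model(model_uid)

# Audio model handles expose an OpenAI-style speech() call (as ChatTTS does);
# the return value is the raw audio payload as bytes.
audio_bytes = model.speech("Hello from CosyVoice!")
with open("cosyvoice_demo.wav", "wb") as f:
    f.write(audio_bytes)
```

The rest of this page shows two of the vendored files in full: the transformer decoder and its decoder layer.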
@@ -0,0 +1,396 @@ xinference/thirdparty/cosyvoice/transformer/decoder.py (new file)

# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Decoder definition."""
from typing import Tuple, List, Optional

import torch
import torch.utils.checkpoint as ckpt
import logging

from cosyvoice.transformer.decoder_layer import DecoderLayer
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
from cosyvoice.utils.class_utils import (
    COSYVOICE_EMB_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
    COSYVOICE_ACTIVATION_CLASSES,
)
from cosyvoice.utils.mask import (subsequent_mask, make_pad_mask)


class TransformerDecoder(torch.nn.Module):
    """Base class of Transfomer decoder module.
    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the hidden units number of position-wise feedforward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before:
            True: use layer_norm before each sub-block of a layer.
            False: use layer_norm after each sub-block of a layer.
        src_attention: if false, encoder-decoder cross attention is not
            applied, such as CIF model
        key_bias: whether use bias in attention.linear_k, False for whisper models.
        gradient_checkpointing: rerunning a forward-pass segment for each
            checkpointed segment during backward.
        tie_word_embedding: Tie or clone module weights depending of whether we are
            using TorchScript or not
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        normalize_before: bool = True,
        src_attention: bool = True,
        key_bias: bool = True,
        activation_type: str = "relu",
        gradient_checkpointing: bool = False,
        tie_word_embedding: bool = False,
    ):
        super().__init__()
        attention_dim = encoder_output_size
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()

        self.embed = torch.nn.Sequential(
            torch.nn.Identity() if input_layer == "no_pos" else
            torch.nn.Embedding(vocab_size, attention_dim),
            COSYVOICE_EMB_CLASSES[input_layer](attention_dim,
                                               positional_dropout_rate),
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
        self.use_output_layer = use_output_layer
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = torch.nn.Identity()
        self.num_blocks = num_blocks
        self.decoders = torch.nn.ModuleList([
            DecoderLayer(
                attention_dim,
                COSYVOICE_ATTENTION_CLASSES["selfattn"](
                    attention_heads, attention_dim,
                    self_attention_dropout_rate, key_bias),
                COSYVOICE_ATTENTION_CLASSES["selfattn"](
                    attention_heads, attention_dim, src_attention_dropout_rate,
                    key_bias) if src_attention else None,
                PositionwiseFeedForward(attention_dim, linear_units,
                                        dropout_rate, activation),
                dropout_rate,
                normalize_before,
            ) for _ in range(self.num_blocks)
        ])

        self.gradient_checkpointing = gradient_checkpointing
        self.tie_word_embedding = tie_word_embedding

    def forward(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        r_ys_in_pad: torch.Tensor = torch.empty(0),
        reverse_weight: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward decoder.
        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
            ys_in_lens: input lengths of this batch (batch)
            r_ys_in_pad: not used in transformer decoder, in order to unify api
                with bidirectional decoder
            reverse_weight: not used in transformer decoder, in order to unify
                api with bidirectional decode
        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out,
                    vocab_size) if use_output_layer is True,
                torch.tensor(0.0), in order to unify api with bidirectional decoder
                olens: (batch, )
        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        tgt = ys_in_pad
        maxlen = tgt.size(1)
        # tgt_mask: (B, 1, L)
        tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
        tgt_mask = tgt_mask.to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1),
                            device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m
        x, _ = self.embed(tgt)
        if self.gradient_checkpointing and self.training:
            x = self.forward_layers_checkpointed(x, tgt_mask, memory,
                                                 memory_mask)
        else:
            x = self.forward_layers(x, tgt_mask, memory, memory_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.use_output_layer:
            x = self.output_layer(x)
        olens = tgt_mask.sum(1)
        return x, torch.tensor(0.0), olens

    def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
                       memory: torch.Tensor,
                       memory_mask: torch.Tensor) -> torch.Tensor:
        for layer in self.decoders:
            x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
                                                     memory_mask)
        return x

    @torch.jit.ignore(drop=True)
    def forward_layers_checkpointed(self, x: torch.Tensor,
                                    tgt_mask: torch.Tensor,
                                    memory: torch.Tensor,
                                    memory_mask: torch.Tensor) -> torch.Tensor:
        for layer in self.decoders:
            x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
                layer.__call__, x, tgt_mask, memory, memory_mask)
        return x

    def forward_one_step(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.
            This is only used for decoding.
        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoded memory mask, (batch, 1, maxlen_in)
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                      dtype=torch.uint8 in PyTorch 1.2-
                      dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                y.shape` is (batch, maxlen_out, token)
        """
        x, _ = self.embed(tgt)
        new_cache = []
        for i, decoder in enumerate(self.decoders):
            if cache is None:
                c = None
            else:
                c = cache[i]
            x, tgt_mask, memory, memory_mask = decoder(x,
                                                       tgt_mask,
                                                       memory,
                                                       memory_mask,
                                                       cache=c)
            new_cache.append(x)
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.use_output_layer:
            y = torch.log_softmax(self.output_layer(y), dim=-1)
        return y, new_cache

    def tie_or_clone_weights(self, jit_mode: bool = True):
        """Tie or clone module weights (between word_emb and output_layer)
            depending of whether we are using TorchScript or not"""
        if not self.use_output_layer:
            return
        if jit_mode:
            logging.info("clone emb.weight to output.weight")
            self.output_layer.weight = torch.nn.Parameter(
                self.embed[0].weight.clone())
        else:
            logging.info("tie emb.weight with output.weight")
            self.output_layer.weight = self.embed[0].weight

        if getattr(self.output_layer, "bias", None) is not None:
            self.output_layer.bias.data = torch.nn.functional.pad(
                self.output_layer.bias.data,
                (
                    0,
                    self.output_layer.weight.shape[0] -
                    self.output_layer.bias.shape[0],
                ),
                "constant",
                0,
            )


class BiTransformerDecoder(torch.nn.Module):
    """Base class of Transfomer decoder module.
    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the hidden units number of position-wise feedforward
        num_blocks: the number of decoder blocks
        r_num_blocks: the number of right to left decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before:
            True: use layer_norm before each sub-block of a layer.
            False: use layer_norm after each sub-block of a layer.
        key_bias: whether use bias in attention.linear_k, False for whisper models.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        r_num_blocks: int = 0,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        normalize_before: bool = True,
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
        tie_word_embedding: bool = False,
    ):

        super().__init__()
        self.tie_word_embedding = tie_word_embedding
        self.left_decoder = TransformerDecoder(
            vocab_size,
            encoder_output_size,
            attention_heads,
            linear_units,
            num_blocks,
            dropout_rate,
            positional_dropout_rate,
            self_attention_dropout_rate,
            src_attention_dropout_rate,
            input_layer,
            use_output_layer,
            normalize_before,
            key_bias=key_bias,
            gradient_checkpointing=gradient_checkpointing,
            tie_word_embedding=tie_word_embedding)

        self.right_decoder = TransformerDecoder(
            vocab_size,
            encoder_output_size,
            attention_heads,
            linear_units,
            r_num_blocks,
            dropout_rate,
            positional_dropout_rate,
            self_attention_dropout_rate,
            src_attention_dropout_rate,
            input_layer,
            use_output_layer,
            normalize_before,
            key_bias=key_bias,
            gradient_checkpointing=gradient_checkpointing,
            tie_word_embedding=tie_word_embedding)

    def forward(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        r_ys_in_pad: torch.Tensor,
        reverse_weight: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward decoder.
        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
            ys_in_lens: input lengths of this batch (batch)
            r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
                used for right to left decoder
            reverse_weight: used for right to left decoder
        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out,
                    vocab_size) if use_output_layer is True,
                r_x: x: decoded token score (right to left decoder)
                    before softmax (batch, maxlen_out, vocab_size)
                    if use_output_layer is True,
                olens: (batch, )
        """
        l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
                                          ys_in_lens)
        r_x = torch.tensor(0.0)
        if reverse_weight > 0.0:
            r_x, _, olens = self.right_decoder(memory, memory_mask,
                                               r_ys_in_pad, ys_in_lens)
        return l_x, r_x, olens

    def forward_one_step(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.
            This is only used for decoding.
        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoded memory mask, (batch, 1, maxlen_in)
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                      dtype=torch.uint8 in PyTorch 1.2-
                      dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                y.shape` is (batch, maxlen_out, token)
        """
        return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
                                                  tgt_mask, cache)

    def tie_or_clone_weights(self, jit_mode: bool = True):
        """Tie or clone module weights (between word_emb and output_layer)
            depending of whether we are using TorchScript or not"""
        self.left_decoder.tie_or_clone_weights(jit_mode)
        self.right_decoder.tie_or_clone_weights(jit_mode)
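
TransformerDecoder.forward builds its self-attention mask by combining a padding mask from make_pad_mask with a causal subsequent_mask, broadcasting (B, 1, L) against (1, L, L) to get a (B, L, L) mask. The vendored helpers live in cosyvoice.utils.mask; the snippet below is a minimal, self-contained re-implementation in plain PyTorch, shown only to make the shapes concrete, not the vendored code itself.

```python
import torch

def make_pad_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """True where a position is padding: (B, L)."""
    idx = torch.arange(max_len, device=lengths.device)
    return idx.unsqueeze(0) >= lengths.unsqueeze(1)

def subsequent_mask(size: int, device="cpu") -> torch.Tensor:
    """Lower-triangular causal mask: (L, L), True where attention is allowed."""
    return torch.tril(torch.ones(size, size, dtype=torch.bool, device=device))

ys_in_lens = torch.tensor([4, 2])   # two sequences, lengths 4 and 2
maxlen = 4

# (B, 1, L): True on real tokens, False on padding
tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
# (1, L, L): causal constraint shared across the batch
m = subsequent_mask(maxlen).unsqueeze(0)
# (B, L, L): position i may attend to j only if j <= i and j is not padding
tgt_mask = tgt_mask & m
print(tgt_mask.shape)               # torch.Size([2, 4, 4])
```
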
@@ -0,0 +1,132 @@ xinference/thirdparty/cosyvoice/transformer/decoder_layer.py (new file)

# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder self-attention layer definition."""
from typing import Optional, Tuple

import torch
from torch import nn


class DecoderLayer(nn.Module):
    """Single decoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Inter-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
            If `None` is passed, Inter-attention is not used, such as
            CIF, GPT, and other decoder only model.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: to use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: nn.Module,
        src_attn: Optional[nn.Module],
        feed_forward: nn.Module,
        dropout_rate: float,
        normalize_before: bool = True,
    ):
        """Construct an DecoderLayer object."""
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, eps=1e-5)
        self.norm2 = nn.LayerNorm(size, eps=1e-5)
        self.norm3 = nn.LayerNorm(size, eps=1e-5)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before

    def forward(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor
                (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory
                (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask
                (#batch, maxlen_in).
            cache (torch.Tensor): cached tensors.
                (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).

        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = tgt_mask[:, -1:, :]

        x = residual + self.dropout(
            self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
        if not self.normalize_before:
            x = self.norm1(x)

        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm2(x)
            x = residual + self.dropout(
                self.src_attn(x, memory, memory, memory_mask)[0])
            if not self.normalize_before:
                x = self.norm2(x)

        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
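
During incremental decoding, DecoderLayer only computes the newest query position and stitches the result back onto a cache of shape (batch, maxlen_out - 1, size). The sketch below isolates just that bookkeeping, with the attention and feed-forward sub-blocks stubbed out, to illustrate the cache contract rather than to exercise the vendored class.

```python
from typing import Optional

import torch

BATCH, SIZE = 2, 8  # illustrative dimensions


def decode_step(tgt: torch.Tensor, cache: Optional[torch.Tensor]) -> torch.Tensor:
    """Mirror of the cache handling in DecoderLayer.forward, sub-blocks stubbed out."""
    if cache is not None:
        # Same invariant the real layer asserts: cache holds every position but the last.
        assert cache.shape == (tgt.shape[0], tgt.shape[1] - 1, SIZE)
        x = tgt[:, -1:, :]  # only the newest position becomes the query
    else:
        x = tgt             # first step: process the full sequence
    x = x + 0.0             # stand-in for self-attn / cross-attn / feed-forward
    if cache is not None:
        x = torch.cat([cache, x], dim=1)  # re-attach history, as the real layer does
    return x


cache = None
for step in range(1, 4):
    tgt = torch.randn(BATCH, step, SIZE)  # decoder input grows by one token per step
    cache = decode_step(tgt, cache)
    print(step, tuple(cache.shape))       # (BATCH, step, SIZE)
```
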