xinference 1.4.1__py3-none-any.whl → 1.5.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (104)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +50 -1
  3. xinference/client/restful/restful_client.py +82 -2
  4. xinference/constants.py +3 -0
  5. xinference/core/chat_interface.py +297 -83
  6. xinference/core/model.py +1 -0
  7. xinference/core/progress_tracker.py +16 -8
  8. xinference/core/supervisor.py +45 -1
  9. xinference/core/worker.py +262 -37
  10. xinference/deploy/cmdline.py +33 -1
  11. xinference/model/audio/core.py +11 -1
  12. xinference/model/audio/megatts.py +105 -0
  13. xinference/model/audio/model_spec.json +24 -1
  14. xinference/model/audio/model_spec_modelscope.json +26 -1
  15. xinference/model/core.py +14 -0
  16. xinference/model/embedding/core.py +6 -1
  17. xinference/model/flexible/core.py +6 -1
  18. xinference/model/image/core.py +6 -1
  19. xinference/model/image/model_spec.json +17 -1
  20. xinference/model/image/model_spec_modelscope.json +17 -1
  21. xinference/model/llm/__init__.py +0 -4
  22. xinference/model/llm/core.py +4 -0
  23. xinference/model/llm/llama_cpp/core.py +40 -16
  24. xinference/model/llm/llm_family.json +415 -84
  25. xinference/model/llm/llm_family.py +24 -1
  26. xinference/model/llm/llm_family_modelscope.json +449 -0
  27. xinference/model/llm/mlx/core.py +16 -2
  28. xinference/model/llm/transformers/__init__.py +14 -0
  29. xinference/model/llm/transformers/core.py +30 -6
  30. xinference/model/llm/transformers/gemma3.py +17 -2
  31. xinference/model/llm/transformers/intern_vl.py +28 -18
  32. xinference/model/llm/transformers/minicpmv26.py +21 -2
  33. xinference/model/llm/transformers/qwen-omni.py +308 -0
  34. xinference/model/llm/transformers/qwen2_audio.py +1 -1
  35. xinference/model/llm/transformers/qwen2_vl.py +20 -4
  36. xinference/model/llm/utils.py +11 -1
  37. xinference/model/llm/vllm/core.py +35 -0
  38. xinference/model/llm/vllm/distributed_executor.py +8 -2
  39. xinference/model/rerank/core.py +6 -1
  40. xinference/model/utils.py +118 -1
  41. xinference/model/video/core.py +6 -1
  42. xinference/thirdparty/megatts3/__init__.py +0 -0
  43. xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
  44. xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
  45. xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
  46. xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
  47. xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
  48. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
  49. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
  50. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
  51. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
  52. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
  53. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
  54. xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
  55. xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
  56. xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
  57. xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
  58. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
  59. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
  60. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
  61. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
  62. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
  63. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
  64. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
  65. xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
  66. xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
  67. xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
  68. xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
  69. xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
  70. xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
  71. xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
  72. xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
  73. xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
  74. xinference/types.py +10 -0
  75. xinference/utils.py +54 -0
  76. xinference/web/ui/build/asset-manifest.json +6 -6
  77. xinference/web/ui/build/index.html +1 -1
  78. xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
  79. xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
  80. xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
  81. xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
  82. xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
  83. xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
  84. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
  86. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
  87. xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
  88. xinference/web/ui/src/locales/en.json +2 -1
  89. xinference/web/ui/src/locales/zh.json +2 -1
  90. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/METADATA +129 -114
  91. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/RECORD +96 -60
  92. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/WHEEL +1 -1
  93. xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
  94. xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
  95. xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
  96. xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
  97. xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
  98. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
  99. xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
  100. xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
  101. /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
  102. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/entry_points.txt +0 -0
  103. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info/licenses}/LICENSE +0 -0
  104. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/top_level.txt +0 -0
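The largest single addition in this release is the vendored MegaTTS3 duration-predictor transformer (entry 53), shown in full below. As a rough, illustrative way to produce a file-level summary like the one above (not the registry's actual tooling), the following sketch compares the member lists of two locally downloaded wheels; the wheel paths are assumptions, and the per-file counts here are whole-file line counts rather than per-hunk +/- counts.

import zipfile

# Assumed local paths; download the wheels first (e.g. with `pip download xinference==...`).
OLD_WHEEL = "xinference-1.4.1-py3-none-any.whl"
NEW_WHEEL = "xinference-1.5.0.post1-py3-none-any.whl"

def member_lines(path):
    """Map each member of a wheel to its line count (None for binary members)."""
    counts = {}
    with zipfile.ZipFile(path) as wheel:
        for name in wheel.namelist():
            try:
                counts[name] = wheel.read(name).decode("utf-8").count("\n")
            except UnicodeDecodeError:
                counts[name] = None  # binary member, skip line counting
    return counts

old, new = member_lines(OLD_WHEEL), member_lines(NEW_WHEEL)
for name in sorted(set(old) | set(new)):
    if name not in old:
        added = f" (+{new[name]} lines)" if new[name] is not None else ""
        print(f"added    {name}{added}")
    elif name not in new:
        print(f"removed  {name}")
    elif old[name] != new[name]:
        print(f"changed  {name}")
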
xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py (new file)
@@ -0,0 +1,767 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import torch
+from torch import nn
+from torch.nn import Parameter, Linear
+from tts.modules.ar_dur.commons.layers import LayerNorm, Embedding
+from tts.modules.ar_dur.commons.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
+import torch.nn.functional as F
+
+DEFAULT_MAX_SOURCE_POSITIONS = 3000
+DEFAULT_MAX_TARGET_POSITIONS = 3000
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+    """This module produces sinusoidal positional embeddings of any length.
+
+    Padding symbols are ignored.
+    """
+
+    def __init__(self, embedding_dim, padding_idx, init_size=1024):
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.padding_idx = padding_idx
+        self.weights = SinusoidalPositionalEmbedding.get_embedding(
+            init_size,
+            embedding_dim,
+            padding_idx,
+        )
+        self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+    @staticmethod
+    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
+        """Build sinusoidal embeddings.
+
+        This matches the implementation in tensor2tensor, but differs slightly
+        from the description in Section 3.5 of "Attention Is All You Need".
+        """
+        half_dim = embedding_dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+        if embedding_dim % 2 == 1:
+            # zero pad
+            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+        if padding_idx is not None:
+            emb[padding_idx, :] = 0
+        return emb
+
+    def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
+        """Input is expected to be of size [bsz x seqlen]."""
+        bsz, seq_len = input.shape[:2]
+        max_pos = self.padding_idx + 1 + seq_len
+        if self.weights is None or max_pos > self.weights.size(0):
+            # recompute/expand embeddings if needed
+            self.weights = SinusoidalPositionalEmbedding.get_embedding(
+                max_pos,
+                self.embedding_dim,
+                self.padding_idx,
+            )
+        self.weights = self.weights.to(self._float_tensor)
+
+        if incremental_state is not None:
+            # positions is the same for every token when decoding a single step
+            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
+        positions = make_positions(input, self.padding_idx) if positions is None else positions
+        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+
+    def max_positions(self):
+        """Maximum number of supported positions."""
+        return int(1e5)  # an arbitrary large number
+
+
+class TransformerFFNLayer(nn.Module):
+    def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu', bias=True):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.dropout = dropout
+        self.act = act
+        if padding == 'SAME':
+            self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size,
+                                   padding=kernel_size // 2, bias=bias)
+        elif padding == 'LEFT':
+            self.ffn_1 = nn.Sequential(
+                nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
+                nn.Conv1d(hidden_size, filter_size, kernel_size, bias=bias)
+            )
+        self.ffn_2 = Linear(filter_size, hidden_size, bias=bias)
+
+    def forward(self, x, incremental_state=None):
+        # x: T x B x C
+        if incremental_state is not None:
+            saved_state = self._get_input_buffer(incremental_state)
+            if 'prev_input' in saved_state:
+                prev_input = saved_state['prev_input']
+                x = torch.cat((prev_input, x), dim=0)
+            x = x[-self.kernel_size:]
+            saved_state['prev_input'] = x
+            self._set_input_buffer(incremental_state, saved_state)
+
+        x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
+        x = x * self.kernel_size ** -0.5
+
+        if incremental_state is not None:
+            x = x[-1:]
+        if self.act == 'gelu':
+            x = F.gelu(x)
+        if self.act == 'relu':
+            x = F.relu(x)
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = self.ffn_2(x)
+        return x
+
+    def _get_input_buffer(self, incremental_state):
+        return get_incremental_state(
+            self,
+            incremental_state,
+            'f',
+        ) or {}
+
+    def _set_input_buffer(self, incremental_state, buffer):
+        set_incremental_state(
+            self,
+            incremental_state,
+            'f',
+            buffer,
+        )
+
+    def clear_buffer(self, incremental_state):
+        if incremental_state is not None:
+            saved_state = self._get_input_buffer(incremental_state)
+            if 'prev_input' in saved_state:
+                del saved_state['prev_input']
+            self._set_input_buffer(incremental_state, saved_state)
+
+
+class MultiheadAttention(nn.Module):
+    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
+                 encoder_decoder_attention=False):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.kdim = kdim if kdim is not None else embed_dim
+        self.vdim = vdim if vdim is not None else embed_dim
+        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+        self.scaling = self.head_dim ** -0.5
+
+        self.self_attention = self_attention
+        self.encoder_decoder_attention = encoder_decoder_attention
+
+        assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
+                                                             'value to be of the same size'
+
+        if self.qkv_same_dim:
+            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+        else:
+            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+
+        if bias:
+            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
+        else:
+            self.register_parameter('in_proj_bias', None)
+
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+        if add_bias_kv:
+            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+        else:
+            self.bias_k = self.bias_v = None
+
+        self.add_zero_attn = add_zero_attn
+
+        self.reset_parameters()
+
+        self.enable_torch_version = False
+        self.last_attn_probs = None
+
+    def reset_parameters(self):
+        if self.qkv_same_dim:
+            nn.init.xavier_uniform_(self.in_proj_weight)
+        else:
+            nn.init.xavier_uniform_(self.k_proj_weight)
+            nn.init.xavier_uniform_(self.v_proj_weight)
+            nn.init.xavier_uniform_(self.q_proj_weight)
+
+        nn.init.xavier_uniform_(self.out_proj.weight)
+        if self.in_proj_bias is not None:
+            nn.init.constant_(self.in_proj_bias, 0.)
+            nn.init.constant_(self.out_proj.bias, 0.)
+        if self.bias_k is not None:
+            nn.init.xavier_normal_(self.bias_k)
+        if self.bias_v is not None:
+            nn.init.xavier_normal_(self.bias_v)
+
+    def forward(
+            self,
+            query, key, value,
+            key_padding_mask=None,
+            incremental_state=None,
+            need_weights=True,
+            static_kv=False,
+            attn_mask=None,
+            before_softmax=False,
+            need_head_weights=False,
+            enc_dec_attn_constraint_mask=None,
+            reset_attn_weight=None
+    ):
+        """Input shape: Time x Batch x Channel
+
+        Args:
+            key_padding_mask (ByteTensor, optional): mask to exclude
+                keys that are pads, of shape `(batch, src_len)`, where
+                padding elements are indicated by 1s.
+            need_weights (bool, optional): return the attention weights,
+                averaged over heads (default: False).
+            attn_mask (ByteTensor, optional): typically used to
+                implement causal attention, where the mask prevents the
+                attention from looking forward in time (default: None).
+            before_softmax (bool, optional): return the raw attention
+                weights and values before the attention softmax.
+            need_head_weights (bool, optional): return the attention
+                weights for each head. Implies *need_weights*. Default:
+                return the average attention weights over all heads.
+        """
+        if need_head_weights:
+            need_weights = True
+
+        tgt_len, bsz, embed_dim = query.size()
+        assert embed_dim == self.embed_dim
+        assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+        if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
+            if self.qkv_same_dim:
+                return F.multi_head_attention_forward(query, key, value,
+                                                      self.embed_dim, self.num_heads,
+                                                      self.in_proj_weight,
+                                                      self.in_proj_bias, self.bias_k, self.bias_v,
+                                                      self.add_zero_attn, self.dropout,
+                                                      self.out_proj.weight, self.out_proj.bias,
+                                                      self.training, key_padding_mask, need_weights,
+                                                      attn_mask)
+            else:
+                return F.multi_head_attention_forward(query, key, value,
+                                                      self.embed_dim, self.num_heads,
+                                                      torch.empty([0]),
+                                                      self.in_proj_bias, self.bias_k, self.bias_v,
+                                                      self.add_zero_attn, self.dropout,
+                                                      self.out_proj.weight, self.out_proj.bias,
+                                                      self.training, key_padding_mask, need_weights,
+                                                      attn_mask, use_separate_proj_weight=True,
+                                                      q_proj_weight=self.q_proj_weight,
+                                                      k_proj_weight=self.k_proj_weight,
+                                                      v_proj_weight=self.v_proj_weight)
+
+        if incremental_state is not None:
+            saved_state = self._get_input_buffer(incremental_state)
+            if 'prev_key' in saved_state:
+                # previous time steps are cached - no need to recompute
+                # key and value if they are static
+                if static_kv:
+                    assert self.encoder_decoder_attention and not self.self_attention
+                    key = value = None
+        else:
+            saved_state = None
+
+        if self.self_attention:
+            # self-attention
+            q, k, v = self.in_proj_qkv(query)
+        elif self.encoder_decoder_attention:
+            # encoder-decoder attention
+            q = self.in_proj_q(query)
+            if key is None:
+                assert value is None
+                k = v = None
+            else:
+                k = self.in_proj_k(key)
+                v = self.in_proj_v(key)
+
+        else:
+            q = self.in_proj_q(query)
+            k = self.in_proj_k(key)
+            v = self.in_proj_v(value)
+        q = q * self.scaling
+
+        if self.bias_k is not None:
+            assert self.bias_v is not None
+            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+            if attn_mask is not None:
+                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+            if key_padding_mask is not None:
+                key_padding_mask = torch.cat(
+                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+        if k is not None:
+            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+        if v is not None:
+            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+        if saved_state is not None:
+            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+            if 'prev_key' in saved_state:
+                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+                if static_kv:
+                    k = prev_key
+                else:
+                    k = torch.cat((prev_key, k), dim=1)
+            if 'prev_value' in saved_state:
+                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+                if static_kv:
+                    v = prev_value
+                else:
+                    v = torch.cat((prev_value, v), dim=1)
+            if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
+                prev_key_padding_mask = saved_state['prev_key_padding_mask']
+                if static_kv:
+                    key_padding_mask = prev_key_padding_mask
+                else:
+                    key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
+
+            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
+            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
+            saved_state['prev_key_padding_mask'] = key_padding_mask
+
+            self._set_input_buffer(incremental_state, saved_state)
+
+        src_len = k.size(1)
+
+        # This is part of a workaround to get around fork/join parallelism
+        # not supporting Optional types.
+        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+            key_padding_mask = None
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.size(0) == bsz
+            assert key_padding_mask.size(1) == src_len
+
+        if self.add_zero_attn:
+            src_len += 1
+            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+            if attn_mask is not None:
+                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+            if key_padding_mask is not None:
+                key_padding_mask = torch.cat(
+                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+        attn_weights = torch.bmm(q, k.transpose(1, 2))
+        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+        if attn_mask is not None:
+            if len(attn_mask.shape) == 2:
+                attn_mask = attn_mask.unsqueeze(0)
+            elif len(attn_mask.shape) == 3:
+                attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+                    bsz * self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights + attn_mask
+
+        if enc_dec_attn_constraint_mask is not None:  # bs x head x L_kv
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.masked_fill(
+                enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
+                -1e8,
+            )
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if key_padding_mask is not None:
+            # don't attend to padding symbols
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.masked_fill(
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                -1e8,
+            )
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+        if before_softmax:
+            return attn_weights, v
+
+        attn_weights_float = softmax(attn_weights, dim=-1)
+        attn_weights = attn_weights_float.type_as(attn_weights)
+        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+        if reset_attn_weight is not None:
+            if reset_attn_weight:
+                self.last_attn_probs = attn_probs.detach()
+            else:
+                assert self.last_attn_probs is not None
+                attn_probs = self.last_attn_probs
+        attn = torch.bmm(attn_probs, v)
+        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        attn = self.out_proj(attn)
+
+        if need_weights:
+            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+            if not need_head_weights:
+                # average attention weights over heads
+                attn_weights = attn_weights.mean(dim=0)
+        else:
+            attn_weights = None
+
+        return attn, (attn_weights, attn_logits)
+
+    def in_proj_qkv(self, query):
+        return self._in_proj(query).chunk(3, dim=-1)
+
+    def in_proj_q(self, query):
+        if self.qkv_same_dim:
+            return self._in_proj(query, end=self.embed_dim)
+        else:
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[:self.embed_dim]
+            return F.linear(query, self.q_proj_weight, bias)
+
+    def in_proj_k(self, key):
+        if self.qkv_same_dim:
+            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+        else:
+            weight = self.k_proj_weight
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[self.embed_dim:2 * self.embed_dim]
+            return F.linear(key, weight, bias)
+
+    def in_proj_v(self, value):
+        if self.qkv_same_dim:
+            return self._in_proj(value, start=2 * self.embed_dim)
+        else:
+            weight = self.v_proj_weight
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[2 * self.embed_dim:]
+            return F.linear(value, weight, bias)
+
+    def _in_proj(self, input, start=0, end=None):
+        weight = self.in_proj_weight
+        bias = self.in_proj_bias
+        weight = weight[start:end, :]
+        if bias is not None:
+            bias = bias[start:end]
+        return F.linear(input, weight, bias)
+
+    def _get_input_buffer(self, incremental_state):
+        return get_incremental_state(
+            self,
+            incremental_state,
+            'attn_state',
+        ) or {}
+
+    def _set_input_buffer(self, incremental_state, buffer):
+        set_incremental_state(
+            self,
+            incremental_state,
+            'attn_state',
+            buffer,
+        )
+
+    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
+        return attn_weights
+
+    def clear_buffer(self, incremental_state=None):
+        if incremental_state is not None:
+            saved_state = self._get_input_buffer(incremental_state)
+            if 'prev_key' in saved_state:
+                del saved_state['prev_key']
+            if 'prev_value' in saved_state:
+                del saved_state['prev_value']
+            self._set_input_buffer(incremental_state, saved_state)
+
+
+class EncSALayer(nn.Module):
+    def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
+                 relu_dropout=0.1, kernel_size=9, padding='SAME', act='gelu',
+                 ffn_hidden_size=1024):
+        super().__init__()
+        self.c = c
+        self.dropout = dropout
+        self.num_heads = num_heads
+        if num_heads > 0:
+            self.layer_norm1 = LayerNorm(c)
+            self.self_attn = MultiheadAttention(
+                self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
+        self.layer_norm2 = LayerNorm(c)
+        self.ffn = TransformerFFNLayer(
+            c, ffn_hidden_size, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
+
+    def forward(self, x, encoder_padding_mask=None, **kwargs):
+        layer_norm_training = kwargs.get('layer_norm_training', None)
+        if layer_norm_training is not None:
+            self.layer_norm1.training = layer_norm_training
+            self.layer_norm2.training = layer_norm_training
+        if self.num_heads > 0:
+            residual = x
+            x = self.layer_norm1(x)
+            x, _, = self.self_attn(
+                query=x,
+                key=x,
+                value=x,
+                key_padding_mask=encoder_padding_mask
+            )
+            x = F.dropout(x, self.dropout, training=self.training)
+            x = residual + x
+            x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+
+        residual = x
+        x = self.layer_norm2(x)
+        x = self.ffn(x)
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = residual + x
+        x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+        return x
+
+
+class DecSALayer(nn.Module):
+    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
+                 kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False):
+        super().__init__()
+        self.c = c
+        self.dropout = dropout
+        self.layer_norm1 = LayerNorm(c)
+        self.self_attn = MultiheadAttention(
+            c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+        )
+        self.layer_norm2 = LayerNorm(c)
+        self.encoder_attn = MultiheadAttention(
+            c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
+        )
+        self.layer_norm3 = LayerNorm(c)
+        self.ffn = TransformerFFNLayer(
+            c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
+        self.post_ln = post_ln
+
+    def forward(
+            self,
+            x,
+            encoder_out=None,
+            encoder_padding_mask=None,
+            incremental_state=None,
+            self_attn_mask=None,
+            self_attn_padding_mask=None,
+            attn_out=None,
+            reset_attn_weight=None,
+            **kwargs,
+    ):
+        layer_norm_training = kwargs.get('layer_norm_training', None)
+        if layer_norm_training is not None:
+            self.layer_norm1.training = layer_norm_training
+            self.layer_norm2.training = layer_norm_training
+            self.layer_norm3.training = layer_norm_training
+        residual = x
+        if not self.post_ln:
+            x = self.layer_norm1(x)
+        x, _ = self.self_attn(
+            query=x,
+            key=x,
+            value=x,
+            key_padding_mask=self_attn_padding_mask,
+            incremental_state=incremental_state,
+            attn_mask=self_attn_mask
+        )
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = residual + x
+        if self.post_ln:
+            x = self.layer_norm1(x)
+
+        attn_logits = None
+        if encoder_out is not None or attn_out is not None:
+            residual = x
+            if not self.post_ln:
+                x = self.layer_norm2(x)
+            if encoder_out is not None:
+                x, attn = self.encoder_attn(
+                    query=x,
+                    key=encoder_out,
+                    value=encoder_out,
+                    key_padding_mask=encoder_padding_mask,
+                    incremental_state=incremental_state,
+                    static_kv=True,
+                    enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
+                                                                       'enc_dec_attn_constraint_mask'),
+                    reset_attn_weight=reset_attn_weight
+                )
+                attn_logits = attn[1]
+            elif attn_out is not None:
+                x = self.encoder_attn.in_proj_v(attn_out)
+            if encoder_out is not None or attn_out is not None:
+                x = F.dropout(x, self.dropout, training=self.training)
+                x = residual + x
+                if self.post_ln:
+                    x = self.layer_norm2(x)
+
+        residual = x
+        if not self.post_ln:
+            x = self.layer_norm3(x)
+        x = self.ffn(x, incremental_state=incremental_state)
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = residual + x
+        if self.post_ln:
+            x = self.layer_norm3(x)
+        return x, attn_logits
+
+    def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+        self.encoder_attn.clear_buffer(incremental_state)
+        self.ffn.clear_buffer(incremental_state)
+
+    def set_buffer(self, name, tensor, incremental_state):
+        return set_incremental_state(self, incremental_state, name, tensor)
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.num_heads = num_heads
+        self.op = EncSALayer(
+            hidden_size, num_heads, dropout=dropout,
+            attention_dropout=0.0, relu_dropout=dropout,
+            kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size)
+
+    def forward(self, x, **kwargs):
+        return self.op(x, **kwargs)
+
+
+class TransformerDecoderLayer(nn.Module):
+    def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024, post_ln=False):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.num_heads = num_heads
+        self.op = DecSALayer(
+            hidden_size, num_heads, dropout=dropout,
+            attention_dropout=0.0, relu_dropout=dropout,
+            kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
+            post_ln=post_ln)
+
+    def forward(self, x, **kwargs):
+        return self.op(x, **kwargs)
+
+    def clear_buffer(self, *args):
+        return self.op.clear_buffer(*args)
+
+    def set_buffer(self, *args):
+        return self.op.set_buffer(*args)
+
+
+class FFTBlocks(nn.Module):
+    def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=0.0,
+                 num_heads=2, use_pos_embed=True, use_last_norm=True,
+                 use_pos_embed_alpha=True, ffn_hidden_size=1024):
+        super().__init__()
+        self.num_layers = num_layers
+        embed_dim = self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.use_pos_embed = use_pos_embed
+        self.use_last_norm = use_last_norm
+        if use_pos_embed:
+            self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
+            self.padding_idx = 0
+            self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
+            self.embed_positions = SinusoidalPositionalEmbedding(
+                embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+            )
+
+        self.layers = nn.ModuleList([])
+        self.layers.extend([
+            TransformerEncoderLayer(self.hidden_size, self.dropout,
+                                    kernel_size=ffn_kernel_size, num_heads=num_heads,
+                                    ffn_hidden_size=ffn_hidden_size)
+            for _ in range(self.num_layers)
+        ])
+        if self.use_last_norm:
+            self.layer_norm = nn.LayerNorm(embed_dim)
+        else:
+            self.layer_norm = None
+
+    def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
+        """
+        :param x: [B, T, C]
+        :param padding_mask: [B, T]
+        :return: [B, T, C] or [L, B, T, C]
+        """
+        padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
+        nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]
+        if self.use_pos_embed:
+            positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
+            x = x + positions
+            x = F.dropout(x, p=self.dropout, training=self.training)
+        # B x T x C -> T x B x C
+        x = x.transpose(0, 1) * nonpadding_mask_TB
+        hiddens = []
+        for layer in self.layers:
+            x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
+            hiddens.append(x)
+        if self.use_last_norm:
+            x = self.layer_norm(x) * nonpadding_mask_TB
+        if return_hiddens:
+            x = torch.stack(hiddens, 0)  # [L, T, B, C]
+            x = x.transpose(1, 2)  # [L, B, T, C]
+        else:
+            x = x.transpose(0, 1)  # [B, T, C]
+        return x
+
+
+class FastSpeechEncoder(FFTBlocks):
+    def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_size=9,
+                 dropout=0.0, num_heads=2, ffn_hidden_size=1024):
+        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
+                         use_pos_embed=False, dropout=dropout, ffn_hidden_size=ffn_hidden_size)
+        self.embed_tokens = Embedding(dict_size, hidden_size, 0)
+        self.embed_scale = math.sqrt(hidden_size)
+        self.padding_idx = 0
+        self.embed_positions = SinusoidalPositionalEmbedding(
+            hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+        )
+
+    def forward(self, txt_tokens, attn_mask=None, other_embeds=0):
+        """
+
+        :param txt_tokens: [B, T]
+        :return: {
+            'encoder_out': [B x T x C]
+        }
+        """
+        encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
+        x = self.forward_embedding(txt_tokens) + other_embeds  # [B, T, H]
+        if self.num_layers > 0:
+            x = super(FastSpeechEncoder, self).forward(x, encoder_padding_mask, attn_mask=attn_mask)
+        return x
+
+    def forward_embedding(self, txt_tokens):
+        # embed tokens and positions
+        x = self.embed_scale * self.embed_tokens(txt_tokens)
+        if self.use_pos_embed:
+            positions = self.embed_positions(txt_tokens)
+            x = x + positions
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        return x
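
For orientation, the snippet below is a standalone re-implementation of the table built by SinusoidalPositionalEmbedding.get_embedding in the new file above. It does not import the vendored xinference.thirdparty.megatts3 package; it is only a minimal sketch to illustrate the embedding layout (sin values in the first half of each row, cos values in the second, with the padding row zeroed).

import math
import torch

def sinusoidal_table(num_embeddings, embedding_dim, padding_idx=None):
    # Same construction as get_embedding above: a geometric progression of
    # inverse frequencies; sin fills the first half of each row, cos the second.
    half_dim = embedding_dim // 2
    freq = torch.exp(torch.arange(half_dim, dtype=torch.float)
                     * -(math.log(10000) / (half_dim - 1)))
    angles = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * freq.unsqueeze(0)
    table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:
        table = torch.cat([table, torch.zeros(num_embeddings, 1)], dim=1)  # zero-pad odd dims
    if padding_idx is not None:
        table[padding_idx, :] = 0  # padding positions get an all-zero embedding
    return table

table = sinusoidal_table(3000, 256, padding_idx=0)  # 3000 matches DEFAULT_MAX_*_POSITIONS
print(table.shape)           # torch.Size([3000, 256])
print(table[0].abs().sum())  # tensor(0.) -- the padding row is zeroed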