xinference 1.4.1__py3-none-any.whl → 1.5.0.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of xinference might be problematic.

Files changed (104)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +50 -1
  3. xinference/client/restful/restful_client.py +82 -2
  4. xinference/constants.py +3 -0
  5. xinference/core/chat_interface.py +297 -83
  6. xinference/core/model.py +1 -0
  7. xinference/core/progress_tracker.py +16 -8
  8. xinference/core/supervisor.py +45 -1
  9. xinference/core/worker.py +262 -37
  10. xinference/deploy/cmdline.py +33 -1
  11. xinference/model/audio/core.py +11 -1
  12. xinference/model/audio/megatts.py +105 -0
  13. xinference/model/audio/model_spec.json +24 -1
  14. xinference/model/audio/model_spec_modelscope.json +26 -1
  15. xinference/model/core.py +14 -0
  16. xinference/model/embedding/core.py +6 -1
  17. xinference/model/flexible/core.py +6 -1
  18. xinference/model/image/core.py +6 -1
  19. xinference/model/image/model_spec.json +17 -1
  20. xinference/model/image/model_spec_modelscope.json +17 -1
  21. xinference/model/llm/__init__.py +0 -4
  22. xinference/model/llm/core.py +4 -0
  23. xinference/model/llm/llama_cpp/core.py +40 -16
  24. xinference/model/llm/llm_family.json +415 -84
  25. xinference/model/llm/llm_family.py +24 -1
  26. xinference/model/llm/llm_family_modelscope.json +449 -0
  27. xinference/model/llm/mlx/core.py +16 -2
  28. xinference/model/llm/transformers/__init__.py +14 -0
  29. xinference/model/llm/transformers/core.py +30 -6
  30. xinference/model/llm/transformers/gemma3.py +17 -2
  31. xinference/model/llm/transformers/intern_vl.py +28 -18
  32. xinference/model/llm/transformers/minicpmv26.py +21 -2
  33. xinference/model/llm/transformers/qwen-omni.py +308 -0
  34. xinference/model/llm/transformers/qwen2_audio.py +1 -1
  35. xinference/model/llm/transformers/qwen2_vl.py +20 -4
  36. xinference/model/llm/utils.py +11 -1
  37. xinference/model/llm/vllm/core.py +35 -0
  38. xinference/model/llm/vllm/distributed_executor.py +8 -2
  39. xinference/model/rerank/core.py +6 -1
  40. xinference/model/utils.py +118 -1
  41. xinference/model/video/core.py +6 -1
  42. xinference/thirdparty/megatts3/__init__.py +0 -0
  43. xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
  44. xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
  45. xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
  46. xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
  47. xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
  48. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
  49. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
  50. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
  51. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
  52. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
  53. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
  54. xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
  55. xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
  56. xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
  57. xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
  58. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
  59. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
  60. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
  61. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
  62. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
  63. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
  64. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
  65. xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
  66. xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
  67. xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
  68. xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
  69. xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
  70. xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
  71. xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
  72. xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
  73. xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
  74. xinference/types.py +10 -0
  75. xinference/utils.py +54 -0
  76. xinference/web/ui/build/asset-manifest.json +6 -6
  77. xinference/web/ui/build/index.html +1 -1
  78. xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
  79. xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
  80. xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
  81. xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
  82. xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
  83. xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
  84. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
  86. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
  87. xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
  88. xinference/web/ui/src/locales/en.json +2 -1
  89. xinference/web/ui/src/locales/zh.json +2 -1
  90. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/METADATA +129 -114
  91. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/RECORD +96 -60
  92. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/WHEEL +1 -1
  93. xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
  94. xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
  95. xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
  96. xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
  97. xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
  98. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
  99. xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
  100. xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
  101. /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
  102. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/entry_points.txt +0 -0
  103. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info/licenses}/LICENSE +0 -0
  104. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/top_level.txt +0 -0
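
Several of the entries above (xinference/model/audio/megatts.py, the audio model_spec files, and the new xinference/thirdparty/megatts3 tree) add MegaTTS 3 text-to-speech support. As a rough orientation only, a minimal client-side sketch follows; it assumes the new spec is registered under the model name "MegaTTS3" and is served through xinference's existing audio-model flow, so the exact name and speech() options should be checked against the release.

# Hypothetical sketch, not taken from the diff: launching the newly added
# MegaTTS 3 audio model through the standard xinference client API.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")      # running xinference endpoint
model_uid = client.launch_model(
    model_name="MegaTTS3",                    # assumption: registered name of the new audio spec
    model_type="audio",
)
model = client.get_model(model_uid)

# speech() mirrors the OpenAI-style TTS endpoint; MegaTTS 3 specific options
# (e.g. a reference voice) are omitted here and would go in as extra kwargs.
audio_bytes = model.speech("Hello from MegaTTS 3.")
with open("out.mp3", "wb") as f:
    f.write(audio_bytes)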
xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py (new file)
@@ -0,0 +1,403 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from tts.modules.ar_dur.commons.layers import Embedding
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def shift_1d(x):
+    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+    return x
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+class Encoder(nn.Module):
+    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
+                 window_size=None, block_length=None, pre_ln=False, **kwargs):
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.block_length = block_length
+        self.pre_ln = pre_ln
+
+        self.drop = nn.Dropout(p_dropout)
+        self.attn_layers = nn.ModuleList()
+        self.norm_layers_1 = nn.ModuleList()
+        self.ffn_layers = nn.ModuleList()
+        self.norm_layers_2 = nn.ModuleList()
+        for i in range(self.n_layers):
+            self.attn_layers.append(
+                MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size,
+                                   p_dropout=p_dropout, block_length=block_length))
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+        if pre_ln:
+            self.last_ln = LayerNorm(hidden_channels)
+
+    def forward(self, x, x_mask, attn_mask=1):
+        if isinstance(attn_mask, torch.Tensor):
+            attn_mask = attn_mask[:, None]
+        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) * attn_mask
+        for i in range(self.n_layers):
+            x = x * x_mask
+            x_ = x
+            if self.pre_ln:
+                x = self.norm_layers_1[i](x)
+            y = self.attn_layers[i](x, x, attn_mask)
+            y = self.drop(y)
+            x = x_ + y
+            if not self.pre_ln:
+                x = self.norm_layers_1[i](x)
+
+            x_ = x
+            if self.pre_ln:
+                x = self.norm_layers_2[i](x)
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = x_ + y
+            if not self.pre_ln:
+                x = self.norm_layers_2[i](x)
+        if self.pre_ln:
+            x = self.last_ln(x)
+        x = x * x_mask
+        return x
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0.,
+                 block_length=None, proximal_bias=False, proximal_init=False):
+        super().__init__()
+        assert channels % n_heads == 0
+
+        self.channels = channels
+        self.out_channels = out_channels
+        self.n_heads = n_heads
+        self.window_size = window_size
+        self.heads_share = heads_share
+        self.block_length = block_length
+        self.proximal_bias = proximal_bias
+        self.p_dropout = p_dropout
+        self.attn = None
+
+        self.k_channels = channels // n_heads
+        self.conv_q = nn.Conv1d(channels, channels, 1)
+        self.conv_k = nn.Conv1d(channels, channels, 1)
+        self.conv_v = nn.Conv1d(channels, channels, 1)
+        if window_size is not None:
+            n_heads_rel = 1 if heads_share else n_heads
+            rel_stddev = self.k_channels ** -0.5
+            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+        self.conv_o = nn.Conv1d(channels, out_channels, 1)
+        self.drop = nn.Dropout(p_dropout)
+
+        nn.init.xavier_uniform_(self.conv_q.weight)
+        nn.init.xavier_uniform_(self.conv_k.weight)
+        if proximal_init:
+            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
+            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
+        nn.init.xavier_uniform_(self.conv_v.weight)
+
+    def forward(self, x, c, attn_mask=None):
+        q = self.conv_q(x)
+        k = self.conv_k(c)
+        v = self.conv_v(c)
+
+        x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+        x = self.conv_o(x)
+        return x
+
+    def attention(self, query, key, value, mask=None):
+        # reshape [b, d, t] -> [b, n_h, t, d_k]
+        b, d, t_s, t_t = (*key.size(), query.size(2))
+        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
+        if self.window_size is not None:
+            assert t_s == t_t, "Relative attention is only available for self-attention."
+            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
+            rel_logits = self._relative_position_to_absolute_position(rel_logits)
+            scores_local = rel_logits / math.sqrt(self.k_channels)
+            scores = scores + scores_local
+        if self.proximal_bias:
+            assert t_s == t_t, "Proximal bias is only available for self-attention."
+            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, -1e4)
+            if self.block_length is not None:
+                block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
+                scores = scores * block_mask + -1e4 * (1 - block_mask)
+        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
+        p_attn = self.drop(p_attn)
+        output = torch.matmul(p_attn, value)
+        if self.window_size is not None:
+            relative_weights = self._absolute_position_to_relative_position(p_attn)
+            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
+            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
+        output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+        return output, p_attn
+
+    def _matmul_with_relative_values(self, x, y):
+        """
+        x: [b, h, l, m]
+        y: [h or 1, m, d]
+        ret: [b, h, l, d]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0))
+        return ret
+
+    def _matmul_with_relative_keys(self, x, y):
+        """
+        x: [b, h, l, d]
+        y: [h or 1, m, d]
+        ret: [b, h, l, m]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+        return ret
+
+    def _get_relative_embeddings(self, relative_embeddings, length):
+        max_relative_position = 2 * self.window_size + 1
+        # Pad first before slice to avoid using cond ops.
+        pad_length = max(length - (self.window_size + 1), 0)
+        slice_start_position = max((self.window_size + 1) - length, 0)
+        slice_end_position = slice_start_position + 2 * length - 1
+        if pad_length > 0:
+            padded_relative_embeddings = F.pad(
+                relative_embeddings,
+                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
+        else:
+            padded_relative_embeddings = relative_embeddings
+        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
+        return used_relative_embeddings
+
+    def _relative_position_to_absolute_position(self, x):
+        """
+        x: [b, h, l, 2*l-1]
+        ret: [b, h, l, l]
+        """
+        batch, heads, length, _ = x.size()
+        # Concat columns of pad to shift from relative to absolute indexing.
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+        # Concat extra elements so to add up to shape (len+1, 2*len-1).
+        x_flat = x.view([batch, heads, length * 2 * length])
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
+
+        # Reshape and slice out the padded elements.
+        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
+        return x_final
+
+    def _absolute_position_to_relative_position(self, x):
+        """
+        x: [b, h, l, l]
+        ret: [b, h, l, 2*l-1]
+        """
+        batch, heads, length, _ = x.size()
+        # padd along column
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
+        x_flat = x.view([batch, heads, -1])
+        # add 0's in the beginning that will skew the elements after reshape
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+        return x_final
+
+    def _attention_bias_proximal(self, length):
+        """Bias for self-attention to encourage attention to close positions.
+        Args:
+          length: an integer scalar.
+        Returns:
+          a Tensor with shape [1, 1, length, length]
+        """
+        r = torch.arange(length, dtype=torch.float32)
+        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.activation = activation
+
+        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1)
+        self.drop = nn.Dropout(p_dropout)
+
+    def forward(self, x, x_mask):
+        x = self.conv_1(x * x_mask)
+        if self.activation == "gelu":
+            x = x * torch.sigmoid(1.702 * x)
+        else:
+            x = torch.relu(x)
+        x = self.drop(x)
+        x = self.conv_2(x * x_mask)
+        return x * x_mask
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, channels, eps=1e-4):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = nn.Parameter(torch.ones(channels))
+        self.beta = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        n_dims = len(x.shape)
+        mean = torch.mean(x, 1, keepdim=True)
+        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
+
+        x = (x - mean) * torch.rsqrt(variance + self.eps)
+
+        shape = [1, -1] + [1] * (n_dims - 2)
+        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+        return x
+
+
+class ConvReluNorm(nn.Module):
+    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_channels = hidden_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.p_dropout = p_dropout
+        assert n_layers > 1, "Number of layers should be larger than 0."
+
+        self.conv_layers = nn.ModuleList()
+        self.norm_layers = nn.ModuleList()
+        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+        self.norm_layers.append(LayerNorm(hidden_channels))
+        self.relu_drop = nn.Sequential(
+            nn.ReLU(),
+            nn.Dropout(p_dropout))
+        for _ in range(n_layers - 1):
+            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+            self.norm_layers.append(LayerNorm(hidden_channels))
+        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+        self.proj.weight.data.zero_()
+        self.proj.bias.data.zero_()
+
+    def forward(self, x, x_mask):
+        x_org = x
+        for i in range(self.n_layers):
+            x = self.conv_layers[i](x * x_mask)
+            x = self.norm_layers[i](x)
+            x = self.relu_drop(x)
+        x = x_org + self.proj(x)
+        return x * x_mask
+
+
+class RelTransformerEncoder(nn.Module):
+    def __init__(self,
+                 n_vocab,
+                 out_channels,
+                 hidden_channels,
+                 filter_channels,
+                 n_heads,
+                 n_layers,
+                 kernel_size,
+                 p_dropout=0.0,
+                 window_size=4,
+                 block_length=None,
+                 in_channels=None,
+                 prenet=True,
+                 pre_ln=True,
+                 ):
+
+        super().__init__()
+
+        self.n_vocab = n_vocab
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.block_length = block_length
+        self.prenet = prenet
+        if n_vocab > 0:
+            self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0)
+
+        if prenet:
+            if in_channels is None:
+                in_channels = hidden_channels
+            self.pre = ConvReluNorm(in_channels, in_channels, in_channels,
+                                    kernel_size=5, n_layers=3, p_dropout=0)
+        if in_channels is not None and in_channels != hidden_channels:
+            self.encoder_inp_proj = nn.Conv1d(in_channels, hidden_channels, 1)
+        self.encoder = Encoder(
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout,
+            window_size=window_size,
+            block_length=block_length,
+            pre_ln=pre_ln,
+        )
+
+    def forward(self, x, x_mask=None, other_embeds=0, attn_mask=1):
+        if self.n_vocab > 0:
+            x_lengths = (x > 0).long().sum(-1)
+            x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
+        else:
+            x_lengths = (x.abs().sum(-1) > 0).long().sum(-1)
+        x = x + other_embeds
+        x = torch.transpose(x, 1, -1)  # [b, h, t]
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+        if self.prenet:
+            x = self.pre(x, x_mask)
+            self.prenet_out = x.transpose(1, 2)
+        if hasattr(self, 'encoder_inp_proj'):
+            x = self.encoder_inp_proj(x) * x_mask
+        x = self.encoder(x, x_mask, attn_mask)
+        return x.transpose(1, 2)
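
For orientation, here is a minimal sketch of how the RelTransformerEncoder added above might be exercised on a dummy token batch. It is not part of the released file; it assumes the bundled megatts3 directory is on sys.path so that the package-internal tts.* imports resolve, and all constructor values are illustrative only.

# Hypothetical usage sketch (not part of the diff).
import sys
import torch

sys.path.insert(0, "xinference/thirdparty/megatts3")  # assumption: makes the bundled "tts" package importable
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder

encoder = RelTransformerEncoder(
    n_vocab=100,            # toy phoneme vocabulary
    out_channels=192,
    hidden_channels=192,
    filter_channels=768,
    n_heads=2,
    n_layers=4,
    kernel_size=3,
    p_dropout=0.0,
    window_size=4,          # relative-attention window (the default)
)

tokens = torch.randint(1, 100, (2, 37))  # [batch, time]; index 0 is padding
out = encoder(tokens)                    # -> [batch, time, hidden_channels]
print(out.shape)                         # torch.Size([2, 37, 192])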