xinference 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (124)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +4 -7
  3. xinference/client/handlers.py +3 -0
  4. xinference/core/chat_interface.py +6 -1
  5. xinference/core/model.py +2 -0
  6. xinference/core/scheduler.py +4 -7
  7. xinference/core/supervisor.py +114 -23
  8. xinference/core/worker.py +70 -4
  9. xinference/deploy/local.py +2 -1
  10. xinference/model/audio/core.py +11 -0
  11. xinference/model/audio/cosyvoice.py +16 -5
  12. xinference/model/audio/kokoro.py +139 -0
  13. xinference/model/audio/melotts.py +110 -0
  14. xinference/model/audio/model_spec.json +80 -0
  15. xinference/model/audio/model_spec_modelscope.json +18 -0
  16. xinference/model/audio/whisper.py +35 -10
  17. xinference/model/llm/llama_cpp/core.py +21 -14
  18. xinference/model/llm/llm_family.json +527 -1
  19. xinference/model/llm/llm_family.py +4 -1
  20. xinference/model/llm/llm_family_modelscope.json +495 -3
  21. xinference/model/llm/memory.py +1 -1
  22. xinference/model/llm/mlx/core.py +24 -6
  23. xinference/model/llm/transformers/core.py +9 -1
  24. xinference/model/llm/transformers/qwen2_audio.py +3 -1
  25. xinference/model/llm/transformers/qwen2_vl.py +20 -3
  26. xinference/model/llm/transformers/utils.py +22 -11
  27. xinference/model/llm/utils.py +115 -1
  28. xinference/model/llm/vllm/core.py +14 -4
  29. xinference/model/llm/vllm/xavier/block.py +3 -4
  30. xinference/model/llm/vllm/xavier/block_tracker.py +71 -58
  31. xinference/model/llm/vllm/xavier/collective.py +74 -0
  32. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  33. xinference/model/llm/vllm/xavier/executor.py +18 -16
  34. xinference/model/llm/vllm/xavier/scheduler.py +79 -63
  35. xinference/model/llm/vllm/xavier/test/test_xavier.py +60 -35
  36. xinference/model/llm/vllm/xavier/transfer.py +53 -32
  37. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  38. xinference/thirdparty/melo/__init__.py +0 -0
  39. xinference/thirdparty/melo/api.py +135 -0
  40. xinference/thirdparty/melo/app.py +61 -0
  41. xinference/thirdparty/melo/attentions.py +459 -0
  42. xinference/thirdparty/melo/commons.py +160 -0
  43. xinference/thirdparty/melo/configs/config.json +94 -0
  44. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  45. xinference/thirdparty/melo/data_utils.py +413 -0
  46. xinference/thirdparty/melo/download_utils.py +67 -0
  47. xinference/thirdparty/melo/infer.py +25 -0
  48. xinference/thirdparty/melo/init_downloads.py +14 -0
  49. xinference/thirdparty/melo/losses.py +58 -0
  50. xinference/thirdparty/melo/main.py +36 -0
  51. xinference/thirdparty/melo/mel_processing.py +174 -0
  52. xinference/thirdparty/melo/models.py +1030 -0
  53. xinference/thirdparty/melo/modules.py +598 -0
  54. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  55. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  56. xinference/thirdparty/melo/preprocess_text.py +135 -0
  57. xinference/thirdparty/melo/split_utils.py +174 -0
  58. xinference/thirdparty/melo/text/__init__.py +35 -0
  59. xinference/thirdparty/melo/text/chinese.py +199 -0
  60. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  61. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  62. xinference/thirdparty/melo/text/cleaner.py +36 -0
  63. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  64. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  65. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  66. xinference/thirdparty/melo/text/english.py +284 -0
  67. xinference/thirdparty/melo/text/english_bert.py +39 -0
  68. xinference/thirdparty/melo/text/english_utils/__init__.py +0 -0
  69. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  70. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  71. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  72. xinference/thirdparty/melo/text/es_phonemizer/__init__.py +0 -0
  73. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  74. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  75. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  76. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  77. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  78. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  79. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  80. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  81. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  82. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  83. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  84. xinference/thirdparty/melo/text/fr_phonemizer/__init__.py +0 -0
  85. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  86. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  87. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  88. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  89. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  90. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  91. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  92. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  93. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  94. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  95. xinference/thirdparty/melo/text/french.py +94 -0
  96. xinference/thirdparty/melo/text/french_bert.py +39 -0
  97. xinference/thirdparty/melo/text/japanese.py +647 -0
  98. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  99. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  100. xinference/thirdparty/melo/text/korean.py +192 -0
  101. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  102. xinference/thirdparty/melo/text/spanish.py +122 -0
  103. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  104. xinference/thirdparty/melo/text/symbols.py +290 -0
  105. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  106. xinference/thirdparty/melo/train.py +635 -0
  107. xinference/thirdparty/melo/train.sh +19 -0
  108. xinference/thirdparty/melo/transforms.py +209 -0
  109. xinference/thirdparty/melo/utils.py +424 -0
  110. xinference/types.py +2 -0
  111. xinference/web/ui/build/asset-manifest.json +3 -3
  112. xinference/web/ui/build/index.html +1 -1
  113. xinference/web/ui/build/static/js/{main.1eb206d1.js → main.b0936c54.js} +3 -3
  114. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  115. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  116. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/METADATA +37 -27
  117. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/RECORD +122 -45
  118. xinference/web/ui/build/static/js/main.1eb206d1.js.map +0 -1
  119. xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +0 -1
  120. /xinference/web/ui/build/static/js/{main.1eb206d1.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  121. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/LICENSE +0 -0
  122. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/WHEEL +0 -0
  123. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/entry_points.txt +0 -0
  124. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/top_level.txt +0 -0
xinference/thirdparty/melo/attentions.py
@@ -0,0 +1,459 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from . import commons
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, channels, eps=1e-5):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = nn.Parameter(torch.ones(channels))
+        self.beta = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        x = x.transpose(1, -1)
+        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+        return x.transpose(1, -1)
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+    return acts
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size=1,
+        p_dropout=0.0,
+        window_size=4,
+        isflow=True,
+        **kwargs
+    ):
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+
+        self.cond_layer_idx = self.n_layers
+        if "gin_channels" in kwargs:
+            self.gin_channels = kwargs["gin_channels"]
+            if self.gin_channels != 0:
+                self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
+                self.cond_layer_idx = (
+                    kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
+                )
+                assert (
+                    self.cond_layer_idx < self.n_layers
+                ), "cond_layer_idx should be less than n_layers"
+        self.drop = nn.Dropout(p_dropout)
+        self.attn_layers = nn.ModuleList()
+        self.norm_layers_1 = nn.ModuleList()
+        self.ffn_layers = nn.ModuleList()
+        self.norm_layers_2 = nn.ModuleList()
+
+        for i in range(self.n_layers):
+            self.attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels,
+                    hidden_channels,
+                    n_heads,
+                    p_dropout=p_dropout,
+                    window_size=window_size,
+                )
+            )
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(
+                    hidden_channels,
+                    hidden_channels,
+                    filter_channels,
+                    kernel_size,
+                    p_dropout=p_dropout,
+                )
+            )
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+    def forward(self, x, x_mask, g=None):
+        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        x = x * x_mask
+        for i in range(self.n_layers):
+            if i == self.cond_layer_idx and g is not None:
+                g = self.spk_emb_linear(g.transpose(1, 2))
+                g = g.transpose(1, 2)
+                x = x + g
+                x = x * x_mask
+            y = self.attn_layers[i](x, x, attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_1[i](x + y)
+
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = self.norm_layers_2[i](x + y)
+        x = x * x_mask
+        return x
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size=1,
+        p_dropout=0.0,
+        proximal_bias=False,
+        proximal_init=True,
+        **kwargs
+    ):
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.proximal_bias = proximal_bias
+        self.proximal_init = proximal_init
+
+        self.drop = nn.Dropout(p_dropout)
+        self.self_attn_layers = nn.ModuleList()
+        self.norm_layers_0 = nn.ModuleList()
+        self.encdec_attn_layers = nn.ModuleList()
+        self.norm_layers_1 = nn.ModuleList()
+        self.ffn_layers = nn.ModuleList()
+        self.norm_layers_2 = nn.ModuleList()
+        for i in range(self.n_layers):
+            self.self_attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels,
+                    hidden_channels,
+                    n_heads,
+                    p_dropout=p_dropout,
+                    proximal_bias=proximal_bias,
+                    proximal_init=proximal_init,
+                )
+            )
+            self.norm_layers_0.append(LayerNorm(hidden_channels))
+            self.encdec_attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+                )
+            )
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(
+                    hidden_channels,
+                    hidden_channels,
+                    filter_channels,
+                    kernel_size,
+                    p_dropout=p_dropout,
+                    causal=True,
+                )
+            )
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+    def forward(self, x, x_mask, h, h_mask):
+        """
+        x: decoder input
+        h: encoder output
+        """
+        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+            device=x.device, dtype=x.dtype
+        )
+        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        x = x * x_mask
+        for i in range(self.n_layers):
+            y = self.self_attn_layers[i](x, x, self_attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_0[i](x + y)
+
+            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_1[i](x + y)
+
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = self.norm_layers_2[i](x + y)
+        x = x * x_mask
+        return x
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(
+        self,
+        channels,
+        out_channels,
+        n_heads,
+        p_dropout=0.0,
+        window_size=None,
+        heads_share=True,
+        block_length=None,
+        proximal_bias=False,
+        proximal_init=False,
+    ):
+        super().__init__()
+        assert channels % n_heads == 0
+
+        self.channels = channels
+        self.out_channels = out_channels
+        self.n_heads = n_heads
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.heads_share = heads_share
+        self.block_length = block_length
+        self.proximal_bias = proximal_bias
+        self.proximal_init = proximal_init
+        self.attn = None
+
+        self.k_channels = channels // n_heads
+        self.conv_q = nn.Conv1d(channels, channels, 1)
+        self.conv_k = nn.Conv1d(channels, channels, 1)
+        self.conv_v = nn.Conv1d(channels, channels, 1)
+        self.conv_o = nn.Conv1d(channels, out_channels, 1)
+        self.drop = nn.Dropout(p_dropout)
+
+        if window_size is not None:
+            n_heads_rel = 1 if heads_share else n_heads
+            rel_stddev = self.k_channels**-0.5
+            self.emb_rel_k = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+            self.emb_rel_v = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+
+        nn.init.xavier_uniform_(self.conv_q.weight)
+        nn.init.xavier_uniform_(self.conv_k.weight)
+        nn.init.xavier_uniform_(self.conv_v.weight)
+        if proximal_init:
+            with torch.no_grad():
+                self.conv_k.weight.copy_(self.conv_q.weight)
+                self.conv_k.bias.copy_(self.conv_q.bias)
+
+    def forward(self, x, c, attn_mask=None):
+        q = self.conv_q(x)
+        k = self.conv_k(c)
+        v = self.conv_v(c)
+
+        x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+        x = self.conv_o(x)
+        return x
+
+    def attention(self, query, key, value, mask=None):
+        # reshape [b, d, t] -> [b, n_h, t, d_k]
+        b, d, t_s, t_t = (*key.size(), query.size(2))
+        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+        if self.window_size is not None:
+            assert (
+                t_s == t_t
+            ), "Relative attention is only available for self-attention."
+            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+            rel_logits = self._matmul_with_relative_keys(
+                query / math.sqrt(self.k_channels), key_relative_embeddings
+            )
+            scores_local = self._relative_position_to_absolute_position(rel_logits)
+            scores = scores + scores_local
+        if self.proximal_bias:
+            assert t_s == t_t, "Proximal bias is only available for self-attention."
+            scores = scores + self._attention_bias_proximal(t_s).to(
+                device=scores.device, dtype=scores.dtype
+            )
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, -1e4)
+            if self.block_length is not None:
+                assert (
+                    t_s == t_t
+                ), "Local attention is only available for self-attention."
+                block_mask = (
+                    torch.ones_like(scores)
+                    .triu(-self.block_length)
+                    .tril(self.block_length)
+                )
+                scores = scores.masked_fill(block_mask == 0, -1e4)
+        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
+        p_attn = self.drop(p_attn)
+        output = torch.matmul(p_attn, value)
+        if self.window_size is not None:
+            relative_weights = self._absolute_position_to_relative_position(p_attn)
+            value_relative_embeddings = self._get_relative_embeddings(
+                self.emb_rel_v, t_s
+            )
+            output = output + self._matmul_with_relative_values(
+                relative_weights, value_relative_embeddings
+            )
+        output = (
+            output.transpose(2, 3).contiguous().view(b, d, t_t)
+        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+        return output, p_attn
+
+    def _matmul_with_relative_values(self, x, y):
+        """
+        x: [b, h, l, m]
+        y: [h or 1, m, d]
+        ret: [b, h, l, d]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0))
+        return ret
+
+    def _matmul_with_relative_keys(self, x, y):
+        """
+        x: [b, h, l, d]
+        y: [h or 1, m, d]
+        ret: [b, h, l, m]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+        return ret
+
+    def _get_relative_embeddings(self, relative_embeddings, length):
+        2 * self.window_size + 1
+        # Pad first before slice to avoid using cond ops.
+        pad_length = max(length - (self.window_size + 1), 0)
+        slice_start_position = max((self.window_size + 1) - length, 0)
+        slice_end_position = slice_start_position + 2 * length - 1
+        if pad_length > 0:
+            padded_relative_embeddings = F.pad(
+                relative_embeddings,
+                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+            )
+        else:
+            padded_relative_embeddings = relative_embeddings
+        used_relative_embeddings = padded_relative_embeddings[
+            :, slice_start_position:slice_end_position
+        ]
+        return used_relative_embeddings
+
+    def _relative_position_to_absolute_position(self, x):
+        """
+        x: [b, h, l, 2*l-1]
+        ret: [b, h, l, l]
+        """
+        batch, heads, length, _ = x.size()
+        # Concat columns of pad to shift from relative to absolute indexing.
+        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+        # Concat extra elements so to add up to shape (len+1, 2*len-1).
+        x_flat = x.view([batch, heads, length * 2 * length])
+        x_flat = F.pad(
+            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+        )
+
+        # Reshape and slice out the padded elements.
+        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+            :, :, :length, length - 1 :
+        ]
+        return x_final
+
+    def _absolute_position_to_relative_position(self, x):
+        """
+        x: [b, h, l, l]
+        ret: [b, h, l, 2*l-1]
+        """
+        batch, heads, length, _ = x.size()
+        # pad along column
+        x = F.pad(
+            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+        )
+        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+        # add 0's in the beginning that will skew the elements after reshape
+        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+        return x_final
+
+    def _attention_bias_proximal(self, length):
+        """Bias for self-attention to encourage attention to close positions.
+        Args:
+            length: an integer scalar.
+        Returns:
+            a Tensor with shape [1, 1, length, length]
+        """
+        r = torch.arange(length, dtype=torch.float32)
+        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        filter_channels,
+        kernel_size,
+        p_dropout=0.0,
+        activation=None,
+        causal=False,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.activation = activation
+        self.causal = causal
+
+        if causal:
+            self.padding = self._causal_padding
+        else:
+            self.padding = self._same_padding
+
+        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+        self.drop = nn.Dropout(p_dropout)
+
+    def forward(self, x, x_mask):
+        x = self.conv_1(self.padding(x * x_mask))
+        if self.activation == "gelu":
+            x = x * torch.sigmoid(1.702 * x)
+        else:
+            x = torch.relu(x)
+        x = self.drop(x)
+        x = self.conv_2(self.padding(x * x_mask))
+        return x * x_mask
+
+    def _causal_padding(self, x):
+        if self.kernel_size == 1:
+            return x
+        pad_l = self.kernel_size - 1
+        pad_r = 0
+        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+        x = F.pad(x, commons.convert_pad_shape(padding))
+        return x
+
+    def _same_padding(self, x):
+        if self.kernel_size == 1:
+            return x
+        pad_l = (self.kernel_size - 1) // 2
+        pad_r = self.kernel_size // 2
+        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+        x = F.pad(x, commons.convert_pad_shape(padding))
+        return x

xinference/thirdparty/melo/commons.py
@@ -0,0 +1,160 @@
+import math
+import torch
+from torch.nn import functional as F
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def convert_pad_shape(pad_shape):
+    layer = pad_shape[::-1]
+    pad_shape = [item for sublist in layer for item in sublist]
+    return pad_shape
+
+
+def intersperse(lst, item):
+    result = [item] * (len(lst) * 2 + 1)
+    result[1::2] = lst
+    return result
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+    """KL(P||Q)"""
+    kl = (logs_q - logs_p) - 0.5
+    kl += (
+        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+    )
+    return kl
+
+
+def rand_gumbel(shape):
+    """Sample from the Gumbel distribution, protect from overflows."""
+    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+    return -torch.log(-torch.log(uniform_samples))
+
+
+def rand_gumbel_like(x):
+    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+    return g
+
+
+def slice_segments(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, :, idx_str:idx_end]
+    return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+    b, d, t = x.size()
+    if x_lengths is None:
+        x_lengths = t
+    ids_str_max = x_lengths - segment_size + 1
+    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+    ret = slice_segments(x, ids_str, segment_size)
+    return ret, ids_str
+
+
+def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+    position = torch.arange(length, dtype=torch.float)
+    num_timescales = channels // 2
+    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+        num_timescales - 1
+    )
+    inv_timescales = min_timescale * torch.exp(
+        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+    )
+    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+    signal = F.pad(signal, [0, 0, 0, channels % 2])
+    signal = signal.view(1, channels, length)
+    return signal
+
+
+def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+    b, channels, length = x.size()
+    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+    return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+    b, channels, length = x.size()
+    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+def subsequent_mask(length):
+    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+    return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+    return acts
+
+
+def convert_pad_shape(pad_shape):
+    layer = pad_shape[::-1]
+    pad_shape = [item for sublist in layer for item in sublist]
+    return pad_shape
+
+
+def shift_1d(x):
+    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+    return x
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def generate_path(duration, mask):
+    """
+    duration: [b, 1, t_x]
+    mask: [b, 1, t_y, t_x]
+    """
+
+    b, _, t_y, t_x = mask.shape
+    cum_duration = torch.cumsum(duration, -1)
+
+    cum_duration_flat = cum_duration.view(b * t_x)
+    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+    path = path.view(b, t_x, t_y)
+    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+    path = path.unsqueeze(1).transpose(2, 3) * mask
+    return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    norm_type = float(norm_type)
+    if clip_value is not None:
+        clip_value = float(clip_value)
+
+    total_norm = 0
+    for p in parameters:
+        param_norm = p.grad.data.norm(norm_type)
+        total_norm += param_norm.item() ** norm_type
+        if clip_value is not None:
+            p.grad.data.clamp_(min=-clip_value, max=clip_value)
+    total_norm = total_norm ** (1.0 / norm_type)
+    return total_norm
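
The helpers above follow the usual VITS conventions for masks and monotonic alignment. As a rough, illustrative sketch (not part of this diff; the import path and toy tensor shapes are assumptions based on the code shown), generate_path expands per-token durations into a hard alignment between input tokens and output frames:

    import torch
    from xinference.thirdparty.melo import commons

    # One utterance with 3 input tokens lasting 2, 1 and 3 frames: [b, 1, t_x]
    duration = torch.tensor([[[2.0, 1.0, 3.0]]])
    t_y = int(duration.sum())                                # 6 output frames
    x_mask = torch.ones(1, 1, 3)                             # [b, 1, t_x]
    y_mask = torch.ones(1, 1, t_y)                           # [b, 1, t_y]
    attn_mask = y_mask.unsqueeze(-1) * x_mask.unsqueeze(2)   # [b, 1, t_y, t_x]

    # Hard monotonic path: each output frame is assigned to exactly one token.
    path = commons.generate_path(duration, attn_mask)
    print(path.shape)  # torch.Size([1, 1, 6, 3])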

xinference/thirdparty/melo/configs/config.json
@@ -0,0 +1,94 @@
+{
+  "train": {
+    "log_interval": 200,
+    "eval_interval": 1000,
+    "seed": 52,
+    "epochs": 10000,
+    "learning_rate": 0.0003,
+    "betas": [
+      0.8,
+      0.99
+    ],
+    "eps": 1e-09,
+    "batch_size": 6,
+    "fp16_run": false,
+    "lr_decay": 0.999875,
+    "segment_size": 16384,
+    "init_lr_ratio": 1,
+    "warmup_epochs": 0,
+    "c_mel": 45,
+    "c_kl": 1.0,
+    "skip_optimizer": true
+  },
+  "data": {
+    "training_files": "",
+    "validation_files": "",
+    "max_wav_value": 32768.0,
+    "sampling_rate": 44100,
+    "filter_length": 2048,
+    "hop_length": 512,
+    "win_length": 2048,
+    "n_mel_channels": 128,
+    "mel_fmin": 0.0,
+    "mel_fmax": null,
+    "add_blank": true,
+    "n_speakers": 256,
+    "cleaned_text": true,
+    "spk2id": {}
+  },
+  "model": {
+    "use_spk_conditioned_encoder": true,
+    "use_noise_scaled_mas": true,
+    "use_mel_posterior_encoder": false,
+    "use_duration_discriminator": true,
+    "inter_channels": 192,
+    "hidden_channels": 192,
+    "filter_channels": 768,
+    "n_heads": 2,
+    "n_layers": 6,
+    "n_layers_trans_flow": 3,
+    "kernel_size": 3,
+    "p_dropout": 0.1,
+    "resblock": "1",
+    "resblock_kernel_sizes": [
+      3,
+      7,
+      11
+    ],
+    "resblock_dilation_sizes": [
+      [
+        1,
+        3,
+        5
+      ],
+      [
+        1,
+        3,
+        5
+      ],
+      [
+        1,
+        3,
+        5
+      ]
+    ],
+    "upsample_rates": [
+      8,
+      8,
+      2,
+      2,
+      2
+    ],
+    "upsample_initial_channel": 512,
+    "upsample_kernel_sizes": [
+      16,
+      16,
+      8,
+      2,
+      2
+    ],
+    "n_layers_q": 3,
+    "use_spectral_norm": false,
+    "gin_channels": 256
+  }
+}
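
For orientation, the "model" block of this config carries the hyperparameters consumed by the modules in attentions.py above. A minimal sketch (illustrative only, not part of this diff; the exact wiring inside MeloTTS may differ) of instantiating the text Encoder from these values:

    import json
    import torch
    from xinference.thirdparty.melo import attentions

    with open("xinference/thirdparty/melo/configs/config.json") as f:
        model_cfg = json.load(f)["model"]

    enc = attentions.Encoder(
        hidden_channels=model_cfg["hidden_channels"],  # 192
        filter_channels=model_cfg["filter_channels"],  # 768
        n_heads=model_cfg["n_heads"],                  # 2
        n_layers=model_cfg["n_layers"],                # 6
        kernel_size=model_cfg["kernel_size"],          # 3
        p_dropout=model_cfg["p_dropout"],              # 0.1
        gin_channels=model_cfg["gin_channels"],        # 256, enables speaker conditioning
    )

    x = torch.randn(1, 192, 50)   # [batch, hidden_channels, frames]
    x_mask = torch.ones(1, 1, 50)
    g = torch.randn(1, 256, 1)    # speaker embedding: [batch, gin_channels, 1]
    out = enc(x, x_mask, g=g)     # -> [1, 192, 50]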