xinference 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic; see the package's registry page for details.

Files changed (124)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +4 -7
  3. xinference/client/handlers.py +3 -0
  4. xinference/core/chat_interface.py +6 -1
  5. xinference/core/model.py +2 -0
  6. xinference/core/scheduler.py +4 -7
  7. xinference/core/supervisor.py +114 -23
  8. xinference/core/worker.py +70 -4
  9. xinference/deploy/local.py +2 -1
  10. xinference/model/audio/core.py +11 -0
  11. xinference/model/audio/cosyvoice.py +16 -5
  12. xinference/model/audio/kokoro.py +139 -0
  13. xinference/model/audio/melotts.py +110 -0
  14. xinference/model/audio/model_spec.json +80 -0
  15. xinference/model/audio/model_spec_modelscope.json +18 -0
  16. xinference/model/audio/whisper.py +35 -10
  17. xinference/model/llm/llama_cpp/core.py +21 -14
  18. xinference/model/llm/llm_family.json +527 -1
  19. xinference/model/llm/llm_family.py +4 -1
  20. xinference/model/llm/llm_family_modelscope.json +495 -3
  21. xinference/model/llm/memory.py +1 -1
  22. xinference/model/llm/mlx/core.py +24 -6
  23. xinference/model/llm/transformers/core.py +9 -1
  24. xinference/model/llm/transformers/qwen2_audio.py +3 -1
  25. xinference/model/llm/transformers/qwen2_vl.py +20 -3
  26. xinference/model/llm/transformers/utils.py +22 -11
  27. xinference/model/llm/utils.py +115 -1
  28. xinference/model/llm/vllm/core.py +14 -4
  29. xinference/model/llm/vllm/xavier/block.py +3 -4
  30. xinference/model/llm/vllm/xavier/block_tracker.py +71 -58
  31. xinference/model/llm/vllm/xavier/collective.py +74 -0
  32. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  33. xinference/model/llm/vllm/xavier/executor.py +18 -16
  34. xinference/model/llm/vllm/xavier/scheduler.py +79 -63
  35. xinference/model/llm/vllm/xavier/test/test_xavier.py +60 -35
  36. xinference/model/llm/vllm/xavier/transfer.py +53 -32
  37. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  38. xinference/thirdparty/melo/__init__.py +0 -0
  39. xinference/thirdparty/melo/api.py +135 -0
  40. xinference/thirdparty/melo/app.py +61 -0
  41. xinference/thirdparty/melo/attentions.py +459 -0
  42. xinference/thirdparty/melo/commons.py +160 -0
  43. xinference/thirdparty/melo/configs/config.json +94 -0
  44. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  45. xinference/thirdparty/melo/data_utils.py +413 -0
  46. xinference/thirdparty/melo/download_utils.py +67 -0
  47. xinference/thirdparty/melo/infer.py +25 -0
  48. xinference/thirdparty/melo/init_downloads.py +14 -0
  49. xinference/thirdparty/melo/losses.py +58 -0
  50. xinference/thirdparty/melo/main.py +36 -0
  51. xinference/thirdparty/melo/mel_processing.py +174 -0
  52. xinference/thirdparty/melo/models.py +1030 -0
  53. xinference/thirdparty/melo/modules.py +598 -0
  54. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  55. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  56. xinference/thirdparty/melo/preprocess_text.py +135 -0
  57. xinference/thirdparty/melo/split_utils.py +174 -0
  58. xinference/thirdparty/melo/text/__init__.py +35 -0
  59. xinference/thirdparty/melo/text/chinese.py +199 -0
  60. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  61. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  62. xinference/thirdparty/melo/text/cleaner.py +36 -0
  63. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  64. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  65. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  66. xinference/thirdparty/melo/text/english.py +284 -0
  67. xinference/thirdparty/melo/text/english_bert.py +39 -0
  68. xinference/thirdparty/melo/text/english_utils/__init__.py +0 -0
  69. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  70. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  71. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  72. xinference/thirdparty/melo/text/es_phonemizer/__init__.py +0 -0
  73. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  74. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  75. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  76. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  77. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  78. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  79. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  80. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  81. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  82. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  83. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  84. xinference/thirdparty/melo/text/fr_phonemizer/__init__.py +0 -0
  85. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  86. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  87. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  88. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  89. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  90. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  91. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  92. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  93. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  94. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  95. xinference/thirdparty/melo/text/french.py +94 -0
  96. xinference/thirdparty/melo/text/french_bert.py +39 -0
  97. xinference/thirdparty/melo/text/japanese.py +647 -0
  98. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  99. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  100. xinference/thirdparty/melo/text/korean.py +192 -0
  101. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  102. xinference/thirdparty/melo/text/spanish.py +122 -0
  103. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  104. xinference/thirdparty/melo/text/symbols.py +290 -0
  105. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  106. xinference/thirdparty/melo/train.py +635 -0
  107. xinference/thirdparty/melo/train.sh +19 -0
  108. xinference/thirdparty/melo/transforms.py +209 -0
  109. xinference/thirdparty/melo/utils.py +424 -0
  110. xinference/types.py +2 -0
  111. xinference/web/ui/build/asset-manifest.json +3 -3
  112. xinference/web/ui/build/index.html +1 -1
  113. xinference/web/ui/build/static/js/{main.1eb206d1.js → main.b0936c54.js} +3 -3
  114. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  115. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  116. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/METADATA +37 -27
  117. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/RECORD +122 -45
  118. xinference/web/ui/build/static/js/main.1eb206d1.js.map +0 -1
  119. xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +0 -1
  120. /xinference/web/ui/build/static/js/{main.1eb206d1.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  121. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/LICENSE +0 -0
  122. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/WHEEL +0 -0
  123. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/entry_points.txt +0 -0
  124. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1030 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from melo import commons
7
+ from melo import modules
8
+ from melo import attentions
9
+
10
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+
13
+ from melo.commons import init_weights, get_padding
14
+ import melo.monotonic_align as monotonic_align
15
+
16
+
17
class DurationDiscriminator(nn.Module):  # vits2
    """Duration discriminator (tagged "vits2" in the source).

    Encodes detached hidden states with two masked conv blocks, then scores a
    single-channel duration track against them, emitting a sigmoid
    probability per time step.

    Args:
        in_channels: Channels of the hidden states ``x``.
        filter_channels: Width of every internal convolution.
        kernel_size: Kernel size of the internal convolutions; padding is
            ``kernel_size // 2``.
        p_dropout: Dropout rate applied after each normalisation.
        gin_channels: Channels of the optional conditioning ``g``. When 0,
            no ``cond`` projection is created.
    """

    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        half_k = kernel_size // 2

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=half_k
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=half_k
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        # Lifts the 1-channel duration track to filter_channels.
        self.dur_proj = nn.Conv1d(1, filter_channels, 1)

        # Consumes the encoded states concatenated with the projected durations.
        self.pre_out_conv_1 = nn.Conv1d(
            2 * filter_channels, filter_channels, kernel_size, padding=half_k
        )
        self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
        self.pre_out_conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=half_k
        )
        self.pre_out_norm_2 = modules.LayerNorm(filter_channels)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

        self.output_layer = nn.Sequential(
            nn.Linear(filter_channels, 1), nn.Sigmoid()
        )

    def _conv_relu_norm_drop(self, h, x_mask, conv, norm):
        """Apply masked conv -> ReLU -> norm -> dropout (the repeated unit)."""
        return self.drop(norm(torch.relu(conv(h * x_mask))))

    def forward_probability(self, x, x_mask, dur, g=None):
        """Return per-step probabilities for one duration track.

        ``g`` is accepted for signature symmetry but is not used here.
        """
        h = torch.cat([x, self.dur_proj(dur)], dim=1)
        h = self._conv_relu_norm_drop(
            h, x_mask, self.pre_out_conv_1, self.pre_out_norm_1
        )
        h = self._conv_relu_norm_drop(
            h, x_mask, self.pre_out_conv_2, self.pre_out_norm_2
        )
        # Move channels last so the Linear head sees filter_channels features.
        h = (h * x_mask).transpose(1, 2)
        return self.output_layer(h)

    def forward(self, x, x_mask, dur_r, dur_hat, g=None):
        """Score the real (``dur_r``) and predicted (``dur_hat``) durations.

        ``x`` and ``g`` are detached first, so no gradient reaches them.

        Returns:
            A two-element list ``[prob_for_dur_r, prob_for_dur_hat]``.
        """
        h = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            h = h + self.cond(g)
        h = self._conv_relu_norm_drop(h, x_mask, self.conv_1, self.norm_1)
        h = self._conv_relu_norm_drop(h, x_mask, self.conv_2, self.norm_2)
        return [
            self.forward_probability(h, x_mask, dur, g) for dur in (dur_r, dur_hat)
        ]
89
+
90
+
91
class TransformerCouplingBlock(nn.Module):
    """Normalising flow made of ``modules.TransformerCouplingLayer`` layers.

    Each coupling layer is followed by a ``modules.Flip``. When
    ``share_parameter`` is true, a single ``attentions.FFT`` network is built
    once and handed to every coupling layer as ``wn_sharing_parameter``.
    Otherwise ``None`` is passed.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
        share_parameter=False,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()

        if share_parameter:
            self.wn = attentions.FFT(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=self.gin_channels,
            )
        else:
            self.wn = None

        for _ in range(n_flows):
            coupling = modules.TransformerCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                n_layers,
                n_heads,
                p_dropout,
                filter_channels,
                mean_only=True,
                wn_sharing_parameter=self.wn,
                gin_channels=self.gin_channels,
            )
            self.flows.append(coupling)
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Run the flow stack.

        In the forward direction each flow returns a pair whose second element
        is discarded. In reverse the flows run back-to-front and each returns
        the tensor directly.
        """
        if reverse:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
            return x
        for flow in self.flows:
            x, _ = flow(x, x_mask, g=g, reverse=reverse)
        return x
155
+
156
+
157
class StochasticDurationPredictor(nn.Module):
    """Flow-based stochastic duration predictor.

    With ``reverse=False`` it takes durations ``w`` and returns ``nll + logq``
    summed over dims 1 and 2 (one value per batch element). With
    ``reverse=True`` it samples noise scaled by ``noise_scale``, inverts the
    flows and returns the first output channel (``logw``).

    Note: the ``filter_channels`` argument is ignored and replaced by
    ``in_channels`` (see the comment in ``__init__``).
    """

    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        filter_channels = in_channels  # it needs to be removed from future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = modules.Log()
        self.flows = self._affine_then_conv_flows(
            n_flows, filter_channels, kernel_size
        )

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        # The posterior stack always uses 4 ConvFlows, independent of n_flows.
        self.post_flows = self._affine_then_conv_flows(
            4, filter_channels, kernel_size
        )

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    @staticmethod
    def _affine_then_conv_flows(n_conv_flows, channels, kernel_size):
        """Build [ElementwiseAffine(2), (ConvFlow, Flip) * n_conv_flows]."""
        stack = nn.ModuleList()
        stack.append(modules.ElementwiseAffine(2))
        for _ in range(n_conv_flows):
            stack.append(modules.ConvFlow(2, channels, kernel_size, n_layers=3))
            stack.append(modules.Flip())
        return stack

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        h = self.pre(torch.detach(x))
        if g is not None:
            g = torch.detach(g)
            h = h + self.cond(g)
        h = self.convs(h, x_mask)
        h = self.proj(h) * x_mask

        if reverse:
            rev = list(reversed(self.flows))
            # Drop the second-to-last entry (the ConvFlow that directly follows
            # the ElementwiseAffine), i.e. "remove a useless vflow".
            rev = rev[:-2] + rev[-1:]
            z = (
                torch.randn(h.size(0), 2, h.size(2)).to(device=h.device, dtype=h.dtype)
                * noise_scale
            )
            for flow in rev:
                z = flow(z, x_mask, g=h, reverse=reverse)
            logw, _ = torch.split(z, [1, 1], 1)
            return logw

        assert w is not None
        log_2pi = math.log(2 * math.pi)

        # Posterior branch: encode w and push Gaussian noise through post_flows.
        q_logdet = 0
        w_feat = self.post_pre(w)
        w_feat = self.post_convs(w_feat, x_mask)
        w_feat = self.post_proj(w_feat) * x_mask
        eps = (
            torch.randn(w.size(0), 2, w.size(2)).to(device=h.device, dtype=h.dtype)
            * x_mask
        )
        z_q = eps
        for flow in self.post_flows:
            z_q, step_logdet = flow(z_q, x_mask, g=(h + w_feat))
            q_logdet += step_logdet
        z_u, z1 = torch.split(z_q, [1, 1], 1)
        u = torch.sigmoid(z_u) * x_mask
        z0 = (w - u) * x_mask
        q_logdet += torch.sum(
            (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
        )
        logq = torch.sum(-0.5 * (log_2pi + (eps**2)) * x_mask, [1, 2]) - q_logdet

        # Prior branch: log-transform z0, then run the main flows.
        p_logdet = 0
        z0, step_logdet = self.log_flow(z0, x_mask)
        p_logdet += step_logdet
        z = torch.cat([z0, z1], 1)
        for flow in self.flows:
            z, step_logdet = flow(z, x_mask, g=h, reverse=reverse)
            p_logdet = p_logdet + step_logdet
        nll = torch.sum(0.5 * (log_2pi + (z**2)) * x_mask, [1, 2]) - p_logdet
        return nll + logq  # [b]
266
+
267
+
268
class DurationPredictor(nn.Module):
    """Deterministic duration predictor.

    Two masked conv -> ReLU -> LayerNorm -> dropout stages, then a 1x1
    projection to one output channel, masked by ``x_mask``. ``x`` and ``g``
    are detached, so no gradient flows back into them.
    """

    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        same_pad = kernel_size // 2
        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        h = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            h = h + self.cond(g)
        for conv, norm in ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2)):
            h = self.drop(norm(torch.relu(conv(h * x_mask))))
        return self.proj(h * x_mask) * x_mask
309
+
310
+
311
class TextEncoder(nn.Module):
    """Text (prior) encoder.

    Sums token, tone and language embeddings with two projected BERT feature
    streams, scales by ``sqrt(hidden_channels)``, runs ``attentions.Encoder``
    and projects to ``2 * out_channels`` channels, split into ``m`` and
    ``logs``.

    Args:
        n_vocab: Size of the token embedding table.
        out_channels: Channels of each of ``m`` and ``logs``.
        hidden_channels: Embedding / encoder width.
        filter_channels, n_heads, n_layers, kernel_size, p_dropout: Passed
            through to ``attentions.Encoder``.
        gin_channels: Conditioning channels passed to the encoder.
        num_languages: Size of the language embedding table. If ``None``,
            taken from ``melo.text``.
        num_tones: Size of the tone embedding table. If ``None``, taken from
            ``melo.text``.
    """

    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        gin_channels=0,
        num_languages=None,
        num_tones=None,
    ):
        super().__init__()
        # Resolve the fallbacks through the ``melo`` package, like every other
        # import in this module. A bare ``from text import ...`` only resolves
        # when the ``melo`` directory itself is on sys.path.
        if num_languages is None:
            from melo.text import num_languages
        if num_tones is None:
            from melo.text import num_tones
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels
        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
        self.tone_emb = nn.Embedding(num_tones, hidden_channels)
        nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
        self.language_emb = nn.Embedding(num_languages, hidden_channels)
        nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
        # BERT feature widths are fixed: 1024 channels for ``bert`` and 768
        # for ``ja_bert``.
        self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.ja_bert_proj = nn.Conv1d(768, hidden_channels, 1)

        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, tone, language, bert, ja_bert, g=None):
        """Encode text.

        ``bert`` and ``ja_bert`` go through Conv1d layers with 1024 and 768
        input channels, so they are expected channels-first. Exact shapes
        should be confirmed against the caller.

        Returns:
            ``(x, m, logs, x_mask)``: encoder output, the two halves of the
            projected statistics, and the float sequence mask.
        """
        bert_emb = self.bert_proj(bert).transpose(1, 2)
        ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
        x = (
            self.emb(x)
            + self.tone_emb(tone)
            + self.language_emb(language)
            + bert_emb
            + ja_bert_emb
        ) * math.sqrt(
            self.hidden_channels
        )  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )

        x = self.encoder(x * x_mask, x_mask, g=g)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask
382
+
383
+
384
class ResidualCouplingBlock(nn.Module):
    """Normalising flow of ``modules.ResidualCouplingLayer`` + ``modules.Flip`` pairs.

    ``n_flows`` coupling layers are built with ``mean_only=True``, each
    followed by a channel flip.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = modules.ResidualCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            self.flows.append(coupling)
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the flows forward (dropping their second output) or in reverse."""
        if reverse:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
            return x
        for flow in self.flows:
            x, _ = flow(x, x_mask, g=g, reverse=reverse)
        return x
427
+
428
+
429
class PosteriorEncoder(nn.Module):
    """Posterior encoder: 1x1 conv -> ``modules.WN`` -> 1x1 conv to (m, logs).

    Returns a sample ``z = (m + eps * tau * exp(logs)) * mask`` with
    ``eps ~ N(0, 1)``, together with ``m``, ``logs`` and the float mask
    built from ``x_lengths``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None, tau=1.0):
        seq_mask = commons.sequence_mask(x_lengths, x.size(2))
        mask = torch.unsqueeze(seq_mask, 1).to(x.dtype)
        hidden = self.enc(self.pre(x) * mask, mask, g=g)
        stats = self.proj(hidden) * mask
        mean, log_std = torch.split(stats, self.out_channels, dim=1)
        sample = (mean + torch.randn_like(mean) * tau * torch.exp(log_std)) * mask
        return sample, mean, log_std, mask
469
+
470
+
471
class Generator(torch.nn.Module):
    """Upsampling convolutional decoder.

    ``conv_pre`` -> for each upsample stage: leaky ReLU, weight-normed
    ``ConvTranspose1d``, then the mean of ``num_kernels`` residual blocks ->
    leaky ReLU -> ``conv_post`` to 1 channel -> ``tanh``. The output is
    presumably a waveform; that is not visible here.

    ``resblock == "1"`` selects ``modules.ResBlock1``; any other value selects
    ``modules.ResBlock2``.
    """

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        block_cls = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        # Stage i maps C / 2**i -> C / 2**(i + 1) channels.
        self.ups = nn.ModuleList(
            weight_norm(
                ConvTranspose1d(
                    upsample_initial_channel // (2**stage),
                    upsample_initial_channel // (2 ** (stage + 1)),
                    kernel,
                    rate,
                    padding=(kernel - rate) // 2,
                )
            )
            for stage, (rate, kernel) in enumerate(
                zip(upsample_rates, upsample_kernel_sizes)
            )
        )

        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (stage + 1))
            for kernel, dilation in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(block_cls(ch, kernel, dilation))

        # ``ch`` is the channel width of the last upsample stage.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for stage in range(self.num_upsamples):
            x = self.ups[stage](F.leaky_relu(x, modules.LRELU_SLOPE))
            first = stage * self.num_kernels
            acc = None
            for offset in range(self.num_kernels):
                branch = self.resblocks[first + offset](x)
                acc = branch if acc is None else acc + branch
            x = acc / self.num_kernels
        # Default leaky-ReLU slope here, not modules.LRELU_SLOPE.
        x = F.leaky_relu(x)
        return torch.tanh(self.conv_post(x))

    def remove_weight_norm(self):
        """Strip weight norm from the upsample convs and all residual blocks."""
        print("Removing weight norm...")
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
546
+
547
+
548
class DiscriminatorP(torch.nn.Module):
    """Period discriminator.

    Reflect-pads the input so its length is a multiple of ``period``, reshapes
    ``[b, c, t]`` to ``[b, c, t // period, period]`` and applies a stack of
    ``(kernel_size, 1)`` 2-D convolutions. Weight norm is used unless
    ``use_spectral_norm`` is anything other than ``False``, in which case
    spectral norm is used.

    Returns:
        ``(flattened_scores, feature_maps)``, where ``feature_maps`` holds the
        activation after every conv, including ``conv_post``.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = spectral_norm if use_spectral_norm is not False else weight_norm
        pad = (get_padding(kernel_size, 1), 0)
        widths = [1, 32, 128, 512, 1024]

        # Four strided convs, then one stride-1 conv at 1024 channels.
        layers = [
            norm_f(Conv2d(c_in, c_out, (kernel_size, 1), (stride, 1), padding=pad))
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        layers.append(norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=pad)))
        self.convs = nn.ModuleList(layers)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        feature_maps = []

        # 1d to 2d
        b, c, t = x.shape
        leftover = t % self.period
        if leftover != 0:  # pad first
            n_pad = self.period - leftover
            x = F.pad(x, (0, n_pad), "reflect")
            t += n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            feature_maps.append(x)
        x = self.conv_post(x)
        feature_maps.append(x)
        return torch.flatten(x, 1, -1), feature_maps
625
+
626
+
627
class DiscriminatorS(torch.nn.Module):
    """Scale discriminator over the raw 1-channel input.

    Six (grouped) 1-D convolutions with leaky ReLU, then a 1-channel
    ``conv_post``. Weight norm is used unless ``use_spectral_norm`` is
    anything other than ``False``, in which case spectral norm is used.

    Returns:
        ``(flattened_scores, feature_maps)``.
    """

    # (in_channels, out_channels, kernel_size, stride, groups, padding)
    _CONV_SPECS = (
        (1, 16, 15, 1, 1, 7),
        (16, 64, 41, 4, 4, 20),
        (64, 256, 41, 4, 16, 20),
        (256, 1024, 41, 4, 64, 20),
        (1024, 1024, 41, 4, 256, 20),
        (1024, 1024, 5, 1, 1, 2),
    )

    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = spectral_norm if use_spectral_norm is not False else weight_norm
        self.convs = nn.ModuleList(
            norm_f(Conv1d(c_in, c_out, k, s, groups=grp, padding=pad))
            for c_in, c_out, k, s, grp, pad in self._CONV_SPECS
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        feature_maps = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            feature_maps.append(x)
        x = self.conv_post(x)
        feature_maps.append(x)
        return torch.flatten(x, 1, -1), feature_maps
655
+
656
+
657
class MultiPeriodDiscriminator(torch.nn.Module):
    """One ``DiscriminatorS`` plus a ``DiscriminatorP`` for each period in (2, 3, 5, 7, 11).

    ``forward(y, y_hat)`` runs every discriminator on ``y`` and then on
    ``y_hat``, and returns four parallel lists:
    ``(real_scores, generated_scores, real_fmaps, generated_fmaps)``.
    """

    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        members.extend(
            DiscriminatorP(period, use_spectral_norm=use_spectral_norm)
            for period in periods
        )
        self.discriminators = nn.ModuleList(members)

    def forward(self, y, y_hat):
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for disc in self.discriminators:
            real_score, real_fmap = disc(y)
            fake_score, fake_fmap = disc(y_hat)
            y_d_rs.append(real_score)
            y_d_gs.append(fake_score)
            fmap_rs.append(real_fmap)
            fmap_gs.append(fake_fmap)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
682
+
683
+
684
class ReferenceEncoder(nn.Module):
    """Encode a spectrogram into a ``gin_channels``-dimensional vector.

    The input is viewed as ``[N, 1, -1, spec_channels]``, optionally
    LayerNorm-ed over the last axis, then passed through six stride-2
    weight-normed 3x3 ``Conv2d`` layers (channels 32, 32, 64, 64, 128, 128)
    with ReLU. The frequency-flattened sequence feeds a 128-unit GRU, and the
    GRU's final hidden state is projected by ``proj``.

    inputs --- [N, Ty/r, n_mels*r] mels
    outputs --- [N, gin_channels]
    """

    def __init__(self, spec_channels, gin_channels=0, layernorm=False):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        n_convs = len(ref_enc_filters)
        channel_plan = [1] + ref_enc_filters
        self.convs = nn.ModuleList(
            weight_norm(
                nn.Conv2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
            )
            for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:])
        )

        # Frequency bins left after the conv stack sets the GRU input width.
        freq_bins = self.calculate_channels(spec_channels, 3, 2, 1, n_convs)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * freq_bins,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)
        if not layernorm:
            self.layernorm = None
        else:
            self.layernorm = nn.LayerNorm(self.spec_channels)
            print('[Ref Enc]: using layer norm')

    def forward(self, inputs, mask=None):
        """Return ``proj`` of the GRU's final hidden state; ``mask`` is unused."""
        batch = inputs.size(0)

        out = inputs.view(batch, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
        if self.layernorm is not None:
            out = self.layernorm(out)

        for conv in self.convs:
            out = F.relu(conv(out))  # [N, 128, Ty//2^K, n_mels//2^K]

        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        n, steps = out.size(0), out.size(1)
        out = out.contiguous().view(n, steps, -1)  # [N, Ty//2^K, 128*n_mels//2^K]

        self.gru.flatten_parameters()
        _, last_hidden = self.gru(out)  # last_hidden --- [1, N, 128]

        return self.proj(last_hidden.squeeze(0))

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        """Apply the conv output-length formula ``n_convs`` times to ``L``."""
        size = L
        for _ in range(n_convs):
            size = (size - kernel_size + 2 * pad) // stride + 1
        return size
750
+
751
+
752
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training.

    Wires together the sub-networks of the text-to-speech model:

    * ``enc_p`` -- ``TextEncoder``; turns ``x`` (token ids over ``n_vocab``),
      ``tone``, ``language`` and ``bert`` / ``ja_bert`` features into
      text-level prior statistics ``m_p`` / ``logs_p``.
    * ``enc_q`` -- ``PosteriorEncoder`` over ``spec_channels``-channel input.
    * ``flow`` -- ``TransformerCouplingBlock`` or ``ResidualCouplingBlock``
      mapping between the posterior latent ``z`` and the prior space ``z_p``.
    * ``dec`` -- ``Generator`` decoding latents into the output ``o``.
    * ``sdp`` / ``dp`` -- stochastic and deterministic duration predictors.
    * ``emb_g`` (when ``n_speakers > 0``) or ``ref_enc`` (otherwise) --
      producers of the global conditioning tensor ``g``.
    """

    def __init__(
        self,
        n_vocab,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=256,
        gin_channels=256,
        use_sdp=True,
        n_flow_layer=4,
        n_layers_trans_flow=6,
        flow_share_parameter=False,
        use_transformer_flow=True,
        use_vc=False,
        num_languages=None,
        num_tones=None,
        norm_refenc=False,
        **kwargs
    ):
        """Build all sub-networks.

        Args:
            n_vocab: First argument of ``TextEncoder`` (its vocabulary size).
            spec_channels: Input channel count of ``PosteriorEncoder`` and of
                ``ReferenceEncoder``.
            segment_size: Segment length passed to
                ``commons.rand_slice_segments`` in ``forward`` when slicing
                latents for the decoder.
            inter_channels: Latent channel count shared by ``enc_p``,
                ``dec``, ``enc_q`` and ``flow``.
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size,
                p_dropout: ``TextEncoder`` configuration. ``hidden_channels``,
                ``filter_channels``, ``n_heads`` and ``p_dropout`` are also
                forwarded to other sub-networks (flow, duration predictors,
                posterior encoder).
            resblock, resblock_kernel_sizes, resblock_dilation_sizes,
                upsample_rates, upsample_initial_channel,
                upsample_kernel_sizes: ``Generator`` (decoder) configuration.
            n_speakers: Size of the ``emb_g`` speaker table. When ``<= 0`` a
                ``ReferenceEncoder`` is built instead and ``g`` is computed
                from the reference input ``y``.
            gin_channels: Channel count of the global conditioning ``g``.
            use_sdp: Stored as ``self.use_sdp``; ``sdp`` is built regardless.
            n_flow_layer: Forwarded to the chosen flow block (presumably its
                number of coupling layers).
            n_layers_trans_flow: Forwarded to ``TransformerCouplingBlock`` only.
            flow_share_parameter: Passed as ``share_parameter`` to
                ``TransformerCouplingBlock``.
            use_transformer_flow: ``True`` builds ``TransformerCouplingBlock``,
                ``False`` builds ``ResidualCouplingBlock`` for ``flow``.
            use_vc: When ``True`` the text encoder is run without speaker
                conditioning (``g=None``).
            num_languages, num_tones: Forwarded to ``TextEncoder``.
            norm_refenc: ``layernorm`` flag for ``ReferenceEncoder``.
            **kwargs: Optional ``use_spk_conditioned_encoder`` (default
                ``True``), ``use_noise_scaled_mas`` (``False``),
                ``mas_noise_scale_initial`` (``0.01``) and
                ``noise_scale_delta`` (``2e-6``); any other key is ignored.
        """
        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.n_layers_trans_flow = n_layers_trans_flow
        self.use_spk_conditioned_encoder = kwargs.get(
            "use_spk_conditioned_encoder", True
        )
        self.use_sdp = use_sdp
        self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
        self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
        # Stored only; nothing in this class applies the delta (presumably a
        # training loop decays current_mas_noise_scale with it).
        self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
        # Scale of the noise added to alignment scores in forward().
        self.current_mas_noise_scale = self.mas_noise_scale_initial
        # The text encoder only receives conditioning channels when enabled.
        if self.use_spk_conditioned_encoder and gin_channels > 0:
            self.enc_gin_channels = gin_channels
        else:
            self.enc_gin_channels = 0
        # NOTE: submodule creation order is kept as-is; it determines
        # parameter initialisation order and state_dict key order.
        self.enc_p = TextEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.enc_gin_channels,
            num_languages=num_languages,
            num_tones=num_tones,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        if use_transformer_flow:
            self.flow = TransformerCouplingBlock(
                inter_channels,
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers_trans_flow,
                5,
                p_dropout,
                n_flow_layer,
                gin_channels=gin_channels,
                share_parameter=flow_share_parameter,
            )
        else:
            self.flow = ResidualCouplingBlock(
                inter_channels,
                hidden_channels,
                5,
                1,
                n_flow_layer,
                gin_channels=gin_channels,
            )
        self.sdp = StochasticDurationPredictor(
            hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
        )
        self.dp = DurationPredictor(
            hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
        )

        if n_speakers > 0:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)
        else:
            self.ref_enc = ReferenceEncoder(
                spec_channels, gin_channels, layernorm=norm_refenc
            )
        self.use_vc = use_vc

    def _speaker_embedding(self, sid, y):
        """Return the global conditioning tensor ``g``.

        Looks ``sid`` up in ``emb_g`` when the model has a speaker table
        (``n_speakers > 0``); otherwise runs ``ref_enc`` on ``y`` transposed
        over its last two axes. Either way a trailing singleton axis is added
        with ``unsqueeze(-1)``.

        Raises:
            ValueError: if the model has no speaker table and ``y`` is None
                (previously this surfaced as an opaque AttributeError).
        """
        if self.n_speakers > 0:
            return self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        if y is None:
            raise ValueError(
                "SynthesizerTrn was built with n_speakers <= 0, so a reference "
                "input `y` is required to compute the speaker conditioning `g`"
            )
        return self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)

    def forward(self, x, x_lengths, y, y_lengths, sid, tone, language, bert, ja_bert):
        """Training pass: align text to frames, compute duration losses, decode.

        Returns:
            ``(o, l_length, attn, ids_slice, x_mask, y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q), (x, logw, logw_))`` where
            ``o`` is the decoder output for the random latent slice starting
            at ``ids_slice``, ``l_length`` is the sum of the ``dp`` and
            ``sdp`` duration losses, and ``attn`` is the alignment returned by
            ``monotonic_align.maximum_path``. ``m_p`` / ``logs_p`` are
            returned already expanded to frame level through ``attn``.

        Raises:
            ValueError: if ``n_speakers <= 0`` and ``y`` is None.
        """
        g = self._speaker_embedding(sid, y)
        # With use_vc the text encoder is left unconditioned.
        g_p = None if self.use_vc else g
        x, m_p, logs_p, x_mask = self.enc_p(
            x, x_lengths, tone, language, bert, ja_bert, g=g_p
        )
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        with torch.no_grad():
            # negative cross-entropy: the four terms below sum to the Gaussian
            # log-density of every frame of z_p under every text position's
            # prior (m_p, exp(logs_p)), summed over the channel axis.
            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent3 = torch.matmul(
                z_p.transpose(1, 2), (m_p * s_p_sq_r)
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
            if self.use_noise_scaled_mas:
                # Optional noise on the alignment scores, scaled by their std.
                epsilon = (
                    torch.std(neg_cent)
                    * torch.randn_like(neg_cent)
                    * self.current_mas_noise_scale
                )
                neg_cent = neg_cent + epsilon

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = (
                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
                .unsqueeze(1)
                .detach()
            )

        # Number of frames assigned to each text position by the alignment.
        w = attn.sum(2)

        l_length_sdp = self.sdp(x, x_mask, w, g=g)
        l_length_sdp = l_length_sdp / torch.sum(x_mask)

        # Regression target for dp: log-durations, masked to valid positions
        # (1e-6 avoids log(0)).
        logw_ = torch.log(w + 1e-6) * x_mask
        logw = self.dp(x, x_mask, g=g)
        l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
            x_mask
        )  # for averaging

        l_length = l_length_dp + l_length_sdp

        # expand prior: map text-level statistics onto frames via the alignment
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=g)
        return (
            o,
            l_length,
            attn,
            ids_slice,
            x_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
            (x, logw, logw_),
        )

    def infer(
        self,
        x,
        x_lengths,
        sid,
        tone,
        language,
        bert,
        ja_bert,
        noise_scale=0.667,
        length_scale=1,
        noise_scale_w=0.8,
        max_len=None,
        sdp_ratio=0,
        y=None,
        g=None,
    ):
        """Synthesize output from text features.

        Args:
            noise_scale: Multiplier on the Gaussian noise added to the
                frame-level prior mean when sampling ``z_p``.
            length_scale: Multiplier applied to the predicted durations.
            noise_scale_w: ``noise_scale`` passed to ``sdp`` in reverse mode.
            max_len: Optional cap on the number of latent frames decoded.
            sdp_ratio: Weight of ``sdp`` against ``dp`` when blending the
                predicted log-durations (``0`` uses ``dp`` only).
            y: Reference input for ``ref_enc``; required when ``g`` is None
                and the model has ``n_speakers <= 0``.
            g: Precomputed conditioning; when given, ``emb_g`` / ``ref_enc``
                are skipped.

        Returns:
            ``(o, attn, y_mask, (z, z_p, m_p, logs_p))``.

        Raises:
            ValueError: if ``g`` and ``y`` are both None and the model has no
                speaker embedding table.
        """
        if g is None:
            g = self._speaker_embedding(sid, y)
        g_p = None if self.use_vc else g
        x, m_p, logs_p, x_mask = self.enc_p(
            x, x_lengths, tone, language, bert, ja_bert, g=g_p
        )
        # sdp is evaluated before dp, as in the original expression, so the
        # random draws inside sdp happen in the same order.
        logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
            sdp_ratio
        ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
        w = torch.exp(logw) * x_mask * length_scale

        # Integer frame counts per text position; at least one output frame.
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
            x_mask.dtype
        )
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        # Alignment derived from the rounded durations by commons.generate_path.
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        # Sample the prior latent, then invert the flow to the posterior space.
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
        return o, attn, y_mask, (z, z_p, m_p, logs_p)

    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
        """Re-synthesize ``y`` under target conditioning.

        Despite their names, ``sid_src`` and ``sid_tgt`` are used directly as
        the conditioning tensors ``g``; no ``emb_g`` lookup happens here.
        ``y`` is encoded by ``enc_q`` (with ``tau`` forwarded to it) and pushed
        through the flow under the source conditioning. The flow is then
        inverted under the target conditioning and the result is decoded.

        Returns:
            ``(o_hat, y_mask, (z, z_p, z_hat))``.
        """
        g_src = sid_src
        g_tgt = sid_tgt
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src, tau=tau)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)