phoonnx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. phoonnx/__init__.py +0 -0
  2. phoonnx/config.py +490 -0
  3. phoonnx/locale/ca/phonetic_spellings.txt +2 -0
  4. phoonnx/locale/en/phonetic_spellings.txt +1 -0
  5. phoonnx/locale/gl/phonetic_spellings.txt +2 -0
  6. phoonnx/locale/pt/phonetic_spellings.txt +2 -0
  7. phoonnx/phoneme_ids.py +453 -0
  8. phoonnx/phonemizers/__init__.py +45 -0
  9. phoonnx/phonemizers/ar.py +42 -0
  10. phoonnx/phonemizers/base.py +216 -0
  11. phoonnx/phonemizers/en.py +250 -0
  12. phoonnx/phonemizers/fa.py +46 -0
  13. phoonnx/phonemizers/gl.py +142 -0
  14. phoonnx/phonemizers/he.py +67 -0
  15. phoonnx/phonemizers/ja.py +119 -0
  16. phoonnx/phonemizers/ko.py +97 -0
  17. phoonnx/phonemizers/mul.py +606 -0
  18. phoonnx/phonemizers/vi.py +44 -0
  19. phoonnx/phonemizers/zh.py +308 -0
  20. phoonnx/thirdparty/__init__.py +0 -0
  21. phoonnx/thirdparty/arpa2ipa.py +249 -0
  22. phoonnx/thirdparty/cotovia/cotovia_aarch64 +0 -0
  23. phoonnx/thirdparty/cotovia/cotovia_x86_64 +0 -0
  24. phoonnx/thirdparty/hangul2ipa.py +783 -0
  25. phoonnx/thirdparty/ko_tables/aspiration.csv +20 -0
  26. phoonnx/thirdparty/ko_tables/assimilation.csv +31 -0
  27. phoonnx/thirdparty/ko_tables/double_coda.csv +17 -0
  28. phoonnx/thirdparty/ko_tables/hanja.tsv +8525 -0
  29. phoonnx/thirdparty/ko_tables/ipa.csv +22 -0
  30. phoonnx/thirdparty/ko_tables/neutralization.csv +11 -0
  31. phoonnx/thirdparty/ko_tables/tensification.csv +56 -0
  32. phoonnx/thirdparty/ko_tables/yale.csv +22 -0
  33. phoonnx/thirdparty/kog2p/__init__.py +385 -0
  34. phoonnx/thirdparty/kog2p/rulebook.txt +212 -0
  35. phoonnx/thirdparty/mantoq/__init__.py +67 -0
  36. phoonnx/thirdparty/mantoq/buck/__init__.py +0 -0
  37. phoonnx/thirdparty/mantoq/buck/phonetise_buckwalter.py +569 -0
  38. phoonnx/thirdparty/mantoq/buck/symbols.py +64 -0
  39. phoonnx/thirdparty/mantoq/buck/tokenization.py +105 -0
  40. phoonnx/thirdparty/mantoq/num2words.py +37 -0
  41. phoonnx/thirdparty/mantoq/pyarabic/__init__.py +12 -0
  42. phoonnx/thirdparty/mantoq/pyarabic/arabrepr.py +64 -0
  43. phoonnx/thirdparty/mantoq/pyarabic/araby.py +1647 -0
  44. phoonnx/thirdparty/mantoq/pyarabic/named_const.py +227 -0
  45. phoonnx/thirdparty/mantoq/pyarabic/normalize.py +161 -0
  46. phoonnx/thirdparty/mantoq/pyarabic/number.py +826 -0
  47. phoonnx/thirdparty/mantoq/pyarabic/number_const.py +1704 -0
  48. phoonnx/thirdparty/mantoq/pyarabic/stack.py +52 -0
  49. phoonnx/thirdparty/mantoq/pyarabic/trans.py +517 -0
  50. phoonnx/thirdparty/mantoq/unicode_symbol2label.py +4173 -0
  51. phoonnx/thirdparty/tashkeel/LICENSE +22 -0
  52. phoonnx/thirdparty/tashkeel/SOURCE +1 -0
  53. phoonnx/thirdparty/tashkeel/__init__.py +212 -0
  54. phoonnx/thirdparty/tashkeel/hint_id_map.json +18 -0
  55. phoonnx/thirdparty/tashkeel/input_id_map.json +56 -0
  56. phoonnx/thirdparty/tashkeel/model.onnx +0 -0
  57. phoonnx/thirdparty/tashkeel/target_id_map.json +17 -0
  58. phoonnx/thirdparty/zh_num.py +238 -0
  59. phoonnx/util.py +705 -0
  60. phoonnx/version.py +6 -0
  61. phoonnx/voice.py +521 -0
  62. phoonnx-0.0.0.dist-info/METADATA +255 -0
  63. phoonnx-0.0.0.dist-info/RECORD +86 -0
  64. phoonnx-0.0.0.dist-info/WHEEL +5 -0
  65. phoonnx-0.0.0.dist-info/top_level.txt +2 -0
  66. phoonnx_train/__main__.py +151 -0
  67. phoonnx_train/export_onnx.py +109 -0
  68. phoonnx_train/norm_audio/__init__.py +92 -0
  69. phoonnx_train/norm_audio/trim.py +54 -0
  70. phoonnx_train/norm_audio/vad.py +54 -0
  71. phoonnx_train/preprocess.py +420 -0
  72. phoonnx_train/vits/__init__.py +0 -0
  73. phoonnx_train/vits/attentions.py +427 -0
  74. phoonnx_train/vits/commons.py +147 -0
  75. phoonnx_train/vits/config.py +330 -0
  76. phoonnx_train/vits/dataset.py +214 -0
  77. phoonnx_train/vits/lightning.py +352 -0
  78. phoonnx_train/vits/losses.py +58 -0
  79. phoonnx_train/vits/mel_processing.py +139 -0
  80. phoonnx_train/vits/models.py +732 -0
  81. phoonnx_train/vits/modules.py +527 -0
  82. phoonnx_train/vits/monotonic_align/__init__.py +20 -0
  83. phoonnx_train/vits/monotonic_align/setup.py +13 -0
  84. phoonnx_train/vits/transforms.py +212 -0
  85. phoonnx_train/vits/utils.py +16 -0
  86. phoonnx_train/vits/wavfile.py +860 -0
@@ -0,0 +1,732 @@
1
+ import math
2
+ import typing
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv1d, Conv2d, ConvTranspose1d
7
+ from torch.nn import functional as F
8
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
9
+
10
+ from . import attentions, commons, modules, monotonic_align
11
+ from .commons import get_padding, init_weights
12
+
13
+
14
class StochasticDurationPredictor(nn.Module):
    """Flow-based duration predictor.

    In training mode (``reverse=False``) ``forward`` needs the target durations
    ``w`` and returns ``nll + logq`` reduced over the channel and time axes.
    In inference mode (``reverse=True``) it samples noise, pushes it through
    the reversed flow stack and returns the first channel as ``logw``.
    """

    def __init__(
        self,
        in_channels: int,
        filter_channels: int,
        kernel_size: int,
        p_dropout: float,
        n_flows: int = 4,
        gin_channels: int = 0,
    ):
        super().__init__()
        # The ``filter_channels`` argument is overridden on purpose; the
        # original author flagged this for removal in a future version.
        filter_channels = in_channels
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        # Main flow stack: one affine layer, then ``n_flows`` (ConvFlow, Flip) pairs.
        self.log_flow = modules.Log()
        self.flows = nn.ModuleList([modules.ElementwiseAffine(2)])
        for _ in range(n_flows):
            self.flows.extend(
                [
                    modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3),
                    modules.Flip(),
                ]
            )

        # Posterior branch, conditioned on the target durations ``w``.
        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = nn.ModuleList([modules.ElementwiseAffine(2)])
        for _ in range(4):  # the posterior stack always uses four flow pairs
            self.post_flows.extend(
                [
                    modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3),
                    modules.Flip(),
                ]
            )

        # Conditioning branch applied to the (detached) encoder output.
        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        # ``cond`` only exists when a global conditioning size was given.
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        """Return the duration loss (training) or sampled ``logw`` (inference).

        Args:
            x: Encoder output; detached so no gradient flows back through it.
            x_mask: Mask multiplied into the intermediate tensors.
            w: Target durations; required when ``reverse`` is false.
            g: Optional global conditioning; detached before use.
            reverse: Selects the sampling path instead of the loss path.
            noise_scale: Multiplier on the sampled noise (sampling path only).
        """
        cond = self.pre(torch.detach(x))
        if g is not None:
            cond = cond + self.cond(torch.detach(g))
        cond = self.convs(cond, x_mask)
        cond = self.proj(cond) * x_mask

        if reverse:
            inverse_flows = list(reversed(self.flows))
            # remove a useless vflow (the second-to-last entry of the reversed stack)
            inverse_flows = inverse_flows[:-2] + inverse_flows[-1:]
            z = torch.randn(cond.size(0), 2, cond.size(2)).type_as(cond) * noise_scale
            for flow in inverse_flows:
                z = flow(z, x_mask, g=cond, reverse=reverse)
            logw, _ = torch.split(z, [1, 1], 1)
            return logw

        assert w is not None
        log_2pi = math.log(2 * math.pi)

        # Posterior pass: noise -> (z_u, z1), conditioned on encoder and durations.
        logdet_tot_q = 0
        dur_feats = self.post_pre(w)
        dur_feats = self.post_convs(dur_feats, x_mask)
        dur_feats = self.post_proj(dur_feats) * x_mask
        e_q = torch.randn(w.size(0), 2, w.size(2)).type_as(cond) * x_mask
        z_q = e_q
        for flow in self.post_flows:
            z_q, step_logdet = flow(z_q, x_mask, g=(cond + dur_feats))
            logdet_tot_q += step_logdet
        z_u, z1 = torch.split(z_q, [1, 1], 1)
        u = torch.sigmoid(z_u) * x_mask
        z0 = (w - u) * x_mask
        logdet_tot_q += torch.sum(
            (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
        )
        logq = torch.sum(-0.5 * (log_2pi + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q

        # Main pass: log-transform z0, then run the forward flow stack.
        logdet_tot = 0
        z0, log_flow_logdet = self.log_flow(z0, x_mask)
        logdet_tot += log_flow_logdet
        z = torch.cat([z0, z1], 1)
        for flow in self.flows:
            z, step_logdet = flow(z, x_mask, g=cond, reverse=reverse)
            logdet_tot = logdet_tot + step_logdet
        nll = torch.sum(0.5 * (log_2pi + (z**2)) * x_mask, [1, 2]) - logdet_tot
        return nll + logq  # [b]
118
+
119
+
120
class DurationPredictor(nn.Module):
    """Deterministic duration predictor: two conv/ReLU/norm/dropout stages
    followed by a 1x1 projection down to a single channel."""

    def __init__(
        self,
        in_channels: int,
        filter_channels: int,
        kernel_size: int,
        p_dropout: float,
        gin_channels: int = 0,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        same_pad = kernel_size // 2

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        # ``cond`` is only created when global conditioning is configured.
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        """Return the masked single-channel prediction for ``x``.

        Both ``x`` and ``g`` are detached, so this module does not send
        gradients back into whatever produced them.
        """
        h = torch.detach(x)
        if g is not None:
            h = h + self.cond(torch.detach(g))

        stages = ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2))
        for conv, norm in stages:
            h = conv(h * x_mask)
            h = norm(torch.relu(h))
            h = self.drop(h)

        return self.proj(h * x_mask) * x_mask
166
+
167
+
168
class TextEncoder(nn.Module):
    """Embeds token ids, runs them through the attention encoder and projects
    the result to the two halves (``m``, ``logs``) of the prior statistics."""

    def __init__(
        self,
        n_vocab: int,
        out_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        # Embedding weights start from N(0, hidden_channels ** -0.5).
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)

        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        # Twice ``out_channels`` so the output can be split into m and logs.
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths):
        """Return ``(encoded, m, logs, x_mask)`` for token ids ``x``."""
        embedded = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        hidden = torch.transpose(embedded, 1, -1)  # [b, h, t]

        lengths_mask = commons.sequence_mask(x_lengths, hidden.size(2))
        x_mask = torch.unsqueeze(lengths_mask, 1).type_as(hidden)

        hidden = self.encoder(hidden * x_mask, x_mask)
        stats = self.proj(hidden) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return hidden, m, logs, x_mask
210
+
211
+
212
class ResidualCouplingBlock(nn.Module):
    """Stack of ``n_flows`` (ResidualCouplingLayer, Flip) pairs that can be
    applied in forward order or, with ``reverse=True``, in reversed order."""

    def __init__(
        self,
        channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        n_flows: int = 4,
        gin_channels: int = 0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = modules.ResidualCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            self.flows.append(coupling)
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Run ``x`` through the flow stack and return the transformed tensor.

        In the forward direction each flow returns a pair whose second item
        is discarded here; in the reverse direction each flow returns only
        the tensor.
        """
        if reverse:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
            return x

        for flow in self.flows:
            x, _ = flow(x, x_mask, g=g, reverse=reverse)
        return x
255
+
256
+
257
class PosteriorEncoder(nn.Module):
    """Encodes the input into posterior statistics (``m``, ``logs``) and draws
    a masked sample ``z = m + eps * exp(logs)`` from them."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        gin_channels: int = 0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        # Twice ``out_channels`` so the output can be split into m and logs.
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """Return ``(z, m, logs, x_mask)`` for input ``x`` with valid lengths
        ``x_lengths`` and optional global conditioning ``g``."""
        lengths_mask = commons.sequence_mask(x_lengths, x.size(2))
        x_mask = torch.unsqueeze(lengths_mask, 1).type_as(x)

        hidden = self.pre(x) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.proj(hidden) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        noise = torch.randn_like(m)
        z = (m + noise * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
297
+
298
+
299
class Generator(torch.nn.Module):
    """Upsampling waveform decoder.

    Each stage applies a weight-normalised transposed convolution and then
    averages the outputs of ``num_kernels`` residual blocks. A final
    convolution and ``tanh`` produce a single output channel.
    """

    def __init__(
        self,
        initial_channel: int,
        resblock: typing.Optional[str],
        resblock_kernel_sizes: typing.Tuple[int, ...],
        resblock_dilation_sizes: typing.Tuple[typing.Tuple[int, ...], ...],
        upsample_rates: typing.Tuple[int, ...],
        upsample_initial_channel: int,
        upsample_kernel_sizes: typing.Tuple[int, ...],
        gin_channels: int = 0,
    ):
        super().__init__()
        self.LRELU_SLOPE = 0.1
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        # Any value other than "1" selects ResBlock2.
        if resblock == "1":
            resblock_module = modules.ResBlock1
        else:
            resblock_module = modules.ResBlock2

        # Every upsampling stage halves the channel count.
        self.ups = nn.ModuleList()
        for stage, (rate, kernel) in enumerate(
            zip(upsample_rates, upsample_kernel_sizes)
        ):
            transposed = ConvTranspose1d(
                upsample_initial_channel // (2**stage),
                upsample_initial_channel // (2 ** (stage + 1)),
                kernel,
                rate,
                padding=(kernel - rate) // 2,
            )
            self.ups.append(weight_norm(transposed))

        # ``num_kernels`` residual blocks per stage, stored flat in stage order.
        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (stage + 1))
            for kernel, dilations in zip(
                resblock_kernel_sizes, resblock_dilation_sizes
            ):
                self.resblocks.append(resblock_module(ch, kernel, dilations))

        # ``ch`` still holds the channel count of the last stage here.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        # ``cond`` is only created when global conditioning is configured.
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        """Decode ``x`` (optionally conditioned on ``g``) to the tanh output."""
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for stage, up in enumerate(self.ups):
            x = up(F.leaky_relu(x, self.LRELU_SLOPE))
            # Placeholder; replaced by the first residual block of this stage.
            xs = torch.zeros(1)
            first = stage * self.num_kernels
            for pos, resblock in enumerate(self.resblocks):
                offset = pos - first
                if offset == 0:
                    xs = resblock(x)
                elif 0 < offset < self.num_kernels:
                    xs += resblock(x)
            x = xs / self.num_kernels

        # NOTE: this activation uses leaky_relu's default slope, not
        # LRELU_SLOPE; kept exactly as in the original code.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Strip weight normalisation from the upsampling layers and ask each
        residual block to do the same."""
        print("Removing weight norm...")
        for layer in self.ups:
            remove_weight_norm(layer)
        for block in self.resblocks:
            block.remove_weight_norm()
376
+
377
+
378
class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the 1-D input into a 2-D view whose last
    axis has length ``period`` and scores it with a stack of 2-D convolutions."""

    def __init__(
        self,
        period: int,
        kernel_size: int = 5,
        stride: int = 3,
        use_spectral_norm: bool = False,
    ):
        super().__init__()
        self.LRELU_SLOPE = 0.1
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = spectral_norm if use_spectral_norm else weight_norm

        # Four strided layers that widen the channels, then one unstrided layer.
        channel_plan = [(1, 32), (32, 128), (128, 512), (512, 1024)]
        layers = [
            norm_f(
                Conv2d(
                    c_in,
                    c_out,
                    (kernel_size, 1),
                    (stride, 1),
                    padding=(get_padding(kernel_size, 1), 0),
                )
            )
            for c_in, c_out in channel_plan
        ]
        layers.append(
            norm_f(
                Conv2d(
                    1024,
                    1024,
                    (kernel_size, 1),
                    1,
                    padding=(get_padding(kernel_size, 1), 0),
                )
            )
        )
        self.convs = nn.ModuleList(layers)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return ``(scores, fmap)``: the flattened final output and the list
        of activations after every layer, the final convolution included."""
        fmap = []

        # 1d to 2d
        batch, channels, length = x.shape
        remainder = length % self.period
        if remainder != 0:  # pad first so the length divides by the period
            n_pad = self.period - remainder
            x = F.pad(x, (0, n_pad), "reflect")
            length = length + n_pad
        x = x.view(batch, channels, length // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(conv(x), self.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap
462
+
463
+
464
class DiscriminatorS(torch.nn.Module):
    """Scale discriminator: a stack of (mostly grouped, strided) 1-D
    convolutions followed by a single-channel output convolution."""

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        self.LRELU_SLOPE = 0.1
        norm_f = weight_norm if not use_spectral_norm else spectral_norm

        # (in_channels, out_channels, kernel_size, stride, groups, padding)
        conv_specs = [
            (1, 16, 15, 1, 1, 7),
            (16, 64, 41, 4, 4, 20),
            (64, 256, 41, 4, 16, 20),
            (256, 1024, 41, 4, 64, 20),
            (1024, 1024, 41, 4, 256, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(c_in, c_out, kernel, step, groups=groups, padding=pad))
                for c_in, c_out, kernel, step, groups, pad in conv_specs
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return ``(scores, fmap)``: the flattened final output and the list
        of activations after every layer, the final convolution included."""
        fmap = []

        for conv in self.convs:
            x = F.leaky_relu(conv(x), self.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap
493
+
494
+
495
class MultiPeriodDiscriminator(torch.nn.Module):
    """Bundle of one ``DiscriminatorS`` followed by one ``DiscriminatorP`` for
    each of the periods 2, 3, 5, 7 and 11."""

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        periods = [2, 3, 5, 7, 11]

        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        for period in periods:
            members.append(
                DiscriminatorP(period, use_spectral_norm=use_spectral_norm)
            )
        self.discriminators = nn.ModuleList(members)

    def forward(self, y, y_hat):
        """Score ``y`` and ``y_hat`` with every sub-discriminator.

        Returns four lists, one entry per sub-discriminator: scores for
        ``y``, scores for ``y_hat``, feature maps for ``y`` and feature
        maps for ``y_hat``.
        """
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []
        for disc in self.discriminators:
            score_r, fmap_r = disc(y)
            score_g, fmap_g = disc(y_hat)
            real_scores.append(score_r)
            fake_scores.append(score_g)
            real_fmaps.append(fmap_r)
            fake_fmaps.append(fmap_g)

        return real_scores, fake_scores, real_fmaps, fake_fmaps
520
+
521
+
522
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training

    Combines the text encoder (``enc_p``), posterior encoder (``enc_q``),
    flow (``flow``), duration predictor (``dp``) and waveform decoder
    (``dec``). A speaker embedding table (``emb_g``) is created only when
    ``n_speakers > 1``.
    """

    def __init__(
        self,
        n_vocab: int,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: typing.Tuple[int, ...],
        resblock_dilation_sizes: typing.Tuple[typing.Tuple[int, ...], ...],
        upsample_rates: typing.Tuple[int, ...],
        upsample_initial_channel: int,
        upsample_kernel_sizes: typing.Tuple[int, ...],
        n_speakers: int = 1,
        gin_channels: int = 0,
        use_sdp: bool = True,
    ):

        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels

        self.use_sdp = use_sdp

        self.enc_p = TextEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        # Posterior encoder: kernel_size=5, dilation_rate=1, n_layers=16.
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        # Flow: kernel_size=5, dilation_rate=1, n_layers=4.
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
        )

        if use_sdp:
            self.dp = StochasticDurationPredictor(
                hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
            )
        else:
            self.dp = DurationPredictor(
                hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
            )

        # Single-speaker models have no embedding table and run unconditioned.
        if n_speakers > 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)

    def _speaker_embedding(self, sid):
        """Return the global conditioning tensor for ``sid``.

        Returns ``None`` for single-speaker models. For multi-speaker models
        a missing ``sid`` fails with a clear message instead of an obscure
        error raised inside the embedding lookup.
        """
        if self.n_speakers > 1:
            assert sid is not None, "Missing speaker id"
            return self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        return None

    def forward(self, x, x_lengths, y, y_lengths, sid=None):
        """Training pass.

        Args:
            x, x_lengths: Input passed to the text encoder, with its lengths.
            y, y_lengths: Input passed to the posterior encoder, with its
                lengths.
            sid: Speaker ids; required when the model has more than one
                speaker.

        Returns:
            ``(o, l_length, attn, ids_slice, x_mask, y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q))`` where ``o`` is the decoder
            output for a random slice of ``z`` and ``l_length`` is the
            duration loss.
        """
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        g = self._speaker_embedding(sid)

        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        # The alignment search is not differentiated through.
        with torch.no_grad():
            # negative cross-entropy
            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent3 = torch.matmul(
                z_p.transpose(1, 2), (m_p * s_p_sq_r)
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = (
                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
                .unsqueeze(1)
                .detach()
            )

        # Target durations: the alignment summed over dim 2.
        w = attn.sum(2)
        if self.use_sdp:
            l_length = self.dp(x, x_mask, w, g=g)
            l_length = l_length / torch.sum(x_mask)
        else:
            logw_ = torch.log(w + 1e-6) * x_mask
            logw = self.dp(x, x_mask, g=g)
            l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
                x_mask
            )  # for averaging

        # expand prior
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=g)
        return (
            o,
            l_length,
            attn,
            ids_slice,
            x_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
        )

    def infer(
        self,
        x,
        x_lengths,
        sid=None,
        noise_scale=0.667,
        length_scale=1,
        noise_scale_w=0.8,
        max_len=None,
    ):
        """Inference pass.

        Args:
            x, x_lengths: Input passed to the text encoder, with its lengths.
            sid: Speaker ids; required when the model has more than one
                speaker.
            noise_scale: Multiplier on the noise added to the prior mean.
            length_scale: Multiplier on the predicted durations.
            noise_scale_w: Noise scale passed to the stochastic duration
                predictor (unused by the deterministic predictor).
            max_len: Optional cap on the number of frames sent to the decoder.

        Returns:
            ``(o, attn, y_mask, (z, z_p, m_p, logs_p))``.
        """
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        g = self._speaker_embedding(sid)

        if self.use_sdp:
            logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
        else:
            logw = self.dp(x, x_mask, g=g)
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        # Clamp so every item has an output length of at least 1.
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(
            commons.sequence_mask(y_lengths, y_lengths.max()), 1
        ).type_as(x_mask)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)

        return o, attn, y_mask, (z, z_p, m_p, logs_p)

    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
        """Encode ``y`` with the source speaker, invert the flow with the
        target speaker and decode.

        Returns ``(o_hat, y_mask, (z, z_p, z_hat))``. Only valid for models
        with more than one speaker.
        """
        assert self.n_speakers > 1, "n_speakers have to be larger than 1."
        g_src = self.emb_g(sid_src).unsqueeze(-1)
        g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)