phoonnx-0.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. phoonnx/__init__.py +0 -0
  2. phoonnx/config.py +490 -0
  3. phoonnx/locale/ca/phonetic_spellings.txt +2 -0
  4. phoonnx/locale/en/phonetic_spellings.txt +1 -0
  5. phoonnx/locale/gl/phonetic_spellings.txt +2 -0
  6. phoonnx/locale/pt/phonetic_spellings.txt +2 -0
  7. phoonnx/phoneme_ids.py +453 -0
  8. phoonnx/phonemizers/__init__.py +45 -0
  9. phoonnx/phonemizers/ar.py +42 -0
  10. phoonnx/phonemizers/base.py +216 -0
  11. phoonnx/phonemizers/en.py +250 -0
  12. phoonnx/phonemizers/fa.py +46 -0
  13. phoonnx/phonemizers/gl.py +142 -0
  14. phoonnx/phonemizers/he.py +67 -0
  15. phoonnx/phonemizers/ja.py +119 -0
  16. phoonnx/phonemizers/ko.py +97 -0
  17. phoonnx/phonemizers/mul.py +606 -0
  18. phoonnx/phonemizers/vi.py +44 -0
  19. phoonnx/phonemizers/zh.py +308 -0
  20. phoonnx/thirdparty/__init__.py +0 -0
  21. phoonnx/thirdparty/arpa2ipa.py +249 -0
  22. phoonnx/thirdparty/cotovia/cotovia_aarch64 +0 -0
  23. phoonnx/thirdparty/cotovia/cotovia_x86_64 +0 -0
  24. phoonnx/thirdparty/hangul2ipa.py +783 -0
  25. phoonnx/thirdparty/ko_tables/aspiration.csv +20 -0
  26. phoonnx/thirdparty/ko_tables/assimilation.csv +31 -0
  27. phoonnx/thirdparty/ko_tables/double_coda.csv +17 -0
  28. phoonnx/thirdparty/ko_tables/hanja.tsv +8525 -0
  29. phoonnx/thirdparty/ko_tables/ipa.csv +22 -0
  30. phoonnx/thirdparty/ko_tables/neutralization.csv +11 -0
  31. phoonnx/thirdparty/ko_tables/tensification.csv +56 -0
  32. phoonnx/thirdparty/ko_tables/yale.csv +22 -0
  33. phoonnx/thirdparty/kog2p/__init__.py +385 -0
  34. phoonnx/thirdparty/kog2p/rulebook.txt +212 -0
  35. phoonnx/thirdparty/mantoq/__init__.py +67 -0
  36. phoonnx/thirdparty/mantoq/buck/__init__.py +0 -0
  37. phoonnx/thirdparty/mantoq/buck/phonetise_buckwalter.py +569 -0
  38. phoonnx/thirdparty/mantoq/buck/symbols.py +64 -0
  39. phoonnx/thirdparty/mantoq/buck/tokenization.py +105 -0
  40. phoonnx/thirdparty/mantoq/num2words.py +37 -0
  41. phoonnx/thirdparty/mantoq/pyarabic/__init__.py +12 -0
  42. phoonnx/thirdparty/mantoq/pyarabic/arabrepr.py +64 -0
  43. phoonnx/thirdparty/mantoq/pyarabic/araby.py +1647 -0
  44. phoonnx/thirdparty/mantoq/pyarabic/named_const.py +227 -0
  45. phoonnx/thirdparty/mantoq/pyarabic/normalize.py +161 -0
  46. phoonnx/thirdparty/mantoq/pyarabic/number.py +826 -0
  47. phoonnx/thirdparty/mantoq/pyarabic/number_const.py +1704 -0
  48. phoonnx/thirdparty/mantoq/pyarabic/stack.py +52 -0
  49. phoonnx/thirdparty/mantoq/pyarabic/trans.py +517 -0
  50. phoonnx/thirdparty/mantoq/unicode_symbol2label.py +4173 -0
  51. phoonnx/thirdparty/tashkeel/LICENSE +22 -0
  52. phoonnx/thirdparty/tashkeel/SOURCE +1 -0
  53. phoonnx/thirdparty/tashkeel/__init__.py +212 -0
  54. phoonnx/thirdparty/tashkeel/hint_id_map.json +18 -0
  55. phoonnx/thirdparty/tashkeel/input_id_map.json +56 -0
  56. phoonnx/thirdparty/tashkeel/model.onnx +0 -0
  57. phoonnx/thirdparty/tashkeel/target_id_map.json +17 -0
  58. phoonnx/thirdparty/zh_num.py +238 -0
  59. phoonnx/util.py +705 -0
  60. phoonnx/version.py +6 -0
  61. phoonnx/voice.py +521 -0
  62. phoonnx-0.0.0.dist-info/METADATA +255 -0
  63. phoonnx-0.0.0.dist-info/RECORD +86 -0
  64. phoonnx-0.0.0.dist-info/WHEEL +5 -0
  65. phoonnx-0.0.0.dist-info/top_level.txt +2 -0
  66. phoonnx_train/__main__.py +151 -0
  67. phoonnx_train/export_onnx.py +109 -0
  68. phoonnx_train/norm_audio/__init__.py +92 -0
  69. phoonnx_train/norm_audio/trim.py +54 -0
  70. phoonnx_train/norm_audio/vad.py +54 -0
  71. phoonnx_train/preprocess.py +420 -0
  72. phoonnx_train/vits/__init__.py +0 -0
  73. phoonnx_train/vits/attentions.py +427 -0
  74. phoonnx_train/vits/commons.py +147 -0
  75. phoonnx_train/vits/config.py +330 -0
  76. phoonnx_train/vits/dataset.py +214 -0
  77. phoonnx_train/vits/lightning.py +352 -0
  78. phoonnx_train/vits/losses.py +58 -0
  79. phoonnx_train/vits/mel_processing.py +139 -0
  80. phoonnx_train/vits/models.py +732 -0
  81. phoonnx_train/vits/modules.py +527 -0
  82. phoonnx_train/vits/monotonic_align/__init__.py +20 -0
  83. phoonnx_train/vits/monotonic_align/setup.py +13 -0
  84. phoonnx_train/vits/transforms.py +212 -0
  85. phoonnx_train/vits/utils.py +16 -0
  86. phoonnx_train/vits/wavfile.py +860 -0
phoonnx_train/vits/attentions.py
@@ -0,0 +1,427 @@
+import math
+import typing
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .commons import subsequent_mask
+from .modules import LayerNorm
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        hidden_channels: int,
+        filter_channels: int,
+        n_heads: int,
+        n_layers: int,
+        kernel_size: int = 1,
+        p_dropout: float = 0.0,
+        window_size: int = 4,
+        **kwargs
+    ):
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+
+        self.drop = nn.Dropout(p_dropout)
+        self.attn_layers = nn.ModuleList()
+        self.norm_layers_1 = nn.ModuleList()
+        self.ffn_layers = nn.ModuleList()
+        self.norm_layers_2 = nn.ModuleList()
+        for i in range(self.n_layers):
+            self.attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels,
+                    hidden_channels,
+                    n_heads,
+                    p_dropout=p_dropout,
+                    window_size=window_size,
+                )
+            )
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(
+                    hidden_channels,
+                    hidden_channels,
+                    filter_channels,
+                    kernel_size,
+                    p_dropout=p_dropout,
+                )
+            )
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+    def forward(self, x, x_mask):
+        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        x = x * x_mask
+        for attn_layer, norm_layer_1, ffn_layer, norm_layer_2 in zip(
+            self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
+        ):
+            y = attn_layer(x, x, attn_mask)
+            y = self.drop(y)
+            x = norm_layer_1(x + y)
+
+            y = ffn_layer(x, x_mask)
+            y = self.drop(y)
+            x = norm_layer_2(x + y)
+        x = x * x_mask
+        return x
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        hidden_channels: int,
+        filter_channels: int,
+        n_heads: int,
+        n_layers: int,
+        kernel_size: int = 1,
+        p_dropout: float = 0.0,
+        proximal_bias: bool = False,
+        proximal_init: bool = True,
+        **kwargs
+    ):
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.proximal_bias = proximal_bias
+        self.proximal_init = proximal_init
+
+        self.drop = nn.Dropout(p_dropout)
+        self.self_attn_layers = nn.ModuleList()
+        self.norm_layers_0 = nn.ModuleList()
+        self.encdec_attn_layers = nn.ModuleList()
+        self.norm_layers_1 = nn.ModuleList()
+        self.ffn_layers = nn.ModuleList()
+        self.norm_layers_2 = nn.ModuleList()
+        for i in range(self.n_layers):
+            self.self_attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels,
+                    hidden_channels,
+                    n_heads,
+                    p_dropout=p_dropout,
+                    proximal_bias=proximal_bias,
+                    proximal_init=proximal_init,
+                )
+            )
+            self.norm_layers_0.append(LayerNorm(hidden_channels))
+            self.encdec_attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+                )
+            )
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(
+                    hidden_channels,
+                    hidden_channels,
+                    filter_channels,
+                    kernel_size,
+                    p_dropout=p_dropout,
+                    causal=True,
+                )
+            )
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+    def forward(self, x, x_mask, h, h_mask):
+        """
+        x: decoder input
+        h: encoder output
+        """
+        self_attn_mask = subsequent_mask(x_mask.size(2)).type_as(x)
+        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        x = x * x_mask
+        for i in range(self.n_layers):
+            y = self.self_attn_layers[i](x, x, self_attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_0[i](x + y)
+
+            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_1[i](x + y)
+
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = self.norm_layers_2[i](x + y)
+        x = x * x_mask
+        return x
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        out_channels: int,
+        n_heads: int,
+        p_dropout: float = 0.0,
+        window_size: typing.Optional[int] = None,
+        heads_share: bool = True,
+        block_length: typing.Optional[int] = None,
+        proximal_bias: bool = False,
+        proximal_init: bool = False,
+    ):
+        super().__init__()
+        assert channels % n_heads == 0
+
+        self.channels = channels
+        self.out_channels = out_channels
+        self.n_heads = n_heads
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.heads_share = heads_share
+        self.block_length = block_length
+        self.proximal_bias = proximal_bias
+        self.proximal_init = proximal_init
+        self.attn = torch.zeros(1)
+
+        self.k_channels = channels // n_heads
+        self.conv_q = nn.Conv1d(channels, channels, 1)
+        self.conv_k = nn.Conv1d(channels, channels, 1)
+        self.conv_v = nn.Conv1d(channels, channels, 1)
+        self.conv_o = nn.Conv1d(channels, out_channels, 1)
+        self.drop = nn.Dropout(p_dropout)
+
+        if window_size is not None:
+            n_heads_rel = 1 if heads_share else n_heads
+            rel_stddev = self.k_channels**-0.5
+            self.emb_rel_k = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+            self.emb_rel_v = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+
+        nn.init.xavier_uniform_(self.conv_q.weight)
+        nn.init.xavier_uniform_(self.conv_k.weight)
+        nn.init.xavier_uniform_(self.conv_v.weight)
+        if proximal_init:
+            with torch.no_grad():
+                self.conv_k.weight.copy_(self.conv_q.weight)
+                self.conv_k.bias.copy_(self.conv_q.bias)
+
+    def forward(self, x, c, attn_mask=None):
+        q = self.conv_q(x)
+        k = self.conv_k(c)
+        v = self.conv_v(c)
+
+        x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+        x = self.conv_o(x)
+        return x
+
+    def attention(self, query, key, value, mask=None):
+        # reshape [b, d, t] -> [b, n_h, t, d_k]
+        b, d, t_s, t_t = (key.size(0), key.size(1), key.size(2), query.size(2))
+        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+        if self.window_size is not None:
+            assert (
+                t_s == t_t
+            ), "Relative attention is only available for self-attention."
+            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+            rel_logits = self._matmul_with_relative_keys(
+                query / math.sqrt(self.k_channels), key_relative_embeddings
+            )
+            scores_local = self._relative_position_to_absolute_position(rel_logits)
+            scores = scores + scores_local
+        if self.proximal_bias:
+            assert t_s == t_t, "Proximal bias is only available for self-attention."
+            scores = scores + self._attention_bias_proximal(t_s).type_as(scores)
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, -1e4)
+            if self.block_length is not None:
+                assert (
+                    t_s == t_t
+                ), "Local attention is only available for self-attention."
+                block_mask = (
+                    torch.ones_like(scores)
+                    .triu(-self.block_length)
+                    .tril(self.block_length)
+                )
+                scores = scores.masked_fill(block_mask == 0, -1e4)
+        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
+        p_attn = self.drop(p_attn)
+        output = torch.matmul(p_attn, value)
+        if self.window_size is not None:
+            relative_weights = self._absolute_position_to_relative_position(p_attn)
+            value_relative_embeddings = self._get_relative_embeddings(
+                self.emb_rel_v, t_s
+            )
+            output = output + self._matmul_with_relative_values(
+                relative_weights, value_relative_embeddings
+            )
+        output = (
+            output.transpose(2, 3).contiguous().view(b, d, t_t)
+        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+        return output, p_attn
+
+    def _matmul_with_relative_values(self, x, y):
+        """
+        x: [b, h, l, m]
+        y: [h or 1, m, d]
+        ret: [b, h, l, d]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0))
+        return ret
+
+    def _matmul_with_relative_keys(self, x, y):
+        """
+        x: [b, h, l, d]
+        y: [h or 1, m, d]
+        ret: [b, h, l, m]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+        return ret
+
+    def _get_relative_embeddings(self, relative_embeddings, length: int):
+        # max_relative_position = 2 * self.window_size + 1
+        # Pad first before slice to avoid using cond ops.
+        pad_length = max(length - (self.window_size + 1), 0)
+        slice_start_position = max((self.window_size + 1) - length, 0)
+        slice_end_position = slice_start_position + 2 * length - 1
+        if pad_length > 0:
+            padded_relative_embeddings = F.pad(
+                relative_embeddings,
+                # convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+                (0, 0, pad_length, pad_length, 0, 0),
+            )
+        else:
+            padded_relative_embeddings = relative_embeddings
+        used_relative_embeddings = padded_relative_embeddings[
+            :, slice_start_position:slice_end_position
+        ]
+        return used_relative_embeddings
+
+    def _relative_position_to_absolute_position(self, x):
+        """
+        x: [b, h, l, 2*l-1]
+        ret: [b, h, l, l]
+        """
+        batch, heads, length, _ = x.size()
+
+        # Concat columns of pad to shift from relative to absolute indexing.
+        # x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+        x = F.pad(x, (0, 1, 0, 0, 0, 0, 0, 0))
+
+        # Concat extra elements so to add up to shape (len+1, 2*len-1).
+        x_flat = x.view([batch, heads, length * 2 * length])
+        # x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
+        x_flat = F.pad(x_flat, (0, length - 1, 0, 0, 0, 0))
+
+        # Reshape and slice out the padded elements.
+        x_final = x_flat.view([batch, heads, length + 1, (2 * length) - 1])[
+            :, :, :length, length - 1 :
+        ]
+        return x_final
+
+    def _absolute_position_to_relative_position(self, x):
+        """
+        x: [b, h, l, l]
+        ret: [b, h, l, 2*l-1]
+        """
+        batch, heads, length, _ = x.size()
+
+        # pad along column
+        # x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
+        x = F.pad(x, (0, length - 1, 0, 0, 0, 0, 0, 0))
+        x_flat = x.view([batch, heads, (length * length) + (length * (length - 1))])
+        # add 0's in the beginning that will skew the elements after reshape
+        # x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        x_flat = F.pad(x_flat, (length, 0, 0, 0, 0, 0))
+        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+        return x_final
+
+    def _attention_bias_proximal(self, length: int):
+        """Bias for self-attention to encourage attention to close positions.
+        Args:
+            length: an integer scalar.
+        Returns:
+            a Tensor with shape [1, 1, length, length]
+        """
+        r = torch.arange(length, dtype=torch.float32)
+        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        filter_channels: int,
+        kernel_size: int,
+        p_dropout: float = 0.0,
+        activation: str = "",
+        causal: bool = False,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.activation = activation
+        self.causal = causal
+
+        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+        self.drop = nn.Dropout(p_dropout)
+
+    def forward(self, x, x_mask):
+        if self.causal:
+            padding1 = self._causal_padding(x * x_mask)
+        else:
+            padding1 = self._same_padding(x * x_mask)
+
+        x = self.conv_1(padding1)
+
+        if self.activation == "gelu":
+            x = x * torch.sigmoid(1.702 * x)
+        else:
+            x = torch.relu(x)
+        x = self.drop(x)
+
+        if self.causal:
+            padding2 = self._causal_padding(x * x_mask)
+        else:
+            padding2 = self._same_padding(x * x_mask)
+
+        x = self.conv_2(padding2)
+
+        return x * x_mask
+
+    def _causal_padding(self, x):
+        if self.kernel_size == 1:
+            return x
+        pad_l = self.kernel_size - 1
+        pad_r = 0
+        # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+        # x = F.pad(x, convert_pad_shape(padding))
+        x = F.pad(x, (pad_l, pad_r, 0, 0, 0, 0))
+        return x
+
+    def _same_padding(self, x):
+        if self.kernel_size == 1:
+            return x
+        pad_l = (self.kernel_size - 1) // 2
+        pad_r = self.kernel_size // 2
+        # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+        # x = F.pad(x, convert_pad_shape(padding))
+        x = F.pad(x, (pad_l, pad_r, 0, 0, 0, 0))
+        return x
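The classes above are self-contained apart from LayerNorm, which comes from phoonnx_train/vits/modules.py (also shipped in this wheel). Below is a minimal smoke test of the Encoder, not part of the published package: it assumes the wheel is installed, and the hyperparameter values (192/768 channels, 2 heads, 6 layers) are arbitrary choices typical of VITS text encoders, not values fixed by this file.

import torch

from phoonnx_train.vits.attentions import Encoder
from phoonnx_train.vits.commons import sequence_mask

# Hypothetical hyperparameters; window_size=4 (the default) enables the
# relative-position attention path in MultiHeadAttention.
enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2, n_layers=6,
              kernel_size=3, p_dropout=0.1)

lengths = torch.tensor([50, 32])                          # valid frames per batch item
x = torch.randn(2, 192, 50)                               # [batch, hidden_channels, time]
x_mask = sequence_mask(lengths, 50).unsqueeze(1).float()  # [batch, 1, time]

y = enc(x, x_mask)   # self-attention + FFN blocks; padded frames zeroed by x_mask
print(y.shape)       # torch.Size([2, 192, 50])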
phoonnx_train/vits/commons.py
@@ -0,0 +1,147 @@
+import logging
+import math
+from typing import Optional
+
+import torch
+from torch.nn import functional as F
+
+_LOGGER = logging.getLogger("vits.commons")
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def intersperse(lst, item):
+    result = [item] * (len(lst) * 2 + 1)
+    result[1::2] = lst
+    return result
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+    """KL(P||Q)"""
+    kl = (logs_q - logs_p) - 0.5
+    kl += (
+        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+    )
+    return kl
+
+
+def rand_gumbel(shape):
+    """Sample from the Gumbel distribution, protect from overflows."""
+    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+    return -torch.log(-torch.log(uniform_samples))
+
+
+def rand_gumbel_like(x):
+    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+    return g
+
+
+def slice_segments(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = max(0, ids_str[i])
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, :, idx_str:idx_end]
+    return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+    b, d, t = x.size()
+    if x_lengths is None:
+        x_lengths = t
+    ids_str_max = x_lengths - segment_size + 1
+    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+    ret = slice_segments(x, ids_str, segment_size)
+    return ret, ids_str
+
+
+def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+    position = torch.arange(length, dtype=torch.float)
+    num_timescales = channels // 2
+    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+        num_timescales - 1
+    )
+    inv_timescales = min_timescale * torch.exp(
+        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+    )
+    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+    signal = F.pad(signal, [0, 0, 0, channels % 2])
+    signal = signal.view(1, channels, length)
+    return signal
+
+
+def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+    b, channels, length = x.size()
+    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+    return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+    b, channels, length = x.size()
+    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+def subsequent_mask(length: int):
+    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+    return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+    return acts
+
+
+def sequence_mask(length, max_length: Optional[int] = None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def generate_path(duration, mask):
+    """
+    duration: [b, 1, t_x]
+    mask: [b, 1, t_y, t_x]
+    """
+    b, _, t_y, t_x = mask.shape
+    cum_duration = torch.cumsum(duration, -1)
+
+    cum_duration_flat = cum_duration.view(b * t_x)
+    path = sequence_mask(cum_duration_flat, t_y).type_as(mask)
+    path = path.view(b, t_x, t_y)
+    path = path - F.pad(path, (0, 0, 1, 0, 0, 0))[:, :-1]
+    path = path.unsqueeze(1).transpose(2, 3) * mask
+    return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    norm_type = float(norm_type)
+    if clip_value is not None:
+        clip_value = float(clip_value)
+
+    total_norm = 0
+    for p in parameters:
+        param_norm = p.grad.data.norm(norm_type)
+        total_norm += param_norm.item() ** norm_type
+        if clip_value is not None:
+            p.grad.data.clamp_(min=-clip_value, max=clip_value)
+    total_norm = total_norm ** (1.0 / norm_type)
+    return total_norm
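Two of these helpers are easy to misread from the code alone: intersperse() builds a blank-interleaved token-id sequence, and generate_path() expands per-token integer durations into a hard monotonic alignment matrix. The small illustration below is not part of the published wheel; it assumes the wheel is installed, and the duration values are made up for the example.

import torch

from phoonnx_train.vits.commons import generate_path, intersperse

# intersperse() puts the given item between tokens and at both ends.
print(intersperse([7, 8, 9], 0))        # [0, 7, 0, 8, 0, 9, 0]

# generate_path() turns durations into a 0/1 alignment between output frames and tokens.
duration = torch.tensor([[[2, 1, 3]]])                    # [b=1, 1, t_x=3] frames per token
t_y = int(duration.sum())                                 # 6 output frames in total
x_mask = torch.ones(1, 1, 3)
y_mask = torch.ones(1, 1, t_y)
attn_mask = y_mask.unsqueeze(-1) * x_mask.unsqueeze(2)    # [1, 1, t_y, t_x]
path = generate_path(duration, attn_mask)
print(path[0, 0])   # rows = frames, cols = tokens: 2 frames on token 0, 1 on token 1, 3 on token 2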