phoonnx-0.0.0-py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (86)
  1. phoonnx/__init__.py +0 -0
  2. phoonnx/config.py +490 -0
  3. phoonnx/locale/ca/phonetic_spellings.txt +2 -0
  4. phoonnx/locale/en/phonetic_spellings.txt +1 -0
  5. phoonnx/locale/gl/phonetic_spellings.txt +2 -0
  6. phoonnx/locale/pt/phonetic_spellings.txt +2 -0
  7. phoonnx/phoneme_ids.py +453 -0
  8. phoonnx/phonemizers/__init__.py +45 -0
  9. phoonnx/phonemizers/ar.py +42 -0
  10. phoonnx/phonemizers/base.py +216 -0
  11. phoonnx/phonemizers/en.py +250 -0
  12. phoonnx/phonemizers/fa.py +46 -0
  13. phoonnx/phonemizers/gl.py +142 -0
  14. phoonnx/phonemizers/he.py +67 -0
  15. phoonnx/phonemizers/ja.py +119 -0
  16. phoonnx/phonemizers/ko.py +97 -0
  17. phoonnx/phonemizers/mul.py +606 -0
  18. phoonnx/phonemizers/vi.py +44 -0
  19. phoonnx/phonemizers/zh.py +308 -0
  20. phoonnx/thirdparty/__init__.py +0 -0
  21. phoonnx/thirdparty/arpa2ipa.py +249 -0
  22. phoonnx/thirdparty/cotovia/cotovia_aarch64 +0 -0
  23. phoonnx/thirdparty/cotovia/cotovia_x86_64 +0 -0
  24. phoonnx/thirdparty/hangul2ipa.py +783 -0
  25. phoonnx/thirdparty/ko_tables/aspiration.csv +20 -0
  26. phoonnx/thirdparty/ko_tables/assimilation.csv +31 -0
  27. phoonnx/thirdparty/ko_tables/double_coda.csv +17 -0
  28. phoonnx/thirdparty/ko_tables/hanja.tsv +8525 -0
  29. phoonnx/thirdparty/ko_tables/ipa.csv +22 -0
  30. phoonnx/thirdparty/ko_tables/neutralization.csv +11 -0
  31. phoonnx/thirdparty/ko_tables/tensification.csv +56 -0
  32. phoonnx/thirdparty/ko_tables/yale.csv +22 -0
  33. phoonnx/thirdparty/kog2p/__init__.py +385 -0
  34. phoonnx/thirdparty/kog2p/rulebook.txt +212 -0
  35. phoonnx/thirdparty/mantoq/__init__.py +67 -0
  36. phoonnx/thirdparty/mantoq/buck/__init__.py +0 -0
  37. phoonnx/thirdparty/mantoq/buck/phonetise_buckwalter.py +569 -0
  38. phoonnx/thirdparty/mantoq/buck/symbols.py +64 -0
  39. phoonnx/thirdparty/mantoq/buck/tokenization.py +105 -0
  40. phoonnx/thirdparty/mantoq/num2words.py +37 -0
  41. phoonnx/thirdparty/mantoq/pyarabic/__init__.py +12 -0
  42. phoonnx/thirdparty/mantoq/pyarabic/arabrepr.py +64 -0
  43. phoonnx/thirdparty/mantoq/pyarabic/araby.py +1647 -0
  44. phoonnx/thirdparty/mantoq/pyarabic/named_const.py +227 -0
  45. phoonnx/thirdparty/mantoq/pyarabic/normalize.py +161 -0
  46. phoonnx/thirdparty/mantoq/pyarabic/number.py +826 -0
  47. phoonnx/thirdparty/mantoq/pyarabic/number_const.py +1704 -0
  48. phoonnx/thirdparty/mantoq/pyarabic/stack.py +52 -0
  49. phoonnx/thirdparty/mantoq/pyarabic/trans.py +517 -0
  50. phoonnx/thirdparty/mantoq/unicode_symbol2label.py +4173 -0
  51. phoonnx/thirdparty/tashkeel/LICENSE +22 -0
  52. phoonnx/thirdparty/tashkeel/SOURCE +1 -0
  53. phoonnx/thirdparty/tashkeel/__init__.py +212 -0
  54. phoonnx/thirdparty/tashkeel/hint_id_map.json +18 -0
  55. phoonnx/thirdparty/tashkeel/input_id_map.json +56 -0
  56. phoonnx/thirdparty/tashkeel/model.onnx +0 -0
  57. phoonnx/thirdparty/tashkeel/target_id_map.json +17 -0
  58. phoonnx/thirdparty/zh_num.py +238 -0
  59. phoonnx/util.py +705 -0
  60. phoonnx/version.py +6 -0
  61. phoonnx/voice.py +521 -0
  62. phoonnx-0.0.0.dist-info/METADATA +255 -0
  63. phoonnx-0.0.0.dist-info/RECORD +86 -0
  64. phoonnx-0.0.0.dist-info/WHEEL +5 -0
  65. phoonnx-0.0.0.dist-info/top_level.txt +2 -0
  66. phoonnx_train/__main__.py +151 -0
  67. phoonnx_train/export_onnx.py +109 -0
  68. phoonnx_train/norm_audio/__init__.py +92 -0
  69. phoonnx_train/norm_audio/trim.py +54 -0
  70. phoonnx_train/norm_audio/vad.py +54 -0
  71. phoonnx_train/preprocess.py +420 -0
  72. phoonnx_train/vits/__init__.py +0 -0
  73. phoonnx_train/vits/attentions.py +427 -0
  74. phoonnx_train/vits/commons.py +147 -0
  75. phoonnx_train/vits/config.py +330 -0
  76. phoonnx_train/vits/dataset.py +214 -0
  77. phoonnx_train/vits/lightning.py +352 -0
  78. phoonnx_train/vits/losses.py +58 -0
  79. phoonnx_train/vits/mel_processing.py +139 -0
  80. phoonnx_train/vits/models.py +732 -0
  81. phoonnx_train/vits/modules.py +527 -0
  82. phoonnx_train/vits/monotonic_align/__init__.py +20 -0
  83. phoonnx_train/vits/monotonic_align/setup.py +13 -0
  84. phoonnx_train/vits/transforms.py +212 -0
  85. phoonnx_train/vits/utils.py +16 -0
  86. phoonnx_train/vits/wavfile.py +860 -0
phoonnx_train/vits/lightning.py
@@ -0,0 +1,352 @@
+ import logging
+ from pathlib import Path
+ from typing import List, Optional, Tuple, Union
+
+ import pytorch_lightning as pl
+ import torch
+ from torch import autocast
+ from torch.nn import functional as F
+ from torch.utils.data import DataLoader, Dataset, random_split
+
+ from .commons import slice_segments
+ from .dataset import Batch, PiperDataset, UtteranceCollate
+ from .losses import discriminator_loss, feature_loss, generator_loss, kl_loss
+ from .mel_processing import mel_spectrogram_torch, spec_to_mel_torch
+ from .models import MultiPeriodDiscriminator, SynthesizerTrn
+
+ _LOGGER = logging.getLogger("vits.lightning")
+
+
+ class VitsModel(pl.LightningModule):
+     def __init__(
+         self,
+         num_symbols: int,
+         num_speakers: int,
+         # audio
+         resblock="2",
+         resblock_kernel_sizes=(3, 5, 7),
+         resblock_dilation_sizes=(
+             (1, 2),
+             (2, 6),
+             (3, 12),
+         ),
+         upsample_rates=(8, 8, 4),
+         upsample_initial_channel=256,
+         upsample_kernel_sizes=(16, 16, 8),
+         # mel
+         filter_length: int = 1024,
+         hop_length: int = 256,
+         win_length: int = 1024,
+         mel_channels: int = 80,
+         sample_rate: int = 22050,
+         sample_bytes: int = 2,
+         channels: int = 1,
+         mel_fmin: float = 0.0,
+         mel_fmax: Optional[float] = None,
+         # model
+         inter_channels: int = 192,
+         hidden_channels: int = 192,
+         filter_channels: int = 768,
+         n_heads: int = 2,
+         n_layers: int = 6,
+         kernel_size: int = 3,
+         p_dropout: float = 0.1,
+         n_layers_q: int = 3,
+         use_spectral_norm: bool = False,
+         gin_channels: int = 0,
+         use_sdp: bool = True,
+         segment_size: int = 8192,
+         # training
+         dataset: Optional[List[Union[str, Path]]] = None,
+         learning_rate: float = 2e-4,
+         betas: Tuple[float, float] = (0.8, 0.99),
+         eps: float = 1e-9,
+         batch_size: int = 1,
+         lr_decay: float = 0.999875,
+         init_lr_ratio: float = 1.0,
+         warmup_epochs: int = 0,
+         c_mel: int = 45,
+         c_kl: float = 1.0,
+         grad_clip: Optional[float] = None,
+         num_workers: int = 1,
+         seed: int = 1234,
+         num_test_examples: int = 5,
+         validation_split: float = 0.1,
+         max_phoneme_ids: Optional[int] = None,
+         **kwargs,
+     ):
+         super().__init__()
+         self.save_hyperparameters()
+
+         if (self.hparams.num_speakers > 1) and (self.hparams.gin_channels <= 0):
+             # Default gin_channels for multi-speaker model
+             self.hparams.gin_channels = 512
+
+         # Set up models
+         self.model_g = SynthesizerTrn(
+             n_vocab=self.hparams.num_symbols,
+             spec_channels=self.hparams.filter_length // 2 + 1,
+             segment_size=self.hparams.segment_size // self.hparams.hop_length,
+             inter_channels=self.hparams.inter_channels,
+             hidden_channels=self.hparams.hidden_channels,
+             filter_channels=self.hparams.filter_channels,
+             n_heads=self.hparams.n_heads,
+             n_layers=self.hparams.n_layers,
+             kernel_size=self.hparams.kernel_size,
+             p_dropout=self.hparams.p_dropout,
+             resblock=self.hparams.resblock,
+             resblock_kernel_sizes=self.hparams.resblock_kernel_sizes,
+             resblock_dilation_sizes=self.hparams.resblock_dilation_sizes,
+             upsample_rates=self.hparams.upsample_rates,
+             upsample_initial_channel=self.hparams.upsample_initial_channel,
+             upsample_kernel_sizes=self.hparams.upsample_kernel_sizes,
+             n_speakers=self.hparams.num_speakers,
+             gin_channels=self.hparams.gin_channels,
+             use_sdp=self.hparams.use_sdp,
+         )
+         self.model_d = MultiPeriodDiscriminator(
+             use_spectral_norm=self.hparams.use_spectral_norm
+         )
+
+         # Dataset splits
+         self._train_dataset: Optional[Dataset] = None
+         self._val_dataset: Optional[Dataset] = None
+         self._test_dataset: Optional[Dataset] = None
+         self._load_datasets(validation_split, num_test_examples, max_phoneme_ids)
+
+         # State kept between training optimizers
+         self._y = None
+         self._y_hat = None
+
+     def _load_datasets(
+         self,
+         validation_split: float,
+         num_test_examples: int,
+         max_phoneme_ids: Optional[int] = None,
+     ):
+         if self.hparams.dataset is None:
+             _LOGGER.debug("No dataset to load")
+             return
+
+         full_dataset = PiperDataset(
+             self.hparams.dataset, max_phoneme_ids=max_phoneme_ids
+         )
+         valid_set_size = int(len(full_dataset) * validation_split)
+         train_set_size = len(full_dataset) - valid_set_size - num_test_examples
+
+         self._train_dataset, self._test_dataset, self._val_dataset = random_split(
+             full_dataset, [train_set_size, num_test_examples, valid_set_size]
+         )
+
+     def forward(self, text, text_lengths, scales, sid=None):
+         noise_scale = scales[0]
+         length_scale = scales[1]
+         noise_scale_w = scales[2]
+         audio, *_ = self.model_g.infer(
+             text,
+             text_lengths,
+             noise_scale=noise_scale,
+             length_scale=length_scale,
+             noise_scale_w=noise_scale_w,
+             sid=sid,
+         )
+
+         return audio
+
+     def train_dataloader(self):
+         return DataLoader(
+             self._train_dataset,
+             collate_fn=UtteranceCollate(
+                 is_multispeaker=self.hparams.num_speakers > 1,
+                 segment_size=self.hparams.segment_size,
+             ),
+             num_workers=self.hparams.num_workers,
+             batch_size=self.hparams.batch_size,
+         )
+
+     def val_dataloader(self):
+         return DataLoader(
+             self._val_dataset,
+             collate_fn=UtteranceCollate(
+                 is_multispeaker=self.hparams.num_speakers > 1,
+                 segment_size=self.hparams.segment_size,
+             ),
+             num_workers=self.hparams.num_workers,
+             batch_size=self.hparams.batch_size,
+         )
+
+     def test_dataloader(self):
+         return DataLoader(
+             self._test_dataset,
+             collate_fn=UtteranceCollate(
+                 is_multispeaker=self.hparams.num_speakers > 1,
+                 segment_size=self.hparams.segment_size,
+             ),
+             num_workers=self.hparams.num_workers,
+             batch_size=self.hparams.batch_size,
+         )
+
+     def training_step(self, batch: Batch, batch_idx: int, optimizer_idx: int):
+         if optimizer_idx == 0:
+             return self.training_step_g(batch)
+
+         if optimizer_idx == 1:
+             return self.training_step_d(batch)
+
+     def training_step_g(self, batch: Batch):
+         x, x_lengths, y, _, spec, spec_lengths, speaker_ids = (
+             batch.phoneme_ids,
+             batch.phoneme_lengths,
+             batch.audios,
+             batch.audio_lengths,
+             batch.spectrograms,
+             batch.spectrogram_lengths,
+             batch.speaker_ids if batch.speaker_ids is not None else None,
+         )
+         (
+             y_hat,
+             l_length,
+             _attn,
+             ids_slice,
+             _x_mask,
+             z_mask,
+             (_z, z_p, m_p, logs_p, _m_q, logs_q),
+         ) = self.model_g(x, x_lengths, spec, spec_lengths, speaker_ids)
+         self._y_hat = y_hat
+
+         mel = spec_to_mel_torch(
+             spec,
+             self.hparams.filter_length,
+             self.hparams.mel_channels,
+             self.hparams.sample_rate,
+             self.hparams.mel_fmin,
+             self.hparams.mel_fmax,
+         )
+         y_mel = slice_segments(
+             mel,
+             ids_slice,
+             self.hparams.segment_size // self.hparams.hop_length,
+         )
+         y_hat_mel = mel_spectrogram_torch(
+             y_hat.squeeze(1),
+             self.hparams.filter_length,
+             self.hparams.mel_channels,
+             self.hparams.sample_rate,
+             self.hparams.hop_length,
+             self.hparams.win_length,
+             self.hparams.mel_fmin,
+             self.hparams.mel_fmax,
+         )
+         y = slice_segments(
+             y,
+             ids_slice * self.hparams.hop_length,
+             self.hparams.segment_size,
+         )  # slice
+
+         # Save for training_step_d
+         self._y = y
+
+         _y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.model_d(y, y_hat)
+
+         with autocast(self.device.type, enabled=False):
+             # Generator loss
+             loss_dur = torch.sum(l_length.float())
+             loss_mel = F.l1_loss(y_mel, y_hat_mel) * self.hparams.c_mel
+             loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * self.hparams.c_kl
+
+             loss_fm = feature_loss(fmap_r, fmap_g)
+             loss_gen, _losses_gen = generator_loss(y_d_hat_g)
+             loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl
+
+         self.log("loss_gen_all", loss_gen_all)
+
+         return loss_gen_all
+
+     def training_step_d(self, batch: Batch):
+         # From training_step_g
+         y = self._y
+         y_hat = self._y_hat
+         y_d_hat_r, y_d_hat_g, _, _ = self.model_d(y, y_hat.detach())
+
+         with autocast(self.device.type, enabled=False):
+             # Discriminator
+             loss_disc, _losses_disc_r, _losses_disc_g = discriminator_loss(
+                 y_d_hat_r, y_d_hat_g
+             )
+             loss_disc_all = loss_disc
+
+         self.log("loss_disc_all", loss_disc_all)
+
+         return loss_disc_all
+
+     def validation_step(self, batch: Batch, batch_idx: int):
+         val_loss = self.training_step_g(batch) + self.training_step_d(batch)
+         self.log("val_loss", val_loss)
+
+         # Generate audio examples
+         for utt_idx, test_utt in enumerate(self._test_dataset):
+             text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
+             text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
+             scales = [0.667, 1.0, 0.8]
+             sid = (
+                 test_utt.speaker_id.to(self.device)
+                 if test_utt.speaker_id is not None
+                 else None
+             )
+             test_audio = self(text, text_lengths, scales, sid=sid).detach()
+
+             # Scale to make louder in [-1, 1]
+             test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max())))
+
+             tag = test_utt.text or str(utt_idx)
+             self.logger.experiment.add_audio(
+                 tag, test_audio, sample_rate=self.hparams.sample_rate
+             )
+
+         return val_loss
+
+     def configure_optimizers(self):
+         optimizers = [
+             torch.optim.AdamW(
+                 self.model_g.parameters(),
+                 lr=self.hparams.learning_rate,
+                 betas=self.hparams.betas,
+                 eps=self.hparams.eps,
+             ),
+             torch.optim.AdamW(
+                 self.model_d.parameters(),
+                 lr=self.hparams.learning_rate,
+                 betas=self.hparams.betas,
+                 eps=self.hparams.eps,
+             ),
+         ]
+         schedulers = [
+             torch.optim.lr_scheduler.ExponentialLR(
+                 optimizers[0], gamma=self.hparams.lr_decay
+             ),
+             torch.optim.lr_scheduler.ExponentialLR(
+                 optimizers[1], gamma=self.hparams.lr_decay
+             ),
+         ]
+
+         return optimizers, schedulers
+
+     @staticmethod
+     def add_model_specific_args(parent_parser):
+         parser = parent_parser.add_argument_group("VitsModel")
+         parser.add_argument("--batch-size", type=int, required=True)
+         parser.add_argument("--validation-split", type=float, default=0.1)
+         parser.add_argument("--num-test-examples", type=int, default=5)
+         parser.add_argument(
+             "--max-phoneme-ids",
+             type=int,
+             help="Exclude utterances with phoneme id lists longer than this",
+         )
+         #
+         parser.add_argument("--hidden-channels", type=int, default=192)
+         parser.add_argument("--inter-channels", type=int, default=192)
+         parser.add_argument("--filter-channels", type=int, default=768)
+         parser.add_argument("--n-layers", type=int, default=6)
+         parser.add_argument("--n-heads", type=int, default=2)
+         #
+         return parent_parser
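
The module above pairs the VITS generator (SynthesizerTrn) with the multi-period discriminator and alternates between them via optimizer_idx: index 0 runs training_step_g, which caches self._y and self._y_hat, and index 1 reuses that cache in training_step_d. A minimal launch sketch follows; it is not part of the package and assumes a preprocessed utterance file that PiperDataset accepts, plus a pytorch_lightning release older than 2.0 (the optimizer_idx argument to training_step relies on the pre-2.0 training loop). num_symbols, the dataset path, and the other values are illustrative only.

    import pytorch_lightning as pl

    from phoonnx_train.vits.lightning import VitsModel

    model = VitsModel(
        num_symbols=256,                  # assumed size of the phoneme-id table
        num_speakers=1,
        dataset=["data/dataset.jsonl"],   # hypothetical output of preprocessing
        batch_size=16,
        validation_split=0.05,
        num_test_examples=2,
    )

    trainer = pl.Trainer(max_epochs=1000)
    trainer.fit(model)  # train/val dataloaders come from the module itself

configure_optimizers returns one AdamW/ExponentialLR pair per network, so the generator and discriminator each keep their own learning-rate schedule with gamma=lr_decay.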
phoonnx_train/vits/losses.py
@@ -0,0 +1,58 @@
+ import torch
+
+
+ def feature_loss(fmap_r, fmap_g):
+     loss = 0
+     for dr, dg in zip(fmap_r, fmap_g):
+         for rl, gl in zip(dr, dg):
+             rl = rl.float().detach()
+             gl = gl.float()
+             loss += torch.mean(torch.abs(rl - gl))
+
+     return loss * 2
+
+
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+     loss = 0
+     r_losses = []
+     g_losses = []
+     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+         dr = dr.float()
+         dg = dg.float()
+         r_loss = torch.mean((1 - dr) ** 2)
+         g_loss = torch.mean(dg**2)
+         loss += r_loss + g_loss
+         r_losses.append(r_loss.item())
+         g_losses.append(g_loss.item())
+
+     return loss, r_losses, g_losses
+
+
+ def generator_loss(disc_outputs):
+     loss = 0
+     gen_losses = []
+     for dg in disc_outputs:
+         dg = dg.float()
+         l_dg = torch.mean((1 - dg) ** 2)
+         gen_losses.append(l_dg)
+         loss += l_dg
+
+     return loss, gen_losses
+
+
+ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
+     """
+     z_p, logs_q: [b, h, t_t]
+     m_p, logs_p: [b, h, t_t]
+     """
+     z_p = z_p.float()
+     logs_q = logs_q.float()
+     m_p = m_p.float()
+     logs_p = logs_p.float()
+     z_mask = z_mask.float()
+
+     kl = logs_p - logs_q - 0.5
+     kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+     kl = torch.sum(kl * z_mask)
+     l_kl = kl / torch.sum(z_mask)
+     return l_kl
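
These are the usual VITS/HiFi-GAN objectives: least-squares GAN losses for generator and discriminator, an L1 feature-matching loss over discriminator feature maps (scaled by 2), and a masked KL term between the flow-mapped posterior sample (z_p with log-scale logs_q) and the text-conditioned prior (m_p, logs_p), normalized by the number of valid frames. A quick shape check on dummy tensors, illustration only; the batch, channel, and frame sizes are arbitrary:

    import torch

    from phoonnx_train.vits.losses import (
        discriminator_loss,
        feature_loss,
        generator_loss,
        kl_loss,
    )

    # One dummy output (and one feature map) per sub-discriminator.
    real_outs = [torch.rand(4, 10) for _ in range(3)]
    fake_outs = [torch.rand(4, 10) for _ in range(3)]
    fmaps_r = [[torch.rand(4, 8, 10)] for _ in range(3)]
    fmaps_g = [[torch.rand(4, 8, 10)] for _ in range(3)]

    loss_d, r_losses, g_losses = discriminator_loss(real_outs, fake_outs)
    loss_g, gen_losses = generator_loss(fake_outs)
    loss_fm = feature_loss(fmaps_r, fmaps_g)

    # KL term: [batch, channels, frames] latents with a [batch, 1, frames] mask.
    z_p, m_p = torch.randn(4, 192, 50), torch.randn(4, 192, 50)
    logs_q, logs_p = torch.zeros(4, 192, 50), torch.zeros(4, 192, 50)
    z_mask = torch.ones(4, 1, 50)
    loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask)

    print(loss_d.item(), loss_g.item(), loss_fm.item(), loss_kl.item())

In training_step_g above these terms are combined as loss_gen + loss_fm + loss_mel + loss_dur + loss_kl, with the mel and KL terms pre-scaled by c_mel and c_kl.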
phoonnx_train/vits/mel_processing.py
@@ -0,0 +1,139 @@
+ import torch
+ import torch.utils.data
+ from librosa.filters import mel as librosa_mel_fn
+
+ MAX_WAV_VALUE = 32768.0
+
+
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+     """
+     PARAMS
+     ------
+     C: compression factor
+     """
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression_torch(x, C=1):
+     """
+     PARAMS
+     ------
+     C: compression factor used to compress
+     """
+     return torch.exp(x) / C
+
+
+ def spectral_normalize_torch(magnitudes):
+     output = dynamic_range_compression_torch(magnitudes)
+     return output
+
+
+ def spectral_de_normalize_torch(magnitudes):
+     output = dynamic_range_decompression_torch(magnitudes)
+     return output
+
+
+ mel_basis = {}
+ hann_window = {}
+
+
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
+     if torch.min(y) < -1.0:
+         print("min value is ", torch.min(y))
+     if torch.max(y) > 1.0:
+         print("max value is ", torch.max(y))
+
+     global hann_window
+     dtype_device = str(y.dtype) + "_" + str(y.device)
+     wnsize_dtype_device = str(win_size) + "_" + dtype_device
+     if wnsize_dtype_device not in hann_window:
+         hann_window[wnsize_dtype_device] = torch.hann_window(win_size).type_as(y)
+
+     y = torch.nn.functional.pad(
+         y.unsqueeze(1),
+         (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+         mode="reflect",
+     )
+     y = y.squeeze(1)
+
+     spec = torch.view_as_real(
+         torch.stft(
+             y,
+             n_fft,
+             hop_length=hop_size,
+             win_length=win_size,
+             window=hann_window[wnsize_dtype_device],
+             center=center,
+             pad_mode="reflect",
+             normalized=False,
+             onesided=True,
+             return_complex=True,
+         )
+     )
+
+     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+     return spec
+
+
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
+     global mel_basis
+     dtype_device = str(spec.dtype) + "_" + str(spec.device)
+     fmax_dtype_device = str(fmax) + "_" + dtype_device
+     if fmax_dtype_device not in mel_basis:
+         mel = librosa_mel_fn(
+             sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+         )
+         mel_basis[fmax_dtype_device] = torch.from_numpy(mel).type_as(spec)
+     spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+     spec = spectral_normalize_torch(spec)
+     return spec
+
+
+ def mel_spectrogram_torch(
+     y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+ ):
+     if torch.min(y) < -1.0:
+         print("min value is ", torch.min(y))
+     if torch.max(y) > 1.0:
+         print("max value is ", torch.max(y))
+
+     global mel_basis, hann_window
+     dtype_device = str(y.dtype) + "_" + str(y.device)
+     fmax_dtype_device = str(fmax) + "_" + dtype_device
+     wnsize_dtype_device = str(win_size) + "_" + dtype_device
+     if fmax_dtype_device not in mel_basis:
+         mel = librosa_mel_fn(
+             sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+         )
+         mel_basis[fmax_dtype_device] = torch.from_numpy(mel).type_as(y)
+     if wnsize_dtype_device not in hann_window:
+         hann_window[wnsize_dtype_device] = torch.hann_window(win_size).type_as(y)
+
+     y = torch.nn.functional.pad(
+         y.unsqueeze(1),
+         (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+         mode="reflect",
+     )
+     y = y.squeeze(1)
+     spec = torch.view_as_real(
+         torch.stft(
+             y,
+             n_fft,
+             hop_length=hop_size,
+             win_length=win_size,
+             window=hann_window[wnsize_dtype_device],
+             center=center,
+             pad_mode="reflect",
+             normalized=False,
+             onesided=True,
+             return_complex=True,
+         )
+     )
+
+     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+     spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+     spec = spectral_normalize_torch(spec)
+
+     return spec
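
Both spectrogram helpers cache the Hann window (keyed by window size, dtype, and device) and the librosa mel filter bank (keyed by fmax, dtype, and device), reflect-pad the waveform by (n_fft - hop_size) / 2 on each side, and run torch.stft with center=False, so a [batch, samples] input yields roughly samples / hop_size frames. A small smoke test on random audio, illustration only (librosa must be installed):

    import torch

    from phoonnx_train.vits.mel_processing import (
        mel_spectrogram_torch,
        spectrogram_torch,
    )

    # One second of fake audio in [-1, 1] at 22.05 kHz.
    y = torch.rand(1, 22050) * 2.0 - 1.0

    spec = spectrogram_torch(y, n_fft=1024, sampling_rate=22050, hop_size=256, win_size=1024)
    mel = mel_spectrogram_torch(
        y,
        n_fft=1024,
        num_mels=80,
        sampling_rate=22050,
        hop_size=256,
        win_size=1024,
        fmin=0.0,
        fmax=None,
    )
    print(spec.shape, mel.shape)  # [1, 513, frames] and [1, 80, frames]

These values mirror the defaults VitsModel passes in (filter_length=1024, hop_length=256, win_length=1024, mel_channels=80, sample_rate=22050); during training, spec_to_mel_torch applies the same cached filter bank to the precomputed linear spectrograms from the dataset.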