phoonnx-0.0.0-py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- phoonnx/__init__.py +0 -0
- phoonnx/config.py +490 -0
- phoonnx/locale/ca/phonetic_spellings.txt +2 -0
- phoonnx/locale/en/phonetic_spellings.txt +1 -0
- phoonnx/locale/gl/phonetic_spellings.txt +2 -0
- phoonnx/locale/pt/phonetic_spellings.txt +2 -0
- phoonnx/phoneme_ids.py +453 -0
- phoonnx/phonemizers/__init__.py +45 -0
- phoonnx/phonemizers/ar.py +42 -0
- phoonnx/phonemizers/base.py +216 -0
- phoonnx/phonemizers/en.py +250 -0
- phoonnx/phonemizers/fa.py +46 -0
- phoonnx/phonemizers/gl.py +142 -0
- phoonnx/phonemizers/he.py +67 -0
- phoonnx/phonemizers/ja.py +119 -0
- phoonnx/phonemizers/ko.py +97 -0
- phoonnx/phonemizers/mul.py +606 -0
- phoonnx/phonemizers/vi.py +44 -0
- phoonnx/phonemizers/zh.py +308 -0
- phoonnx/thirdparty/__init__.py +0 -0
- phoonnx/thirdparty/arpa2ipa.py +249 -0
- phoonnx/thirdparty/cotovia/cotovia_aarch64 +0 -0
- phoonnx/thirdparty/cotovia/cotovia_x86_64 +0 -0
- phoonnx/thirdparty/hangul2ipa.py +783 -0
- phoonnx/thirdparty/ko_tables/aspiration.csv +20 -0
- phoonnx/thirdparty/ko_tables/assimilation.csv +31 -0
- phoonnx/thirdparty/ko_tables/double_coda.csv +17 -0
- phoonnx/thirdparty/ko_tables/hanja.tsv +8525 -0
- phoonnx/thirdparty/ko_tables/ipa.csv +22 -0
- phoonnx/thirdparty/ko_tables/neutralization.csv +11 -0
- phoonnx/thirdparty/ko_tables/tensification.csv +56 -0
- phoonnx/thirdparty/ko_tables/yale.csv +22 -0
- phoonnx/thirdparty/kog2p/__init__.py +385 -0
- phoonnx/thirdparty/kog2p/rulebook.txt +212 -0
- phoonnx/thirdparty/mantoq/__init__.py +67 -0
- phoonnx/thirdparty/mantoq/buck/__init__.py +0 -0
- phoonnx/thirdparty/mantoq/buck/phonetise_buckwalter.py +569 -0
- phoonnx/thirdparty/mantoq/buck/symbols.py +64 -0
- phoonnx/thirdparty/mantoq/buck/tokenization.py +105 -0
- phoonnx/thirdparty/mantoq/num2words.py +37 -0
- phoonnx/thirdparty/mantoq/pyarabic/__init__.py +12 -0
- phoonnx/thirdparty/mantoq/pyarabic/arabrepr.py +64 -0
- phoonnx/thirdparty/mantoq/pyarabic/araby.py +1647 -0
- phoonnx/thirdparty/mantoq/pyarabic/named_const.py +227 -0
- phoonnx/thirdparty/mantoq/pyarabic/normalize.py +161 -0
- phoonnx/thirdparty/mantoq/pyarabic/number.py +826 -0
- phoonnx/thirdparty/mantoq/pyarabic/number_const.py +1704 -0
- phoonnx/thirdparty/mantoq/pyarabic/stack.py +52 -0
- phoonnx/thirdparty/mantoq/pyarabic/trans.py +517 -0
- phoonnx/thirdparty/mantoq/unicode_symbol2label.py +4173 -0
- phoonnx/thirdparty/tashkeel/LICENSE +22 -0
- phoonnx/thirdparty/tashkeel/SOURCE +1 -0
- phoonnx/thirdparty/tashkeel/__init__.py +212 -0
- phoonnx/thirdparty/tashkeel/hint_id_map.json +18 -0
- phoonnx/thirdparty/tashkeel/input_id_map.json +56 -0
- phoonnx/thirdparty/tashkeel/model.onnx +0 -0
- phoonnx/thirdparty/tashkeel/target_id_map.json +17 -0
- phoonnx/thirdparty/zh_num.py +238 -0
- phoonnx/util.py +705 -0
- phoonnx/version.py +6 -0
- phoonnx/voice.py +521 -0
- phoonnx-0.0.0.dist-info/METADATA +255 -0
- phoonnx-0.0.0.dist-info/RECORD +86 -0
- phoonnx-0.0.0.dist-info/WHEEL +5 -0
- phoonnx-0.0.0.dist-info/top_level.txt +2 -0
- phoonnx_train/__main__.py +151 -0
- phoonnx_train/export_onnx.py +109 -0
- phoonnx_train/norm_audio/__init__.py +92 -0
- phoonnx_train/norm_audio/trim.py +54 -0
- phoonnx_train/norm_audio/vad.py +54 -0
- phoonnx_train/preprocess.py +420 -0
- phoonnx_train/vits/__init__.py +0 -0
- phoonnx_train/vits/attentions.py +427 -0
- phoonnx_train/vits/commons.py +147 -0
- phoonnx_train/vits/config.py +330 -0
- phoonnx_train/vits/dataset.py +214 -0
- phoonnx_train/vits/lightning.py +352 -0
- phoonnx_train/vits/losses.py +58 -0
- phoonnx_train/vits/mel_processing.py +139 -0
- phoonnx_train/vits/models.py +732 -0
- phoonnx_train/vits/modules.py +527 -0
- phoonnx_train/vits/monotonic_align/__init__.py +20 -0
- phoonnx_train/vits/monotonic_align/setup.py +13 -0
- phoonnx_train/vits/transforms.py +212 -0
- phoonnx_train/vits/utils.py +16 -0
- phoonnx_train/vits/wavfile.py +860 -0
phoonnx_train/vits/lightning.py
@@ -0,0 +1,352 @@
import logging
from pathlib import Path
from typing import List, Optional, Tuple, Union

import pytorch_lightning as pl
import torch
from torch import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split

from .commons import slice_segments
from .dataset import Batch, PiperDataset, UtteranceCollate
from .losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from .mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from .models import MultiPeriodDiscriminator, SynthesizerTrn

_LOGGER = logging.getLogger("vits.lightning")


class VitsModel(pl.LightningModule):
    def __init__(
        self,
        num_symbols: int,
        num_speakers: int,
        # audio
        resblock="2",
        resblock_kernel_sizes=(3, 5, 7),
        resblock_dilation_sizes=(
            (1, 2),
            (2, 6),
            (3, 12),
        ),
        upsample_rates=(8, 8, 4),
        upsample_initial_channel=256,
        upsample_kernel_sizes=(16, 16, 8),
        # mel
        filter_length: int = 1024,
        hop_length: int = 256,
        win_length: int = 1024,
        mel_channels: int = 80,
        sample_rate: int = 22050,
        sample_bytes: int = 2,
        channels: int = 1,
        mel_fmin: float = 0.0,
        mel_fmax: Optional[float] = None,
        # model
        inter_channels: int = 192,
        hidden_channels: int = 192,
        filter_channels: int = 768,
        n_heads: int = 2,
        n_layers: int = 6,
        kernel_size: int = 3,
        p_dropout: float = 0.1,
        n_layers_q: int = 3,
        use_spectral_norm: bool = False,
        gin_channels: int = 0,
        use_sdp: bool = True,
        segment_size: int = 8192,
        # training
        dataset: Optional[List[Union[str, Path]]] = None,
        learning_rate: float = 2e-4,
        betas: Tuple[float, float] = (0.8, 0.99),
        eps: float = 1e-9,
        batch_size: int = 1,
        lr_decay: float = 0.999875,
        init_lr_ratio: float = 1.0,
        warmup_epochs: int = 0,
        c_mel: int = 45,
        c_kl: float = 1.0,
        grad_clip: Optional[float] = None,
        num_workers: int = 1,
        seed: int = 1234,
        num_test_examples: int = 5,
        validation_split: float = 0.1,
        max_phoneme_ids: Optional[int] = None,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()

        if (self.hparams.num_speakers > 1) and (self.hparams.gin_channels <= 0):
            # Default gin_channels for multi-speaker model
            self.hparams.gin_channels = 512

        # Set up models
        self.model_g = SynthesizerTrn(
            n_vocab=self.hparams.num_symbols,
            spec_channels=self.hparams.filter_length // 2 + 1,
            segment_size=self.hparams.segment_size // self.hparams.hop_length,
            inter_channels=self.hparams.inter_channels,
            hidden_channels=self.hparams.hidden_channels,
            filter_channels=self.hparams.filter_channels,
            n_heads=self.hparams.n_heads,
            n_layers=self.hparams.n_layers,
            kernel_size=self.hparams.kernel_size,
            p_dropout=self.hparams.p_dropout,
            resblock=self.hparams.resblock,
            resblock_kernel_sizes=self.hparams.resblock_kernel_sizes,
            resblock_dilation_sizes=self.hparams.resblock_dilation_sizes,
            upsample_rates=self.hparams.upsample_rates,
            upsample_initial_channel=self.hparams.upsample_initial_channel,
            upsample_kernel_sizes=self.hparams.upsample_kernel_sizes,
            n_speakers=self.hparams.num_speakers,
            gin_channels=self.hparams.gin_channels,
            use_sdp=self.hparams.use_sdp,
        )
        self.model_d = MultiPeriodDiscriminator(
            use_spectral_norm=self.hparams.use_spectral_norm
        )

        # Dataset splits
        self._train_dataset: Optional[Dataset] = None
        self._val_dataset: Optional[Dataset] = None
        self._test_dataset: Optional[Dataset] = None
        self._load_datasets(validation_split, num_test_examples, max_phoneme_ids)

        # State kept between training optimizers
        self._y = None
        self._y_hat = None

    def _load_datasets(
        self,
        validation_split: float,
        num_test_examples: int,
        max_phoneme_ids: Optional[int] = None,
    ):
        if self.hparams.dataset is None:
            _LOGGER.debug("No dataset to load")
            return

        full_dataset = PiperDataset(
            self.hparams.dataset, max_phoneme_ids=max_phoneme_ids
        )
        valid_set_size = int(len(full_dataset) * validation_split)
        train_set_size = len(full_dataset) - valid_set_size - num_test_examples

        self._train_dataset, self._test_dataset, self._val_dataset = random_split(
            full_dataset, [train_set_size, num_test_examples, valid_set_size]
        )

    def forward(self, text, text_lengths, scales, sid=None):
        noise_scale = scales[0]
        length_scale = scales[1]
        noise_scale_w = scales[2]
        audio, *_ = self.model_g.infer(
            text,
            text_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            sid=sid,
        )

        return audio

    def train_dataloader(self):
        return DataLoader(
            self._train_dataset,
            collate_fn=UtteranceCollate(
                is_multispeaker=self.hparams.num_speakers > 1,
                segment_size=self.hparams.segment_size,
            ),
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
        )

    def val_dataloader(self):
        return DataLoader(
            self._val_dataset,
            collate_fn=UtteranceCollate(
                is_multispeaker=self.hparams.num_speakers > 1,
                segment_size=self.hparams.segment_size,
            ),
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
        )

    def test_dataloader(self):
        return DataLoader(
            self._test_dataset,
            collate_fn=UtteranceCollate(
                is_multispeaker=self.hparams.num_speakers > 1,
                segment_size=self.hparams.segment_size,
            ),
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
        )

    def training_step(self, batch: Batch, batch_idx: int, optimizer_idx: int):
        if optimizer_idx == 0:
            return self.training_step_g(batch)

        if optimizer_idx == 1:
            return self.training_step_d(batch)

    def training_step_g(self, batch: Batch):
        x, x_lengths, y, _, spec, spec_lengths, speaker_ids = (
            batch.phoneme_ids,
            batch.phoneme_lengths,
            batch.audios,
            batch.audio_lengths,
            batch.spectrograms,
            batch.spectrogram_lengths,
            batch.speaker_ids if batch.speaker_ids is not None else None,
        )
        (
            y_hat,
            l_length,
            _attn,
            ids_slice,
            _x_mask,
            z_mask,
            (_z, z_p, m_p, logs_p, _m_q, logs_q),
        ) = self.model_g(x, x_lengths, spec, spec_lengths, speaker_ids)
        self._y_hat = y_hat

        mel = spec_to_mel_torch(
            spec,
            self.hparams.filter_length,
            self.hparams.mel_channels,
            self.hparams.sample_rate,
            self.hparams.mel_fmin,
            self.hparams.mel_fmax,
        )
        y_mel = slice_segments(
            mel,
            ids_slice,
            self.hparams.segment_size // self.hparams.hop_length,
        )
        y_hat_mel = mel_spectrogram_torch(
            y_hat.squeeze(1),
            self.hparams.filter_length,
            self.hparams.mel_channels,
            self.hparams.sample_rate,
            self.hparams.hop_length,
            self.hparams.win_length,
            self.hparams.mel_fmin,
            self.hparams.mel_fmax,
        )
        y = slice_segments(
            y,
            ids_slice * self.hparams.hop_length,
            self.hparams.segment_size,
        )  # slice

        # Save for training_step_d
        self._y = y

        _y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.model_d(y, y_hat)

        with autocast(self.device.type, enabled=False):
            # Generator loss
            loss_dur = torch.sum(l_length.float())
            loss_mel = F.l1_loss(y_mel, y_hat_mel) * self.hparams.c_mel
            loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * self.hparams.c_kl

            loss_fm = feature_loss(fmap_r, fmap_g)
            loss_gen, _losses_gen = generator_loss(y_d_hat_g)
            loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl

        self.log("loss_gen_all", loss_gen_all)

        return loss_gen_all

    def training_step_d(self, batch: Batch):
        # From training_step_g
        y = self._y
        y_hat = self._y_hat
        y_d_hat_r, y_d_hat_g, _, _ = self.model_d(y, y_hat.detach())

        with autocast(self.device.type, enabled=False):
            # Discriminator
            loss_disc, _losses_disc_r, _losses_disc_g = discriminator_loss(
                y_d_hat_r, y_d_hat_g
            )
            loss_disc_all = loss_disc

        self.log("loss_disc_all", loss_disc_all)

        return loss_disc_all

    def validation_step(self, batch: Batch, batch_idx: int):
        val_loss = self.training_step_g(batch) + self.training_step_d(batch)
        self.log("val_loss", val_loss)

        # Generate audio examples
        for utt_idx, test_utt in enumerate(self._test_dataset):
            text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
            text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
            scales = [0.667, 1.0, 0.8]
            sid = (
                test_utt.speaker_id.to(self.device)
                if test_utt.speaker_id is not None
                else None
            )
            test_audio = self(text, text_lengths, scales, sid=sid).detach()

            # Scale to make louder in [-1, 1]
            test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max())))

            tag = test_utt.text or str(utt_idx)
            self.logger.experiment.add_audio(
                tag, test_audio, sample_rate=self.hparams.sample_rate
            )

        return val_loss

    def configure_optimizers(self):
        optimizers = [
            torch.optim.AdamW(
                self.model_g.parameters(),
                lr=self.hparams.learning_rate,
                betas=self.hparams.betas,
                eps=self.hparams.eps,
            ),
            torch.optim.AdamW(
                self.model_d.parameters(),
                lr=self.hparams.learning_rate,
                betas=self.hparams.betas,
                eps=self.hparams.eps,
            ),
        ]
        schedulers = [
            torch.optim.lr_scheduler.ExponentialLR(
                optimizers[0], gamma=self.hparams.lr_decay
            ),
            torch.optim.lr_scheduler.ExponentialLR(
                optimizers[1], gamma=self.hparams.lr_decay
            ),
        ]

        return optimizers, schedulers

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("VitsModel")
        parser.add_argument("--batch-size", type=int, required=True)
        parser.add_argument("--validation-split", type=float, default=0.1)
        parser.add_argument("--num-test-examples", type=int, default=5)
        parser.add_argument(
            "--max-phoneme-ids",
            type=int,
            help="Exclude utterances with phoneme id lists longer than this",
        )
        #
        parser.add_argument("--hidden-channels", type=int, default=192)
        parser.add_argument("--inter-channels", type=int, default=192)
        parser.add_argument("--filter-channels", type=int, default=768)
        parser.add_argument("--n-layers", type=int, default=6)
        parser.add_argument("--n-heads", type=int, default=2)
        #
        return parent_parser
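The LightningModule above wires the VITS generator and discriminator behind Lightning's multi-optimizer loop (optimizer_idx 0 trains the generator, 1 the discriminator). A minimal, hypothetical training sketch follows; num_symbols, num_speakers, and the dataset path are placeholders rather than values from the package, and the optimizer_idx signature assumes pytorch_lightning 1.x:

import pytorch_lightning as pl

from phoonnx_train.vits.lightning import VitsModel

# Placeholder values; real runs take these from the preprocess.py output.
model = VitsModel(
    num_symbols=256,                  # phoneme-id vocabulary size (placeholder)
    num_speakers=1,                   # >1 switches gin_channels to 512
    dataset=["data/dataset.jsonl"],   # utterance file(s) for PiperDataset (assumed layout)
    batch_size=16,
)

# training_step(batch, batch_idx, optimizer_idx) is the pytorch_lightning 1.x
# multi-optimizer API; as written it is not compatible with Lightning >= 2.0.
trainer = pl.Trainer(max_epochs=1000)
trainer.fit(model)  # uses the train/val dataloaders defined on the module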
phoonnx_train/vits/losses.py
@@ -0,0 +1,58 @@
import torch


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            rl = rl.float().detach()
            gl = gl.float()
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        dr = dr.float()
        dg = dg.float()
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg**2)
        loss += r_loss + g_loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        dg = dg.float()
        l_dg = torch.mean((1 - dg) ** 2)
        gen_losses.append(l_dg)
        loss += l_dg

    return loss, gen_losses


def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """
    z_p, logs_q: [b, h, t_t]
    m_p, logs_p: [b, h, t_t]
    """
    z_p = z_p.float()
    logs_q = logs_q.float()
    m_p = m_p.float()
    logs_p = logs_p.float()
    z_mask = z_mask.float()

    kl = logs_p - logs_q - 0.5
    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
    kl = torch.sum(kl * z_mask)
    l_kl = kl / torch.sum(z_mask)
    return l_kl
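For reference, these are the standard least-squares GAN and VITS objectives. Writing $\log\sigma$ for the logs_* tensors, kl_loss computes the masked average

$$
\mathcal{L}_{\mathrm{kl}} = \frac{\sum \left( \log\sigma_p - \log\sigma_q - \tfrac{1}{2} + \frac{(z_p - \mu_p)^2}{2\sigma_p^2} \right) z_{\mathrm{mask}}}{\sum z_{\mathrm{mask}}},
$$

where $\exp(-2\log\sigma_p) = 1/\sigma_p^2$ matches the torch.exp(-2.0 * logs_p) factor in the code. discriminator_loss and generator_loss are the LSGAN terms $\mathbb{E}[(1 - D(y))^2] + \mathbb{E}[D(\hat{y})^2]$ and $\mathbb{E}[(1 - D(\hat{y}))^2]$, summed over the sub-discriminators, and feature_loss is an L1 feature-matching loss over discriminator feature maps, scaled by 2.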
phoonnx_train/vits/mel_processing.py
@@ -0,0 +1,139 @@
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).type_as(y)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[wnsize_dtype_device],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    return spec


def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    global mel_basis
    dtype_device = str(spec.dtype) + "_" + str(spec.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).type_as(spec)
    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)
    return spec


def mel_spectrogram_torch(
    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).type_as(y)
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).type_as(y)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)
    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[wnsize_dtype_device],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)

    return spec
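A small self-contained sanity check of the mel pipeline, using the same STFT settings as the VitsModel defaults above (the input waveform here is random and purely illustrative):

import torch

from phoonnx_train.vits.mel_processing import mel_spectrogram_torch

# One second of fake audio in [-1, 1], shape (batch, samples).
y = torch.rand(1, 22050) * 2.0 - 1.0

mel = mel_spectrogram_torch(
    y,
    n_fft=1024,
    num_mels=80,
    sampling_rate=22050,
    hop_size=256,
    win_size=1024,
    fmin=0.0,
    fmax=None,        # librosa defaults fmax to sampling_rate / 2
)
print(mel.shape)      # torch.Size([1, 80, 86]) with these settings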