lt-tensor 0.0.1a12__py3-none-any.whl → 0.0.1a14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/datasets/audio.py +141 -46
- lt_tensor/misc_utils.py +38 -1
- lt_tensor/model_zoo/__init__.py +18 -9
- lt_tensor/model_zoo/{bsc.py → basic.py} +118 -2
- lt_tensor/model_zoo/features.py +416 -0
- lt_tensor/model_zoo/fusion.py +164 -0
- lt_tensor/model_zoo/istft/generator.py +5 -65
- lt_tensor/model_zoo/istft/sg.py +142 -0
- lt_tensor/model_zoo/istft/trainer.py +227 -59
- lt_tensor/model_zoo/residual.py +252 -0
- lt_tensor/model_zoo/{tfrms.py → transformer.py} +2 -2
- lt_tensor/processors/audio.py +207 -80
- lt_tensor/transform.py +7 -16
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a14.dist-info}/METADATA +7 -5
- lt_tensor-0.0.1a14.dist-info/RECORD +32 -0
- lt_tensor/model_zoo/fsn.py +0 -67
- lt_tensor/model_zoo/gns.py +0 -185
- lt_tensor/model_zoo/istft.py +0 -591
- lt_tensor/model_zoo/rsd.py +0 -107
- lt_tensor-0.0.1a12.dist-info/RECORD +0 -32
- /lt_tensor/model_zoo/{disc.py → discriminator.py} +0 -0
- /lt_tensor/model_zoo/{pos.py → pos_encoder.py} +0 -0
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a14.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a14.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a14.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,142 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import math
|
4
|
+
from einops import repeat
|
5
|
+
|
6
|
+
|
7
|
+
class SineGen(nn.Module):
    """Harmonic sine excitation generator driven by a frame-level F0 contour.

    Each F0 frame (in Hz) is repeated ``upsample_scale`` times along time, and the
    fundamental plus ``harmonic_num`` harmonics are synthesised as sinusoids. Frames
    with ``f0 > voiced_threshold`` are voiced; on unvoiced frames the sinusoids are
    replaced by Gaussian noise (see ``forward`` / ``_forward`` for the exact mix).

    Args:
        samp_rate: Output sampling rate in Hz.
        upsample_scale: Number of output samples generated per input F0 frame.
        harmonic_num: Number of harmonics above the fundamental; the output has
            ``harmonic_num + 1`` channels.
        sine_amp: Amplitude applied to the sinusoids.
        noise_std: Standard deviation of the Gaussian noise.
        voiced_threshold: F0 values strictly greater than this are treated as voiced.
        flag_for_pulse: If True every channel starts at phase 0; otherwise each
            channel starts at a uniformly random phase in [0, 2*pi).
    """

    def __init__(
        self,
        samp_rate,
        upsample_scale,
        harmonic_num=0,
        sine_amp=0.1,
        noise_std=0.003,
        voiced_threshold=0,
        flag_for_pulse=False,
    ):
        super().__init__()
        self.sampling_rate = samp_rate
        self.upsample_scale = upsample_scale
        self.harmonic_num = harmonic_num
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.dim = self.harmonic_num + 1  # fundamental + harmonics

    # ------------------------------------------------------------------ helpers

    def _harmonic_multipliers(self, device):
        """Return the float32 multipliers ``[1, 2, ..., dim]`` shaped (1, 1, dim)."""
        return torch.arange(1, self.dim + 1, device=device).float().view(1, 1, -1)

    def _upsample(self, x):
        """Repeat every frame of a (B, T, D) tensor ``upsample_scale`` times -> (B, T*r, D).

        Equivalent to einops ``repeat(x, "b t d -> b (t r) d", r=upsample_scale)``.
        """
        return x.repeat_interleave(self.upsample_scale, dim=1)

    def _initial_phase(self, batch, device):
        """Initial phase in radians, shaped (batch, 1, dim).

        Zeros in pulse mode, otherwise uniform in [0, 2*pi) per batch item and channel.
        """
        if self.flag_for_pulse:
            unit = torch.zeros((batch, 1, self.dim), device=device)
        else:
            unit = torch.rand((batch, 1, self.dim), device=device)
        return unit * 2 * math.pi

    def _f02uv_b(self, f0):
        """Voiced mask (1.0 voiced / 0.0 unvoiced) with the same shape as ``f0``."""
        return (f0 > self.voiced_threshold).float()

    def _f02uv(self, f0):
        """Voiced mask with a trailing channel axis: (B, T) -> (B, T, 1)."""
        return (f0 > self.voiced_threshold).float().unsqueeze(-1)

    # ------------------------------------------------------------ synthesis

    @torch.no_grad()
    def _f02sine(self, f0_values):
        """Unscaled harmonic sinusoids for an F0 contour (no gradients tracked).

        Args:
            f0_values: F0 in Hz, shape (B, T, 1).

        Returns:
            ``sin(phase)`` shaped (B, T * upsample_scale, dim).
        """
        device = f0_values.device
        f0_harm = self._upsample(f0_values) * self._harmonic_multipliers(device)

        # Per-sample increment in cycles; dropping whole cycles does not change sin().
        cycles_per_sample = (f0_harm / self.sampling_rate) % 1.0
        phase = torch.cumsum(cycles_per_sample * 2 * math.pi, dim=1)
        phase = phase + self._initial_phase(f0_values.size(0), device)
        return torch.sin(phase)

    def _forward(self, f0):
        """Alternative mix: scaled sinusoids on voiced samples, noise on all channels when unvoiced.

        Args:
            f0: F0 in Hz, shape (B, T, 1).

        Returns:
            Tensor shaped (B, T * upsample_scale, dim).
        """
        sine_waves = self._f02sine(f0)  # (B, T_up, dim)
        uv = self._upsample(self._f02uv_b(f0))  # (B, T_up, 1)

        sine_signal = self.sine_amp * sine_waves * uv
        noise = torch.randn_like(sine_signal) * self.noise_std
        return sine_signal + noise * (1.0 - uv)  # noise only where unvoiced

    def forward(self, f0):
        """Generate the harmonic excitation for a frame-level F0 contour.

        Args:
            f0: F0 in Hz before upsampling, shape (B, T) or (B, T, 1).

        Returns:
            sine_waves: (B, T_up, dim) with ``T_up = T * upsample_scale``. On voiced
                samples every channel is ``sine_amp * sin(phase)``. On unvoiced
                samples channel 0 carries ``noise`` and the harmonic channels are 0.
            uv: (B, T_up, 1) voiced mask (1.0 voiced, 0.0 unvoiced).
            noise: (B, T_up, 1) Gaussian noise with std ``noise_std``.
        """
        if f0.dim() == 3 and f0.size(-1) == 1:
            f0 = f0.squeeze(-1)  # also accept (B, T, 1)
        device = f0.device

        uv_up = self._upsample(self._f02uv(f0))  # (B, T_up, 1)
        f0_harm_up = self._upsample(
            f0.unsqueeze(-1) * self._harmonic_multipliers(device)
        )  # (B, T_up, dim)

        # Hz -> cycles/sample -> radians/sample, then integrate to obtain the phase.
        rad_per_sample = f0_harm_up / self.sampling_rate * 2 * math.pi
        batch, t_up, _ = rad_per_sample.shape
        phase = torch.cumsum(rad_per_sample, dim=1) + self._initial_phase(batch, device)

        sine_waves = torch.sin(phase) * self.sine_amp  # (B, T_up, dim)
        noise = torch.randn(batch, t_up, 1, device=device) * self.noise_std

        # Silence every channel on unvoiced samples. With f0 == 0 the phase is frozen,
        # so the harmonics would otherwise output a constant sin(initial_phase) offset.
        # Then put the noise on the fundamental only.
        sine_waves = sine_waves * uv_up
        fundamental = sine_waves[:, :, 0:1] + noise * (1 - uv_up)
        sine_waves = torch.cat([fundamental, sine_waves[:, :, 1:]], dim=-1)

        return sine_waves, uv_up, noise
|
@@ -1,41 +1,45 @@
|
|
1
|
-
__all__ = ["AudioSettings", "
|
1
|
+
__all__ = ["AudioSettings", "AudioDecoderTrainer", "AudioGeneratorOnlyTrainer"]
|
2
2
|
import gc
|
3
|
-
import math
|
4
3
|
import itertools
|
5
4
|
from lt_utils.common import *
|
6
5
|
import torch.nn.functional as F
|
7
6
|
from lt_tensor.torch_commons import *
|
8
7
|
from lt_tensor.model_base import Model
|
9
|
-
from lt_tensor.misc_utils import log_tensor
|
10
8
|
from lt_utils.misc_utils import log_traceback
|
11
9
|
from lt_tensor.processors import AudioProcessor
|
12
10
|
from lt_tensor.misc_utils import set_seed, clear_cache
|
13
|
-
from lt_utils.type_utils import is_dir, is_pathlike
|
14
|
-
from lt_tensor.config_templates import
|
11
|
+
from lt_utils.type_utils import is_dir, is_pathlike
|
12
|
+
from lt_tensor.config_templates import ModelConfig
|
15
13
|
from lt_tensor.model_zoo.istft.generator import iSTFTGenerator
|
16
|
-
from lt_tensor.model_zoo.
|
17
|
-
|
14
|
+
from lt_tensor.model_zoo.discriminator import (
|
15
|
+
MultiPeriodDiscriminator,
|
16
|
+
MultiScaleDiscriminator,
|
17
|
+
)
|
18
18
|
|
19
19
|
|
20
|
-
def feature_loss(
|
21
|
-
loss = 0
|
22
|
-
for
|
23
|
-
for
|
24
|
-
loss +=
|
25
|
-
return loss
|
20
|
+
def feature_loss(fmap_r, fmap_g):
    """Feature-matching L1 loss between real and generated feature maps.

    Args:
        fmap_r: Two-level nested sequence of feature tensors from real audio
            (outer level presumably one entry per discriminator, inner level one
            tensor per layer).
        fmap_g: Same structure, computed from generated audio.

    Returns:
        Twice the sum, over every paired tensor, of the mean absolute difference.
        Pairing uses ``zip``, so unmatched trailing entries are ignored. With no
        pairs at all the result is the integer ``0``.
    """
    per_layer_l1 = (
        torch.mean(torch.abs(real_feat - fake_feat))
        for real_group, fake_group in zip(fmap_r, fmap_g)
        for real_feat, fake_feat in zip(real_group, fake_group)
    )
    return sum(per_layer_l1) * 2
|
26
26
|
|
27
27
|
|
28
|
-
def generator_adv_loss(
|
29
|
-
loss = 0
|
30
|
-
for
|
31
|
-
|
28
|
+
def generator_adv_loss(disc_outputs):
    """Least-squares adversarial loss for the generator.

    Args:
        disc_outputs: Iterable of discriminator score tensors for generated audio.

    Returns:
        Sum over the iterable of ``mean((1 - score) ** 2)``; the integer ``0``
        when the iterable is empty.
    """
    return sum(torch.mean((1 - score) ** 2) for score in disc_outputs)
|
33
34
|
|
34
35
|
|
35
|
-
def discriminator_loss(
|
36
|
-
loss = 0
|
37
|
-
|
38
|
-
|
36
|
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares adversarial loss for the discriminator.

    Real scores are pushed toward 1 and generated scores toward 0.

    Args:
        disc_real_outputs: Iterable of discriminator score tensors for real audio.
        disc_generated_outputs: Matching iterable of score tensors for generated
            audio. Pairing uses ``zip``, so unmatched trailing entries are ignored.

    Returns:
        Sum over pairs of ``mean((1 - real) ** 2) + mean(fake ** 2)``; the integer
        ``0`` when there are no pairs.
    """
    return sum(
        torch.mean((1 - real_score) ** 2) + torch.mean(fake_score**2)
        for real_score, fake_score in zip(disc_real_outputs, disc_generated_outputs)
    )
|
40
44
|
|
41
45
|
|
@@ -79,12 +83,12 @@ class AudioSettings(ModelConfig):
|
|
79
83
|
self.scheduler_template = scheduler_template
|
80
84
|
|
81
85
|
|
82
|
-
class
|
86
|
+
class AudioDecoderTrainer(Model):
|
83
87
|
def __init__(
|
84
88
|
self,
|
85
89
|
audio_processor: AudioProcessor,
|
86
90
|
settings: Optional[AudioSettings] = None,
|
87
|
-
generator: Optional[Union[Model, "iSTFTGenerator"]] = None, # non
|
91
|
+
generator: Optional[Union[Model, "iSTFTGenerator"]] = None, # non initialized!
|
88
92
|
):
|
89
93
|
super().__init__()
|
90
94
|
if settings is None:
|
@@ -284,9 +288,6 @@ class AudioDecoder(Model):
|
|
284
288
|
win_length=self.settings.n_fft,
|
285
289
|
# length=real_audio.shape[-1]
|
286
290
|
)[:, : real_audio.shape[-1]]
|
287
|
-
# smallest = min(real_audio.shape[-1], fake_audio.shape[-1])
|
288
|
-
# real_audio = real_audio[:, :, :smallest].squeeze(1)
|
289
|
-
# fake_audio = fake_audio[:, :smallest]
|
290
291
|
|
291
292
|
disc_kwargs = dict(
|
292
293
|
real_audio=real_audio,
|
@@ -299,7 +300,7 @@ class AudioDecoder(Model):
|
|
299
300
|
else:
|
300
301
|
disc_out = self._discriminator_step(**disc_kwargs)
|
301
302
|
|
302
|
-
|
303
|
+
generator_kwargs = dict(
|
303
304
|
mels=mels,
|
304
305
|
real_audio=real_audio,
|
305
306
|
fake_audio=fake_audio,
|
@@ -314,8 +315,8 @@ class AudioDecoder(Model):
|
|
314
315
|
|
315
316
|
if is_generator_frozen:
|
316
317
|
with torch.no_grad():
|
317
|
-
return self._generator_step(**
|
318
|
-
return self._generator_step(**
|
318
|
+
return self._generator_step(**generator_kwargs)
|
319
|
+
return self._generator_step(**generator_kwargs)
|
319
320
|
|
320
321
|
def _discriminator_step(
|
321
322
|
self,
|
@@ -324,7 +325,8 @@ class AudioDecoder(Model):
|
|
324
325
|
am_i_frozen: bool = False,
|
325
326
|
):
|
326
327
|
# ========== Discriminator Forward Pass ==========
|
327
|
-
|
328
|
+
if not am_i_frozen:
|
329
|
+
self.d_optim.zero_grad()
|
328
330
|
# MPD
|
329
331
|
real_mpd_preds, _ = self.mpd(real_audio)
|
330
332
|
fake_mpd_preds, _ = self.mpd(fake_audio)
|
@@ -337,7 +339,6 @@ class AudioDecoder(Model):
|
|
337
339
|
loss_d = loss_d_mpd + loss_d_msd
|
338
340
|
|
339
341
|
if not am_i_frozen:
|
340
|
-
self.d_optim.zero_grad()
|
341
342
|
loss_d.backward()
|
342
343
|
self.d_optim.step()
|
343
344
|
|
@@ -359,6 +360,8 @@ class AudioDecoder(Model):
|
|
359
360
|
am_i_frozen: bool = False,
|
360
361
|
):
|
361
362
|
# ========== Generator Loss ==========
|
363
|
+
if not am_i_frozen:
|
364
|
+
self.g_optim.zero_grad()
|
362
365
|
real_mpd_feats = self.mpd(real_audio)[1]
|
363
366
|
real_msd_feats = self.msd(real_audio)[1]
|
364
367
|
|
@@ -372,7 +375,7 @@ class AudioDecoder(Model):
|
|
372
375
|
|
373
376
|
loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
|
374
377
|
loss_mel = (
|
375
|
-
F.
|
378
|
+
F.huber_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
|
376
379
|
)
|
377
380
|
loss_fm = ((loss_fm_mpd + loss_fm_msd) * fm_scale) + fm_add
|
378
381
|
|
@@ -380,9 +383,10 @@ class AudioDecoder(Model):
|
|
380
383
|
|
381
384
|
loss_g = loss_adv + loss_fm + loss_stft + loss_mel
|
382
385
|
if not am_i_frozen:
|
383
|
-
self.g_optim.zero_grad()
|
384
386
|
loss_g.backward()
|
385
387
|
self.g_optim.step()
|
388
|
+
|
389
|
+
lr_g, lr_d = self.get_lr()
|
386
390
|
return {
|
387
391
|
"loss_g": loss_g.item(),
|
388
392
|
"loss_d": loss_d,
|
@@ -390,8 +394,8 @@ class AudioDecoder(Model):
|
|
390
394
|
"loss_fm": loss_fm.item(),
|
391
395
|
"loss_stft": loss_stft.item(),
|
392
396
|
"loss_mel": loss_mel.item(),
|
393
|
-
"lr_g":
|
394
|
-
"lr_d":
|
397
|
+
"lr_g": lr_g,
|
398
|
+
"lr_d": lr_d,
|
395
399
|
}
|
396
400
|
|
397
401
|
def step_scheduler(
|
@@ -417,34 +421,198 @@ class AudioDecoder(Model):
|
|
417
421
|
self.g_scheduler = self.settings.scheduler_template(self.g_optim)
|
418
422
|
|
419
423
|
|
420
|
-
class
|
424
|
+
class AudioGeneratorOnlyTrainer(Model):
    """Trainer for an iSTFT generator without discriminators.

    The generator maps a mel spectrogram to a ``(spec, phase)`` pair. The waveform
    is rebuilt with ``audio_processor.inverse_transform``, and the generator is
    optimized with an STFT loss, a Huber mel loss and an optional extra loss.
    """

    def __init__(
        self,
        audio_processor: AudioProcessor,
        settings: Optional[AudioSettings] = None,
        generator: Optional[Union[Model, "iSTFTGenerator"]] = None,  # non initialized!
    ):
        """
        Args:
            audio_processor: Provides ``inverse_transform``, ``compute_mel`` and
                ``stft_loss``.
            settings: An ``AudioSettings`` instance, a dict of ``AudioSettings``
                keyword arguments, or None for the defaults.
            generator: Generator class (not an instance). It is instantiated here
                with hyper-parameters from ``settings``; defaults to ``iSTFTGenerator``.

        Raises:
            ValueError: If ``settings`` is not None, a dict or an ``AudioSettings``.
        """
        super().__init__()
        if settings is None:
            self.settings = AudioSettings()
        elif isinstance(settings, dict):
            self.settings = AudioSettings(**settings)
        elif isinstance(settings, AudioSettings):
            self.settings = settings
        else:
            raise ValueError(
                "Cannot initialize the AudioGeneratorOnlyTrainer with the given settings. "
                "Use either a dictionary, or the class AudioSettings to setup the settings. "
                "Alternatively, leave it None to use the default values."
            )
        if self.settings.seed is not None:
            set_seed(self.settings.seed)
        if generator is None:
            generator = iSTFTGenerator
        self.generator: iSTFTGenerator = generator(
            in_channels=self.settings.in_channels,
            upsample_rates=self.settings.upsample_rates,
            upsample_kernel_sizes=self.settings.upsample_kernel_sizes,
            upsample_initial_channel=self.settings.upsample_initial_channel,
            resblock_kernel_sizes=self.settings.resblock_kernel_sizes,
            resblock_dilation_sizes=self.settings.resblock_dilation_sizes,
            n_fft=self.settings.n_fft,
            activation=self.settings.activation,
        )
        self.generator.eval()
        self.gen_training = False
        self.audio_processor = audio_processor
        # Built lazily by `update_schedulers_and_optimizer`. They are defined here so
        # the `is not None` guards in get_lr / set_lr / step_scheduler /
        # reset_schedulers work before training starts instead of raising
        # AttributeError.
        self.g_optim = None
        self.g_scheduler = None

    def setup_training_mode(self, *args, **kwargs):
        """Create the optimizer/scheduler and mark the trainer as training. Returns True."""
        self.finish_training_setup()
        self.update_schedulers_and_optimizer()
        self.gen_training = True
        return True

    def update_schedulers_and_optimizer(self):
        """(Re)build the AdamW optimizer and the scheduler from ``settings``."""
        self.g_optim = optim.AdamW(
            self.generator.parameters(),
            lr=self.settings.lr,
            betas=self.settings.adamw_betas,
        )
        self.g_scheduler = self.settings.scheduler_template(self.g_optim)

    def set_lr(self, new_lr: float = 1e-4):
        """Set the learning rate of every param group (if an optimizer exists); return it."""
        if self.g_optim is not None:
            for groups in self.g_optim.param_groups:
                groups["lr"] = new_lr
        return self.get_lr()

    def get_lr(self) -> float:
        """Learning rate of the first param group, or NaN before the optimizer exists."""
        if self.g_optim is not None:
            return self.g_optim.param_groups[0]["lr"]
        return float("nan")

    def save_weights(self, path, replace=True):
        """Save the generator weights as ``generator.pt`` inside *path*.

        If *path* ends with ``.pt``, its parent directory is used instead.
        """
        is_pathlike(path, check_if_empty=True, validate=True)
        if str(path).endswith(".pt"):
            path = Path(path).parent
        else:
            path = Path(path)
        self.generator.save_weights(Path(path, "generator.pt"), replace)

    def load_weights(
        self,
        path,
        raise_if_not_exists=False,
        strict=True,
        assign=False,
        weights_only=False,
        mmap=None,
        **torch_loader_kwargs
    ):
        """Load generator weights from a ``.pt`` file or from ``<dir>/generator.pt``."""
        is_pathlike(path, check_if_empty=True, validate=True)
        if str(path).endswith(".pt"):
            path = Path(path)
        else:
            path = Path(path, "generator.pt")

        self.generator.load_weights(
            path,
            raise_if_not_exists,
            strict,
            assign,
            weights_only,
            mmap,
            **torch_loader_kwargs,
        )

    def finish_training_setup(self):
        """Free memory, switch to eval mode and clear the training flag.

        NOTE(review): this leaves the module in eval mode even when called from
        ``setup_training_mode``; presumably ``generator.train_step`` switches the
        generator to train mode. Confirm.
        """
        gc.collect()
        clear_cache()
        self.eval()
        self.gen_training = False

    def forward(self, mel_spec: Tensor) -> Tuple[Tensor, Tensor]:
        """Returns the generated spec and phase"""
        return self.generator.forward(mel_spec)

    def inference(
        self,
        mel_spec: Tensor,
        return_dict: bool = False,
    ) -> Union[Dict[str, Tensor], Tensor]:
        """Synthesize a waveform from a mel spectrogram.

        Returns the trimmed waveform, or a dict with ``wave``, ``spec`` and ``phase``
        when ``return_dict`` is True.
        """
        spec, phase = super().inference(mel_spec)
        wave = self.audio_processor.inverse_transform(
            spec,
            phase,
            self.settings.n_fft,
            hop_length=4,  # hard-coded hop; TODO confirm it matches the generator
            win_length=self.settings.n_fft,
        )
        # Drop the last 256 samples (hard-coded; presumably iSTFT padding, TODO confirm).
        wave = wave[:, : wave.shape[-1] - 256]
        if not return_dict:
            return wave
        return {
            "wave": wave,
            "spec": spec,
            "phase": phase,
        }

    def set_device(self, device: str):
        """Move the trainer, its generator and the audio processor to *device*."""
        # This trainer owns no discriminators, so there is no msd/mpd to move.
        self.to(device=device)
        self.generator.to(device=device)
        self.audio_processor.to(device=device)

    def train_step(
        self,
        mels: Tensor,
        real_audio: Tensor,
        stft_scale: float = 1.0,
        mel_scale: float = 1.0,
        ext_loss: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
    ):
        """Run one optimization step of the generator.

        Args:
            mels: Mel spectrogram input for the generator (also the mel-loss target).
            real_audio: Target waveform; a size-1 dim 1 is squeezed away.
            stft_scale: Weight of the STFT loss.
            mel_scale: Weight of the Huber mel loss.
            ext_loss: Optional extra loss ``fn(fake_audio, real_audio) -> Tensor``.

        Returns:
            Dict with ``loss``, ``loss_stft``, ``loss_mel``, ``loss_ext`` (floats)
            and ``lr``.
        """
        if not self.gen_training:
            self.setup_training_mode()

        self.g_optim.zero_grad()
        spec, phase = self.generator.train_step(mels)

        real_audio = real_audio.squeeze(1)
        # Must NOT run under torch.no_grad(): every loss below is computed from
        # fake_audio, so the inverse transform has to stay in the autograd graph for
        # gradients to reach the generator (assumes inverse_transform is
        # differentiable, e.g. torch.istft-based).
        fake_audio = self.audio_processor.inverse_transform(
            spec,
            phase,
            self.settings.n_fft,
            hop_length=4,
            win_length=self.settings.n_fft,
        )[:, : real_audio.shape[-1]]
        loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
        loss_mel = (
            F.huber_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
        )
        loss_g = loss_stft + loss_mel
        loss_ext = 0

        if ext_loss is not None:
            l_ext = ext_loss(fake_audio, real_audio)
            loss_g = loss_g + l_ext
            loss_ext = l_ext.item()

        # Back-propagate only once the total loss (including ext_loss) is assembled.
        loss_g.backward()
        self.g_optim.step()
        return {
            "loss": loss_g.item(),
            "loss_stft": loss_stft.item(),
            "loss_mel": loss_mel.item(),
            "loss_ext": loss_ext,
            "lr": self.get_lr(),
        }

    def step_scheduler(self):
        """Advance the LR scheduler, if one has been created."""
        if self.g_scheduler is not None:
            self.g_scheduler.step()

    def reset_schedulers(self, lr: Optional[float] = None):
        """
        In case you have adopted another strategy, with this function,
        it is possible restart the scheduler and set the lr to another value.
        """
        if lr is not None:
            self.set_lr(lr)
        if self.g_optim is not None:
            self.g_scheduler = self.settings.scheduler_template(self.g_optim)
|