lt-tensor 0.0.1a13__py3-none-any.whl → 0.0.1a15__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- lt_tensor/datasets/audio.py +23 -6
- lt_tensor/misc_utils.py +1 -1
- lt_tensor/model_base.py +163 -123
- lt_tensor/model_zoo/diffwave/__init__.py +0 -0
- lt_tensor/model_zoo/diffwave/model.py +200 -0
- lt_tensor/model_zoo/diffwave/params.py +58 -0
- lt_tensor/model_zoo/discriminator.py +269 -151
- lt_tensor/model_zoo/features.py +102 -11
- lt_tensor/model_zoo/istft/generator.py +10 -66
- lt_tensor/model_zoo/istft/trainer.py +224 -72
- lt_tensor/model_zoo/residual.py +136 -32
- lt_tensor/processors/audio.py +5 -16
- {lt_tensor-0.0.1a13.dist-info → lt_tensor-0.0.1a15.dist-info}/METADATA +2 -2
- {lt_tensor-0.0.1a13.dist-info → lt_tensor-0.0.1a15.dist-info}/RECORD +17 -14
- {lt_tensor-0.0.1a13.dist-info → lt_tensor-0.0.1a15.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a13.dist-info → lt_tensor-0.0.1a15.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a13.dist-info → lt_tensor-0.0.1a15.dist-info}/top_level.txt +0 -0
lt_tensor/model_zoo/istft/generator.py +10 -66

```diff
@@ -1,52 +1,7 @@
-__all__ = ["iSTFTGenerator"
-import gc
-import math
-import itertools
+__all__ = ["iSTFTGenerator"]
 from lt_utils.common import *
 from lt_tensor.torch_commons import *
-from lt_tensor.
-from lt_tensor.misc_utils import log_tensor
-from lt_tensor.model_zoo.residual import ResBlock1D, ConvNets, get_weight_norm
-from lt_utils.misc_utils import log_traceback
-from lt_tensor.processors import AudioProcessor
-from lt_utils.type_utils import is_dir, is_pathlike
-from lt_tensor.misc_utils import set_seed, clear_cache
-from lt_tensor.model_zoo.discriminator import MultiPeriodDiscriminator, MultiScaleDiscriminator
-import torch.nn.functional as F
-from lt_tensor.config_templates import updateDict, ModelConfig
-
-
-class ResBlocks(ConvNets):
-    def __init__(
-        self,
-        channels: int,
-        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
-        resblock_dilation_sizes: List[Union[int, List[int]]] = [
-            [1, 3, 5],
-            [1, 3, 5],
-            [1, 3, 5],
-        ],
-        activation: nn.Module = nn.LeakyReLU(0.1),
-    ):
-        super().__init__()
-        self.num_kernels = len(resblock_kernel_sizes)
-        self.rb = nn.ModuleList()
-        self.activation = activation
-
-        for k, j in zip(resblock_kernel_sizes, resblock_dilation_sizes):
-            self.rb.append(ResBlock1D(channels, k, j, activation))
-
-        self.rb.apply(self.init_weights)
-
-    def forward(self, x: torch.Tensor):
-        xs = None
-        for i, block in enumerate(self.rb):
-            if i == 0:
-                xs = block(x)
-            else:
-                xs += block(x)
-        x = xs / self.num_kernels
-        return self.activation(x)
+from lt_tensor.model_zoo.residual import ConvNets, ResBlocks1D, ResBlock1D, ResBlock1D2
 
 
 class iSTFTGenerator(ConvNets):
@@ -65,6 +20,7 @@ class iSTFTGenerator(ConvNets):
         n_fft: int = 16,
         activation: nn.Module = nn.LeakyReLU(0.1),
         hop_length: int = 256,
+        residual_cls: Union[ResBlock1D, ResBlock1D2] = ResBlock1D
     ):
         super().__init__()
         self.num_kernels = len(resblock_kernel_sizes)
@@ -82,6 +38,7 @@ class iSTFTGenerator(ConvNets):
                     upsample_initial_channel,
                     resblock_kernel_sizes,
                     resblock_dilation_sizes,
+                    residual_cls
                 )
             )
 
@@ -91,25 +48,13 @@ class iSTFTGenerator(ConvNets):
         self.conv_post.apply(self.init_weights)
         self.reflection_pad = nn.ReflectionPad1d((1, 0))
 
-        self.phase = nn.Sequential(
-            nn.LeakyReLU(0.2),
-            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2),
-            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
-        )
-        self.spec = nn.Sequential(
-            nn.LeakyReLU(0.2),
-            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2),
-            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
-        )
-
     def _make_blocks(
         self,
         state: Tuple[int, int, int],
         upsample_initial_channel: int,
         resblock_kernel_sizes: List[Union[int, List[int]]],
         resblock_dilation_sizes: List[int | List[int]],
+        residual: nn.Module
     ):
         i, k, u = state
         channels = upsample_initial_channel // (2 ** (i + 1))
@@ -127,11 +72,12 @@ class iSTFTGenerator(ConvNets):
                     )
                 ).apply(self.init_weights),
             ),
-            residual=
+            residual=ResBlocks1D(
                 channels,
                 resblock_kernel_sizes,
                 resblock_dilation_sizes,
                 self.activation,
+                residual
             ),
         )
     )
@@ -142,9 +88,7 @@ class iSTFTGenerator(ConvNets):
             x = block["up"](x)
             x = block["residual"](x)
 
-        x = self.reflection_pad(x)
-
-
-        phase = torch.sin(self.phase(x[:, self.post_n_fft :, :]))
-
+        x = self.conv_post(self.activation(self.reflection_pad(x)))
+        spec = torch.exp(x[:, : self.post_n_fft, :])
+        phase = torch.sin(x[:, self.post_n_fft :, :])
         return spec, phase
```
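The generator hunks above replace the module-local `ResBlocks` helper and the separate `spec`/`phase` heads with the shared `ResBlocks1D` block from `lt_tensor.model_zoo.residual`, expose the residual block type through a new `residual_cls` argument, and make `forward` derive both the magnitude and the phase from a single `conv_post` projection. A minimal usage sketch of the new argument, assuming the constructor keywords visible elsewhere in this diff; the numeric values and the dummy mel tensor are illustrative assumptions, not values taken from the package:

```python
# Sketch only: keyword names come from this diff; shapes and sizes are assumed.
import torch
from lt_tensor.model_zoo.istft.generator import iSTFTGenerator
from lt_tensor.model_zoo.residual import ResBlock1D2

gen = iSTFTGenerator(
    in_channels=80,                     # number of mel bands (assumed)
    upsample_rates=[8, 8],              # assumed example values
    upsample_kernel_sizes=[16, 16],     # assumed example values
    upsample_initial_channel=512,       # assumed example value
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    n_fft=16,
    hop_length=256,
    residual_cls=ResBlock1D2,           # new in 0.0.1a15; defaults to ResBlock1D
)

mel = torch.randn(1, 80, 32)            # (batch, mels, frames), dummy input
spec, phase = gen(mel)                  # exp-magnitude and sin-phase, per the diff
```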
lt_tensor/model_zoo/istft/trainer.py +224 -72

```diff
@@ -1,20 +1,21 @@
-__all__ = ["AudioSettings", "
+__all__ = ["AudioSettings", "AudioDecoderTrainer", "AudioGeneratorOnlyTrainer"]
 import gc
-import math
 import itertools
 from lt_utils.common import *
 import torch.nn.functional as F
 from lt_tensor.torch_commons import *
 from lt_tensor.model_base import Model
-from lt_tensor.misc_utils import log_tensor
 from lt_utils.misc_utils import log_traceback
 from lt_tensor.processors import AudioProcessor
 from lt_tensor.misc_utils import set_seed, clear_cache
-from lt_utils.type_utils import is_dir, is_pathlike
-from lt_tensor.config_templates import
+from lt_utils.type_utils import is_dir, is_pathlike
+from lt_tensor.config_templates import ModelConfig
 from lt_tensor.model_zoo.istft.generator import iSTFTGenerator
-from lt_tensor.model_zoo.
-
+from lt_tensor.model_zoo.discriminator import (
+    MultiPeriodDiscriminator,
+    MultiScaleDiscriminator,
+)
+from lt_tensor.model_zoo.residual import ResBlock1D2, ResBlock1D
 
 
 def feature_loss(fmap_r, fmap_g):
@@ -29,7 +30,6 @@ def generator_adv_loss(disc_outputs):
     loss = 0
     for dg in disc_outputs:
         l = torch.mean((1 - dg) ** 2)
-
         loss += l
     return loss
 
@@ -44,29 +44,6 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
     return loss
 
 
-"""def feature_loss(fmap_r, fmap_g):
-    loss = 0
-    for dr, dg in zip(fmap_r, fmap_g):
-        for rl, gl in zip(dr, dg):
-            loss += torch.mean(torch.abs(rl - gl))
-    return loss * 2
-
-
-def generator_adv_loss(fake_preds):
-    loss = 0.0
-    for f in fake_preds:
-        loss += torch.mean((f - 1.0) ** 2)
-    return loss
-
-
-def discriminator_loss(real_preds, fake_preds):
-    loss = 0.0
-    for r, f in zip(real_preds, fake_preds):
-        loss += torch.mean((r - 1.0) ** 2) + torch.mean(f**2)
-    return loss
-"""
-
-
 class AudioSettings(ModelConfig):
     def __init__(
         self,
@@ -90,6 +67,7 @@ class AudioSettings(ModelConfig):
         scheduler_template: Callable[
             [optim.Optimizer], optim.lr_scheduler.LRScheduler
         ] = lambda optimizer: optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.998),
+        residual_cls: Union[ResBlock1D, ResBlock1D2] = ResBlock1D,
     ):
         self.in_channels = n_mels
         self.upsample_rates = upsample_rates
@@ -105,14 +83,15 @@ class AudioSettings(ModelConfig):
         self.lr = lr
         self.adamw_betas = adamw_betas
         self.scheduler_template = scheduler_template
+        self.residual_cls = residual_cls
 
 
-class
+class AudioDecoderTrainer(Model):
     def __init__(
         self,
         audio_processor: AudioProcessor,
         settings: Optional[AudioSettings] = None,
-        generator: Optional[Union[Model, "iSTFTGenerator"]] = None,  # non
+        generator: Optional[Union[Model, "iSTFTGenerator"]] = None,  # non initialized!
     ):
         super().__init__()
         if settings is None:
@@ -175,14 +154,20 @@ class AudioDecoder(Model):
         return True
 
     def update_schedulers_and_optimizer(self):
+        gc.collect()
+        self.g_optim = None
+        self.g_scheduler = None
+        gc.collect()
         self.g_optim = optim.AdamW(
             self.generator.parameters(),
             lr=self.settings.lr,
             betas=self.settings.adamw_betas,
         )
+        gc.collect()
         self.g_scheduler = self.settings.scheduler_template(self.g_optim)
         if any([self.mpd is None, self.msd is None]):
             return
+        gc.collect()
         self.d_optim = optim.AdamW(
             itertools.chain(self.mpd.parameters(), self.msd.parameters()),
             lr=self.settings.lr,
@@ -274,9 +259,9 @@ class AudioDecoder(Model):
             win_length=self.settings.n_fft,
         )
         if not return_dict:
-            return wave
+            return wave
         return {
-            "wave": wave
+            "wave": wave,
            "spec": spec,
            "phase": phase,
        }
@@ -310,8 +295,8 @@ class AudioDecoder(Model):
                 self.settings.n_fft,
                 hop_length=4,
                 win_length=self.settings.n_fft,
-
-            )
+                length=real_audio.shape[-1],
+            )
 
         disc_kwargs = dict(
             real_audio=real_audio,
@@ -324,7 +309,7 @@ class AudioDecoder(Model):
         else:
             disc_out = self._discriminator_step(**disc_kwargs)
 
-
+        generator_kwargs = dict(
             mels=mels,
             real_audio=real_audio,
             fake_audio=fake_audio,
@@ -339,8 +324,8 @@ class AudioDecoder(Model):
 
         if is_generator_frozen:
             with torch.no_grad():
-                return self._generator_step(**
-        return self._generator_step(**
+                return self._generator_step(**generator_kwargs)
+        return self._generator_step(**generator_kwargs)
 
     def _discriminator_step(
         self,
@@ -349,7 +334,8 @@ class AudioDecoder(Model):
         am_i_frozen: bool = False,
     ):
         # ========== Discriminator Forward Pass ==========
-
+        if not am_i_frozen:
+            self.d_optim.zero_grad()
         # MPD
         real_mpd_preds, _ = self.mpd(real_audio)
         fake_mpd_preds, _ = self.mpd(fake_audio)
@@ -362,7 +348,6 @@ class AudioDecoder(Model):
         loss_d = loss_d_mpd + loss_d_msd
 
         if not am_i_frozen:
-            self.d_optim.zero_grad()
             loss_d.backward()
             self.d_optim.step()
 
@@ -384,6 +369,8 @@ class AudioDecoder(Model):
         am_i_frozen: bool = False,
     ):
         # ========== Generator Loss ==========
+        if not am_i_frozen:
+            self.g_optim.zero_grad()
         real_mpd_feats = self.mpd(real_audio)[1]
         real_msd_feats = self.msd(real_audio)[1]
 
@@ -395,7 +382,7 @@ class AudioDecoder(Model):
         loss_fm_mpd = feature_loss(real_mpd_feats, fake_mpd_feats)
         loss_fm_msd = feature_loss(real_msd_feats, fake_msd_feats)
 
-        loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
+        # loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
         loss_mel = (
             F.huber_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
         )
@@ -403,20 +390,21 @@ class AudioDecoder(Model):
 
         loss_adv = (loss_adv_mpd + loss_adv_msd) * adv_scale
 
-        loss_g = loss_adv + loss_fm +
+        loss_g = loss_adv + loss_fm + loss_mel  # + loss_stft
         if not am_i_frozen:
-            self.g_optim.zero_grad()
             loss_g.backward()
             self.g_optim.step()
+
+        lr_g, lr_d = self.get_lr()
         return {
             "loss_g": loss_g.item(),
             "loss_d": loss_d,
             "loss_adv": loss_adv.item(),
             "loss_fm": loss_fm.item(),
-            "loss_stft": loss_stft.item(),
+            "loss_stft": 1.0,  # loss_stft.item(),
             "loss_mel": loss_mel.item(),
-            "lr_g":
-            "lr_d":
+            "lr_g": lr_g,
+            "lr_d": lr_d,
         }
 
     def step_scheduler(
@@ -442,34 +430,198 @@ class AudioDecoder(Model):
         self.g_scheduler = self.settings.scheduler_template(self.g_optim)
 
 
-class
+class AudioGeneratorOnlyTrainer(Model):
     def __init__(
         self,
-
-
-
-            [1, 3, 5],
-            [1, 3, 5],
-            [1, 3, 5],
-        ],
-        activation: nn.Module = nn.LeakyReLU(0.1),
+        audio_processor: AudioProcessor,
+        settings: Optional[AudioSettings] = None,
+        generator: Optional[Union[Model, "iSTFTGenerator"]] = None,  # non initialized!
     ):
         super().__init__()
-
-
-
+        if settings is None:
+            self.settings = AudioSettings()
+        elif isinstance(settings, dict):
+            self.settings = AudioSettings(**settings)
+        elif isinstance(settings, AudioSettings):
+            self.settings = settings
+        else:
+            raise ValueError(
+                "Cannot initialize the waveDecoder with the given settings. "
+                "Use either a dictionary, or the class WaveSettings to setup the settings. "
+                "Alternatively, leave it None to use the default values."
+            )
+        if self.settings.seed is not None:
+            set_seed(self.settings.seed)
+        if generator is None:
+            generator = iSTFTGenerator
+        self.generator: iSTFTGenerator = generator(
+            in_channels=self.settings.in_channels,
+            upsample_rates=self.settings.upsample_rates,
+            upsample_kernel_sizes=self.settings.upsample_kernel_sizes,
+            upsample_initial_channel=self.settings.upsample_initial_channel,
+            resblock_kernel_sizes=self.settings.resblock_kernel_sizes,
+            resblock_dilation_sizes=self.settings.resblock_dilation_sizes,
+            n_fft=self.settings.n_fft,
+            activation=self.settings.activation,
+        )
+        self.generator.eval()
+        self.gen_training = False
+        self.audio_processor = audio_processor
+
+    def setup_training_mode(self, *args, **kwargs):
+        self.finish_training_setup()
+        self.update_schedulers_and_optimizer()
+        self.gen_training = True
+        return True
+
+    def update_schedulers_and_optimizer(self):
+        self.g_optim = optim.AdamW(
+            self.generator.parameters(),
+            lr=self.settings.lr,
+            betas=self.settings.adamw_betas,
+        )
+        self.g_scheduler = self.settings.scheduler_template(self.g_optim)
+
+    def set_lr(self, new_lr: float = 1e-4):
+        if self.g_optim is not None:
+            for groups in self.g_optim.param_groups:
+                groups["lr"] = new_lr
+        return self.get_lr()
+
+    def get_lr(self) -> Tuple[float, float]:
+        if self.g_optim is not None:
+            return self.g_optim.param_groups[0]["lr"]
+        return float("nan")
+
+    def save_weights(self, path, replace=True):
+        is_pathlike(path, check_if_empty=True, validate=True)
+        if str(path).endswith(".pt"):
+            path = Path(path).parent
+        else:
+            path = Path(path)
+        self.generator.save_weights(Path(path, "generator.pt"), replace)
 
-
-
+    def load_weights(
+        self,
+        path,
+        raise_if_not_exists=False,
+        strict=True,
+        assign=False,
+        weights_only=False,
+        mmap=None,
+        **torch_loader_kwargs
+    ):
+        is_pathlike(path, check_if_empty=True, validate=True)
+        if str(path).endswith(".pt"):
+            path = Path(path)
+        else:
+            path = Path(path, "generator.pt")
 
-        self.
+        self.generator.load_weights(
+            path,
+            raise_if_not_exists,
+            strict,
+            assign,
+            weights_only,
+            mmap,
+            **torch_loader_kwargs,
+        )
+
+    def finish_training_setup(self):
+        gc.collect()
+        clear_cache()
+        self.eval()
+        self.gen_training = False
+
+    def forward(self, mel_spec: Tensor) -> Tuple[Tensor, Tensor]:
+        """Returns the generated spec and phase"""
+        return self.generator.forward(mel_spec)
 
-    def
-
-
-
-
-
-
-
-
+    def inference(
+        self,
+        mel_spec: Tensor,
+        return_dict: bool = False,
+    ) -> Union[Dict[str, Tensor], Tensor]:
+        spec, phase = super().inference(mel_spec)
+        wave = self.audio_processor.inverse_transform(
+            spec,
+            phase,
+            self.settings.n_fft,
+            hop_length=4,
+            win_length=self.settings.n_fft,
+        )
+        if not return_dict:
+            return wave[:, : wave.shape[-1] - 256]
+        return {
+            "wave": wave[:, : wave.shape[-1] - 256],
+            "spec": spec,
+            "phase": phase,
+        }
+
+    def set_device(self, device: str):
+        self.to(device=device)
+        self.generator.to(device=device)
+        self.audio_processor.to(device=device)
+        self.msd.to(device=device)
+        self.mpd.to(device=device)
+
+    def train_step(
+        self,
+        mels: Tensor,
+        real_audio: Tensor,
+        stft_scale: float = 1.0,
+        mel_scale: float = 1.0,
+        ext_loss: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
+    ):
+        if not self.gen_training:
+            self.setup_training_mode()
+
+        self.g_optim.zero_grad()
+        spec, phase = self.generator.train_step(mels)
+
+        real_audio = real_audio.squeeze(1)
+        with torch.no_grad():
+            fake_audio = self.audio_processor.inverse_transform(
+                spec,
+                phase,
+                self.settings.n_fft,
+                hop_length=4,
+                win_length=self.settings.n_fft,
+            )[:, : real_audio.shape[-1]]
+        loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
+        loss_mel = (
+            F.huber_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
+        )
+        loss_g.backward()
+        loss_g = loss_stft + loss_mel
+        loss_ext = 0
+
+        if ext_loss is not None:
+            l_ext = ext_loss(fake_audio, real_audio)
+            loss_g = loss_g + l_ext
+            loss_ext = l_ext.item()
+
+        self.g_optim.step()
+        return {
+            "loss": loss_g.item(),
+            "loss_stft": loss_stft.item(),
+            "loss_mel": loss_mel.item(),
+            "loss_ext": loss_ext,
+            "lr": self.get_lr(),
+        }
+
+    def step_scheduler(self):
+
+        if self.g_scheduler is not None:
+            self.g_scheduler.step()
+
+    def reset_schedulers(self, lr: Optional[float] = None):
+        """
+        In case you have adopted another strategy, with this function,
+        it is possible restart the scheduler and set the lr to another value.
+        """
+        if lr is not None:
+            self.set_lr(lr)
+        if self.g_optim is not None:
+            self.g_scheduler = None
+            self.g_scheduler = self.settings.scheduler_template(self.g_optim)
```