lt-tensor 0.0.1a13__py3-none-any.whl → 0.0.1a15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,52 +1,7 @@
1
- __all__ = ["iSTFTGenerator", "ResBlocks"]
2
- import gc
3
- import math
4
- import itertools
1
+ __all__ = ["iSTFTGenerator"]
5
2
  from lt_utils.common import *
6
3
  from lt_tensor.torch_commons import *
7
- from lt_tensor.model_base import Model
8
- from lt_tensor.misc_utils import log_tensor
9
- from lt_tensor.model_zoo.residual import ResBlock1D, ConvNets, get_weight_norm
10
- from lt_utils.misc_utils import log_traceback
11
- from lt_tensor.processors import AudioProcessor
12
- from lt_utils.type_utils import is_dir, is_pathlike
13
- from lt_tensor.misc_utils import set_seed, clear_cache
14
- from lt_tensor.model_zoo.discriminator import MultiPeriodDiscriminator, MultiScaleDiscriminator
15
- import torch.nn.functional as F
16
- from lt_tensor.config_templates import updateDict, ModelConfig
17
-
18
-
19
- class ResBlocks(ConvNets):
20
- def __init__(
21
- self,
22
- channels: int,
23
- resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
24
- resblock_dilation_sizes: List[Union[int, List[int]]] = [
25
- [1, 3, 5],
26
- [1, 3, 5],
27
- [1, 3, 5],
28
- ],
29
- activation: nn.Module = nn.LeakyReLU(0.1),
30
- ):
31
- super().__init__()
32
- self.num_kernels = len(resblock_kernel_sizes)
33
- self.rb = nn.ModuleList()
34
- self.activation = activation
35
-
36
- for k, j in zip(resblock_kernel_sizes, resblock_dilation_sizes):
37
- self.rb.append(ResBlock1D(channels, k, j, activation))
38
-
39
- self.rb.apply(self.init_weights)
40
-
41
- def forward(self, x: torch.Tensor):
42
- xs = None
43
- for i, block in enumerate(self.rb):
44
- if i == 0:
45
- xs = block(x)
46
- else:
47
- xs += block(x)
48
- x = xs / self.num_kernels
49
- return self.activation(x)
4
+ from lt_tensor.model_zoo.residual import ConvNets, ResBlocks1D, ResBlock1D, ResBlock1D2
50
5
 
51
6
 
52
7
  class iSTFTGenerator(ConvNets):
@@ -65,6 +20,7 @@ class iSTFTGenerator(ConvNets):
65
20
  n_fft: int = 16,
66
21
  activation: nn.Module = nn.LeakyReLU(0.1),
67
22
  hop_length: int = 256,
23
+ residual_cls: Union[ResBlock1D, ResBlock1D2] = ResBlock1D
68
24
  ):
69
25
  super().__init__()
70
26
  self.num_kernels = len(resblock_kernel_sizes)
@@ -82,6 +38,7 @@ class iSTFTGenerator(ConvNets):
82
38
  upsample_initial_channel,
83
39
  resblock_kernel_sizes,
84
40
  resblock_dilation_sizes,
41
+ residual_cls
85
42
  )
86
43
  )
87
44
 
@@ -91,25 +48,13 @@ class iSTFTGenerator(ConvNets):
91
48
  self.conv_post.apply(self.init_weights)
92
49
  self.reflection_pad = nn.ReflectionPad1d((1, 0))
93
50
 
94
- self.phase = nn.Sequential(
95
- nn.LeakyReLU(0.2),
96
- nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
97
- nn.LeakyReLU(0.2),
98
- nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
99
- )
100
- self.spec = nn.Sequential(
101
- nn.LeakyReLU(0.2),
102
- nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
103
- nn.LeakyReLU(0.2),
104
- nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
105
- )
106
-
107
51
  def _make_blocks(
108
52
  self,
109
53
  state: Tuple[int, int, int],
110
54
  upsample_initial_channel: int,
111
55
  resblock_kernel_sizes: List[Union[int, List[int]]],
112
56
  resblock_dilation_sizes: List[int | List[int]],
57
+ residual: nn.Module
113
58
  ):
114
59
  i, k, u = state
115
60
  channels = upsample_initial_channel // (2 ** (i + 1))
@@ -127,11 +72,12 @@ class iSTFTGenerator(ConvNets):
127
72
  )
128
73
  ).apply(self.init_weights),
129
74
  ),
130
- residual=ResBlocks(
75
+ residual=ResBlocks1D(
131
76
  channels,
132
77
  resblock_kernel_sizes,
133
78
  resblock_dilation_sizes,
134
79
  self.activation,
80
+ residual
135
81
  ),
136
82
  )
137
83
  )
@@ -142,9 +88,7 @@ class iSTFTGenerator(ConvNets):
142
88
  x = block["up"](x)
143
89
  x = block["residual"](x)
144
90
 
145
- x = self.reflection_pad(x)
146
- x = self.conv_post(x)
147
- spec = torch.exp(self.spec(x[:, : self.post_n_fft, :]))
148
- phase = torch.sin(self.phase(x[:, self.post_n_fft :, :]))
149
-
91
+ x = self.conv_post(self.activation(self.reflection_pad(x)))
92
+ spec = torch.exp(x[:, : self.post_n_fft, :])
93
+ phase = torch.sin(x[:, self.post_n_fft :, :])
150
94
  return spec, phase
@@ -1,20 +1,21 @@
1
- __all__ = ["AudioSettings", "AudioDecoder"]
1
+ __all__ = ["AudioSettings", "AudioDecoderTrainer", "AudioGeneratorOnlyTrainer"]
2
2
  import gc
3
- import math
4
3
  import itertools
5
4
  from lt_utils.common import *
6
5
  import torch.nn.functional as F
7
6
  from lt_tensor.torch_commons import *
8
7
  from lt_tensor.model_base import Model
9
- from lt_tensor.misc_utils import log_tensor
10
8
  from lt_utils.misc_utils import log_traceback
11
9
  from lt_tensor.processors import AudioProcessor
12
10
  from lt_tensor.misc_utils import set_seed, clear_cache
13
- from lt_utils.type_utils import is_dir, is_pathlike, is_file
14
- from lt_tensor.config_templates import updateDict, ModelConfig
11
+ from lt_utils.type_utils import is_dir, is_pathlike
12
+ from lt_tensor.config_templates import ModelConfig
15
13
  from lt_tensor.model_zoo.istft.generator import iSTFTGenerator
16
- from lt_tensor.model_zoo.residual import ResBlock1D, ConvNets, get_weight_norm
17
- from lt_tensor.model_zoo.discriminator import MultiPeriodDiscriminator, MultiScaleDiscriminator
14
+ from lt_tensor.model_zoo.discriminator import (
15
+ MultiPeriodDiscriminator,
16
+ MultiScaleDiscriminator,
17
+ )
18
+ from lt_tensor.model_zoo.residual import ResBlock1D2, ResBlock1D
18
19
 
19
20
 
20
21
  def feature_loss(fmap_r, fmap_g):
@@ -29,7 +30,6 @@ def generator_adv_loss(disc_outputs):
29
30
  loss = 0
30
31
  for dg in disc_outputs:
31
32
  l = torch.mean((1 - dg) ** 2)
32
-
33
33
  loss += l
34
34
  return loss
35
35
 
@@ -44,29 +44,6 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
44
44
  return loss
45
45
 
46
46
 
47
- """def feature_loss(fmap_r, fmap_g):
48
- loss = 0
49
- for dr, dg in zip(fmap_r, fmap_g):
50
- for rl, gl in zip(dr, dg):
51
- loss += torch.mean(torch.abs(rl - gl))
52
- return loss * 2
53
-
54
-
55
- def generator_adv_loss(fake_preds):
56
- loss = 0.0
57
- for f in fake_preds:
58
- loss += torch.mean((f - 1.0) ** 2)
59
- return loss
60
-
61
-
62
- def discriminator_loss(real_preds, fake_preds):
63
- loss = 0.0
64
- for r, f in zip(real_preds, fake_preds):
65
- loss += torch.mean((r - 1.0) ** 2) + torch.mean(f**2)
66
- return loss
67
- """
68
-
69
-
70
47
  class AudioSettings(ModelConfig):
71
48
  def __init__(
72
49
  self,
@@ -90,6 +67,7 @@ class AudioSettings(ModelConfig):
90
67
  scheduler_template: Callable[
91
68
  [optim.Optimizer], optim.lr_scheduler.LRScheduler
92
69
  ] = lambda optimizer: optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.998),
70
+ residual_cls: Type[Union[ResBlock1D, ResBlock1D2]] = ResBlock1D,
93
71
  ):
94
72
  self.in_channels = n_mels
95
73
  self.upsample_rates = upsample_rates
@@ -105,14 +83,15 @@ class AudioSettings(ModelConfig):
105
83
  self.lr = lr
106
84
  self.adamw_betas = adamw_betas
107
85
  self.scheduler_template = scheduler_template
86
+ self.residual_cls = residual_cls
108
87
 
109
88
 
110
- class AudioDecoder(Model):
89
+ class AudioDecoderTrainer(Model):
111
90
  def __init__(
112
91
  self,
113
92
  audio_processor: AudioProcessor,
114
93
  settings: Optional[AudioSettings] = None,
115
- generator: Optional[Union[Model, "iSTFTGenerator"]] = None, # non initalized!
94
+ generator: Optional[Union[Model, "iSTFTGenerator"]] = None, # non initialized!
116
95
  ):
117
96
  super().__init__()
118
97
  if settings is None:
@@ -175,14 +154,20 @@ class AudioDecoder(Model):
175
154
  return True
176
155
 
177
156
  def update_schedulers_and_optimizer(self):
157
+ gc.collect()
158
+ self.g_optim = None
159
+ self.g_scheduler = None
160
+ gc.collect()
178
161
  self.g_optim = optim.AdamW(
179
162
  self.generator.parameters(),
180
163
  lr=self.settings.lr,
181
164
  betas=self.settings.adamw_betas,
182
165
  )
166
+ gc.collect()
183
167
  self.g_scheduler = self.settings.scheduler_template(self.g_optim)
184
168
  if any([self.mpd is None, self.msd is None]):
185
169
  return
170
+ gc.collect()
186
171
  self.d_optim = optim.AdamW(
187
172
  itertools.chain(self.mpd.parameters(), self.msd.parameters()),
188
173
  lr=self.settings.lr,
@@ -274,9 +259,9 @@ class AudioDecoder(Model):
274
259
  win_length=self.settings.n_fft,
275
260
  )
276
261
  if not return_dict:
277
- return wave[:, : wave.shape[-1] - 256]
262
+ return wave
278
263
  return {
279
- "wave": wave[:, : wave.shape[-1] - 256],
264
+ "wave": wave,
280
265
  "spec": spec,
281
266
  "phase": phase,
282
267
  }
@@ -310,8 +295,8 @@ class AudioDecoder(Model):
310
295
  self.settings.n_fft,
311
296
  hop_length=4,
312
297
  win_length=self.settings.n_fft,
313
- # length=real_audio.shape[-1]
314
- )[:, : real_audio.shape[-1]]
298
+ length=real_audio.shape[-1],
299
+ )
315
300
 
316
301
  disc_kwargs = dict(
317
302
  real_audio=real_audio,
@@ -324,7 +309,7 @@ class AudioDecoder(Model):
324
309
  else:
325
310
  disc_out = self._discriminator_step(**disc_kwargs)
326
311
 
327
- generato_kwargs = dict(
312
+ generator_kwargs = dict(
328
313
  mels=mels,
329
314
  real_audio=real_audio,
330
315
  fake_audio=fake_audio,
@@ -339,8 +324,8 @@ class AudioDecoder(Model):
339
324
 
340
325
  if is_generator_frozen:
341
326
  with torch.no_grad():
342
- return self._generator_step(**generato_kwargs)
343
- return self._generator_step(**generato_kwargs)
327
+ return self._generator_step(**generator_kwargs)
328
+ return self._generator_step(**generator_kwargs)
344
329
 
345
330
  def _discriminator_step(
346
331
  self,
@@ -349,7 +334,8 @@ class AudioDecoder(Model):
349
334
  am_i_frozen: bool = False,
350
335
  ):
351
336
  # ========== Discriminator Forward Pass ==========
352
-
337
+ if not am_i_frozen:
338
+ self.d_optim.zero_grad()
353
339
  # MPD
354
340
  real_mpd_preds, _ = self.mpd(real_audio)
355
341
  fake_mpd_preds, _ = self.mpd(fake_audio)
@@ -362,7 +348,6 @@ class AudioDecoder(Model):
362
348
  loss_d = loss_d_mpd + loss_d_msd
363
349
 
364
350
  if not am_i_frozen:
365
- self.d_optim.zero_grad()
366
351
  loss_d.backward()
367
352
  self.d_optim.step()
368
353
 
@@ -384,6 +369,8 @@ class AudioDecoder(Model):
384
369
  am_i_frozen: bool = False,
385
370
  ):
386
371
  # ========== Generator Loss ==========
372
+ if not am_i_frozen:
373
+ self.g_optim.zero_grad()
387
374
  real_mpd_feats = self.mpd(real_audio)[1]
388
375
  real_msd_feats = self.msd(real_audio)[1]
389
376
 
@@ -395,7 +382,7 @@ class AudioDecoder(Model):
395
382
  loss_fm_mpd = feature_loss(real_mpd_feats, fake_mpd_feats)
396
383
  loss_fm_msd = feature_loss(real_msd_feats, fake_msd_feats)
397
384
 
398
- loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
385
+ # loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
399
386
  loss_mel = (
400
387
  F.huber_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
401
388
  )
@@ -403,20 +390,21 @@ class AudioDecoder(Model):
403
390
 
404
391
  loss_adv = (loss_adv_mpd + loss_adv_msd) * adv_scale
405
392
 
406
- loss_g = loss_adv + loss_fm + loss_stft # + loss_mel
393
+ loss_g = loss_adv + loss_fm + loss_mel # + loss_stft
407
394
  if not am_i_frozen:
408
- self.g_optim.zero_grad()
409
395
  loss_g.backward()
410
396
  self.g_optim.step()
397
+
398
+ lr_g, lr_d = self.get_lr()
411
399
  return {
412
400
  "loss_g": loss_g.item(),
413
401
  "loss_d": loss_d,
414
402
  "loss_adv": loss_adv.item(),
415
403
  "loss_fm": loss_fm.item(),
416
- "loss_stft": loss_stft.item(),
404
+ "loss_stft": 1.0, # loss_stft.item(),
417
405
  "loss_mel": loss_mel.item(),
418
- "lr_g": self.g_optim.param_groups[0]["lr"],
419
- "lr_d": self.d_optim.param_groups[0]["lr"],
406
+ "lr_g": lr_g,
407
+ "lr_d": lr_d,
420
408
  }
421
409
 
422
410
  def step_scheduler(
@@ -442,34 +430,198 @@ class AudioDecoder(Model):
442
430
  self.g_scheduler = self.settings.scheduler_template(self.g_optim)
443
431
 
444
432
 
445
- class ResBlocks(ConvNets):
433
+ class AudioGeneratorOnlyTrainer(Model):
446
434
  def __init__(
447
435
  self,
448
- channels: int,
449
- resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
450
- resblock_dilation_sizes: List[Union[int, List[int]]] = [
451
- [1, 3, 5],
452
- [1, 3, 5],
453
- [1, 3, 5],
454
- ],
455
- activation: nn.Module = nn.LeakyReLU(0.1),
436
+ audio_processor: AudioProcessor,
437
+ settings: Optional[AudioSettings] = None,
438
+ generator: Optional[Union[Model, "iSTFTGenerator"]] = None, # non initialized!
456
439
  ):
457
440
  super().__init__()
458
- self.num_kernels = len(resblock_kernel_sizes)
459
- self.rb = nn.ModuleList()
460
- self.activation = activation
441
+ if settings is None:
442
+ self.settings = AudioSettings()
443
+ elif isinstance(settings, dict):
444
+ self.settings = AudioSettings(**settings)
445
+ elif isinstance(settings, AudioSettings):
446
+ self.settings = settings
447
+ else:
448
+ raise ValueError(
449
+ "Cannot initialize the waveDecoder with the given settings. "
450
+ "Use either a dictionary, or the class WaveSettings to setup the settings. "
451
+ "Alternatively, leave it None to use the default values."
452
+ )
453
+ if self.settings.seed is not None:
454
+ set_seed(self.settings.seed)
455
+ if generator is None:
456
+ generator = iSTFTGenerator
457
+ self.generator: iSTFTGenerator = generator(
458
+ in_channels=self.settings.in_channels,
459
+ upsample_rates=self.settings.upsample_rates,
460
+ upsample_kernel_sizes=self.settings.upsample_kernel_sizes,
461
+ upsample_initial_channel=self.settings.upsample_initial_channel,
462
+ resblock_kernel_sizes=self.settings.resblock_kernel_sizes,
463
+ resblock_dilation_sizes=self.settings.resblock_dilation_sizes,
464
+ n_fft=self.settings.n_fft,
465
+ activation=self.settings.activation,
466
+ )
467
+ self.generator.eval()
468
+ self.gen_training = False
469
+ self.audio_processor = audio_processor
470
+
471
+ def setup_training_mode(self, *args, **kwargs):
472
+ self.finish_training_setup()
473
+ self.update_schedulers_and_optimizer()
474
+ self.gen_training = True
475
+ return True
476
+
477
+ def update_schedulers_and_optimizer(self):
478
+ self.g_optim = optim.AdamW(
479
+ self.generator.parameters(),
480
+ lr=self.settings.lr,
481
+ betas=self.settings.adamw_betas,
482
+ )
483
+ self.g_scheduler = self.settings.scheduler_template(self.g_optim)
484
+
485
+ def set_lr(self, new_lr: float = 1e-4):
486
+ if self.g_optim is not None:
487
+ for groups in self.g_optim.param_groups:
488
+ groups["lr"] = new_lr
489
+ return self.get_lr()
490
+
491
+ def get_lr(self) -> Tuple[float, float]:
492
+ if self.g_optim is not None:
493
+ return self.g_optim.param_groups[0]["lr"]
494
+ return float("nan")
495
+
496
+ def save_weights(self, path, replace=True):
497
+ is_pathlike(path, check_if_empty=True, validate=True)
498
+ if str(path).endswith(".pt"):
499
+ path = Path(path).parent
500
+ else:
501
+ path = Path(path)
502
+ self.generator.save_weights(Path(path, "generator.pt"), replace)
461
503
 
462
- for k, j in zip(resblock_kernel_sizes, resblock_dilation_sizes):
463
- self.rb.append(ResBlock1D(channels, k, j, activation))
504
+ def load_weights(
505
+ self,
506
+ path,
507
+ raise_if_not_exists=False,
508
+ strict=True,
509
+ assign=False,
510
+ weights_only=False,
511
+ mmap=None,
512
+ **torch_loader_kwargs
513
+ ):
514
+ is_pathlike(path, check_if_empty=True, validate=True)
515
+ if str(path).endswith(".pt"):
516
+ path = Path(path)
517
+ else:
518
+ path = Path(path, "generator.pt")
464
519
 
465
- self.rb.apply(self.init_weights)
520
+ self.generator.load_weights(
521
+ path,
522
+ raise_if_not_exists,
523
+ strict,
524
+ assign,
525
+ weights_only,
526
+ mmap,
527
+ **torch_loader_kwargs,
528
+ )
529
+
530
+ def finish_training_setup(self):
531
+ gc.collect()
532
+ clear_cache()
533
+ self.eval()
534
+ self.gen_training = False
535
+
536
+ def forward(self, mel_spec: Tensor) -> Tuple[Tensor, Tensor]:
537
+ """Returns the generated spec and phase"""
538
+ return self.generator.forward(mel_spec)
466
539
 
467
- def forward(self, x: torch.Tensor):
468
- xs = None
469
- for i, block in enumerate(self.rb):
470
- if i == 0:
471
- xs = block(x)
472
- else:
473
- xs += block(x)
474
- x = xs / self.num_kernels
475
- return self.activation(x)
540
+ def inference(
541
+ self,
542
+ mel_spec: Tensor,
543
+ return_dict: bool = False,
544
+ ) -> Union[Dict[str, Tensor], Tensor]:
545
+ spec, phase = super().inference(mel_spec)
546
+ wave = self.audio_processor.inverse_transform(
547
+ spec,
548
+ phase,
549
+ self.settings.n_fft,
550
+ hop_length=4,
551
+ win_length=self.settings.n_fft,
552
+ )
553
+ if not return_dict:
554
+ return wave[:, : wave.shape[-1] - 256]
555
+ return {
556
+ "wave": wave[:, : wave.shape[-1] - 256],
557
+ "spec": spec,
558
+ "phase": phase,
559
+ }
560
+
561
+ def set_device(self, device: str):
562
+ self.to(device=device)
563
+ self.generator.to(device=device)
564
+ self.audio_processor.to(device=device)
567
+
568
+ def train_step(
569
+ self,
570
+ mels: Tensor,
571
+ real_audio: Tensor,
572
+ stft_scale: float = 1.0,
573
+ mel_scale: float = 1.0,
574
+ ext_loss: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
575
+ ):
576
+ if not self.gen_training:
577
+ self.setup_training_mode()
578
+
579
+ self.g_optim.zero_grad()
580
+ spec, phase = self.generator.train_step(mels)
581
+
582
+ real_audio = real_audio.squeeze(1)
583
+ fake_audio = self.audio_processor.inverse_transform(
584
+ spec,
585
+ phase,
586
+ self.settings.n_fft,
587
+ hop_length=4,
588
+ win_length=self.settings.n_fft,
589
+ )[:, : real_audio.shape[-1]]
591
+ loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
592
+ loss_mel = (
593
+ F.huber_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
594
+ )
595
+ loss_g = loss_stft + loss_mel
596
+ loss_ext = 0
597
+
598
+ if ext_loss is not None:
599
+ l_ext = ext_loss(fake_audio, real_audio)
600
+ loss_g = loss_g + l_ext
601
+ loss_ext = l_ext.item()
602
+
603
+ loss_g.backward()
604
+ self.g_optim.step()
605
+ return {
606
+ "loss": loss_g.item(),
607
+ "loss_stft": loss_stft.item(),
608
+ "loss_mel": loss_mel.item(),
609
+ "loss_ext": loss_ext,
610
+ "lr": self.get_lr(),
611
+ }
612
+
613
+ def step_scheduler(self):
614
+
615
+ if self.g_scheduler is not None:
616
+ self.g_scheduler.step()
617
+
618
+ def reset_schedulers(self, lr: Optional[float] = None):
619
+ """
620
+ In case you have adopted another strategy, with this function,
621
+ it is possible restart the scheduler and set the lr to another value.
622
+ """
623
+ if lr is not None:
624
+ self.set_lr(lr)
625
+ if self.g_optim is not None:
626
+ self.g_scheduler = None
627
+ self.g_scheduler = self.settings.scheduler_template(self.g_optim)