lt-tensor 0.0.1a12__py3-none-any.whl → 0.0.1a14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,142 @@
+ import torch
+ import torch.nn as nn
+ import math
+ from einops import repeat
+
+
+ class SineGen(nn.Module):
+     def __init__(
+         self,
+         samp_rate,
+         upsample_scale,
+         harmonic_num=0,
+         sine_amp=0.1,
+         noise_std=0.003,
+         voiced_threshold=0,
+         flag_for_pulse=False,
+     ):
+         super().__init__()
+         self.sampling_rate = samp_rate
+         self.upsample_scale = upsample_scale
+         self.harmonic_num = harmonic_num
+         self.sine_amp = sine_amp
+         self.noise_std = noise_std
+         self.voiced_threshold = voiced_threshold
+         self.flag_for_pulse = flag_for_pulse
+         self.dim = self.harmonic_num + 1  # fundamental + harmonics
+
+     def _f02uv_b(self, f0):
+         return (f0 > self.voiced_threshold).float()  # same shape as f0
+
+     def _f02uv(self, f0):
+         return (f0 > self.voiced_threshold).float().unsqueeze(-1)  # -> (B, T, 1)
+
+     @torch.no_grad()
+     def _f02sine(self, f0_values):
+         """
+         f0_values: (B, T, 1)
+         Output: sine waves (B, T * upsample, dim)
+         """
+         B, T, _ = f0_values.size()
+         f0_upsampled = repeat(
+             f0_values, "b t d -> b (t r) d", r=self.upsample_scale
+         )  # (B, T_up, 1)
+
+         # Create harmonics
+         harmonics = (
+             torch.arange(1, self.dim + 1, device=f0_values.device)
+             .float()
+             .view(1, 1, -1)
+         )
+         f0_harm = f0_upsampled * harmonics  # (B, T_up, dim)
+
+         # Convert Hz to cycles per sample, then integrate to get phase
+         rad_values = f0_harm / self.sampling_rate  # normalized freq
+         rad_values = rad_values % 1.0  # keep only the fractional cycle in [0, 1)
+
+         # Random initial phase per harmonic (zeroed in pulse mode)
+         if self.flag_for_pulse:
+             rand_ini = torch.zeros((B, 1, self.dim), device=f0_values.device)
+         else:
+             rand_ini = torch.rand((B, 1, self.dim), device=f0_values.device)
+
+         rand_ini = rand_ini * 2 * math.pi
+
+         # Compute cumulative phase
+         rad_values = rad_values * 2 * math.pi
+         phase = torch.cumsum(rad_values, dim=1) + rand_ini  # (B, T_up, dim)
+
+         sine_waves = torch.sin(phase)  # (B, T_up, dim)
+         return sine_waves
+
+     def _forward(self, f0):
+         """
+         f0: (B, T, 1)
+         returns: sine signal with harmonics and noise added
+         """
+         sine_waves = self._f02sine(f0)  # (B, T_up, dim)
+         uv = self._f02uv_b(f0)  # (B, T, 1)
+         uv = repeat(uv, "b t d -> b (t r) d", r=self.upsample_scale)  # (B, T_up, 1)
+
+         # voiced sine + unvoiced noise
+         sine_signal = self.sine_amp * sine_waves * uv  # (B, T_up, dim)
+         noise = torch.randn_like(sine_signal) * self.noise_std
+         output = sine_signal + noise * (1.0 - uv)  # noise added only on unvoiced
+
+         return output  # (B, T_up, dim)
+
+     def forward(self, f0):
+         """
+         Args:
+             f0: (B, T) in Hz (before upsampling)
+         Returns:
+             sine_waves: (B, T_up, dim)
+             uv: (B, T_up, 1)
+             noise: (B, T_up, 1)
+         """
+         B, T = f0.shape
+         device = f0.device
+
+         # Get uv mask (before upsampling)
+         uv = self._f02uv(f0)  # (B, T, 1)
+
+         # Expand f0 to include harmonics: (B, T, dim)
+         f0 = f0.unsqueeze(-1)  # (B, T, 1)
+         harmonics = (
+             torch.arange(1, self.dim + 1, device=device).float().view(1, 1, -1)
+         )  # (1, 1, dim)
+         f0_harm = f0 * harmonics  # (B, T, dim)
+
+         # Upsample
+         f0_harm_up = repeat(
+             f0_harm, "b t d -> b (t r) d", r=self.upsample_scale
+         )  # (B, T_up, dim)
+         uv_up = repeat(uv, "b t d -> b (t r) d", r=self.upsample_scale)  # (B, T_up, 1)
+
+         # Convert to radians
+         rad_per_sample = f0_harm_up / self.sampling_rate  # Hz → cycles/sample
+         rad_per_sample = rad_per_sample * 2 * math.pi  # cycles → radians/sample
+
+         # Random phase init for each harmonic
+         B, T_up, D = rad_per_sample.shape
+         rand_phase = torch.rand(B, D, device=device) * 2 * math.pi  # (B, D)
+
+         # Compute cumulative phase
+         phase = torch.cumsum(rad_per_sample, dim=1) + rand_phase.unsqueeze(
+             1
+         )  # (B, T_up, D)
+
+         # Apply sine
+         sine_waves = torch.sin(phase) * self.sine_amp  # (B, T_up, D)
+
+         # Handle unvoiced: create noise only for the fundamental
+         noise = torch.randn(B, T_up, 1, device=device) * self.noise_std
+         if self.flag_for_pulse:
+             # If pulse mode is on, align phase at the start of voiced segments.
+             # Optional and tricky to implement; may require segmenting uv.
+             pass
+
+         # Replace sine by noise for unvoiced (only on the fundamental)
+         sine_waves[:, :, 0:1] = sine_waves[:, :, 0:1] * uv_up + noise * (1 - uv_up)
+
+         return sine_waves, uv_up, noise
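
SineGen converts an f0 contour to cycles per sample, integrates it with a cumulative sum to obtain phase, and takes the sine of that phase for the fundamental plus each harmonic, substituting noise on unvoiced frames. A minimal usage sketch; the sample rate, hop size, and f0 values below are illustrative assumptions, not package defaults:

# Illustrative only: drive SineGen with a half-voiced 220 Hz contour.
import torch

gen = SineGen(samp_rate=24000, upsample_scale=256, harmonic_num=8)
f0 = torch.full((1, 50), 220.0)  # (B, T) frame-level f0 in Hz
f0[:, 25:] = 0.0                 # second half unvoiced
sine, uv, noise = gen(f0)        # (1, 12800, 9), (1, 12800, 1), (1, 12800, 1)
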
@@ -1,41 +1,45 @@
- __all__ = ["AudioSettings", "AudioDecoder"]
+ __all__ = ["AudioSettings", "AudioDecoderTrainer", "AudioGeneratorOnlyTrainer"]
  import gc
- import math
  import itertools
  from lt_utils.common import *
  import torch.nn.functional as F
  from lt_tensor.torch_commons import *
  from lt_tensor.model_base import Model
- from lt_tensor.misc_utils import log_tensor
  from lt_utils.misc_utils import log_traceback
  from lt_tensor.processors import AudioProcessor
  from lt_tensor.misc_utils import set_seed, clear_cache
- from lt_utils.type_utils import is_dir, is_pathlike, is_file
- from lt_tensor.config_templates import updateDict, ModelConfig
+ from lt_utils.type_utils import is_dir, is_pathlike
+ from lt_tensor.config_templates import ModelConfig
  from lt_tensor.model_zoo.istft.generator import iSTFTGenerator
- from lt_tensor.model_zoo.rsd import ResBlock1D, ConvNets, get_weight_norm
- from lt_tensor.model_zoo.disc import MultiPeriodDiscriminator, MultiScaleDiscriminator
+ from lt_tensor.model_zoo.discriminator import (
+     MultiPeriodDiscriminator,
+     MultiScaleDiscriminator,
+ )


- def feature_loss(real_feats, fake_feats):
-     loss = 0.0
-     for r, f in zip(real_feats, fake_feats):
-         for ri, fi in zip(r, f):
-             loss += F.l1_loss(ri, fi)
-     return loss
+ def feature_loss(fmap_r, fmap_g):
+     loss = 0
+     for dr, dg in zip(fmap_r, fmap_g):
+         for rl, gl in zip(dr, dg):
+             loss += torch.mean(torch.abs(rl - gl))
+     return loss * 2


- def generator_adv_loss(fake_preds):
-     loss = 0.0
-     for f in fake_preds:
-         loss += torch.mean((f - 1.0) ** 2)
+ def generator_adv_loss(disc_outputs):
+     loss = 0
+     for dg in disc_outputs:
+         l = torch.mean((1 - dg) ** 2)
+         loss += l
      return loss


- def discriminator_loss(real_preds, fake_preds):
-     loss = 0.0
-     for r, f in zip(real_preds, fake_preds):
-         loss += torch.mean((r - 1.0) ** 2) + torch.mean(f**2)
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+     loss = 0
+
+     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+         r_loss = torch.mean((1 - dr) ** 2)
+         g_loss = torch.mean(dg**2)
+         loss += r_loss + g_loss
      return loss


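These replacements are the least-squares GAN objectives used by HiFi-GAN-style vocoders: discriminator_loss drives real outputs toward 1 and generated outputs toward 0, generator_adv_loss drives generated outputs toward 1, and feature_loss is a doubled L1 match over the discriminators' intermediate feature maps. A quick sanity sketch with invented tensors, assuming the functions above are in scope:

# Invented inputs; only meant to show the loss semantics.
import torch

real_out = [torch.ones(4)]   # discriminator output on real audio
fake_out = [torch.zeros(4)]  # discriminator output on generated audio
print(discriminator_loss(real_out, fake_out))  # tensor(0.): perfect discriminator
print(generator_adv_loss(fake_out))            # tensor(1.): generator fully penalized
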
@@ -79,12 +83,12 @@ class AudioSettings(ModelConfig):
          self.scheduler_template = scheduler_template


- class AudioDecoder(Model):
+ class AudioDecoderTrainer(Model):
      def __init__(
          self,
          audio_processor: AudioProcessor,
          settings: Optional[AudioSettings] = None,
-         generator: Optional[Union[Model, "iSTFTGenerator"]] = None,  # non initalized!
+         generator: Optional[Union[Model, "iSTFTGenerator"]] = None,  # non initialized!
      ):
          super().__init__()
          if settings is None:
@@ -284,9 +288,6 @@ class AudioDecoder(Model):
              win_length=self.settings.n_fft,
              # length=real_audio.shape[-1]
          )[:, : real_audio.shape[-1]]
-         # smallest = min(real_audio.shape[-1], fake_audio.shape[-1])
-         # real_audio = real_audio[:, :, :smallest].squeeze(1)
-         # fake_audio = fake_audio[:, :smallest]

          disc_kwargs = dict(
              real_audio=real_audio,
@@ -299,7 +300,7 @@ class AudioDecoder(Model):
          else:
              disc_out = self._discriminator_step(**disc_kwargs)

-         generato_kwargs = dict(
+         generator_kwargs = dict(
              mels=mels,
              real_audio=real_audio,
              fake_audio=fake_audio,
@@ -314,8 +315,8 @@ class AudioDecoder(Model):

          if is_generator_frozen:
              with torch.no_grad():
-                 return self._generator_step(**generato_kwargs)
-         return self._generator_step(**generato_kwargs)
+                 return self._generator_step(**generator_kwargs)
+         return self._generator_step(**generator_kwargs)

      def _discriminator_step(
          self,
@@ -324,7 +325,8 @@ class AudioDecoder(Model):
          am_i_frozen: bool = False,
      ):
          # ========== Discriminator Forward Pass ==========
-
+         if not am_i_frozen:
+             self.d_optim.zero_grad()
          # MPD
          real_mpd_preds, _ = self.mpd(real_audio)
          fake_mpd_preds, _ = self.mpd(fake_audio)
@@ -337,7 +339,6 @@ class AudioDecoder(Model):
          loss_d = loss_d_mpd + loss_d_msd

          if not am_i_frozen:
-             self.d_optim.zero_grad()
              loss_d.backward()
              self.d_optim.step()

@@ -359,6 +360,8 @@ class AudioDecoder(Model):
          am_i_frozen: bool = False,
      ):
          # ========== Generator Loss ==========
+         if not am_i_frozen:
+             self.g_optim.zero_grad()
          real_mpd_feats = self.mpd(real_audio)[1]
          real_msd_feats = self.msd(real_audio)[1]

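The two hunks above move zero_grad() from just before backward() to the top of each step, so stale gradients are cleared before the forward pass rather than between it and the backward pass. Sketched with stand-in linear models (nothing here is the package's API), the ordering each step now follows:

# Stand-in models; only the zero_grad → backward → step ordering matters.
import torch
import torch.nn as nn

gen, disc = nn.Linear(4, 8), nn.Linear(8, 1)
g_optim = torch.optim.AdamW(gen.parameters(), lr=2e-4)
d_optim = torch.optim.AdamW(disc.parameters(), lr=2e-4)
real, z = torch.randn(2, 8), torch.randn(2, 4)
fake = gen(z)

d_optim.zero_grad()  # clear first, as in _discriminator_step
loss_d = torch.mean((1 - disc(real)) ** 2) + torch.mean(disc(fake.detach()) ** 2)
loss_d.backward()
d_optim.step()

g_optim.zero_grad()  # clear first, as in _generator_step
loss_g = torch.mean((1 - disc(fake)) ** 2)
loss_g.backward()
g_optim.step()
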
@@ -372,7 +375,7 @@ class AudioDecoder(Model):

          loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
          loss_mel = (
-             F.l1_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
+             F.huber_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
          )
          loss_fm = ((loss_fm_mpd + loss_fm_msd) * fm_scale) + fm_add

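The mel reconstruction loss switches from F.l1_loss to F.huber_loss, which is quadratic for residuals below PyTorch's default delta of 1.0 and linear above it, so outlier mel bins pull on the generator less sharply than under pure L1. Invented numbers to illustrate:

# Invented residuals: Huber is gentler than L1 on the outlier.
import torch
import torch.nn.functional as F

pred, target = torch.tensor([0.1, 3.0]), torch.zeros(2)
print(F.l1_loss(pred, target))     # tensor(1.5500)
print(F.huber_loss(pred, target))  # tensor(1.2525) = mean(0.5 * 0.1**2, 3.0 - 0.5)
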
@@ -380,9 +383,10 @@ class AudioDecoder(Model):

          loss_g = loss_adv + loss_fm + loss_stft + loss_mel
          if not am_i_frozen:
-             self.g_optim.zero_grad()
              loss_g.backward()
              self.g_optim.step()
+
+         lr_g, lr_d = self.get_lr()
          return {
              "loss_g": loss_g.item(),
              "loss_d": loss_d,
@@ -390,8 +394,8 @@ class AudioDecoder(Model):
              "loss_fm": loss_fm.item(),
              "loss_stft": loss_stft.item(),
              "loss_mel": loss_mel.item(),
-             "lr_g": self.g_optim.param_groups[0]["lr"],
-             "lr_d": self.d_optim.param_groups[0]["lr"],
+             "lr_g": lr_g,
+             "lr_d": lr_d,
          }

      def step_scheduler(
@@ -417,34 +421,198 @@ class AudioDecoder(Model):
          self.g_scheduler = self.settings.scheduler_template(self.g_optim)


- class ResBlocks(ConvNets):
+ class AudioGeneratorOnlyTrainer(Model):
      def __init__(
          self,
-         channels: int,
-         resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
-         resblock_dilation_sizes: List[Union[int, List[int]]] = [
-             [1, 3, 5],
-             [1, 3, 5],
-             [1, 3, 5],
-         ],
-         activation: nn.Module = nn.LeakyReLU(0.1),
+         audio_processor: AudioProcessor,
+         settings: Optional[AudioSettings] = None,
+         generator: Optional[Union[Model, "iSTFTGenerator"]] = None,  # non initialized!
      ):
          super().__init__()
-         self.num_kernels = len(resblock_kernel_sizes)
-         self.rb = nn.ModuleList()
-         self.activation = activation
+         if settings is None:
+             self.settings = AudioSettings()
+         elif isinstance(settings, dict):
+             self.settings = AudioSettings(**settings)
+         elif isinstance(settings, AudioSettings):
+             self.settings = settings
+         else:
+             raise ValueError(
+                 "Cannot initialize the AudioGeneratorOnlyTrainer with the given settings. "
+                 "Use either a dictionary or the class AudioSettings to set up the settings. "
+                 "Alternatively, leave it None to use the default values."
+             )
+         if self.settings.seed is not None:
+             set_seed(self.settings.seed)
+         if generator is None:
+             generator = iSTFTGenerator
+         self.generator: iSTFTGenerator = generator(
+             in_channels=self.settings.in_channels,
+             upsample_rates=self.settings.upsample_rates,
+             upsample_kernel_sizes=self.settings.upsample_kernel_sizes,
+             upsample_initial_channel=self.settings.upsample_initial_channel,
+             resblock_kernel_sizes=self.settings.resblock_kernel_sizes,
+             resblock_dilation_sizes=self.settings.resblock_dilation_sizes,
+             n_fft=self.settings.n_fft,
+             activation=self.settings.activation,
+         )
+         self.generator.eval()
+         self.gen_training = False
+         self.audio_processor = audio_processor

-         for k, j in zip(resblock_kernel_sizes, resblock_dilation_sizes):
-             self.rb.append(ResBlock1D(channels, k, j, activation))
+     def setup_training_mode(self, *args, **kwargs):
+         self.finish_training_setup()
+         self.update_schedulers_and_optimizer()
+         self.gen_training = True
+         return True
+
+     def update_schedulers_and_optimizer(self):
+         self.g_optim = optim.AdamW(
+             self.generator.parameters(),
+             lr=self.settings.lr,
+             betas=self.settings.adamw_betas,
+         )
+         self.g_scheduler = self.settings.scheduler_template(self.g_optim)

-         self.rb.apply(self.init_weights)
+     def set_lr(self, new_lr: float = 1e-4):
+         if self.g_optim is not None:
+             for groups in self.g_optim.param_groups:
+                 groups["lr"] = new_lr
+         return self.get_lr()

-     def forward(self, x: torch.Tensor):
-         xs = None
-         for i, block in enumerate(self.rb):
-             if i == 0:
-                 xs = block(x)
-             else:
-                 xs += block(x)
-         x = xs / self.num_kernels
-         return self.activation(x)
+     def get_lr(self) -> float:
+         if self.g_optim is not None:
+             return self.g_optim.param_groups[0]["lr"]
+         return float("nan")
+
+     def save_weights(self, path, replace=True):
+         is_pathlike(path, check_if_empty=True, validate=True)
+         if str(path).endswith(".pt"):
+             path = Path(path).parent
+         else:
+             path = Path(path)
+         self.generator.save_weights(Path(path, "generator.pt"), replace)
+
+     def load_weights(
+         self,
+         path,
+         raise_if_not_exists=False,
+         strict=True,
+         assign=False,
+         weights_only=False,
+         mmap=None,
+         **torch_loader_kwargs
+     ):
+         is_pathlike(path, check_if_empty=True, validate=True)
+         if str(path).endswith(".pt"):
+             path = Path(path)
+         else:
+             path = Path(path, "generator.pt")
+
+         self.generator.load_weights(
+             path,
+             raise_if_not_exists,
+             strict,
+             assign,
+             weights_only,
+             mmap,
+             **torch_loader_kwargs,
+         )
+
+     def finish_training_setup(self):
+         gc.collect()
+         clear_cache()
+         self.eval()
+         self.gen_training = False
+
+     def forward(self, mel_spec: Tensor) -> Tuple[Tensor, Tensor]:
+         """Returns the generated spec and phase."""
+         return self.generator.forward(mel_spec)
+
+     def inference(
+         self,
+         mel_spec: Tensor,
+         return_dict: bool = False,
+     ) -> Union[Dict[str, Tensor], Tensor]:
+         spec, phase = super().inference(mel_spec)
+         wave = self.audio_processor.inverse_transform(
+             spec,
+             phase,
+             self.settings.n_fft,
+             hop_length=4,
+             win_length=self.settings.n_fft,
+         )
+         if not return_dict:
+             return wave[:, : wave.shape[-1] - 256]
+         return {
+             "wave": wave[:, : wave.shape[-1] - 256],
+             "spec": spec,
+             "phase": phase,
+         }
+
+     def set_device(self, device: str):
+         self.to(device=device)
+         self.generator.to(device=device)
+         self.audio_processor.to(device=device)
+
+     def train_step(
+         self,
+         mels: Tensor,
+         real_audio: Tensor,
+         stft_scale: float = 1.0,
+         mel_scale: float = 1.0,
+         ext_loss: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
+     ):
+         if not self.gen_training:
+             self.setup_training_mode()
+
+         self.g_optim.zero_grad()
+         spec, phase = self.generator.train_step(mels)
+
+         real_audio = real_audio.squeeze(1)
+         fake_audio = self.audio_processor.inverse_transform(
+             spec,
+             phase,
+             self.settings.n_fft,
+             hop_length=4,
+             win_length=self.settings.n_fft,
+         )[:, : real_audio.shape[-1]]
+         loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
+         loss_mel = (
+             F.huber_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
+         )
+         loss_g = loss_stft + loss_mel
+         loss_ext = 0
+
+         if ext_loss is not None:
+             l_ext = ext_loss(fake_audio, real_audio)
+             loss_g = loss_g + l_ext
+             loss_ext = l_ext.item()
+
+         loss_g.backward()
+         self.g_optim.step()
+         return {
+             "loss": loss_g.item(),
+             "loss_stft": loss_stft.item(),
+             "loss_mel": loss_mel.item(),
+             "loss_ext": loss_ext,
+             "lr": self.get_lr(),
+         }
+
+     def step_scheduler(self):
+         if self.g_scheduler is not None:
+             self.g_scheduler.step()
+
+     def reset_schedulers(self, lr: Optional[float] = None):
+         """
+         In case you have adopted another strategy, this function makes it
+         possible to restart the scheduler and set the lr to another value.
+         """
+         if lr is not None:
+             self.set_lr(lr)
+         if self.g_optim is not None:
+             self.g_scheduler = None
+             self.g_scheduler = self.settings.scheduler_template(self.g_optim)
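
A minimal training-loop sketch for the new AudioGeneratorOnlyTrainer. The names processor (an AudioProcessor) and loader (an iterable of (mels, audio) pairs) are assumptions standing in for your own data pipeline, so this is a usage outline rather than package code:

# Hypothetical loop; `processor` and `loader` are placeholders, not package objects.
trainer = AudioGeneratorOnlyTrainer(audio_processor=processor)
trainer.set_device("cuda")
for epoch in range(10):
    for mels, audio in loader:
        stats = trainer.train_step(mels.cuda(), audio.cuda())
    trainer.step_scheduler()
    print(epoch, stats["loss"], stats["lr"])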