lt-tensor 0.0.1a32__tar.gz → 0.0.1a33__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/LICENSE +1 -1
  2. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/PKG-INFO +2 -2
  3. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/__init__.py +1 -1
  4. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/math_ops.py +19 -6
  5. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/losses/discriminators.py +253 -60
  6. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/processors/audio.py +105 -59
  7. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor.egg-info/PKG-INFO +2 -2
  8. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor.egg-info/requires.txt +1 -1
  9. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/setup.py +2 -2
  10. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/README.md +0 -0
  11. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/config_templates.py +0 -0
  12. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/losses.py +0 -0
  13. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/lr_schedulers.py +0 -0
  14. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/misc_utils.py +0 -0
  15. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_base.py +0 -0
  16. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/__init__.py +0 -0
  17. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +0 -0
  18. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/activations/alias_free_torch/act.py +0 -0
  19. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/activations/alias_free_torch/filter.py +0 -0
  20. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/activations/alias_free_torch/resample.py +0 -0
  21. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/activations/snake/__init__.py +0 -0
  22. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/audio_models/__init__.py +0 -0
  23. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py +0 -0
  24. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/audio_models/hifigan/__init__.py +0 -0
  25. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/audio_models/istft/__init__.py +0 -0
  26. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/basic.py +0 -0
  27. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/convs.py +0 -0
  28. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/features.py +0 -0
  29. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/fusion.py +0 -0
  30. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/losses/__init__.py +0 -0
  31. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/pos_encoder.py +0 -0
  32. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/residual.py +0 -0
  33. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/transformer.py +0 -0
  34. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/monotonic_align.py +0 -0
  35. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/noise_tools.py +0 -0
  36. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/processors/__init__.py +0 -0
  37. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/torch_commons.py +0 -0
  38. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/transform.py +0 -0
  39. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor.egg-info/SOURCES.txt +0 -0
  40. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor.egg-info/dependency_links.txt +0 -0
  41. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor.egg-info/top_level.txt +0 -0
  42. {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/setup.cfg +0 -0
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/LICENSE
@@ -186,7 +186,7 @@
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

-   Copyright 2025 gr1336
+   Copyright 2025 gr1336 (Gabriel Ribeiro)

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lt-tensor
- Version: 0.0.1a32
+ Version: 0.0.1a33
  Summary: General utilities for PyTorch and others. Built for general use.
  Home-page: https://github.com/gr1336/lt-tensor/
  Author: gr1336
@@ -17,7 +17,7 @@ Requires-Dist: numpy>=1.26.4
  Requires-Dist: tokenizers
  Requires-Dist: pyyaml>=6.0.0
  Requires-Dist: numba>0.60.0
- Requires-Dist: lt-utils>=0.0.3
+ Requires-Dist: lt-utils>=0.0.4
  Requires-Dist: librosa==0.11.*
  Requires-Dist: einops
  Requires-Dist: plotly
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/__init__.py
@@ -1,4 +1,4 @@
- __version__ = "0.0.1a"
+ __version__ = "0.0.1a33"

  from . import (
      lr_schedulers,
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/math_ops.py
@@ -6,10 +6,12 @@ __all__ = [
      "apply_window",
      "shift_ring",
      "dot_product",
-     "normalize_tensor",
      "log_magnitude",
      "shift_time",
      "phase",
+     "normalize_unit_norm",
+     "normalize_minmax",
+     "normalize_zscore",
  ]

  from lt_tensor.torch_commons import *
@@ -61,11 +63,6 @@ def dot_product(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
      return torch.sum(x * y, dim=dim)


- def normalize_tensor(x: Tensor, eps: float = 1e-8) -> Tensor:
-     """Normalizes a tensor to unit norm (L2)."""
-     return x / (torch.norm(x, dim=-1, keepdim=True) + eps)
-
-
  def log_magnitude(stft_complex: Tensor, eps: float = 1e-5) -> Tensor:
      """Returns log magnitude from complex STFT."""
      magnitude = torch.abs(stft_complex)
@@ -76,3 +73,19 @@ def phase(stft_complex: Tensor) -> Tensor:
      """Returns phase from complex STFT."""
      return torch.angle(stft_complex)

+
+ def normalize_unit_norm(x: torch.Tensor, eps: float = 1e-6):
+     norm = torch.norm(x, dim=-1, keepdim=True)
+     return x / (norm + eps)
+
+
+ def normalize_minmax(x: torch.Tensor, eps: float = 1e-6):
+     min_val = x.amin(dim=-1, keepdim=True)
+     max_val = x.amax(dim=-1, keepdim=True)
+     return (x - min_val) / (max_val - min_val + eps)
+
+
+ def normalize_zscore(x: torch.Tensor, eps: float = 1e-6):
+     mean = x.mean(dim=-1, keepdim=True)
+     std = x.std(dim=-1, keepdim=True)
+     return (x - mean) / (std + eps)
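For reference, a minimal usage sketch of the three normalizers that replace the removed normalize_tensor. The example below is not part of the diff; the tensor shape is an assumption and only torch plus the new lt_tensor.math_ops helpers are used:

import torch
from lt_tensor.math_ops import normalize_unit_norm, normalize_minmax, normalize_zscore

x = torch.randn(4, 128)          # batch of 4 feature vectors (assumed shape)
unit = normalize_unit_norm(x)    # rows scaled to roughly unit L2 norm
ranged = normalize_minmax(x)     # rows mapped into [0, 1]
zscored = normalize_zscore(x)    # rows shifted/scaled to ~zero mean, ~unit std
print(unit.norm(dim=-1))         # ≈ 1.0 per row (up to the eps term)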
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/model_zoo/losses/discriminators.py
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  from lt_tensor.model_zoo.audio_models.hifigan import ConvNets
  from lt_utils.common import *
  from lt_tensor.torch_commons import *
@@ -5,6 +7,8 @@ from lt_tensor.model_base import Model
  from lt_tensor.model_zoo.convs import ConvNets
  from torch.nn import functional as F
  from torchaudio import transforms as T
+ from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
+

  MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
      List[Tensor],
@@ -14,11 +18,92 @@ MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
  ]


+ class MultiDiscriminatorWrapper(Model):
+     def __init__(self, list_discriminator: List["_MultiDiscriminatorT"]):
+         """Setup example:
+         model_d = MultiDiscriminatorStep(
+             [
+                 MultiEnvelopeDiscriminator(),
+                 MultiBandDiscriminator(),
+                 MultiResolutionDiscriminator(),
+                 MultiPeriodDiscriminator(0.5),
+             ]
+         )
+         """
+         super().__init__()
+         self.disc: Sequence[_MultiDiscriminatorT] = nn.ModuleList(list_discriminator)
+         self.total = len(self.disc)
+
+     def forward(
+         self,
+         y: Tensor,
+         y_hat: Tensor,
+         step_type: Literal["discriminator", "generator"],
+     ) -> Union[
+         Tuple[Tensor, Tensor, List[float]], Tuple[Tensor, List[float], List[float]]
+     ]:
+         """
+         It returns the content based on the choice of "step_type", being it a
+         'discriminator' or 'generator'
+
+         For generator it returns:
+             Tuple[Tensor, Tensor, List[float]]
+             "gen_loss, feat_loss, all_g_losses"
+
+         For 'discriminator' it returns:
+             Tuple[Tensor, List[float], List[float]]
+             "disc_loss, disc_real_losses, disc_gen_losses"
+         """
+         if step_type == "generator":
+             all_g_losses: List[float] = []
+             feat_loss: Tensor = 0
+             gen_loss: Tensor = 0
+         else:
+             disc_loss: Tensor = 0
+             disc_real_losses: List[float] = []
+             disc_gen_losses: List[float] = []
+
+         for disc in self.disc:
+             if step_type == "generator":
+                 # feature loss, generator loss, list of generator losses (float)]
+                 f_loss, g_loss, g_losses = disc.gen_step(y, y_hat)
+                 gen_loss += g_loss
+                 feat_loss += f_loss
+                 all_g_losses.extend(g_losses)
+             else:
+                 # [discriminator loss, (disc losses real, disc losses generated)]
+                 d_loss, (d_real_losses, d_gen_losses) = disc.disc_step(y, y_hat)
+                 disc_loss += d_loss
+                 disc_real_losses.extend(d_real_losses)
+                 disc_gen_losses.extend(d_gen_losses)
+
+         if step_type == "generator":
+             return gen_loss, feat_loss, all_g_losses
+         return disc_loss, disc_real_losses, disc_gen_losses
+
+
+ def normalize_unit_norm(x: torch.Tensor, eps: float = 1e-5):
+     norm = torch.norm(x, dim=-1, keepdim=True)
+     return x / (norm + eps)
+
+
+ def normalize_minmax(x: torch.Tensor, eps: float = 1e-5):
+     min_val = x.amin(dim=-1, keepdim=True)
+     max_val = x.amax(dim=-1, keepdim=True)
+     return (x - min_val) / (max_val - min_val + eps)
+
+
+ def normalize_zscore(x: torch.Tensor, eps: float = 1e-5):
+     mean = x.mean(dim=-1, keepdim=True)
+     std = x.std(dim=-1, keepdim=True)
+     return (x - mean) / (std + eps)
+
+
  def get_padding(kernel_size, dilation=1):
      return int((kernel_size * dilation - dilation) / 2)


- class MultiDiscriminatorWrapper(ConvNets):
+ class _MultiDiscriminatorT(ConvNets):
      """Base for all multi-steps type of discriminators"""

      def __init__(self, *args, **kwargs):
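For orientation, a hedged sketch of how the relocated MultiDiscriminatorWrapper might be driven in a GAN training loop, based on its forward docstring above. The generator call, the real/fake tensors, and the backward handling are placeholders, not part of the diff:

model_d = MultiDiscriminatorWrapper(
    [
        MultiEnvelopeDiscriminator(),
        MultiBandDiscriminator(),
        MultiResolutionDiscriminator(),
        MultiPeriodDiscriminator(0.5),
    ]
)
fake = generator(mel)  # hypothetical generator forward pass
# discriminator step: summed loss plus per-discriminator real/generated losses
d_loss, d_real_losses, d_gen_losses = model_d(real, fake.detach(), step_type="discriminator")
d_loss.backward()
# generator step: summed adversarial loss, feature-matching loss, per-loss floats
g_loss, feat_loss, all_g_losses = model_d(real, fake, step_type="generator")
(g_loss + feat_loss).backward()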
@@ -171,7 +256,7 @@ class DiscriminatorP(ConvNets):
          return x.flatten(1, -1), fmap


- class MultiPeriodDiscriminator(MultiDiscriminatorWrapper):
+ class MultiPeriodDiscriminator(_MultiDiscriminatorT):
      def __init__(
          self,
          discriminator_channel_mult: Number = 1,
@@ -258,7 +343,7 @@ class DiscriminatorEnvelope(ConvNets):
          return x.flatten(1), fmap


- class MultiEnvelopeDiscriminator(MultiDiscriminatorWrapper):
+ class MultiEnvelopeDiscriminator(_MultiDiscriminatorT):
      def __init__(self, use_spectral_norm: bool = False):
          super().__init__()
          self.discriminators = nn.ModuleList(
@@ -375,7 +460,7 @@ class DiscriminatorB(ConvNets):
          return x, fmap


- class MultiBandDiscriminator(MultiDiscriminatorWrapper):
+ class MultiBandDiscriminator(_MultiDiscriminatorT):
      """
      Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
      and the modified code adapted from https://github.com/gemelo-ai/vocos.
@@ -514,7 +599,7 @@ class DiscriminatorR(ConvNets):
          return mag


- class MultiResolutionDiscriminator(MultiDiscriminatorWrapper):
+ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
      def __init__(
          self,
          use_spectral_norm: bool = False,
@@ -554,69 +639,177 @@ class MultiResolutionDiscriminator(MultiDiscriminatorWrapper):
          return y_d_rs, y_d_gs, fmap_rs, fmap_gs


- class MultiDiscriminatorStep(Model):
+ class MultiMelScaleLoss(Model):
+     # TODO: Make the normalization an argument to be chosen by the dev
      def __init__(
-         self, list_discriminator: List[MultiDiscriminatorWrapper]
+         self,
+         sample_rate: int,
+         n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
+         window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
+         n_ffts: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
+         hops: List[int] = [8, 16, 32, 64, 128, 256, 512],
+         weight: float = 1.0,
+         lambda_mel: float = 1.0,
+         f_min: float = [0, 0, 0, 0, 0, 0, 0],
+         f_max: Optional[float] = [None, None, None, None, None, None, None],
+         loss_fn: Callable = nn.L1Loss(),
+         center: bool = True,
+         power: float = 1.0,
+         normalized: bool = False,
+         pad_mode: str = "reflect",
+         onesided: Optional[bool] = None,
+         std: int = 4,
+         mean: int = -4,
+         auto_interpolate: bool = True,
+         use_istft_norm: bool = True,
+         use_pitch_loss: bool = False,
+         use_rms_loss: bool = False,
+         lambda_pitch: float = 0.5,
+         lambda_rms: float = 0.5,
      ):
-         """Setup example:
-         model_d = MultiDiscriminatorStep(
+         super().__init__()
+         assert (
+             len(n_mels)
+             == len(window_lengths)
+             == len(n_ffts)
+             == len(hops)
+             == len(f_min)
+             == len(f_max)
+         )
+         self.loss_fn = loss_fn
+         self.lambda_mel = lambda_mel
+         self.weight = weight
+         self.use_istft_norm = use_istft_norm
+         self.auto_interpolate = auto_interpolate if not self.use_istft_norm else False
+         self.use_pitch_loss = use_pitch_loss
+         self.use_rms_loss = use_rms_loss
+         self.lambda_pitch = lambda_pitch
+         self.lambda_rms = lambda_rms
+
+         self._setup_mels(
+             sample_rate,
+             n_mels,
+             window_lengths,
+             n_ffts,
+             hops,
+             f_min,
+             f_max,
+             center,
+             power,
+             normalized,
+             pad_mode,
+             onesided,
+             std,
+             mean,
+         )
+
+     def _setup_mels(
+         self,
+         sample_rate: int,
+         n_mels: List[int],
+         window_lengths: List[int],
+         n_ffts: List[int],
+         hops: List[int],
+         f_min: List[float],
+         f_max: List[Optional[float]],
+         center: bool,
+         power: float,
+         normalized: bool,
+         pad_mode: str = "reflect",
+         onesided: Optional[bool] = None,
+         std: int = 4,
+         mean: int = -4,
+     ):
+         assert (
+             len(n_mels)
+             == len(window_lengths)
+             == len(n_ffts)
+             == len(hops)
+             == len(f_min)
+             == len(f_max)
+         )
+         _mel_kwargs = dict(
+             sample_rate=sample_rate,
+             center=center,
+             onesided=onesided,
+             normalized=normalized,
+             power=power,
+             pad_mode=pad_mode,
+             std=std,
+             mean=mean,
+         )
+         self.mel_spectrograms: List[AudioProcessor] = nn.ModuleList(
              [
-                 MultiEnvelopeDiscriminator(),
-                 MultiBandDiscriminator(),
-                 MultiResolutionDiscriminator(),
-                 MultiPeriodDiscriminator(0.5),
+                 AudioProcessor(
+                     AudioProcessorConfig(
+                         **_mel_kwargs,
+                         n_mels=mel,
+                         n_fft=n_fft,
+                         win_length=win,
+                         hop_length=hop,
+                         f_min=fmin,
+                         f_max=fmax,
+                     )
+                 )
+                 for mel, win, n_fft, hop, fmin, fmax in zip(
+                     n_mels, window_lengths, n_ffts, hops, f_min, f_max
+                 )
              ]
          )
-         """
-         super().__init__()
-         self.disc: Sequence[MultiDiscriminatorWrapper] = nn.ModuleList(
-             list_discriminator
-         )
-         self.total = len(self.disc)

-     def forward(
+     def _process_tensor(
          self,
-         y: Tensor,
-         y_hat: Tensor,
-         step_type: Literal["discriminator", "generator"],
-     ) -> Union[
-         Tuple[Tensor, Tensor, List[float]], Tuple[Tensor, List[float], List[float]]
-     ]:
-         """
-         It returns the content based on the choice of "step_type", being it a
-         'discriminator' or 'generator'
+         input_wave: torch.Tensor,
+         target_wave: torch.Tensor,
+     ):
+         if input_wave.shape[-1] != target_wave.shape[-1]:
+             if input_wave.ndim < 3:
+                 # To be compatible with interpolatin
+                 if input_wave.ndim == 2:
+                     input_wave = input_wave.unsqueeze(1)
+                 else:
+                     input_wave = input_wave.unsqueeze(0).unsqueeze(0)
+             input_wave = F.interpolate(input_wave, target_wave.shape[-1], mode="linear")
+         return input_wave

-         For generator it returns:
-             Tuple[Tensor, Tensor, List[float]]
-             "gen_loss, feat_loss, all_g_losses"
+     def forward(
+         self, input_wave: torch.Tensor, target_wave: torch.Tensor
+     ) -> torch.Tensor:
+         assert (
+             self.use_istft_norm
+             or self.auto_interpolate
+             or input_wave.shape[-1] == target_wave.shape[-1]
+         )
+         if self.auto_interpolate:
+             input_wave = self._process_tensor(input_wave, target_wave)
+
+         losses = 0.0
+         for M in self.mel_spectrograms:
+             # Apply normalization if requested
+             if self.use_istft_norm:
+                 input_proc = M.istft_norm(input_wave, length=target_wave.shape[-1])
+                 target_proc = M.istft_norm(target_wave, length=target_wave.shape[-1])
+             else:
+                 input_proc, target_proc = input_wave, target_wave

-         For 'discriminator' it returns:
-             Tuple[Tensor, List[float], List[float]]
-             "disc_loss, disc_real_losses, disc_gen_losses"
-         """
-         if step_type == "generator":
-             all_g_losses: List[float] = []
-             feat_loss: Tensor = 0
-             gen_loss: Tensor = 0
-         else:
-             disc_loss: Tensor = 0
-             disc_real_losses: List[float] = []
-             disc_gen_losses: List[float] = []
+             x_mels = M(input_proc)
+             y_mels = M(target_proc)

-         for disc in self.disc:
-             if step_type == "generator":
-                 # feature loss, generator loss, list of generator losses (float)]
-                 f_loss, g_loss, g_losses = disc.gen_step(y, y_hat)
-                 gen_loss += g_loss
-                 feat_loss += f_loss
-                 all_g_losses.extend(g_losses)
-             else:
-                 # [discriminator loss, (disc losses real, disc losses generated)]
-                 d_loss, (d_real_losses, d_gen_losses) = disc.disc_step(y, y_hat)
-                 disc_loss += d_loss
-                 disc_real_losses.extend(d_real_losses)
-                 disc_gen_losses.extend(d_gen_losses)
+             loss = self.loss_fn(x_mels.squeeze(), y_mels.squeeze())
+             losses += loss * self.lambda_mel

-         if step_type == "generator":
-             return gen_loss, feat_loss, all_g_losses
-         return disc_loss, disc_real_losses, disc_gen_losses
+             # pitch/f0 loss
+             if self.use_pitch_loss:
+                 x_pitch = normalize_unit_norm(M.compute_pitch(input_proc))
+                 y_pitch = normalize_unit_norm(M.compute_pitch(target_proc))
+                 f0_loss = self.loss_fn(x_pitch, y_pitch)
+                 losses += f0_loss * self.lambda_pitch
+
+             # energy/rms loss
+             if self.use_rms_loss:
+                 x_rms = normalize_unit_norm(M.compute_rms(input_proc, x_mels))
+                 y_rms = normalize_unit_norm(M.compute_rms(target_proc, y_mels))
+                 rms_loss = self.loss_fn(x_rms, y_rms)
+                 losses += rms_loss * self.lambda_rms
+
+         return losses * self.weight
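A hedged usage sketch of the new MultiMelScaleLoss defined above; the sample rate, tensor shapes, and flag choices are assumptions for illustration, not taken from the diff:

mel_loss = MultiMelScaleLoss(sample_rate=24000, use_pitch_loss=False, use_rms_loss=False)
fake_wave = torch.randn(2, 24000, requires_grad=True)  # generated audio, batch of 2 (assumed shape)
real_wave = torch.randn(2, 24000)                      # reference audio
loss = mel_loss(fake_wave, real_wave)  # scalar: L1 over all mel resolutions, scaled by weight
loss.backward()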
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor/processors/audio.py
@@ -23,7 +23,7 @@ class AudioProcessorConfig(ModelConfig):
      win_length: int = 1024
      hop_length: int = 256
      f_min: float = 0
-     f_max: float = 8000.0
+     f_max: Optional[float] = None
      center: bool = True
      mel_scale: Literal["htk" "slaney"] = "htk"
      std: int = 4
@@ -41,8 +41,8 @@ class AudioProcessorConfig(ModelConfig):
          n_fft: int = 1024,
          win_length: Optional[int] = None,
          hop_length: Optional[int] = None,
-         f_min: float = 1,
-         f_max: float = 12000.0,
+         f_min: float = 0,
+         f_max: Optional[float] = None,
          center: bool = True,
          mel_scale: Literal["htk", "slaney"] = "htk",
          std: int = 4,
@@ -71,9 +71,12 @@ class AudioProcessorConfig(ModelConfig):
          self.post_process()

      def post_process(self):
-         self.f_min = max(self.f_min, 1)
-         self.f_max = max(min(self.f_max, self.n_fft // 2), self.f_min + 1)
          self.n_stft = self.n_fft // 2 + 1
+         # some functions needs this to be a non-zero or not None value.
+         self.f_min = max(self.f_min, (self.sample_rate / (self.n_fft - 1)) * 2)
+         self.default_f_max = min(
+             default(self.f_max, self.sample_rate // 2), self.sample_rate // 2
+         )
          self.hop_length = default(self.hop_length, self.n_fft // 4)
          self.win_length = default(self.win_length, self.n_fft)
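As a worked example of the new defaults (the values are assumptions, not from the diff): with sample_rate=24000 and n_fft=1024, post_process now clamps f_min to at least (24000 / 1023) * 2 ≈ 46.9 Hz, and with f_max left as None, default_f_max falls back to sample_rate // 2 = 12000 Hz.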
@@ -105,7 +108,7 @@ class AudioProcessor(Model):
              onesided=self.cfg.onesided,
              normalized=self.cfg.normalized,
          )
-         self.mel_rscale = torchaudio.transforms.InverseMelScale(
+         self._mel_rscale = torchaudio.transforms.InverseMelScale(
              n_stft=self.cfg.n_stft,
              n_mels=self.cfg.n_mels,
              sample_rate=self.cfg.sample_rate,
@@ -119,32 +122,39 @@ class AudioProcessor(Model):
              (torch.hann_window(self.cfg.win_length) if window is None else window),
          )

-     def from_numpy(
-         self,
-         array: np.ndarray,
-         device: Optional[torch.device] = None,
-         dtype: Optional[torch.dtype] = None,
-     ):
-         converted = torch.from_numpy(array)
-         if device is None:
-             device = self.device
-         return converted.to(device=device, dtype=dtype)

-     def from_numpy_batch(
+
+     def compute_mel(
          self,
-         arrays: List[np.ndarray],
-         device: Optional[torch.device] = None,
-         dtype: Optional[torch.dtype] = None,
-     ):
-         stacked = torch.stack([torch.from_numpy(x) for x in arrays])
-         if device is None:
-             device = self.device
-         return stacked.to(device=device, dtype=dtype)
+         wave: Tensor,
+         raw_mel_only: bool = False,
+         eps: float = 1e-5,
+         *,
+         _recall: bool = False,
+     ) -> Tensor:
+         """Returns: [B, M, T]"""
+         try:
+             mel_tensor = self._mel_spec(wave.to(self.device))  # [M, T]
+             if not raw_mel_only:
+                 mel_tensor = (
+                     torch.log(eps + mel_tensor.unsqueeze(0)) - self.cfg.mean
+                 ) / self.cfg.std
+             return mel_tensor.squeeze()

-     def to_numpy_safe(self, tensor: Union[Tensor, np.ndarray]):
-         if isinstance(tensor, np.ndarray):
-             return tensor
-         return tensor.detach().to(DEFAULT_DEVICE).numpy(force=True)
+         except RuntimeError as e:
+             if not _recall:
+                 self._mel_spec.to(self.device)
+                 return self.compute_mel(wave, raw_mel_only, eps, _recall=True)
+             raise e
+
+     def compute_inverse_mel(self, melspec: Tensor, *, _recall=False):
+         try:
+             return self._mel_rscale.forward(melspec.to(self.device)).squeeze()
+         except RuntimeError as e:
+             if not _recall:
+                 self._mel_rscale.to(self.device)
+                 return self.compute_inverse_mel(melspec, _recall=True)
+             raise e

      def compute_rms(
          self,
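A brief round-trip sketch of the relocated compute_mel and the new compute_inverse_mel; the processor settings, tensor shapes, and values here are illustrative assumptions, not part of the diff:

ap = AudioProcessor(AudioProcessorConfig(sample_rate=24000))
wave = torch.randn(24000)                        # one second of audio (assumed shape)
mel = ap.compute_mel(wave)                       # log-mel, shifted/scaled by cfg.mean and cfg.std
raw = ap.compute_mel(wave, raw_mel_only=True)    # mel energies without the log normalization
approx_spec = ap.compute_inverse_mel(raw)        # approximate linear spectrogram via InverseMelScale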
@@ -192,12 +202,44 @@ class AudioProcessor(Model):
          else:
              rms_ = []
              for i in range(B):
-                 _r = librosa.feature.rms(_comp_rms_helper(i, audio, mel), **rms_kwargs)[
+                 _t = _comp_rms_helper(i, audio, mel)
+                 _r = librosa.feature.rms(**_t, **rms_kwargs)[
                      0
                  ]
                  rms_.append(_r)
              return self.from_numpy_batch(rms_, default_device, default_dtype).squeeze()

+     def pitch_shift(self, audio: torch.Tensor, sample_rate: Optional[int] = None, n_steps: float = 2.0):
+         """
+         Shifts the pitch of an audio tensor by `n_steps` semitones.
+
+         Args:
+             audio (torch.Tensor): Tensor of shape (B, T) or (T,)
+             sample_rate (int, optional): Sample rate of the audio. Will use the class sample rate if unset.
+             n_steps (float): Number of semitones to shift. Can be negative.
+
+         Returns:
+             torch.Tensor: Pitch-shifted audio.
+         """
+         src_device = audio.device
+         src_dtype = audio.dtype
+         audio = audio.squeeze()
+         sample_rate = default(sample_rate, self.cfg.sample_rate)
+         def _shift_one(wav):
+             wav_np = self.to_numpy_safe(wav)
+             shifted_np = librosa.effects.pitch_shift(wav_np, sr=sample_rate, n_steps=n_steps)
+             return torch.from_numpy(shifted_np)
+
+         if audio.ndim == 1:
+             return _shift_one(audio).to(device=src_device, dtype=src_dtype)
+         return torch.stack([_shift_one(a) for a in audio]).to(device=src_device, dtype=src_dtype)
+
+
+     @staticmethod
+     def calc_pitch_fmin(sr: int, frame_length: float):
+         """For pitch f_min"""
+         return (sr / (frame_length - 1)) * 2
+
      def compute_pitch(
          self,
          audio: Tensor,
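A short usage sketch of the new pitch_shift helper and the calc_pitch_fmin floor it pairs with; the file name, sample rate, and shift amounts are illustrative assumptions:

ap = AudioProcessor(AudioProcessorConfig(sample_rate=24000))
wave = ap.load_audio("speech.wav")               # hypothetical input file
up_two = ap.pitch_shift(wave, n_steps=2.0)       # shift up by two semitones
down_one = ap.pitch_shift(wave, n_steps=-1.0)    # shift down by one semitone
fmin_floor = AudioProcessor.calc_pitch_fmin(24000, 1024)  # ≈ 46.9 Hz, the same floor used by compute_pitch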
@@ -218,9 +260,9 @@
          else:
              B = 1
          sr = default(sr, self.cfg.sample_rate)
-         fmin = max(default(fmin, self.cfg.f_min), 65)
-         fmax = min(default(fmax, self.cfg.f_max), sr // 2)
          frame_length = default(frame_length, self.cfg.n_fft)
+         fmin = max(default(fmin, self.cfg.f_min), self.calc_pitch_fmin(sr, frame_length))
+         fmax = min(max(default(fmax, self.cfg.default_f_max), fmin+1), sr // 2)
          hop_length = default(hop_length, self.cfg.hop_length)
          center = default(center, self.cfg.center)
          yn_kwargs = dict(
@@ -257,10 +299,10 @@
          frame_length: Optional[Number] = None,
      ):
          sr = default(sr, self.sample_rate)
-         fmin = max(default(fmin, self.f_min), 1)
-         fmax = min(default(fmax, self.f_max), sr // 2)
-         win_length = default(win_length, self.win_length)
-         frame_length = default(frame_length, self.n_fft)
+         win_length = default(win_length, self.cfg.win_length)
+         frame_length = default(frame_length, self.cfg.n_fft)
+         fmin = default(fmin, self.calc_pitch_fmin(sr, frame_length))
+         fmax = default(fmax, self.cfg.default_f_max)
          return detect_pitch_frequency(
              audio,
              sample_rate=sr,
@@ -270,6 +312,33 @@
              freq_high=fmax,
          ).squeeze()

+     def from_numpy(
+         self,
+         array: np.ndarray,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ):
+         converted = torch.from_numpy(array)
+         if device is None:
+             device = self.device
+         return converted.to(device=device, dtype=dtype)
+
+     def from_numpy_batch(
+         self,
+         arrays: List[np.ndarray],
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ):
+         stacked = torch.stack([torch.from_numpy(x) for x in arrays])
+         if device is None:
+             device = self.device
+         return stacked.to(device=device, dtype=dtype)
+
+     def to_numpy_safe(self, tensor: Union[Tensor, np.ndarray]):
+         if isinstance(tensor, np.ndarray):
+             return tensor
+         return tensor.detach().to(DEFAULT_DEVICE).numpy(force=True)
+
      def interpolate(
          self,
          tensor: Tensor,
@@ -391,29 +460,6 @@
              return self.istft_norm(wave, length, _recall=True)
          raise e

-     def compute_mel(
-         self,
-         wave: Tensor,
-         raw_mel_only: bool = False,
-         eps: float = 1e-5,
-         *,
-         _recall: bool = False,
-     ) -> Tensor:
-         """Returns: [B, M, T]"""
-         try:
-             mel_tensor = self._mel_spec(wave.to(self.device))  # [M, T]
-             if not raw_mel_only:
-                 mel_tensor = (
-                     torch.log(eps + mel_tensor.unsqueeze(0)) - self.cfg.mean
-                 ) / self.cfg.std
-             return mel_tensor.squeeze()
-
-         except RuntimeError as e:
-             if not _recall:
-                 self._mel_spec.to(self.device)
-                 return self.compute_mel(wave, raw_mel_only, eps, _recall=True)
-             raise e
-
      def load_audio(
          self,
          path: PathLike,
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lt-tensor
- Version: 0.0.1a32
+ Version: 0.0.1a33
  Summary: General utilities for PyTorch and others. Built for general use.
  Home-page: https://github.com/gr1336/lt-tensor/
  Author: gr1336
@@ -17,7 +17,7 @@ Requires-Dist: numpy>=1.26.4
  Requires-Dist: tokenizers
  Requires-Dist: pyyaml>=6.0.0
  Requires-Dist: numba>0.60.0
- Requires-Dist: lt-utils>=0.0.3
+ Requires-Dist: lt-utils>=0.0.4
  Requires-Dist: librosa==0.11.*
  Requires-Dist: einops
  Requires-Dist: plotly
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/lt_tensor.egg-info/requires.txt
@@ -4,7 +4,7 @@ numpy>=1.26.4
  tokenizers
  pyyaml>=6.0.0
  numba>0.60.0
- lt-utils>=0.0.3
+ lt-utils>=0.0.4
  librosa==0.11.*
  einops
  plotly
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a33}/setup.py
@@ -4,7 +4,7 @@ with open("README.md", "r", encoding="utf-8") as f:
      long_description = f.read()

  setup(
-     version="0.0.1a32",
+     version="0.0.1a33",
      name="lt-tensor",
      description="General utilities for PyTorch and others. Built for general use.",
      long_description=long_description,
@@ -17,7 +17,7 @@ setup(
      "tokenizers",
      "pyyaml>=6.0.0",
      "numba>0.60.0",
-     "lt-utils>=0.0.3",
+     "lt-utils>=0.0.4",
      "librosa==0.11.*",
      "einops",
      "plotly",