lt-tensor 0.0.1a33__py3-none-any.whl → 0.0.1a34__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, taken from a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registry.
lt_tensor/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.0.1a33"
+__version__ = "0.0.1a34"
 
 from . import (
     lr_schedulers,
lt_tensor/losses.py CHANGED
@@ -6,19 +6,24 @@ __all__ = [
     "hybrid_loss",
     "diff_loss",
     "cosine_loss",
-    "gan_loss",
     "ft_n_loss",
+    "MultiMelScaleLoss",
 ]
 import math
 import random
 from lt_tensor.torch_commons import *
 from lt_utils.common import *
 import torch.nn.functional as F
+from lt_tensor.model_base import Model
+from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
+from lt_tensor.math_ops import normalize_minmax, normalize_unit_norm, normalize_zscore
+
 
 def ft_n_loss(output: Tensor, target: Tensor, weight: Optional[Tensor] = None):
     if weight is not None:
-        return torch.mean((torch.abs(output - target) + weight) **0.5)
-    return torch.mean(torch.abs(output - target)**0.5)
+        return torch.mean((torch.abs(output - target) + weight) ** 0.5)
+    return torch.mean(torch.abs(output - target) ** 0.5)
+
 
 def adaptive_l1_loss(
     inp: Tensor,
@@ -58,50 +63,6 @@ def cosine_loss(inp, tgt):
     return 1 - cos.mean()  # Lower is better
 
 
-class GanLosses:
-    @staticmethod
-    def get_loss(
-        pred: Tensor,
-        target_is_real: bool,
-        loss_type: Literal["bce", "mse", "hinge", "wasserstein"] = "bce",
-    ) -> Tensor:
-        if loss_type == "bce":  # Standard GAN
-            target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
-            return F.binary_cross_entropy_with_logits(pred, target)
-
-        elif loss_type == "mse":  # LSGAN
-            target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
-            return F.mse_loss(torch.sigmoid(pred), target)
-
-        elif loss_type == "hinge":
-            if target_is_real:
-                return torch.mean(F.relu(1.0 - pred))
-            else:
-                return torch.mean(F.relu(1.0 + pred))
-
-        elif loss_type == "wasserstein":
-            return -pred.mean() if target_is_real else pred.mean()
-
-        else:
-            raise ValueError(f"Unknown loss_type: {loss_type}")
-
-    @staticmethod
-    def generator_loss(fake_pred: Tensor, loss_type: str = "bce") -> Tensor:
-        return GanLosses.get_loss(fake_pred, target_is_real=True, loss_type=loss_type)
-
-    @staticmethod
-    def discriminator_loss(
-        real_pred: Tensor, fake_pred: Tensor, loss_type: str = "bce"
-    ) -> Tensor:
-        real_loss = GanLosses.get_loss(
-            real_pred, target_is_real=True, loss_type=loss_type
-        )
-        fake_loss = GanLosses.get_loss(
-            fake_pred.detach(), target_is_real=False, loss_type=loss_type
-        )
-        return (real_loss + fake_loss) * 0.5
-
-
 def masked_cross_entropy(
     logits: torch.Tensor,  # [B, T, V]
     targets: torch.Tensor,  # [B, T]
@@ -157,3 +118,160 @@ def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
             torch.log(1 - fake + 1e-7)
         )
     return loss
+
+
+class MultiMelScaleLoss(Model):
+    def __init__(
+        self,
+        sample_rate: int,
+        n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
+        window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
+        n_ffts: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
+        hops: List[int] = [8, 16, 32, 64, 128, 256, 512],
+        f_min: float = [0, 0, 0, 0, 0, 0, 0],
+        f_max: Optional[float] = [None, None, None, None, None, None, None],
+        loss_fn: Callable = nn.L1Loss(),
+        center: bool = True,
+        power: float = 1.0,
+        normalized: bool = False,
+        pad_mode: str = "reflect",
+        onesided: Optional[bool] = None,
+        std: int = 4,
+        mean: int = -4,
+        use_istft_norm: bool = True,
+        use_pitch_loss: bool = True,
+        use_rms_loss: bool = True,
+        norm_pitch_fn: Callable[[Tensor], Tensor] = normalize_unit_norm,
+        norm_rms_fn: Callable[[Tensor], Tensor] = normalize_unit_norm,
+        lambda_mel: float = 1.0,
+        lambda_rms: float = 1.0,
+        lambda_pitch: float = 1.0,
+        weight: float = 1.0,
+    ):
+        super().__init__()
+        assert (
+            len(n_mels)
+            == len(window_lengths)
+            == len(n_ffts)
+            == len(hops)
+            == len(f_min)
+            == len(f_max)
+        )
+        self.loss_fn = loss_fn
+        self.lambda_mel = lambda_mel
+        self.weight = weight
+        self.use_istft_norm = use_istft_norm
+        self.use_pitch_loss = use_pitch_loss
+        self.use_rms_loss = use_rms_loss
+        self.lambda_pitch = lambda_pitch
+        self.lambda_rms = lambda_rms
+
+        self.norm_pitch_fn = norm_pitch_fn
+        self.norm_rms = norm_rms_fn
+
+        self._setup_mels(
+            sample_rate,
+            n_mels,
+            window_lengths,
+            n_ffts,
+            hops,
+            f_min,
+            f_max,
+            center,
+            power,
+            normalized,
+            pad_mode,
+            onesided,
+            std,
+            mean,
+        )
+
+    def _setup_mels(
+        self,
+        sample_rate: int,
+        n_mels: List[int],
+        window_lengths: List[int],
+        n_ffts: List[int],
+        hops: List[int],
+        f_min: List[float],
+        f_max: List[Optional[float]],
+        center: bool,
+        power: float,
+        normalized: bool,
+        pad_mode: str,
+        onesided: Optional[bool],
+        std: int,
+        mean: int,
+    ):
+        assert (
+            len(n_mels)
+            == len(window_lengths)
+            == len(n_ffts)
+            == len(hops)
+            == len(f_min)
+            == len(f_max)
+        )
+        _mel_kwargs = dict(
+            sample_rate=sample_rate,
+            center=center,
+            onesided=onesided,
+            normalized=normalized,
+            power=power,
+            pad_mode=pad_mode,
+            std=std,
+            mean=mean,
+        )
+        self.mel_spectrograms: List[AudioProcessor] = nn.ModuleList(
+            [
+                AudioProcessor(
+                    AudioProcessorConfig(
+                        **_mel_kwargs,
+                        n_mels=mel,
+                        n_fft=n_fft,
+                        win_length=win,
+                        hop_length=hop,
+                        f_min=fmin,
+                        f_max=fmax,
+                    )
+                )
+                for mel, win, n_fft, hop, fmin, fmax in zip(
+                    n_mels, window_lengths, n_ffts, hops, f_min, f_max
+                )
+            ]
+        )
+
+    def forward(
+        self, input_wave: torch.Tensor, target_wave: torch.Tensor
+    ) -> torch.Tensor:
+        assert self.use_istft_norm or input_wave.shape[-1] == target_wave.shape[-1]
+        target_wave = target_wave.to(input_wave.device)
+        losses = 0.0
+        for M in self.mel_spectrograms:
+            # Apply normalization if requested
+            if self.use_istft_norm:
+                input_proc = M.istft_norm(input_wave, length=target_wave.shape[-1])
+                target_proc = M.istft_norm(target_wave, length=target_wave.shape[-1])
+            else:
+                input_proc, target_proc = input_wave, target_wave
+
+            x_mels = M(input_proc)
+            y_mels = M(target_proc)
+
+            loss = self.loss_fn(x_mels.squeeze(), y_mels.squeeze())
+            losses += loss * self.lambda_mel
+
+            # pitch/f0 loss
+            if self.use_pitch_loss:
+                x_pitch = self.norm_pitch_fn(M.compute_pitch(input_proc))
+                y_pitch = self.norm_pitch_fn(M.compute_pitch(target_proc))
+                f0_loss = self.loss_fn(x_pitch, y_pitch)
+                losses += f0_loss * self.lambda_pitch
+
+            # energy/rms loss
+            if self.use_rms_loss:
+                x_rms = self.norm_rms(M.compute_rms(input_proc, x_mels))
+                y_rms = self.norm_rms(M.compute_rms(target_proc, y_mels))
+                rms_loss = self.loss_fn(x_rms, y_rms)
+                losses += rms_loss * self.lambda_rms
+
+        return losses * self.weight
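Note for readers of the hunk above: MultiMelScaleLoss now lives in lt_tensor/losses.py and sums a mel-spectrogram loss (plus optional pitch/f0 and RMS terms) across seven mel resolutions. A minimal, hypothetical usage sketch follows; it assumes Model behaves like a torch.nn.Module (so the instance is callable) and that mono waveforms shaped [batch, samples] are accepted — neither is confirmed by this diff, and the sample rate is an arbitrary choice.

    import torch
    from lt_tensor.losses import MultiMelScaleLoss

    # Hypothetical example; sample_rate value and tensor shapes are assumptions.
    criterion = MultiMelScaleLoss(sample_rate=24000)  # defaults cover 7 mel scales
    generated = torch.randn(2, 24000)  # e.g. vocoder output, 1 s per item
    reference = torch.randn(2, 24000)  # ground-truth waveform
    loss = criterion(generated, reference)  # forward(input_wave, target_wave)
    loss.backward()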
lt_tensor/model_zoo/losses/discriminators.py CHANGED
@@ -82,23 +82,6 @@ class MultiDiscriminatorWrapper(Model):
         return disc_loss, disc_real_losses, disc_gen_losses
 
 
-def normalize_unit_norm(x: torch.Tensor, eps: float = 1e-5):
-    norm = torch.norm(x, dim=-1, keepdim=True)
-    return x / (norm + eps)
-
-
-def normalize_minmax(x: torch.Tensor, eps: float = 1e-5):
-    min_val = x.amin(dim=-1, keepdim=True)
-    max_val = x.amax(dim=-1, keepdim=True)
-    return (x - min_val) / (max_val - min_val + eps)
-
-
-def normalize_zscore(x: torch.Tensor, eps: float = 1e-5):
-    mean = x.mean(dim=-1, keepdim=True)
-    std = x.std(dim=-1, keepdim=True)
-    return (x - mean) / (std + eps)
-
-
 def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)
 
@@ -637,179 +620,3 @@ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
             y_d_gs.append(y_d_g)
             fmap_gs.append(fmap_g)
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-class MultiMelScaleLoss(Model):
-    # TODO: Make the normalization an argument to be chosen by the dev
-    def __init__(
-        self,
-        sample_rate: int,
-        n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
-        window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
-        n_ffts: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
-        hops: List[int] = [8, 16, 32, 64, 128, 256, 512],
-        weight: float = 1.0,
-        lambda_mel: float = 1.0,
-        f_min: float = [0, 0, 0, 0, 0, 0, 0],
-        f_max: Optional[float] = [None, None, None, None, None, None, None],
-        loss_fn: Callable = nn.L1Loss(),
-        center: bool = True,
-        power: float = 1.0,
-        normalized: bool = False,
-        pad_mode: str = "reflect",
-        onesided: Optional[bool] = None,
-        std: int = 4,
-        mean: int = -4,
-        auto_interpolate: bool = True,
-        use_istft_norm: bool = True,
-        use_pitch_loss: bool = False,
-        use_rms_loss: bool = False,
-        lambda_pitch: float = 0.5,
-        lambda_rms: float = 0.5,
-    ):
-        super().__init__()
-        assert (
-            len(n_mels)
-            == len(window_lengths)
-            == len(n_ffts)
-            == len(hops)
-            == len(f_min)
-            == len(f_max)
-        )
-        self.loss_fn = loss_fn
-        self.lambda_mel = lambda_mel
-        self.weight = weight
-        self.use_istft_norm = use_istft_norm
-        self.auto_interpolate = auto_interpolate if not self.use_istft_norm else False
-        self.use_pitch_loss = use_pitch_loss
-        self.use_rms_loss = use_rms_loss
-        self.lambda_pitch = lambda_pitch
-        self.lambda_rms = lambda_rms
-
-        self._setup_mels(
-            sample_rate,
-            n_mels,
-            window_lengths,
-            n_ffts,
-            hops,
-            f_min,
-            f_max,
-            center,
-            power,
-            normalized,
-            pad_mode,
-            onesided,
-            std,
-            mean,
-        )
-
-    def _setup_mels(
-        self,
-        sample_rate: int,
-        n_mels: List[int],
-        window_lengths: List[int],
-        n_ffts: List[int],
-        hops: List[int],
-        f_min: List[float],
-        f_max: List[Optional[float]],
-        center: bool,
-        power: float,
-        normalized: bool,
-        pad_mode: str = "reflect",
-        onesided: Optional[bool] = None,
-        std: int = 4,
-        mean: int = -4,
-    ):
-        assert (
-            len(n_mels)
-            == len(window_lengths)
-            == len(n_ffts)
-            == len(hops)
-            == len(f_min)
-            == len(f_max)
-        )
-        _mel_kwargs = dict(
-            sample_rate=sample_rate,
-            center=center,
-            onesided=onesided,
-            normalized=normalized,
-            power=power,
-            pad_mode=pad_mode,
-            std=std,
-            mean=mean,
-        )
-        self.mel_spectrograms: List[AudioProcessor] = nn.ModuleList(
-            [
-                AudioProcessor(
-                    AudioProcessorConfig(
-                        **_mel_kwargs,
-                        n_mels=mel,
-                        n_fft=n_fft,
-                        win_length=win,
-                        hop_length=hop,
-                        f_min=fmin,
-                        f_max=fmax,
-                    )
-                )
-                for mel, win, n_fft, hop, fmin, fmax in zip(
-                    n_mels, window_lengths, n_ffts, hops, f_min, f_max
-                )
-            ]
-        )
-
-    def _process_tensor(
-        self,
-        input_wave: torch.Tensor,
-        target_wave: torch.Tensor,
-    ):
-        if input_wave.shape[-1] != target_wave.shape[-1]:
-            if input_wave.ndim < 3:
-                # To be compatible with interpolatin
-                if input_wave.ndim == 2:
-                    input_wave = input_wave.unsqueeze(1)
-                else:
-                    input_wave = input_wave.unsqueeze(0).unsqueeze(0)
-            input_wave = F.interpolate(input_wave, target_wave.shape[-1], mode="linear")
-        return input_wave
-
-    def forward(
-        self, input_wave: torch.Tensor, target_wave: torch.Tensor
-    ) -> torch.Tensor:
-        assert (
-            self.use_istft_norm
-            or self.auto_interpolate
-            or input_wave.shape[-1] == target_wave.shape[-1]
-        )
-        if self.auto_interpolate:
-            input_wave = self._process_tensor(input_wave, target_wave)
-
-        losses = 0.0
-        for M in self.mel_spectrograms:
-            # Apply normalization if requested
-            if self.use_istft_norm:
-                input_proc = M.istft_norm(input_wave, length=target_wave.shape[-1])
-                target_proc = M.istft_norm(target_wave, length=target_wave.shape[-1])
-            else:
-                input_proc, target_proc = input_wave, target_wave
-
-            x_mels = M(input_proc)
-            y_mels = M(target_proc)
-
-            loss = self.loss_fn(x_mels.squeeze(), y_mels.squeeze())
-            losses += loss * self.lambda_mel
-
-            # pitch/f0 loss
-            if self.use_pitch_loss:
-                x_pitch = normalize_unit_norm(M.compute_pitch(input_proc))
-                y_pitch = normalize_unit_norm(M.compute_pitch(target_proc))
-                f0_loss = self.loss_fn(x_pitch, y_pitch)
-                losses += f0_loss * self.lambda_pitch
-
-            # energy/rms loss
-            if self.use_rms_loss:
-                x_rms = normalize_unit_norm(M.compute_rms(input_proc, x_mels))
-                y_rms = normalize_unit_norm(M.compute_rms(target_proc, y_mels))
-                rms_loss = self.loss_fn(x_rms, y_rms)
-                losses += rms_loss * self.lambda_rms
-
-        return losses * self.weight
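The class removed above carried a TODO ("Make the normalization an argument to be chosen by the dev"); the replacement in losses.py resolves it via the new norm_pitch_fn and norm_rms_fn parameters. A hypothetical sketch of overriding the default normalize_unit_norm with the normalize_zscore helper that losses.py now imports from lt_tensor.math_ops (sample_rate is again an arbitrary assumption):

    from lt_tensor.losses import MultiMelScaleLoss
    from lt_tensor.math_ops import normalize_zscore

    # Hypothetical example; only the parameter names come from this diff.
    criterion = MultiMelScaleLoss(
        sample_rate=24000,
        norm_pitch_fn=normalize_zscore,  # was hard-coded to normalize_unit_norm
        norm_rms_fn=normalize_zscore,
    )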
lt_tensor-0.0.1a33.dist-info/METADATA → lt_tensor-0.0.1a34.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a33
+Version: 0.0.1a34
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
lt_tensor-0.0.1a33.dist-info/RECORD → lt_tensor-0.0.1a34.dist-info/RECORD RENAMED
@@ -1,6 +1,6 @@
-lt_tensor/__init__.py,sha256=f3wraCpbx0fV2tQgsZfKw1ifTPp87hSCOZmE0d09LYk,441
+lt_tensor/__init__.py,sha256=WAGPuMPq5c4DGAJ57x1Ykgzg3vMlLq9BiWk5EdJcUsU,441
 lt_tensor/config_templates.py,sha256=F9UvL8paAjkSvio890kp8WznpYeI50pYnm9iqQroBxk,2797
-lt_tensor/losses.py,sha256=zvkCOnE5XpF3v6ymivRIdqPTsMM5zc94ZMom7YDi3zM,4946
+lt_tensor/losses.py,sha256=fHVMqOFo3ekjORYy89R_aRjmtT6lo27Z1egzOYjQ1W8,8646
 lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
 lt_tensor/math_ops.py,sha256=ahX6Z1Mt3X-FhmwSZYZea5mB1B0S8GDuvKPfAm5e_FQ,2646
 lt_tensor/misc_utils.py,sha256=N2r3UmxC4RM2BZBQhpjDZ_BKLrzsyIlKzopTzJbnjFU,28962
@@ -27,11 +27,11 @@ lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=PDuDYN1omD1RoAXcmxH
 lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=7GJqKLw7-juXpfp5IFzjASLut0uouDhjZ1CQknf3H68,16533
 lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=ltIuD9t1gmS3bTmCqZIwJHKrhC6DYya3OaXlskWX9kw,17606
 lt_tensor/model_zoo/losses/__init__.py,sha256=B9RAUxBiOZwooztnij1oLeRwZ7_MjnN3mPoum7saD6s,59
-lt_tensor/model_zoo/losses/discriminators.py,sha256=0b4ikOFy8Ubozq0Igs7X1ELQD5JrPA3jwR4dzuEa6hM,27047
+lt_tensor/model_zoo/losses/discriminators.py,sha256=ZpyByFgc7L7uV_XRBsV9vkdVItbJO3z--Y6LlvTvtwY,20765
 lt_tensor/processors/__init__.py,sha256=Pvxhh0KR65zLCgUd53_k5Z0y5JWWcO0ZBXFK9rv0o5w,109
 lt_tensor/processors/audio.py,sha256=HNr1GS-6M2q0Rda4cErf5y2Jlc9f4jD58FvpX2ua9d4,18369
-lt_tensor-0.0.1a33.dist-info/licenses/LICENSE,sha256=TbiyJWLgNqqgqhfCnrGwFIxy7EqGNrIZZcKhHrefcuU,11354
-lt_tensor-0.0.1a33.dist-info/METADATA,sha256=6xlFyxd0mYYqTi8oSS3M99mnqZUQrmtp3_AJt-rlewg,1062
-lt_tensor-0.0.1a33.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-lt_tensor-0.0.1a33.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
-lt_tensor-0.0.1a33.dist-info/RECORD,,
+lt_tensor-0.0.1a34.dist-info/licenses/LICENSE,sha256=TbiyJWLgNqqgqhfCnrGwFIxy7EqGNrIZZcKhHrefcuU,11354
+lt_tensor-0.0.1a34.dist-info/METADATA,sha256=WkTafcY5nYZbrZ7WzUc3JXnmg9NtUAXrchx42dCok9I,1062
+lt_tensor-0.0.1a34.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a34.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a34.dist-info/RECORD,,