lt-tensor 0.0.1a33__tar.gz → 0.0.1a34__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/PKG-INFO +1 -1
  2. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/__init__.py +1 -1
  3. lt_tensor-0.0.1a34/lt_tensor/losses.py +277 -0
  4. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/losses/discriminators.py +0 -193
  5. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor.egg-info/PKG-INFO +1 -1
  6. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/setup.py +1 -1
  7. lt_tensor-0.0.1a33/lt_tensor/losses.py +0 -159
  8. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/LICENSE +0 -0
  9. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/README.md +0 -0
  10. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/config_templates.py +0 -0
  11. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/lr_schedulers.py +0 -0
  12. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/math_ops.py +0 -0
  13. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/misc_utils.py +0 -0
  14. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_base.py +0 -0
  15. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/__init__.py +0 -0
  16. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +0 -0
  17. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/activations/alias_free_torch/act.py +0 -0
  18. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/activations/alias_free_torch/filter.py +0 -0
  19. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/activations/alias_free_torch/resample.py +0 -0
  20. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/activations/snake/__init__.py +0 -0
  21. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/audio_models/__init__.py +0 -0
  22. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py +0 -0
  23. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/audio_models/hifigan/__init__.py +0 -0
  24. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/audio_models/istft/__init__.py +0 -0
  25. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/basic.py +0 -0
  26. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/convs.py +0 -0
  27. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/features.py +0 -0
  28. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/fusion.py +0 -0
  29. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/losses/__init__.py +0 -0
  30. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/pos_encoder.py +0 -0
  31. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/residual.py +0 -0
  32. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/transformer.py +0 -0
  33. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/monotonic_align.py +0 -0
  34. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/noise_tools.py +0 -0
  35. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/processors/__init__.py +0 -0
  36. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/processors/audio.py +0 -0
  37. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/torch_commons.py +0 -0
  38. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor/transform.py +0 -0
  39. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor.egg-info/SOURCES.txt +0 -0
  40. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor.egg-info/dependency_links.txt +0 -0
  41. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor.egg-info/requires.txt +0 -0
  42. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/lt_tensor.egg-info/top_level.txt +0 -0
  43. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a34}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lt-tensor
3
- Version: 0.0.1a33
3
+ Version: 0.0.1a34
4
4
  Summary: General utilities for PyTorch and others. Built for general use.
5
5
  Home-page: https://github.com/gr1336/lt-tensor/
6
6
  Author: gr1336
@@ -1,4 +1,4 @@
1
- __version__ = "0.0.1a33"
1
+ __version__ = "0.0.1a34"
2
2
 
3
3
  from . import (
4
4
  lr_schedulers,
@@ -0,0 +1,277 @@
1
# Public API of this module. Every entry must name an object that is actually
# defined in this file; a stale name makes `from lt_tensor.losses import *`
# fail with AttributeError.
__all__ = [
    "masked_cross_entropy",
    "adaptive_l1_loss",
    "contrastive_loss",
    "smooth_l1_loss",
    "hybrid_diff_loss",
    "diff_loss",
    "cosine_loss",
    "ft_n_loss",
    "gan_d_loss",
    "MultiMelScaleLoss",
]
12
+ import math
13
+ import random
14
+ from lt_tensor.torch_commons import *
15
+ from lt_utils.common import *
16
+ import torch.nn.functional as F
17
+ from lt_tensor.model_base import Model
18
+ from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
19
+ from lt_tensor.math_ops import normalize_minmax, normalize_unit_norm, normalize_zscore
20
+
21
+
22
def ft_n_loss(output: Tensor, target: Tensor, weight: Optional[Tensor] = None):
    """Return the mean of the square root of the absolute error.

    Args:
        output: Predicted tensor.
        target: Reference tensor, broadcastable against ``output``.
        weight: Optional term added to the absolute error *before* the
            square root is taken.

    Returns:
        Scalar tensor with the averaged loss.
    """
    abs_error = (output - target).abs()
    if weight is None:
        return torch.pow(abs_error, 0.5).mean()
    return torch.pow(abs_error + weight, 0.5).mean()
26
+
27
+
28
def adaptive_l1_loss(
    inp: Tensor,
    tgt: Tensor,
    weight: Optional[Tensor] = None,
    scale: float = 1.0,
    inverted: bool = False,
):
    """Return a scaled mean absolute error, optionally shifted and negated.

    Args:
        inp: Predicted tensor.
        tgt: Reference tensor.
        weight: Optional tensor whose mean is added to ``inp - tgt`` before
            the absolute value is taken.
        scale: Multiplier applied to the averaged loss.
        inverted: When True the negated loss is returned.

    Returns:
        Scalar tensor with the (possibly negated) loss.
    """
    delta = inp - tgt
    if weight is not None:
        # The offset is the scalar mean of ``weight``, not an element-wise term.
        delta = delta + weight.mean()
    loss = delta.abs().mean()
    loss.mul_(scale)
    return -loss if inverted else loss
44
+
45
+
46
def smooth_l1_loss(inp: Tensor, tgt: Tensor, beta=1.0, weight=None):
    """Return the mean smooth-L1 (Huber-style) loss.

    Errors below ``beta`` are penalized quadratically
    (``0.5 * diff**2 / beta``); larger errors linearly
    (``diff - 0.5 * beta``).

    Args:
        inp: Predicted tensor.
        tgt: Reference tensor.
        beta: Threshold between the quadratic and linear regimes. ``0``
            degenerates to a plain L1 loss.
        weight: Optional multiplier applied to the element-wise loss before
            averaging. It may broadcast to a larger shape than the loss.

    Returns:
        Scalar tensor with the averaged loss.
    """
    diff = torch.abs(inp - tgt)
    if isinstance(beta, (int, float)) and beta == 0:
        # Same forward value as the general formula, but skips the
        # division by zero whose NaN would leak into the gradients through
        # the unselected torch.where branch.
        loss = diff
    else:
        loss = torch.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)
    if weight is not None:
        # Out-of-place: an in-place multiply fails when `weight` broadcasts
        # to a shape larger than `loss`.
        loss = loss * weight
    return loss.mean()
52
+
53
+
54
def contrastive_loss(x1: Tensor, x2: Tensor, label: Tensor, margin: float = 1.0):
    """Return the mean pairwise contrastive loss.

    Args:
        x1: First batch of embeddings.
        x2: Second batch of embeddings.
        label: 1 for similar pairs, 0 for dissimilar pairs.
        margin: Distance below which dissimilar pairs are penalized.

    Returns:
        Scalar tensor with the averaged loss.
    """
    distance = torch.nn.functional.pairwise_distance(x1, x2)
    # Similar pairs are pulled together: squared distance.
    pull_term = label * distance.pow(2)
    # Dissimilar pairs are pushed apart until they are `margin` away.
    push_term = (1 - label) * torch.clamp(margin - distance, min=0.0).pow(2)
    return (pull_term + push_term).mean()
59
+
60
+
61
def cosine_loss(inp, tgt):
    """Return ``1 - mean(cosine_similarity)`` computed along the last dim.

    Lower is better: 0 means every pair of vectors points the same way.
    """
    similarity = torch.nn.functional.cosine_similarity(inp, tgt, dim=-1)
    mean_similarity = similarity.mean()
    return 1 - mean_similarity
64
+
65
+
66
def masked_cross_entropy(
    logits: torch.Tensor,  # [B, T, V]
    targets: torch.Tensor,  # [B, T]
    lengths: torch.Tensor,  # [B]
    reduction: str = "mean",
) -> torch.Tensor:
    """Cross-entropy over variable-length sequences, ignoring padded steps.

    Args:
        logits: Unnormalized scores of shape [B, T, V].
        targets: Ground-truth class indices of shape [B, T].
        lengths: Number of valid time steps per sequence, shape [B].
        reduction: ``"mean"`` or ``"sum"`` reduce over the valid positions;
            any other value returns the unreduced per-position losses
            (kept for backward compatibility with the ``"none"`` fallback).

    Returns:
        A scalar tensor for ``"mean"``/``"sum"``, otherwise a 1-D tensor
        with one loss per valid position.
    """
    B, T, V = logits.size()
    # reshape (not view) so non-contiguous inputs, e.g. transposed logits, work.
    flat_logits = logits.reshape(-1, V)
    flat_targets = targets.reshape(-1)

    # Build the mask on the logits' device so it can index them even when
    # `lengths` lives elsewhere (e.g. on the CPU).
    positions = torch.arange(T, device=logits.device)
    mask = positions.unsqueeze(0) < lengths.to(logits.device).unsqueeze(1)
    mask = mask.reshape(-1)

    ce_reduction = reduction if reduction in ("mean", "sum") else "none"
    # Apply CE only where mask == True
    return F.cross_entropy(
        flat_logits[mask], flat_targets[mask], reduction=ce_reduction
    )
93
+
94
+
95
def diff_loss(pred_noise, true_noise, mask=None):
    """Standard diffusion noise-prediction loss (e.g., DDPM).

    Computes the MSE between predicted and true noise. When ``mask`` is
    given, both tensors are multiplied by it first; the mean is still taken
    over every element, masked ones included.
    """
    if mask is None:
        return F.mse_loss(pred_noise, true_noise)
    masked_pred = pred_noise * mask
    masked_true = true_noise * mask
    return F.mse_loss(masked_pred, masked_true)
100
+
101
+
102
def hybrid_diff_loss(pred_noise, true_noise, alpha=0.5):
    """Blend L1 and L2 losses: ``alpha * L1 + (1 - alpha) * L2``."""
    l1_term = alpha * F.l1_loss(pred_noise, true_noise)
    l2_term = (1 - alpha) * F.mse_loss(pred_noise, true_noise)
    return l1_term + l2_term
107
+
108
+
109
def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
    """Sum the discriminator loss over paired real/fake predictions.

    Args:
        real_preds: Iterable of discriminator outputs for real samples.
        fake_preds: Iterable of discriminator outputs for generated samples.
        use_lsgan: If True use the least-squares objective (real -> 1,
            fake -> 0); otherwise use the log-based objective with a 1e-7
            offset inside each logarithm.

    Returns:
        The accumulated loss; the integer 0 when there are no pairs.
    """

    def _least_squares(real, fake):
        return F.mse_loss(real, torch.ones_like(real)) + F.mse_loss(
            fake, torch.zeros_like(fake)
        )

    def _log_likelihood(real, fake):
        return -torch.mean(torch.log(real + 1e-7)) - torch.mean(
            torch.log(1 - fake + 1e-7)
        )

    pair_loss = _least_squares if use_lsgan else _log_likelihood
    total = 0
    for real, fake in zip(real_preds, fake_preds):
        total += pair_loss(real, fake)
    return total
121
+
122
+
123
class MultiMelScaleLoss(Model):
    """Multi-resolution mel-spectrogram loss with optional pitch and RMS terms.

    One ``AudioProcessor`` is created per resolution: ``n_mels``,
    ``window_lengths``, ``n_ffts``, ``hops``, ``f_min`` and ``f_max`` are
    zipped together, so they must all have the same length. ``forward``
    sums, over every resolution, ``loss_fn`` applied to the mel
    spectrograms and, when enabled, to the normalized pitch and RMS
    features.
    """

    def __init__(
        self,
        sample_rate: int,
        # NOTE: the list defaults below are shared objects; they are only
        # read (zipped) by this class and must never be mutated.
        n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
        window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
        n_ffts: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
        hops: List[int] = [8, 16, 32, 64, 128, 256, 512],
        f_min: List[float] = [0, 0, 0, 0, 0, 0, 0],
        f_max: List[Optional[float]] = [None, None, None, None, None, None, None],
        loss_fn: Callable = nn.L1Loss(),
        center: bool = True,
        power: float = 1.0,
        normalized: bool = False,
        pad_mode: str = "reflect",
        onesided: Optional[bool] = None,
        std: int = 4,
        mean: int = -4,
        use_istft_norm: bool = True,
        use_pitch_loss: bool = True,
        use_rms_loss: bool = True,
        norm_pitch_fn: Callable[[Tensor], Tensor] = normalize_unit_norm,
        norm_rms_fn: Callable[[Tensor], Tensor] = normalize_unit_norm,
        lambda_mel: float = 1.0,
        lambda_rms: float = 1.0,
        lambda_pitch: float = 1.0,
        weight: float = 1.0,
    ):
        """Build one audio processor per resolution and store the loss settings.

        Args:
            sample_rate: Sample rate forwarded to every processor config.
            n_mels, window_lengths, n_ffts, hops, f_min, f_max: Per-resolution
                settings; entry ``i`` of each list configures processor ``i``.
            loss_fn: Callable comparing two tensors; used for the mel, pitch
                and RMS terms alike.
            center, power, normalized, pad_mode, onesided, std, mean:
                Forwarded unchanged to every ``AudioProcessorConfig``.
            use_istft_norm: Pass both waves through the processor's
                ``istft_norm`` (with the target length) before analysis.
            use_pitch_loss: Add the pitch term.
            use_rms_loss: Add the RMS term.
            norm_pitch_fn: Normalization applied to the pitch features.
            norm_rms_fn: Normalization applied to the RMS features.
            lambda_mel, lambda_rms, lambda_pitch: Per-term multipliers.
            weight: Multiplier applied to the summed loss.

        Raises:
            ValueError: If the per-resolution lists differ in length.
        """
        super().__init__()
        self.loss_fn = loss_fn
        self.lambda_mel = lambda_mel
        self.weight = weight
        self.use_istft_norm = use_istft_norm
        self.use_pitch_loss = use_pitch_loss
        self.use_rms_loss = use_rms_loss
        self.lambda_pitch = lambda_pitch
        self.lambda_rms = lambda_rms

        self.norm_pitch_fn = norm_pitch_fn
        self.norm_rms = norm_rms_fn

        self._setup_mels(
            sample_rate,
            n_mels,
            window_lengths,
            n_ffts,
            hops,
            f_min,
            f_max,
            center,
            power,
            normalized,
            pad_mode,
            onesided,
            std,
            mean,
        )

    @staticmethod
    def _check_same_length(**named_lists) -> None:
        """Raise ValueError unless every given list has the same length."""
        sizes = {name: len(values) for name, values in named_lists.items()}
        if len(set(sizes.values())) > 1:
            raise ValueError(
                f"Per-resolution settings must have the same length, got {sizes}"
            )

    def _setup_mels(
        self,
        sample_rate: int,
        n_mels: List[int],
        window_lengths: List[int],
        n_ffts: List[int],
        hops: List[int],
        f_min: List[float],
        f_max: List[Optional[float]],
        center: bool,
        power: float,
        normalized: bool,
        pad_mode: str,
        onesided: Optional[bool],
        std: int,
        mean: int,
    ):
        """Create ``self.mel_spectrograms``: one ``AudioProcessor`` per resolution.

        Raises:
            ValueError: If the per-resolution lists differ in length.
        """
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, after which zip() would silently truncate to the
        # shortest list.
        self._check_same_length(
            n_mels=n_mels,
            window_lengths=window_lengths,
            n_ffts=n_ffts,
            hops=hops,
            f_min=f_min,
            f_max=f_max,
        )
        # Settings shared by every resolution.
        _mel_kwargs = dict(
            sample_rate=sample_rate,
            center=center,
            onesided=onesided,
            normalized=normalized,
            power=power,
            pad_mode=pad_mode,
            std=std,
            mean=mean,
        )
        self.mel_spectrograms: List[AudioProcessor] = nn.ModuleList(
            [
                AudioProcessor(
                    AudioProcessorConfig(
                        **_mel_kwargs,
                        n_mels=mel,
                        n_fft=n_fft,
                        win_length=win,
                        hop_length=hop,
                        f_min=fmin,
                        f_max=fmax,
                    )
                )
                for mel, win, n_fft, hop, fmin, fmax in zip(
                    n_mels, window_lengths, n_ffts, hops, f_min, f_max
                )
            ]
        )

    def forward(
        self, input_wave: torch.Tensor, target_wave: torch.Tensor
    ) -> torch.Tensor:
        """Return the weighted sum of the per-resolution losses.

        Args:
            input_wave: Generated waveform.
            target_wave: Reference waveform; moved to ``input_wave``'s device.

        Returns:
            The accumulated loss multiplied by ``self.weight``.

        Raises:
            ValueError: If ``use_istft_norm`` is disabled and the two waves
                differ in their last dimension.
        """
        if not self.use_istft_norm and input_wave.shape[-1] != target_wave.shape[-1]:
            raise ValueError(
                "input_wave and target_wave must have the same length when "
                "use_istft_norm is disabled, got "
                f"{input_wave.shape[-1]} and {target_wave.shape[-1]}"
            )
        target_wave = target_wave.to(input_wave.device)
        losses = 0.0
        for M in self.mel_spectrograms:
            # Apply normalization if requested; both waves are brought to the
            # target's length.
            if self.use_istft_norm:
                input_proc = M.istft_norm(input_wave, length=target_wave.shape[-1])
                target_proc = M.istft_norm(target_wave, length=target_wave.shape[-1])
            else:
                input_proc, target_proc = input_wave, target_wave

            x_mels = M(input_proc)
            y_mels = M(target_proc)

            loss = self.loss_fn(x_mels.squeeze(), y_mels.squeeze())
            losses += loss * self.lambda_mel

            # pitch/f0 loss
            if self.use_pitch_loss:
                x_pitch = self.norm_pitch_fn(M.compute_pitch(input_proc))
                y_pitch = self.norm_pitch_fn(M.compute_pitch(target_proc))
                f0_loss = self.loss_fn(x_pitch, y_pitch)
                losses += f0_loss * self.lambda_pitch

            # energy/rms loss
            if self.use_rms_loss:
                x_rms = self.norm_rms(M.compute_rms(input_proc, x_mels))
                y_rms = self.norm_rms(M.compute_rms(target_proc, y_mels))
                rms_loss = self.loss_fn(x_rms, y_rms)
                losses += rms_loss * self.lambda_rms

        return losses * self.weight
@@ -82,23 +82,6 @@ class MultiDiscriminatorWrapper(Model):
82
82
  return disc_loss, disc_real_losses, disc_gen_losses
83
83
 
84
84
 
85
- def normalize_unit_norm(x: torch.Tensor, eps: float = 1e-5):
86
- norm = torch.norm(x, dim=-1, keepdim=True)
87
- return x / (norm + eps)
88
-
89
-
90
- def normalize_minmax(x: torch.Tensor, eps: float = 1e-5):
91
- min_val = x.amin(dim=-1, keepdim=True)
92
- max_val = x.amax(dim=-1, keepdim=True)
93
- return (x - min_val) / (max_val - min_val + eps)
94
-
95
-
96
- def normalize_zscore(x: torch.Tensor, eps: float = 1e-5):
97
- mean = x.mean(dim=-1, keepdim=True)
98
- std = x.std(dim=-1, keepdim=True)
99
- return (x - mean) / (std + eps)
100
-
101
-
102
85
  def get_padding(kernel_size, dilation=1):
103
86
  return int((kernel_size * dilation - dilation) / 2)
104
87
 
@@ -637,179 +620,3 @@ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
637
620
  y_d_gs.append(y_d_g)
638
621
  fmap_gs.append(fmap_g)
639
622
  return y_d_rs, y_d_gs, fmap_rs, fmap_gs
640
-
641
-
642
- class MultiMelScaleLoss(Model):
643
- # TODO: Make the normalization an argument to be chosen by the dev
644
- def __init__(
645
- self,
646
- sample_rate: int,
647
- n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
648
- window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
649
- n_ffts: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
650
- hops: List[int] = [8, 16, 32, 64, 128, 256, 512],
651
- weight: float = 1.0,
652
- lambda_mel: float = 1.0,
653
- f_min: float = [0, 0, 0, 0, 0, 0, 0],
654
- f_max: Optional[float] = [None, None, None, None, None, None, None],
655
- loss_fn: Callable = nn.L1Loss(),
656
- center: bool = True,
657
- power: float = 1.0,
658
- normalized: bool = False,
659
- pad_mode: str = "reflect",
660
- onesided: Optional[bool] = None,
661
- std: int = 4,
662
- mean: int = -4,
663
- auto_interpolate: bool = True,
664
- use_istft_norm: bool = True,
665
- use_pitch_loss: bool = False,
666
- use_rms_loss: bool = False,
667
- lambda_pitch: float = 0.5,
668
- lambda_rms: float = 0.5,
669
- ):
670
- super().__init__()
671
- assert (
672
- len(n_mels)
673
- == len(window_lengths)
674
- == len(n_ffts)
675
- == len(hops)
676
- == len(f_min)
677
- == len(f_max)
678
- )
679
- self.loss_fn = loss_fn
680
- self.lambda_mel = lambda_mel
681
- self.weight = weight
682
- self.use_istft_norm = use_istft_norm
683
- self.auto_interpolate = auto_interpolate if not self.use_istft_norm else False
684
- self.use_pitch_loss = use_pitch_loss
685
- self.use_rms_loss = use_rms_loss
686
- self.lambda_pitch = lambda_pitch
687
- self.lambda_rms = lambda_rms
688
-
689
- self._setup_mels(
690
- sample_rate,
691
- n_mels,
692
- window_lengths,
693
- n_ffts,
694
- hops,
695
- f_min,
696
- f_max,
697
- center,
698
- power,
699
- normalized,
700
- pad_mode,
701
- onesided,
702
- std,
703
- mean,
704
- )
705
-
706
- def _setup_mels(
707
- self,
708
- sample_rate: int,
709
- n_mels: List[int],
710
- window_lengths: List[int],
711
- n_ffts: List[int],
712
- hops: List[int],
713
- f_min: List[float],
714
- f_max: List[Optional[float]],
715
- center: bool,
716
- power: float,
717
- normalized: bool,
718
- pad_mode: str = "reflect",
719
- onesided: Optional[bool] = None,
720
- std: int = 4,
721
- mean: int = -4,
722
- ):
723
- assert (
724
- len(n_mels)
725
- == len(window_lengths)
726
- == len(n_ffts)
727
- == len(hops)
728
- == len(f_min)
729
- == len(f_max)
730
- )
731
- _mel_kwargs = dict(
732
- sample_rate=sample_rate,
733
- center=center,
734
- onesided=onesided,
735
- normalized=normalized,
736
- power=power,
737
- pad_mode=pad_mode,
738
- std=std,
739
- mean=mean,
740
- )
741
- self.mel_spectrograms: List[AudioProcessor] = nn.ModuleList(
742
- [
743
- AudioProcessor(
744
- AudioProcessorConfig(
745
- **_mel_kwargs,
746
- n_mels=mel,
747
- n_fft=n_fft,
748
- win_length=win,
749
- hop_length=hop,
750
- f_min=fmin,
751
- f_max=fmax,
752
- )
753
- )
754
- for mel, win, n_fft, hop, fmin, fmax in zip(
755
- n_mels, window_lengths, n_ffts, hops, f_min, f_max
756
- )
757
- ]
758
- )
759
-
760
- def _process_tensor(
761
- self,
762
- input_wave: torch.Tensor,
763
- target_wave: torch.Tensor,
764
- ):
765
- if input_wave.shape[-1] != target_wave.shape[-1]:
766
- if input_wave.ndim < 3:
767
- # To be compatible with interpolatin
768
- if input_wave.ndim == 2:
769
- input_wave = input_wave.unsqueeze(1)
770
- else:
771
- input_wave = input_wave.unsqueeze(0).unsqueeze(0)
772
- input_wave = F.interpolate(input_wave, target_wave.shape[-1], mode="linear")
773
- return input_wave
774
-
775
- def forward(
776
- self, input_wave: torch.Tensor, target_wave: torch.Tensor
777
- ) -> torch.Tensor:
778
- assert (
779
- self.use_istft_norm
780
- or self.auto_interpolate
781
- or input_wave.shape[-1] == target_wave.shape[-1]
782
- )
783
- if self.auto_interpolate:
784
- input_wave = self._process_tensor(input_wave, target_wave)
785
-
786
- losses = 0.0
787
- for M in self.mel_spectrograms:
788
- # Apply normalization if requested
789
- if self.use_istft_norm:
790
- input_proc = M.istft_norm(input_wave, length=target_wave.shape[-1])
791
- target_proc = M.istft_norm(target_wave, length=target_wave.shape[-1])
792
- else:
793
- input_proc, target_proc = input_wave, target_wave
794
-
795
- x_mels = M(input_proc)
796
- y_mels = M(target_proc)
797
-
798
- loss = self.loss_fn(x_mels.squeeze(), y_mels.squeeze())
799
- losses += loss * self.lambda_mel
800
-
801
- # pitch/f0 loss
802
- if self.use_pitch_loss:
803
- x_pitch = normalize_unit_norm(M.compute_pitch(input_proc))
804
- y_pitch = normalize_unit_norm(M.compute_pitch(target_proc))
805
- f0_loss = self.loss_fn(x_pitch, y_pitch)
806
- losses += f0_loss * self.lambda_pitch
807
-
808
- # energy/rms loss
809
- if self.use_rms_loss:
810
- x_rms = normalize_unit_norm(M.compute_rms(input_proc, x_mels))
811
- y_rms = normalize_unit_norm(M.compute_rms(target_proc, y_mels))
812
- rms_loss = self.loss_fn(x_rms, y_rms)
813
- losses += rms_loss * self.lambda_rms
814
-
815
- return losses * self.weight
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lt-tensor
3
- Version: 0.0.1a33
3
+ Version: 0.0.1a34
4
4
  Summary: General utilities for PyTorch and others. Built for general use.
5
5
  Home-page: https://github.com/gr1336/lt-tensor/
6
6
  Author: gr1336
@@ -4,7 +4,7 @@ with open("README.md", "r", encoding="utf-8") as f:
4
4
  long_description = f.read()
5
5
 
6
6
  setup(
7
- version="0.0.1a33",
7
+ version="0.0.1a34",
8
8
  name="lt-tensor",
9
9
  description="General utilities for PyTorch and others. Built for general use.",
10
10
  long_description=long_description,
@@ -1,159 +0,0 @@
1
- __all__ = [
2
- "masked_cross_entropy",
3
- "adaptive_l1_loss",
4
- "contrastive_loss",
5
- "smooth_l1_loss",
6
- "hybrid_loss",
7
- "diff_loss",
8
- "cosine_loss",
9
- "gan_loss",
10
- "ft_n_loss",
11
- ]
12
- import math
13
- import random
14
- from lt_tensor.torch_commons import *
15
- from lt_utils.common import *
16
- import torch.nn.functional as F
17
-
18
- def ft_n_loss(output: Tensor, target: Tensor, weight: Optional[Tensor] = None):
19
- if weight is not None:
20
- return torch.mean((torch.abs(output - target) + weight) **0.5)
21
- return torch.mean(torch.abs(output - target)**0.5)
22
-
23
- def adaptive_l1_loss(
24
- inp: Tensor,
25
- tgt: Tensor,
26
- weight: Optional[Tensor] = None,
27
- scale: float = 1.0,
28
- inverted: bool = False,
29
- ):
30
-
31
- if weight is not None:
32
- loss = torch.mean(torch.abs((inp - tgt) + weight.mean()))
33
- else:
34
- loss = torch.mean(torch.abs(inp - tgt))
35
- loss *= scale
36
- if inverted:
37
- return -loss
38
- return loss
39
-
40
-
41
- def smooth_l1_loss(inp: Tensor, tgt: Tensor, beta=1.0, weight=None):
42
- diff = torch.abs(inp - tgt)
43
- loss = torch.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)
44
- if weight is not None:
45
- loss *= weight
46
- return loss.mean()
47
-
48
-
49
- def contrastive_loss(x1: Tensor, x2: Tensor, label: Tensor, margin: float = 1.0):
50
- # label == 1: similar, label == 0: dissimilar
51
- dist = torch.nn.functional.pairwise_distance(x1, x2)
52
- loss = label * dist**2 + (1 - label) * torch.clamp(margin - dist, min=0.0) ** 2
53
- return loss.mean()
54
-
55
-
56
- def cosine_loss(inp, tgt):
57
- cos = torch.nn.functional.cosine_similarity(inp, tgt, dim=-1)
58
- return 1 - cos.mean() # Lower is better
59
-
60
-
61
- class GanLosses:
62
- @staticmethod
63
- def get_loss(
64
- pred: Tensor,
65
- target_is_real: bool,
66
- loss_type: Literal["bce", "mse", "hinge", "wasserstein"] = "bce",
67
- ) -> Tensor:
68
- if loss_type == "bce": # Standard GAN
69
- target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
70
- return F.binary_cross_entropy_with_logits(pred, target)
71
-
72
- elif loss_type == "mse": # LSGAN
73
- target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
74
- return F.mse_loss(torch.sigmoid(pred), target)
75
-
76
- elif loss_type == "hinge":
77
- if target_is_real:
78
- return torch.mean(F.relu(1.0 - pred))
79
- else:
80
- return torch.mean(F.relu(1.0 + pred))
81
-
82
- elif loss_type == "wasserstein":
83
- return -pred.mean() if target_is_real else pred.mean()
84
-
85
- else:
86
- raise ValueError(f"Unknown loss_type: {loss_type}")
87
-
88
- @staticmethod
89
- def generator_loss(fake_pred: Tensor, loss_type: str = "bce") -> Tensor:
90
- return GanLosses.get_loss(fake_pred, target_is_real=True, loss_type=loss_type)
91
-
92
- @staticmethod
93
- def discriminator_loss(
94
- real_pred: Tensor, fake_pred: Tensor, loss_type: str = "bce"
95
- ) -> Tensor:
96
- real_loss = GanLosses.get_loss(
97
- real_pred, target_is_real=True, loss_type=loss_type
98
- )
99
- fake_loss = GanLosses.get_loss(
100
- fake_pred.detach(), target_is_real=False, loss_type=loss_type
101
- )
102
- return (real_loss + fake_loss) * 0.5
103
-
104
-
105
- def masked_cross_entropy(
106
- logits: torch.Tensor, # [B, T, V]
107
- targets: torch.Tensor, # [B, T]
108
- lengths: torch.Tensor, # [B]
109
- reduction: str = "mean",
110
- ) -> torch.Tensor:
111
- """
112
- CrossEntropyLoss with masking for variable-length sequences.
113
- - logits: unnormalized scores [B, T, V]
114
- - targets: ground truth indices [B, T]
115
- - lengths: actual sequence lengths [B]
116
- """
117
- B, T, V = logits.size()
118
- logits = logits.view(-1, V)
119
- targets = targets.view(-1)
120
-
121
- # Create mask
122
- mask = torch.arange(T, device=lengths.device).expand(B, T) < lengths.unsqueeze(1)
123
- mask = mask.reshape(-1)
124
-
125
- # Apply CE only where mask == True
126
- loss = F.cross_entropy(
127
- logits[mask], targets[mask], reduction="mean" if reduction == "mean" else "none"
128
- )
129
- if reduction == "none":
130
- return loss
131
- return loss
132
-
133
-
134
- def diff_loss(pred_noise, true_noise, mask=None):
135
- """Standard diffusion noise-prediction loss (e.g., DDPM)"""
136
- if mask is not None:
137
- return F.mse_loss(pred_noise * mask, true_noise * mask)
138
- return F.mse_loss(pred_noise, true_noise)
139
-
140
-
141
- def hybrid_diff_loss(pred_noise, true_noise, alpha=0.5):
142
- """Combines L1 and L2"""
143
- l1 = F.l1_loss(pred_noise, true_noise)
144
- l2 = F.mse_loss(pred_noise, true_noise)
145
- return alpha * l1 + (1 - alpha) * l2
146
-
147
-
148
- def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
149
- loss = 0
150
- for real, fake in zip(real_preds, fake_preds):
151
- if use_lsgan:
152
- loss += F.mse_loss(real, torch.ones_like(real)) + F.mse_loss(
153
- fake, torch.zeros_like(fake)
154
- )
155
- else:
156
- loss += -torch.mean(torch.log(real + 1e-7)) - torch.mean(
157
- torch.log(1 - fake + 1e-7)
158
- )
159
- return loss
File without changes
File without changes
File without changes