lt-tensor 0.0.1a32__tar.gz → 0.0.1a34__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/LICENSE +1 -1
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/PKG-INFO +2 -2
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/__init__.py +1 -1
- lt_tensor-0.0.1a34/lt_tensor/losses.py +277 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/math_ops.py +19 -6
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/losses/discriminators.py +73 -73
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/processors/audio.py +105 -59
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor.egg-info/PKG-INFO +2 -2
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor.egg-info/requires.txt +1 -1
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/setup.py +2 -2
- lt_tensor-0.0.1a32/lt_tensor/losses.py +0 -159
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/README.md +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/config_templates.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/lr_schedulers.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/misc_utils.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_base.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/__init__.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/activations/alias_free_torch/act.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/activations/alias_free_torch/filter.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/activations/alias_free_torch/resample.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/activations/snake/__init__.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/audio_models/__init__.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/audio_models/hifigan/__init__.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/audio_models/istft/__init__.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/basic.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/convs.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/features.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/fusion.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/losses/__init__.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/pos_encoder.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/residual.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/transformer.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/monotonic_align.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/noise_tools.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/processors/__init__.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/torch_commons.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/transform.py +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor.egg-info/SOURCES.txt +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor.egg-info/dependency_links.txt +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor.egg-info/top_level.txt +0 -0
- {lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/setup.cfg +0 -0
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/LICENSE
@@ -186,7 +186,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.

-   Copyright 2025 gr1336
+   Copyright 2025 gr1336 (Gabriel Ribeiro)

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a32
+Version: 0.0.1a34
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
@@ -17,7 +17,7 @@ Requires-Dist: numpy>=1.26.4
 Requires-Dist: tokenizers
 Requires-Dist: pyyaml>=6.0.0
 Requires-Dist: numba>0.60.0
-Requires-Dist: lt-utils>=0.0.
+Requires-Dist: lt-utils>=0.0.4
 Requires-Dist: librosa==0.11.*
 Requires-Dist: einops
 Requires-Dist: plotly
lt_tensor-0.0.1a34/lt_tensor/losses.py
@@ -0,0 +1,277 @@
+__all__ = [
+    "masked_cross_entropy",
+    "adaptive_l1_loss",
+    "contrastive_loss",
+    "smooth_l1_loss",
+    "hybrid_loss",
+    "diff_loss",
+    "cosine_loss",
+    "ft_n_loss",
+    "MultiMelScaleLoss",
+]
+import math
+import random
+from lt_tensor.torch_commons import *
+from lt_utils.common import *
+import torch.nn.functional as F
+from lt_tensor.model_base import Model
+from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
+from lt_tensor.math_ops import normalize_minmax, normalize_unit_norm, normalize_zscore
+
+
+def ft_n_loss(output: Tensor, target: Tensor, weight: Optional[Tensor] = None):
+    if weight is not None:
+        return torch.mean((torch.abs(output - target) + weight) ** 0.5)
+    return torch.mean(torch.abs(output - target) ** 0.5)
+
+
+def adaptive_l1_loss(
+    inp: Tensor,
+    tgt: Tensor,
+    weight: Optional[Tensor] = None,
+    scale: float = 1.0,
+    inverted: bool = False,
+):
+
+    if weight is not None:
+        loss = torch.mean(torch.abs((inp - tgt) + weight.mean()))
+    else:
+        loss = torch.mean(torch.abs(inp - tgt))
+    loss *= scale
+    if inverted:
+        return -loss
+    return loss
+
+
+def smooth_l1_loss(inp: Tensor, tgt: Tensor, beta=1.0, weight=None):
+    diff = torch.abs(inp - tgt)
+    loss = torch.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)
+    if weight is not None:
+        loss *= weight
+    return loss.mean()
+
+
+def contrastive_loss(x1: Tensor, x2: Tensor, label: Tensor, margin: float = 1.0):
+    # label == 1: similar, label == 0: dissimilar
+    dist = torch.nn.functional.pairwise_distance(x1, x2)
+    loss = label * dist**2 + (1 - label) * torch.clamp(margin - dist, min=0.0) ** 2
+    return loss.mean()
+
+
+def cosine_loss(inp, tgt):
+    cos = torch.nn.functional.cosine_similarity(inp, tgt, dim=-1)
+    return 1 - cos.mean()  # Lower is better
+
+
+def masked_cross_entropy(
+    logits: torch.Tensor,  # [B, T, V]
+    targets: torch.Tensor,  # [B, T]
+    lengths: torch.Tensor,  # [B]
+    reduction: str = "mean",
+) -> torch.Tensor:
+    """
+    CrossEntropyLoss with masking for variable-length sequences.
+    - logits: unnormalized scores [B, T, V]
+    - targets: ground truth indices [B, T]
+    - lengths: actual sequence lengths [B]
+    """
+    B, T, V = logits.size()
+    logits = logits.view(-1, V)
+    targets = targets.view(-1)
+
+    # Create mask
+    mask = torch.arange(T, device=lengths.device).expand(B, T) < lengths.unsqueeze(1)
+    mask = mask.reshape(-1)
+
+    # Apply CE only where mask == True
+    loss = F.cross_entropy(
+        logits[mask], targets[mask], reduction="mean" if reduction == "mean" else "none"
+    )
+    if reduction == "none":
+        return loss
+    return loss
+
+
+def diff_loss(pred_noise, true_noise, mask=None):
+    """Standard diffusion noise-prediction loss (e.g., DDPM)"""
+    if mask is not None:
+        return F.mse_loss(pred_noise * mask, true_noise * mask)
+    return F.mse_loss(pred_noise, true_noise)
+
+
+def hybrid_diff_loss(pred_noise, true_noise, alpha=0.5):
+    """Combines L1 and L2"""
+    l1 = F.l1_loss(pred_noise, true_noise)
+    l2 = F.mse_loss(pred_noise, true_noise)
+    return alpha * l1 + (1 - alpha) * l2
+
+
+def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
+    loss = 0
+    for real, fake in zip(real_preds, fake_preds):
+        if use_lsgan:
+            loss += F.mse_loss(real, torch.ones_like(real)) + F.mse_loss(
+                fake, torch.zeros_like(fake)
+            )
+        else:
+            loss += -torch.mean(torch.log(real + 1e-7)) - torch.mean(
+                torch.log(1 - fake + 1e-7)
+            )
+    return loss
+
+
+class MultiMelScaleLoss(Model):
+    def __init__(
+        self,
+        sample_rate: int,
+        n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
+        window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
+        n_ffts: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
+        hops: List[int] = [8, 16, 32, 64, 128, 256, 512],
+        f_min: float = [0, 0, 0, 0, 0, 0, 0],
+        f_max: Optional[float] = [None, None, None, None, None, None, None],
+        loss_fn: Callable = nn.L1Loss(),
+        center: bool = True,
+        power: float = 1.0,
+        normalized: bool = False,
+        pad_mode: str = "reflect",
+        onesided: Optional[bool] = None,
+        std: int = 4,
+        mean: int = -4,
+        use_istft_norm: bool = True,
+        use_pitch_loss: bool = True,
+        use_rms_loss: bool = True,
+        norm_pitch_fn: Callable[[Tensor], Tensor] = normalize_unit_norm,
+        norm_rms_fn: Callable[[Tensor], Tensor] = normalize_unit_norm,
+        lambda_mel: float = 1.0,
+        lambda_rms: float = 1.0,
+        lambda_pitch: float = 1.0,
+        weight: float = 1.0,
+    ):
+        super().__init__()
+        assert (
+            len(n_mels)
+            == len(window_lengths)
+            == len(n_ffts)
+            == len(hops)
+            == len(f_min)
+            == len(f_max)
+        )
+        self.loss_fn = loss_fn
+        self.lambda_mel = lambda_mel
+        self.weight = weight
+        self.use_istft_norm = use_istft_norm
+        self.use_pitch_loss = use_pitch_loss
+        self.use_rms_loss = use_rms_loss
+        self.lambda_pitch = lambda_pitch
+        self.lambda_rms = lambda_rms
+
+        self.norm_pitch_fn = norm_pitch_fn
+        self.norm_rms = norm_rms_fn
+
+        self._setup_mels(
+            sample_rate,
+            n_mels,
+            window_lengths,
+            n_ffts,
+            hops,
+            f_min,
+            f_max,
+            center,
+            power,
+            normalized,
+            pad_mode,
+            onesided,
+            std,
+            mean,
+        )
+
+    def _setup_mels(
+        self,
+        sample_rate: int,
+        n_mels: List[int],
+        window_lengths: List[int],
+        n_ffts: List[int],
+        hops: List[int],
+        f_min: List[float],
+        f_max: List[Optional[float]],
+        center: bool,
+        power: float,
+        normalized: bool,
+        pad_mode: str,
+        onesided: Optional[bool],
+        std: int,
+        mean: int,
+    ):
+        assert (
+            len(n_mels)
+            == len(window_lengths)
+            == len(n_ffts)
+            == len(hops)
+            == len(f_min)
+            == len(f_max)
+        )
+        _mel_kwargs = dict(
+            sample_rate=sample_rate,
+            center=center,
+            onesided=onesided,
+            normalized=normalized,
+            power=power,
+            pad_mode=pad_mode,
+            std=std,
+            mean=mean,
+        )
+        self.mel_spectrograms: List[AudioProcessor] = nn.ModuleList(
+            [
+                AudioProcessor(
+                    AudioProcessorConfig(
+                        **_mel_kwargs,
+                        n_mels=mel,
+                        n_fft=n_fft,
+                        win_length=win,
+                        hop_length=hop,
+                        f_min=fmin,
+                        f_max=fmax,
+                    )
+                )
+                for mel, win, n_fft, hop, fmin, fmax in zip(
+                    n_mels, window_lengths, n_ffts, hops, f_min, f_max
+                )
+            ]
+        )
+
+    def forward(
+        self, input_wave: torch.Tensor, target_wave: torch.Tensor
+    ) -> torch.Tensor:
+        assert self.use_istft_norm or input_wave.shape[-1] == target_wave.shape[-1]
+        target_wave = target_wave.to(input_wave.device)
+        losses = 0.0
+        for M in self.mel_spectrograms:
+            # Apply normalization if requested
+            if self.use_istft_norm:
+                input_proc = M.istft_norm(input_wave, length=target_wave.shape[-1])
+                target_proc = M.istft_norm(target_wave, length=target_wave.shape[-1])
+            else:
+                input_proc, target_proc = input_wave, target_wave
+
+            x_mels = M(input_proc)
+            y_mels = M(target_proc)
+
+            loss = self.loss_fn(x_mels.squeeze(), y_mels.squeeze())
+            losses += loss * self.lambda_mel
+
+            # pitch/f0 loss
+            if self.use_pitch_loss:
+                x_pitch = self.norm_pitch_fn(M.compute_pitch(input_proc))
+                y_pitch = self.norm_pitch_fn(M.compute_pitch(target_proc))
+                f0_loss = self.loss_fn(x_pitch, y_pitch)
+                losses += f0_loss * self.lambda_pitch
+
+            # energy/rms loss
+            if self.use_rms_loss:
+                x_rms = self.norm_rms(M.compute_rms(input_proc, x_mels))
+                y_rms = self.norm_rms(M.compute_rms(target_proc, y_mels))
+                rms_loss = self.loss_fn(x_rms, y_rms)
+                losses += rms_loss * self.lambda_rms
+
+        return losses * self.weight
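For orientation (not part of the diff): a minimal sketch of the new `MultiMelScaleLoss` in use, assuming a 24 kHz setup; the tensors and sample rate below are illustrative placeholders, not package defaults being asserted.

```python
import torch
from lt_tensor.losses import MultiMelScaleLoss

# Defaults cover 7 mel resolutions (32- to 2048-sample windows) and
# enable the pitch (f0) and RMS terms alongside the per-scale mel term.
mel_loss = MultiMelScaleLoss(sample_rate=24000)

generated = torch.randn(2, 24000)  # placeholder model output, [B, T]
reference = torch.randn(2, 24000)  # placeholder ground truth, [B, T]

# Scalar total: sum over scales of lambda_mel * mel + lambda_pitch * f0
# + lambda_rms * RMS, finally scaled by `weight`.
total = mel_loss(generated, reference)
```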
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/math_ops.py
@@ -6,10 +6,12 @@ __all__ = [
     "apply_window",
     "shift_ring",
     "dot_product",
-    "normalize_tensor",
     "log_magnitude",
     "shift_time",
     "phase",
+    "normalize_unit_norm",
+    "normalize_minmax",
+    "normalize_zscore",
 ]

 from lt_tensor.torch_commons import *
@@ -61,11 +63,6 @@ def dot_product(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
     return torch.sum(x * y, dim=dim)


-def normalize_tensor(x: Tensor, eps: float = 1e-8) -> Tensor:
-    """Normalizes a tensor to unit norm (L2)."""
-    return x / (torch.norm(x, dim=-1, keepdim=True) + eps)
-
-
 def log_magnitude(stft_complex: Tensor, eps: float = 1e-5) -> Tensor:
     """Returns log magnitude from complex STFT."""
     magnitude = torch.abs(stft_complex)
@@ -76,3 +73,19 @@ def phase(stft_complex: Tensor) -> Tensor:
     """Returns phase from complex STFT."""
     return torch.angle(stft_complex)

+
+def normalize_unit_norm(x: torch.Tensor, eps: float = 1e-6):
+    norm = torch.norm(x, dim=-1, keepdim=True)
+    return x / (norm + eps)
+
+
+def normalize_minmax(x: torch.Tensor, eps: float = 1e-6):
+    min_val = x.amin(dim=-1, keepdim=True)
+    max_val = x.amax(dim=-1, keepdim=True)
+    return (x - min_val) / (max_val - min_val + eps)
+
+
+def normalize_zscore(x: torch.Tensor, eps: float = 1e-6):
+    mean = x.mean(dim=-1, keepdim=True)
+    std = x.std(dim=-1, keepdim=True)
+    return (x - mean) / (std + eps)
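For orientation (not part of the diff): the three functions that replace `normalize_tensor` all reduce over the last dimension, so a `[B, T]` batch is normalized row by row. A quick sketch:

```python
import torch
from lt_tensor.math_ops import (
    normalize_minmax,
    normalize_unit_norm,
    normalize_zscore,
)

x = torch.randn(4, 100)  # placeholder batch

print(normalize_unit_norm(x).norm(dim=-1))  # ~1 per row (L2 norm)
print(normalize_minmax(x).amin(dim=-1))     # ~0 per row; max ~1
print(normalize_zscore(x).mean(dim=-1))     # ~0 per row; std ~1
```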
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/model_zoo/losses/discriminators.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from lt_tensor.model_zoo.audio_models.hifigan import ConvNets
 from lt_utils.common import *
 from lt_tensor.torch_commons import *
@@ -5,6 +7,8 @@ from lt_tensor.model_base import Model
 from lt_tensor.model_zoo.convs import ConvNets
 from torch.nn import functional as F
 from torchaudio import transforms as T
+from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
+

 MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
     List[Tensor],
@@ -14,11 +18,75 @@ MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
 ]


+class MultiDiscriminatorWrapper(Model):
+    def __init__(self, list_discriminator: List["_MultiDiscriminatorT"]):
+        """Setup example:
+        model_d = MultiDiscriminatorStep(
+            [
+                MultiEnvelopeDiscriminator(),
+                MultiBandDiscriminator(),
+                MultiResolutionDiscriminator(),
+                MultiPeriodDiscriminator(0.5),
+            ]
+        )
+        """
+        super().__init__()
+        self.disc: Sequence[_MultiDiscriminatorT] = nn.ModuleList(list_discriminator)
+        self.total = len(self.disc)
+
+    def forward(
+        self,
+        y: Tensor,
+        y_hat: Tensor,
+        step_type: Literal["discriminator", "generator"],
+    ) -> Union[
+        Tuple[Tensor, Tensor, List[float]], Tuple[Tensor, List[float], List[float]]
+    ]:
+        """
+        It returns the content based on the choice of "step_type", being it a
+        'discriminator' or 'generator'
+
+        For generator it returns:
+            Tuple[Tensor, Tensor, List[float]]
+            "gen_loss, feat_loss, all_g_losses"
+
+        For 'discriminator' it returns:
+            Tuple[Tensor, List[float], List[float]]
+            "disc_loss, disc_real_losses, disc_gen_losses"
+        """
+        if step_type == "generator":
+            all_g_losses: List[float] = []
+            feat_loss: Tensor = 0
+            gen_loss: Tensor = 0
+        else:
+            disc_loss: Tensor = 0
+            disc_real_losses: List[float] = []
+            disc_gen_losses: List[float] = []
+
+        for disc in self.disc:
+            if step_type == "generator":
+                # feature loss, generator loss, list of generator losses (float)]
+                f_loss, g_loss, g_losses = disc.gen_step(y, y_hat)
+                gen_loss += g_loss
+                feat_loss += f_loss
+                all_g_losses.extend(g_losses)
+            else:
+                # [discriminator loss, (disc losses real, disc losses generated)]
+                d_loss, (d_real_losses, d_gen_losses) = disc.disc_step(y, y_hat)
+                disc_loss += d_loss
+                disc_real_losses.extend(d_real_losses)
+                disc_gen_losses.extend(d_gen_losses)
+
+        if step_type == "generator":
+            return gen_loss, feat_loss, all_g_losses
+        return disc_loss, disc_real_losses, disc_gen_losses
+
+
 def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)


-class MultiDiscriminatorWrapper(ConvNets):
+class _MultiDiscriminatorT(ConvNets):
     """Base for all multi-steps type of discriminators"""

     def __init__(self, *args, **kwargs):
@@ -171,7 +239,7 @@ class DiscriminatorP(ConvNets):
         return x.flatten(1, -1), fmap


-class MultiPeriodDiscriminator(MultiDiscriminatorWrapper):
+class MultiPeriodDiscriminator(_MultiDiscriminatorT):
     def __init__(
         self,
         discriminator_channel_mult: Number = 1,
@@ -258,7 +326,7 @@ class DiscriminatorEnvelope(ConvNets):
         return x.flatten(1), fmap


-class MultiEnvelopeDiscriminator(MultiDiscriminatorWrapper):
+class MultiEnvelopeDiscriminator(_MultiDiscriminatorT):
     def __init__(self, use_spectral_norm: bool = False):
         super().__init__()
         self.discriminators = nn.ModuleList(
@@ -375,7 +443,7 @@ class DiscriminatorB(ConvNets):
         return x, fmap


-class MultiBandDiscriminator(MultiDiscriminatorWrapper):
+class MultiBandDiscriminator(_MultiDiscriminatorT):
     """
     Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
     and the modified code adapted from https://github.com/gemelo-ai/vocos.
@@ -514,7 +582,7 @@ class DiscriminatorR(ConvNets):
         return mag


-class MultiResolutionDiscriminator(MultiDiscriminatorWrapper):
+class MultiResolutionDiscriminator(_MultiDiscriminatorT):
     def __init__(
         self,
         use_spectral_norm: bool = False,
@@ -552,71 +620,3 @@ class MultiResolutionDiscriminator(MultiDiscriminatorWrapper):
         y_d_gs.append(y_d_g)
         fmap_gs.append(fmap_g)
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-class MultiDiscriminatorStep(Model):
-    def __init__(
-        self, list_discriminator: List[MultiDiscriminatorWrapper]
-    ):
-        """Setup example:
-        model_d = MultiDiscriminatorStep(
-            [
-                MultiEnvelopeDiscriminator(),
-                MultiBandDiscriminator(),
-                MultiResolutionDiscriminator(),
-                MultiPeriodDiscriminator(0.5),
-            ]
-        )
-        """
-        super().__init__()
-        self.disc: Sequence[MultiDiscriminatorWrapper] = nn.ModuleList(
-            list_discriminator
-        )
-        self.total = len(self.disc)
-
-    def forward(
-        self,
-        y: Tensor,
-        y_hat: Tensor,
-        step_type: Literal["discriminator", "generator"],
-    ) -> Union[
-        Tuple[Tensor, Tensor, List[float]], Tuple[Tensor, List[float], List[float]]
-    ]:
-        """
-        It returns the content based on the choice of "step_type", being it a
-        'discriminator' or 'generator'
-
-        For generator it returns:
-            Tuple[Tensor, Tensor, List[float]]
-            "gen_loss, feat_loss, all_g_losses"
-
-        For 'discriminator' it returns:
-            Tuple[Tensor, List[float], List[float]]
-            "disc_loss, disc_real_losses, disc_gen_losses"
-        """
-        if step_type == "generator":
-            all_g_losses: List[float] = []
-            feat_loss: Tensor = 0
-            gen_loss: Tensor = 0
-        else:
-            disc_loss: Tensor = 0
-            disc_real_losses: List[float] = []
-            disc_gen_losses: List[float] = []
-
-        for disc in self.disc:
-            if step_type == "generator":
-                # feature loss, generator loss, list of generator losses (float)]
-                f_loss, g_loss, g_losses = disc.gen_step(y, y_hat)
-                gen_loss += g_loss
-                feat_loss += f_loss
-                all_g_losses.extend(g_losses)
-            else:
-                # [discriminator loss, (disc losses real, disc losses generated)]
-                d_loss, (d_real_losses, d_gen_losses) = disc.disc_step(y, y_hat)
-                disc_loss += d_loss
-                disc_real_losses.extend(d_real_losses)
-                disc_gen_losses.extend(d_gen_losses)
-
-        if step_type == "generator":
-            return gen_loss, feat_loss, all_g_losses
-        return disc_loss, disc_real_losses, disc_gen_losses
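For orientation (not part of the diff): `MultiDiscriminatorStep` is renamed to `MultiDiscriminatorWrapper` and moved to the top of the module, while the old wrapper base class becomes `_MultiDiscriminatorT`. A hedged sketch of one training iteration, with placeholder waveform shapes:

```python
import torch
from lt_tensor.model_zoo.losses.discriminators import (
    MultiBandDiscriminator,
    MultiDiscriminatorWrapper,
    MultiEnvelopeDiscriminator,
    MultiPeriodDiscriminator,
    MultiResolutionDiscriminator,
)

model_d = MultiDiscriminatorWrapper(
    [
        MultiEnvelopeDiscriminator(),
        MultiBandDiscriminator(),
        MultiResolutionDiscriminator(),
        MultiPeriodDiscriminator(0.5),
    ]
)

y = torch.randn(1, 1, 8192)      # placeholder real waveform
y_hat = torch.randn(1, 1, 8192)  # placeholder generated waveform

# Discriminator pass: summed loss plus per-discriminator float losses.
disc_loss, disc_real_losses, disc_gen_losses = model_d(y, y_hat, "discriminator")

# Generator pass: adversarial loss, feature-matching loss, per-disc losses.
gen_loss, feat_loss, all_g_losses = model_d(y, y_hat, "generator")
```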
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor/processors/audio.py
@@ -23,7 +23,7 @@ class AudioProcessorConfig(ModelConfig):
     win_length: int = 1024
     hop_length: int = 256
     f_min: float = 0
-    f_max: float =
+    f_max: Optional[float] = None
     center: bool = True
     mel_scale: Literal["htk" "slaney"] = "htk"
     std: int = 4
@@ -41,8 +41,8 @@ class AudioProcessorConfig(ModelConfig):
         n_fft: int = 1024,
         win_length: Optional[int] = None,
         hop_length: Optional[int] = None,
-        f_min: float =
-        f_max: float =
+        f_min: float = 0,
+        f_max: Optional[float] = None,
         center: bool = True,
         mel_scale: Literal["htk", "slaney"] = "htk",
         std: int = 4,
@@ -71,9 +71,12 @@ class AudioProcessorConfig(ModelConfig):
         self.post_process()

     def post_process(self):
-        self.f_min = max(self.f_min, 1)
-        self.f_max = max(min(self.f_max, self.n_fft // 2), self.f_min + 1)
         self.n_stft = self.n_fft // 2 + 1
+        # some functions needs this to be a non-zero or not None value.
+        self.f_min = max(self.f_min, (self.sample_rate / (self.n_fft - 1)) * 2)
+        self.default_f_max = min(
+            default(self.f_max, self.sample_rate // 2), self.sample_rate // 2
+        )
         self.hop_length = default(self.hop_length, self.n_fft // 4)
         self.win_length = default(self.win_length, self.n_fft)

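As a sanity check (illustrative, not from the package): the new `f_min` floor, `(sample_rate / (n_fft - 1)) * 2`, works out to roughly two FFT-bin widths, which keeps downstream pitch estimation away from a zero or None minimum frequency:

```python
# Illustrative values; any sample_rate/n_fft pair works the same way.
sample_rate, n_fft = 24000, 1024
f_min_floor = (sample_rate / (n_fft - 1)) * 2
print(round(f_min_floor, 2))  # 46.92 (Hz)
```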
@@ -105,7 +108,7 @@ class AudioProcessor(Model):
             onesided=self.cfg.onesided,
             normalized=self.cfg.normalized,
         )
-        self.
+        self._mel_rscale = torchaudio.transforms.InverseMelScale(
             n_stft=self.cfg.n_stft,
             n_mels=self.cfg.n_mels,
             sample_rate=self.cfg.sample_rate,
@@ -119,32 +122,39 @@ class AudioProcessor(Model):
             (torch.hann_window(self.cfg.win_length) if window is None else window),
         )

-    def from_numpy(
-        self,
-        array: np.ndarray,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        converted = torch.from_numpy(array)
-        if device is None:
-            device = self.device
-        return converted.to(device=device, dtype=dtype)

-    def from_numpy_batch(
+
+    def compute_mel(
         self,
-        arrays: List[np.ndarray],
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        stacked = torch.stack([torch.from_numpy(x) for x in arrays])
-        if device is None:
-            device = self.device
-        return stacked.to(device=device, dtype=dtype)
+        wave: Tensor,
+        raw_mel_only: bool = False,
+        eps: float = 1e-5,
+        *,
+        _recall: bool = False,
+    ) -> Tensor:
+        """Returns: [B, M, T]"""
+        try:
+            mel_tensor = self._mel_spec(wave.to(self.device))  # [M, T]
+            if not raw_mel_only:
+                mel_tensor = (
+                    torch.log(eps + mel_tensor.unsqueeze(0)) - self.cfg.mean
+                ) / self.cfg.std
+            return mel_tensor.squeeze()

-    def to_numpy_safe(self, tensor: Union[Tensor, np.ndarray]):
-        if isinstance(tensor, np.ndarray):
-            return tensor
-        return tensor.detach().to(DEFAULT_DEVICE).numpy(force=True)
+        except RuntimeError as e:
+            if not _recall:
+                self._mel_spec.to(self.device)
+                return self.compute_mel(wave, raw_mel_only, eps, _recall=True)
+            raise e
+
+    def compute_inverse_mel(self, melspec: Tensor, *, _recall=False):
+        try:
+            return self._mel_rscale.forward(melspec.to(self.device)).squeeze()
+        except RuntimeError as e:
+            if not _recall:
+                self._mel_rscale.to(self.device)
+                return self.compute_inverse_mel(melspec, _recall=True)
+            raise e

     def compute_rms(
         self,
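For orientation (not part of the diff): `compute_mel` moves up here from the bottom of the class and gains a companion `compute_inverse_mel`, while the numpy bridge helpers removed above reappear further down (see the +312 hunk). A hedged round-trip sketch; the config values are placeholders:

```python
import torch
from lt_tensor.processors import AudioProcessor, AudioProcessorConfig

ap = AudioProcessor(AudioProcessorConfig(sample_rate=24000))  # placeholder config

wave = torch.randn(1, 24000)
mel = ap.compute_mel(wave)                     # log-mel, standardized by mean/std
raw = ap.compute_mel(wave, raw_mel_only=True)  # linear mel, no log/standardize
spec = ap.compute_inverse_mel(raw)             # approximate linear spectrogram
```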
@@ -192,12 +202,44 @@ class AudioProcessor(Model):
         else:
             rms_ = []
             for i in range(B):
-
+                _t = _comp_rms_helper(i, audio, mel)
+                _r = librosa.feature.rms(**_t, **rms_kwargs)[
                     0
                 ]
                 rms_.append(_r)
         return self.from_numpy_batch(rms_, default_device, default_dtype).squeeze()

+    def pitch_shift(self, audio: torch.Tensor, sample_rate: Optional[int] = None, n_steps: float = 2.0):
+        """
+        Shifts the pitch of an audio tensor by `n_steps` semitones.
+
+        Args:
+            audio (torch.Tensor): Tensor of shape (B, T) or (T,)
+            sample_rate (int, optional): Sample rate of the audio. Will use the class sample rate if unset.
+            n_steps (float): Number of semitones to shift. Can be negative.
+
+        Returns:
+            torch.Tensor: Pitch-shifted audio.
+        """
+        src_device = audio.device
+        src_dtype = audio.dtype
+        audio = audio.squeeze()
+        sample_rate = default(sample_rate, self.cfg.sample_rate)
+        def _shift_one(wav):
+            wav_np = self.to_numpy_safe(wav)
+            shifted_np = librosa.effects.pitch_shift(wav_np, sr=sample_rate, n_steps=n_steps)
+            return torch.from_numpy(shifted_np)
+
+        if audio.ndim == 1:
+            return _shift_one(audio).to(device=src_device, dtype=src_dtype)
+        return torch.stack([_shift_one(a) for a in audio]).to(device=src_device, dtype=src_dtype)
+
+
+    @staticmethod
+    def calc_pitch_fmin(sr: int, frame_length: float):
+        """For pitch f_min"""
+        return (sr / (frame_length - 1)) * 2
+
     def compute_pitch(
         self,
         audio: Tensor,
@@ -218,9 +260,9 @@ class AudioProcessor(Model):
         else:
             B = 1
         sr = default(sr, self.cfg.sample_rate)
-        fmin = max(default(fmin, self.cfg.f_min), 65)
-        fmax = min(default(fmax, self.cfg.f_max), sr // 2)
         frame_length = default(frame_length, self.cfg.n_fft)
+        fmin = max(default(fmin, self.cfg.f_min), self.calc_pitch_fmin(sr, frame_length))
+        fmax = min(max(default(fmax, self.cfg.default_f_max), fmin+1), sr // 2)
         hop_length = default(hop_length, self.cfg.hop_length)
         center = default(center, self.cfg.center)
         yn_kwargs = dict(
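For orientation (not part of the diff): a small sketch of the new `pitch_shift` helper; the audio tensor and semitone counts are placeholders:

```python
import torch
from lt_tensor.processors import AudioProcessor, AudioProcessorConfig

ap = AudioProcessor(AudioProcessorConfig(sample_rate=24000))  # placeholder config

audio = torch.randn(24000)                  # 1 s of audio, shape (T,)
up = ap.pitch_shift(audio, n_steps=2.0)     # two semitones up
down = ap.pitch_shift(audio, n_steps=-3.0)  # three semitones down
```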
@@ -257,10 +299,10 @@ class AudioProcessor(Model):
         frame_length: Optional[Number] = None,
     ):
         sr = default(sr, self.sample_rate)
-
-
-
-
+        win_length = default(win_length, self.cfg.win_length)
+        frame_length = default(frame_length, self.cfg.n_fft)
+        fmin = default(fmin, self.calc_pitch_fmin(sr, frame_length))
+        fmax = default(fmax, self.cfg.default_f_max)
         return detect_pitch_frequency(
             audio,
             sample_rate=sr,
@@ -270,6 +312,33 @@ class AudioProcessor(Model):
             freq_high=fmax,
         ).squeeze()

+    def from_numpy(
+        self,
+        array: np.ndarray,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        converted = torch.from_numpy(array)
+        if device is None:
+            device = self.device
+        return converted.to(device=device, dtype=dtype)
+
+    def from_numpy_batch(
+        self,
+        arrays: List[np.ndarray],
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        stacked = torch.stack([torch.from_numpy(x) for x in arrays])
+        if device is None:
+            device = self.device
+        return stacked.to(device=device, dtype=dtype)
+
+    def to_numpy_safe(self, tensor: Union[Tensor, np.ndarray]):
+        if isinstance(tensor, np.ndarray):
+            return tensor
+        return tensor.detach().to(DEFAULT_DEVICE).numpy(force=True)
+
     def interpolate(
         self,
         tensor: Tensor,
@@ -391,29 +460,6 @@ class AudioProcessor(Model):
         return self.istft_norm(wave, length, _recall=True)
         raise e

-    def compute_mel(
-        self,
-        wave: Tensor,
-        raw_mel_only: bool = False,
-        eps: float = 1e-5,
-        *,
-        _recall: bool = False,
-    ) -> Tensor:
-        """Returns: [B, M, T]"""
-        try:
-            mel_tensor = self._mel_spec(wave.to(self.device))  # [M, T]
-            if not raw_mel_only:
-                mel_tensor = (
-                    torch.log(eps + mel_tensor.unsqueeze(0)) - self.cfg.mean
-                ) / self.cfg.std
-            return mel_tensor.squeeze()
-
-        except RuntimeError as e:
-            if not _recall:
-                self._mel_spec.to(self.device)
-                return self.compute_mel(wave, raw_mel_only, eps, _recall=True)
-            raise e
-
     def load_audio(
         self,
         path: PathLike,
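For orientation (not part of the diff): the numpy bridge helpers are relocated below the pitch utilities rather than removed. A small round trip, with placeholder array contents:

```python
import numpy as np
import torch
from lt_tensor.processors import AudioProcessor, AudioProcessorConfig

ap = AudioProcessor(AudioProcessorConfig(sample_rate=24000))  # placeholder config

arr = np.random.randn(24000).astype(np.float32)
t = ap.from_numpy(arr)                   # torch.Tensor on ap.device
back = ap.to_numpy_safe(t)               # np.ndarray, detached
batch = ap.from_numpy_batch([arr, arr])  # stacked [2, 24000] tensor
```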
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/lt_tensor.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a32
+Version: 0.0.1a34
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
@@ -17,7 +17,7 @@ Requires-Dist: numpy>=1.26.4
 Requires-Dist: tokenizers
 Requires-Dist: pyyaml>=6.0.0
 Requires-Dist: numba>0.60.0
-Requires-Dist: lt-utils>=0.0.
+Requires-Dist: lt-utils>=0.0.4
 Requires-Dist: librosa==0.11.*
 Requires-Dist: einops
 Requires-Dist: plotly
{lt_tensor-0.0.1a32 → lt_tensor-0.0.1a34}/setup.py
@@ -4,7 +4,7 @@ with open("README.md", "r", encoding="utf-8") as f:
     long_description = f.read()

 setup(
-    version="0.0.1a32",
+    version="0.0.1a34",
     name="lt-tensor",
     description="General utilities for PyTorch and others. Built for general use.",
     long_description=long_description,
@@ -17,7 +17,7 @@ setup(
         "tokenizers",
         "pyyaml>=6.0.0",
         "numba>0.60.0",
-        "lt-utils>=0.0.
+        "lt-utils>=0.0.4",
         "librosa==0.11.*",
         "einops",
         "plotly",
lt_tensor-0.0.1a32/lt_tensor/losses.py
@@ -1,159 +0,0 @@
-__all__ = [
-    "masked_cross_entropy",
-    "adaptive_l1_loss",
-    "contrastive_loss",
-    "smooth_l1_loss",
-    "hybrid_loss",
-    "diff_loss",
-    "cosine_loss",
-    "gan_loss",
-    "ft_n_loss",
-]
-import math
-import random
-from lt_tensor.torch_commons import *
-from lt_utils.common import *
-import torch.nn.functional as F
-
-def ft_n_loss(output: Tensor, target: Tensor, weight: Optional[Tensor] = None):
-    if weight is not None:
-        return torch.mean((torch.abs(output - target) + weight) **0.5)
-    return torch.mean(torch.abs(output - target)**0.5)
-
-def adaptive_l1_loss(
-    inp: Tensor,
-    tgt: Tensor,
-    weight: Optional[Tensor] = None,
-    scale: float = 1.0,
-    inverted: bool = False,
-):
-
-    if weight is not None:
-        loss = torch.mean(torch.abs((inp - tgt) + weight.mean()))
-    else:
-        loss = torch.mean(torch.abs(inp - tgt))
-    loss *= scale
-    if inverted:
-        return -loss
-    return loss
-
-
-def smooth_l1_loss(inp: Tensor, tgt: Tensor, beta=1.0, weight=None):
-    diff = torch.abs(inp - tgt)
-    loss = torch.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)
-    if weight is not None:
-        loss *= weight
-    return loss.mean()
-
-
-def contrastive_loss(x1: Tensor, x2: Tensor, label: Tensor, margin: float = 1.0):
-    # label == 1: similar, label == 0: dissimilar
-    dist = torch.nn.functional.pairwise_distance(x1, x2)
-    loss = label * dist**2 + (1 - label) * torch.clamp(margin - dist, min=0.0) ** 2
-    return loss.mean()
-
-
-def cosine_loss(inp, tgt):
-    cos = torch.nn.functional.cosine_similarity(inp, tgt, dim=-1)
-    return 1 - cos.mean()  # Lower is better
-
-
-class GanLosses:
-    @staticmethod
-    def get_loss(
-        pred: Tensor,
-        target_is_real: bool,
-        loss_type: Literal["bce", "mse", "hinge", "wasserstein"] = "bce",
-    ) -> Tensor:
-        if loss_type == "bce":  # Standard GAN
-            target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
-            return F.binary_cross_entropy_with_logits(pred, target)
-
-        elif loss_type == "mse":  # LSGAN
-            target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
-            return F.mse_loss(torch.sigmoid(pred), target)
-
-        elif loss_type == "hinge":
-            if target_is_real:
-                return torch.mean(F.relu(1.0 - pred))
-            else:
-                return torch.mean(F.relu(1.0 + pred))
-
-        elif loss_type == "wasserstein":
-            return -pred.mean() if target_is_real else pred.mean()
-
-        else:
-            raise ValueError(f"Unknown loss_type: {loss_type}")
-
-    @staticmethod
-    def generator_loss(fake_pred: Tensor, loss_type: str = "bce") -> Tensor:
-        return GanLosses.get_loss(fake_pred, target_is_real=True, loss_type=loss_type)
-
-    @staticmethod
-    def discriminator_loss(
-        real_pred: Tensor, fake_pred: Tensor, loss_type: str = "bce"
-    ) -> Tensor:
-        real_loss = GanLosses.get_loss(
-            real_pred, target_is_real=True, loss_type=loss_type
-        )
-        fake_loss = GanLosses.get_loss(
-            fake_pred.detach(), target_is_real=False, loss_type=loss_type
-        )
-        return (real_loss + fake_loss) * 0.5
-
-
-def masked_cross_entropy(
-    logits: torch.Tensor,  # [B, T, V]
-    targets: torch.Tensor,  # [B, T]
-    lengths: torch.Tensor,  # [B]
-    reduction: str = "mean",
-) -> torch.Tensor:
-    """
-    CrossEntropyLoss with masking for variable-length sequences.
-    - logits: unnormalized scores [B, T, V]
-    - targets: ground truth indices [B, T]
-    - lengths: actual sequence lengths [B]
-    """
-    B, T, V = logits.size()
-    logits = logits.view(-1, V)
-    targets = targets.view(-1)
-
-    # Create mask
-    mask = torch.arange(T, device=lengths.device).expand(B, T) < lengths.unsqueeze(1)
-    mask = mask.reshape(-1)
-
-    # Apply CE only where mask == True
-    loss = F.cross_entropy(
-        logits[mask], targets[mask], reduction="mean" if reduction == "mean" else "none"
-    )
-    if reduction == "none":
-        return loss
-    return loss
-
-
-def diff_loss(pred_noise, true_noise, mask=None):
-    """Standard diffusion noise-prediction loss (e.g., DDPM)"""
-    if mask is not None:
-        return F.mse_loss(pred_noise * mask, true_noise * mask)
-    return F.mse_loss(pred_noise, true_noise)
-
-
-def hybrid_diff_loss(pred_noise, true_noise, alpha=0.5):
-    """Combines L1 and L2"""
-    l1 = F.l1_loss(pred_noise, true_noise)
-    l2 = F.mse_loss(pred_noise, true_noise)
-    return alpha * l1 + (1 - alpha) * l2
-
-
-def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
-    loss = 0
-    for real, fake in zip(real_preds, fake_preds):
-        if use_lsgan:
-            loss += F.mse_loss(real, torch.ones_like(real)) + F.mse_loss(
-                fake, torch.zeros_like(fake)
-            )
-        else:
-            loss += -torch.mean(torch.log(real + 1e-7)) - torch.mean(
-                torch.log(1 - fake + 1e-7)
-            )
-    return loss