lt-tensor 0.0.1a4__py3-none-any.whl → 0.0.1a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +5 -1
- lt_tensor/datasets/audio.py +45 -1
- lt_tensor/losses.py +145 -0
- lt_tensor/math_ops.py +7 -0
- lt_tensor/misc_utils.py +10 -96
- lt_tensor/model_base.py +102 -3
- lt_tensor/model_zoo/disc.py +14 -14
- lt_tensor/model_zoo/istft.py +41 -0
- lt_tensor/noise_tools.py +362 -0
- lt_tensor/transform.py +13 -22
- {lt_tensor-0.0.1a4.dist-info → lt_tensor-0.0.1a6.dist-info}/METADATA +1 -1
- lt_tensor-0.0.1a6.dist-info/RECORD +26 -0
- lt_tensor-0.0.1a4.dist-info/RECORD +0 -24
- {lt_tensor-0.0.1a4.dist-info → lt_tensor-0.0.1a6.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a4.dist-info → lt_tensor-0.0.1a6.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a4.dist-info → lt_tensor-0.0.1a6.dist-info}/top_level.txt +0 -0
lt_tensor/__init__.py
CHANGED
@@ -1,13 +1,15 @@
|
|
1
1
|
__version__ = "0.0.1a"
|
2
2
|
|
3
3
|
from . import (
|
4
|
+
lr_schedulers,
|
4
5
|
model_zoo,
|
5
6
|
model_base,
|
6
7
|
math_ops,
|
7
8
|
misc_utils,
|
8
9
|
monotonic_align,
|
9
10
|
transform,
|
10
|
-
|
11
|
+
noise_tools,
|
12
|
+
losses,
|
11
13
|
)
|
12
14
|
|
13
15
|
__all__ = [
|
@@ -18,4 +20,6 @@ __all__ = [
|
|
18
20
|
"monotonic_align",
|
19
21
|
"transform",
|
20
22
|
"lr_schedulers",
|
23
|
+
"noise_tools",
|
24
|
+
"losses",
|
21
25
|
]
|
lt_tensor/datasets/audio.py
CHANGED
@@ -3,9 +3,10 @@ from ..torch_commons import *
|
|
3
3
|
import torchaudio
|
4
4
|
from lt_utils.common import *
|
5
5
|
import librosa
|
6
|
-
from lt_utils.type_utils import is_file
|
6
|
+
from lt_utils.type_utils import is_file, is_array
|
7
7
|
from torchaudio.functional import resample
|
8
8
|
from ..transform import inverse_transform
|
9
|
+
from lt_utils.file_ops import FileScan, load_text, get_file_name
|
9
10
|
|
10
11
|
|
11
12
|
class AudioProcessor:
|
@@ -108,3 +109,46 @@ class AudioProcessor:
|
|
108
109
|
.float()
|
109
110
|
.unsqueeze(0)
|
110
111
|
)
|
112
|
+
|
113
|
+
def find_audios(self, path: PathLike, additional_extensions: List[str] = []):
    """Scan ``path`` for audio files.

    The built-in glob patterns (wav, aac, m4a, mp3, ogg, opus, flac) are
    combined with every entry of ``additional_extensions`` that is a string
    containing ``"*"``; other entries are ignored.

    Returns:
        Whatever ``FileScan.files(path, patterns)`` returns.
    """
    patterns = ["*.wav", "*.aac", "*.m4a", "*.mp3", "*.ogg", "*.opus", "*.flac"]
    for candidate in additional_extensions:
        # Only glob-style string patterns are accepted as extras.
        if isinstance(candidate, str) and "*" in candidate:
            patterns.append(candidate)
    return FileScan.files(path, patterns)
|
130
|
+
|
131
|
+
def find_audio_text_pairs(
    self,
    path,
    additional_extensions: List[str] = [],
    text_file_patterns: List[str] = [".normalized.txt", ".original.txt"],
):
    """Find audio files under ``path`` and the text files that sit next to them.

    For every audio file returned by ``self.find_audios`` the method looks, in
    the same directory, for ``<audio name without extension> + pattern`` for
    each entry of ``text_file_patterns``.

    Args:
        path: Directory handed to ``self.find_audios``.
        additional_extensions: Extra glob patterns for audio files. Entries
            that are not strings, contain no ``"*"``, or contain one of the
            text patterns are dropped.
        text_file_patterns: Suffixes appended to the audio base name to build
            candidate text file names.

    Returns:
        ``(audio_files, text_files)``: every audio file found, and the paths
        of the text files that actually exist (one entry per matching
        pattern, as ``Path`` objects). The two lists are not index-aligned.
    """
    is_array(text_file_patterns, True, validate=True)  # Raises if empty or not valid
    # The original expression called ``map`` with a single argument and ``any``
    # with two, which raised TypeError for any valid extra extension.
    additional_extensions = [
        ext
        for ext in additional_extensions
        if isinstance(ext, str)
        and "*" in ext
        and not any(pattern in ext for pattern in text_file_patterns)
    ]
    audio_files = self.find_audios(path, additional_extensions)
    text_files = []
    for audio in audio_files:
        base_audio_dir = Path(audio).parent
        audio_name = get_file_name(audio, False)
        for pattern in text_file_patterns:
            possible_txt_file = Path(base_audio_dir, audio_name + pattern)
            if is_file(possible_txt_file):
                # Collect the text file itself (the audio path was appended before).
                text_files.append(possible_txt_file)
    return audio_files, text_files
|
lt_tensor/losses.py
ADDED
@@ -0,0 +1,145 @@
|
|
1
|
+
__all__ = ["masked_cross_entropy"]
|
2
|
+
import math
|
3
|
+
import random
|
4
|
+
from .torch_commons import *
|
5
|
+
from lt_utils.common import *
|
6
|
+
import torch.nn.functional as F
|
7
|
+
|
8
|
+
|
9
|
+
def masked_cross_entropy(
    logits: torch.Tensor,  # [B, T, V]
    targets: torch.Tensor,  # [B, T]
    lengths: torch.Tensor,  # [B]
    reduction: str = "mean",
) -> torch.Tensor:
    """Cross-entropy over variable-length sequences, ignoring padded steps.

    Args:
        logits: Unnormalized scores of shape ``[B, T, V]``.
        targets: Ground-truth class indices of shape ``[B, T]``.
        lengths: Number of valid time steps per sequence, shape ``[B]``.
        reduction: ``"mean"``, ``"sum"`` or ``"none"``, forwarded to
            ``F.cross_entropy``. With ``"none"`` the result has one entry per
            valid (unmasked) position.

    Returns:
        The reduced loss, or the per-position losses for ``"none"``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"Unsupported reduction: {reduction!r}")

    B, T, V = logits.size()
    # reshape (not view) so non-contiguous inputs, e.g. after transpose, work.
    flat_logits = logits.reshape(-1, V)
    flat_targets = targets.reshape(-1)

    # Build the mask on the logits' device so it can index them directly,
    # even when ``lengths`` lives on another device.
    positions = torch.arange(T, device=logits.device).unsqueeze(0)  # [1, T]
    valid = positions < lengths.to(logits.device).unsqueeze(1)  # [B, T]
    valid = valid.reshape(-1)

    # Apply CE only where the mask is True.
    return F.cross_entropy(
        flat_logits[valid], flat_targets[valid], reduction=reduction
    )
|
36
|
+
|
37
|
+
|
38
|
+
def diff_loss(pred_noise, true_noise, mask=None):
    """Mean-squared error between predicted and true noise.

    When ``mask`` is given, both tensors are multiplied by it before the MSE
    is taken; the mean still runs over every element, masked ones included.
    """
    if mask is None:
        return F.mse_loss(pred_noise, true_noise)
    masked_pred = pred_noise * mask
    masked_true = true_noise * mask
    return F.mse_loss(masked_pred, masked_true)
|
43
|
+
|
44
|
+
|
45
|
+
def hybrid_diff_loss(pred_noise, true_noise, alpha=0.5):
    """Weighted blend of L1 and L2 error: ``alpha * L1 + (1 - alpha) * L2``."""
    return alpha * F.l1_loss(pred_noise, true_noise) + (1 - alpha) * F.mse_loss(
        pred_noise, true_noise
    )
|
50
|
+
|
51
|
+
|
52
|
+
def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
    """Discriminator loss summed over paired real/fake prediction tensors.

    This function was previously defined twice with identical bodies; the
    duplicate has been removed.

    Args:
        real_preds: Iterable of discriminator outputs for real samples.
        fake_preds: Iterable of discriminator outputs for generated samples,
            paired with ``real_preds`` via ``zip``.
        use_lsgan: If True, use least-squares targets (real -> 1, fake -> 0).
            Otherwise use the log loss ``-log(real) - log(1 - fake)``, with a
            1e-7 offset inside each log.

    Returns:
        The accumulated loss, or the integer ``0`` when there are no pairs.
    """
    loss = 0
    for real, fake in zip(real_preds, fake_preds):
        if use_lsgan:
            loss += F.mse_loss(real, torch.ones_like(real)) + F.mse_loss(
                fake, torch.zeros_like(fake)
            )
        else:
            loss += -torch.mean(torch.log(real + 1e-7)) - torch.mean(
                torch.log(1 - fake + 1e-7)
            )
    return loss
|
78
|
+
|
79
|
+
|
80
|
+
def gan_g_loss(fake_preds, use_lsgan=True):
    """Generator adversarial loss summed over all discriminator outputs.

    With ``use_lsgan`` each output is pushed towards 1 with an MSE term;
    otherwise the term is ``-mean(log(fake + 1e-7))``.
    """
    total = 0
    for prediction in fake_preds:
        term = (
            F.mse_loss(prediction, torch.ones_like(prediction))
            if use_lsgan
            else -torch.mean(torch.log(prediction + 1e-7))
        )
        total += term
    return total
|
88
|
+
|
89
|
+
|
90
|
+
def feature_matching_loss(real_feats, fake_feats):
    """Sum of L1 distances between paired intermediate feature maps.

    ``real_feats`` and ``fake_feats`` are lists (one per discriminator) of
    lists of feature tensors. Real features are detached before comparison.
    """
    total = 0
    layer_pairs = (
        (real_map, fake_map)
        for real_layers, fake_layers in zip(real_feats, fake_feats)
        for real_map, fake_map in zip(real_layers, fake_layers)
    )
    for real_map, fake_map in layer_pairs:
        total += F.l1_loss(fake_map, real_map.detach())
    return total
|
97
|
+
|
98
|
+
|
99
|
+
def feature_loss(real_fmaps, fake_fmaps, weight=2.0):
    """Weighted sum of L1 distances between real and generated feature maps.

    Each element of ``real_fmaps`` / ``fake_fmaps`` is the list of layer
    outputs from one discriminator; the summed L1 terms are scaled by
    ``weight``.
    """
    accumulated = 0.0
    for real_layers, fake_layers in zip(real_fmaps, fake_fmaps):
        for real_map, fake_map in zip(real_layers, fake_layers):
            accumulated += F.l1_loss(real_map, fake_map)
    return accumulated * weight
|
105
|
+
|
106
|
+
|
107
|
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss over paired outputs.

    Returns:
        ``(total, real_losses, generated_losses)`` where the two lists hold
        the per-discriminator MSE terms (real vs. ones, generated vs. zeros)
        as tensors.
    """
    total = 0.0
    real_terms, fake_terms = [], []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = F.mse_loss(real_out, torch.ones_like(real_out))
        fake_term = F.mse_loss(fake_out, torch.zeros_like(fake_out))
        total += real_term + fake_term
        real_terms.append(real_term)
        fake_terms.append(fake_term)
    return total, real_terms, fake_terms
|
120
|
+
|
121
|
+
|
122
|
+
def generator_loss(fake_outputs):
    """Least-squares generator loss: each output is compared against ones.

    Returns:
        ``(total, per_output_losses)`` with the individual MSE terms kept as
        tensors.
    """
    running_total = 0.0
    per_output = []
    for output in fake_outputs:
        term = F.mse_loss(output, torch.ones_like(output))
        per_output.append(term)
        running_total += term
    return running_total, per_output
|
130
|
+
|
131
|
+
|
132
|
+
def multi_resolution_stft_loss(y, y_hat, fft_sizes=(512, 1024, 2048)):
    """Sum of L1 distances between STFT magnitudes at several resolutions.

    For every FFT size the hop length is ``fft_size // 4`` and the window
    length equals the FFT size. A Hann window is passed explicitly: calling
    ``torch.stft`` without one uses a rectangular window, which PyTorch warns
    about because of spectral leakage.

    Args:
        y: Reference waveform(s) accepted by ``torch.stft`` (1-D or 2-D).
        y_hat: Predicted waveform(s) with the same shape as ``y``.
        fft_sizes: Iterable of FFT sizes to evaluate.

    Returns:
        The accumulated loss, or the integer ``0`` if ``fft_sizes`` is empty.
    """
    loss = 0
    for fft_size in fft_sizes:
        hop = fft_size // 4
        window = torch.hann_window(fft_size, device=y.device, dtype=y.dtype)
        y_stft = torch.stft(
            y,
            n_fft=fft_size,
            hop_length=hop,
            win_length=fft_size,
            window=window,
            return_complex=True,
        )
        y_hat_stft = torch.stft(
            y_hat,
            n_fft=fft_size,
            hop_length=hop,
            win_length=fft_size,
            window=window,
            return_complex=True,
        )
        loss += F.l1_loss(torch.abs(y_stft), torch.abs(y_hat_stft))
    return loss
|
lt_tensor/math_ops.py
CHANGED
@@ -8,6 +8,7 @@ __all__ = [
|
|
8
8
|
"dot_product",
|
9
9
|
"normalize_tensor",
|
10
10
|
"log_magnitude",
|
11
|
+
"shift_time",
|
11
12
|
"phase",
|
12
13
|
]
|
13
14
|
|
@@ -50,6 +51,11 @@ def shift_ring(x: Tensor, dim: int = -1) -> Tensor:
|
|
50
51
|
return torch.roll(x, shifts=1, dims=dim)
|
51
52
|
|
52
53
|
|
54
|
+
def shift_time(x: torch.Tensor, shift: int) -> torch.Tensor:
    """Circularly roll ``x`` by ``shift`` positions along its last dimension."""
    rolled = x.roll(shifts=shift, dims=-1)
    return rolled
|
57
|
+
|
58
|
+
|
53
59
|
def dot_product(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
|
54
60
|
"""Computes dot product along the specified dimension."""
|
55
61
|
return torch.sum(x * y, dim=dim)
|
@@ -69,3 +75,4 @@ def log_magnitude(stft_complex: Tensor, eps: float = 1e-5) -> Tensor:
|
|
69
75
|
def phase(stft_complex: Tensor) -> Tensor:
|
70
76
|
"""Returns phase from complex STFT."""
|
71
77
|
return torch.angle(stft_complex)
|
78
|
+
|
lt_tensor/misc_utils.py
CHANGED
@@ -8,7 +8,6 @@ __all__ = [
|
|
8
8
|
"unfreeze_selected_weights",
|
9
9
|
"clip_gradients",
|
10
10
|
"detach_hidden",
|
11
|
-
"tensor_summary",
|
12
11
|
"one_hot",
|
13
12
|
"safe_divide",
|
14
13
|
"batch_pad",
|
@@ -18,22 +17,20 @@ __all__ = [
|
|
18
17
|
"default_device",
|
19
18
|
"Packing",
|
20
19
|
"Padding",
|
21
|
-
"
|
22
|
-
"masked_cross_entropy",
|
23
|
-
"NoiseScheduler",
|
20
|
+
"Masking",
|
24
21
|
]
|
25
22
|
|
26
23
|
import gc
|
24
|
+
import sys
|
27
25
|
import random
|
28
26
|
import numpy as np
|
29
27
|
from lt_utils.type_utils import is_str
|
30
28
|
from .torch_commons import *
|
31
|
-
from lt_utils.misc_utils import
|
32
|
-
from lt_utils.file_ops import load_json, load_yaml, save_json, save_yaml
|
33
|
-
import math
|
29
|
+
from lt_utils.misc_utils import cache_wrapper
|
34
30
|
from lt_utils.common import *
|
35
31
|
import torch.nn.functional as F
|
36
32
|
|
33
|
+
|
37
34
|
def log_tensor(
|
38
35
|
item: Union[Tensor, np.ndarray],
|
39
36
|
title: Optional[str] = None,
|
@@ -64,10 +61,13 @@ def log_tensor(
|
|
64
61
|
print(f"mean: {item.mean(dim=dim):.4f}")
|
65
62
|
except:
|
66
63
|
pass
|
67
|
-
|
68
|
-
|
64
|
+
if print_tensor:
|
65
|
+
print(item)
|
69
66
|
if has_title:
|
70
67
|
print("".join(["-"] * _b), "\n")
|
68
|
+
else:
|
69
|
+
print("\n")
|
70
|
+
sys.stdout.flush()
|
71
71
|
|
72
72
|
|
73
73
|
def set_seed(seed: int):
|
@@ -136,11 +136,6 @@ def detach_hidden(hidden):
|
|
136
136
|
return tuple(detach_hidden(h) for h in hidden)
|
137
137
|
|
138
138
|
|
139
|
-
def tensor_summary(tensor: torch.Tensor) -> str:
|
140
|
-
"""Prints min/max/mean/std of a tensor for debugging."""
|
141
|
-
return f"Shape: {tuple(tensor.shape)}, dtype: {tensor.dtype}, min: {tensor.min():.4f}, max: {tensor.max():.4f}, mean: {tensor.mean():.4f}, std: {tensor.std():.4f}"
|
142
|
-
|
143
|
-
|
144
139
|
def one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
|
145
140
|
"""One-hot encodes a tensor of labels."""
|
146
141
|
return F.one_hot(labels, num_classes).float()
|
@@ -463,7 +458,7 @@ class Padding:
|
|
463
458
|
return torch.stack(padded), lengths
|
464
459
|
|
465
460
|
|
466
|
-
class
|
461
|
+
class Masking:
|
467
462
|
|
468
463
|
@staticmethod
|
469
464
|
def apply_mask(x: Tensor, mask: Tensor, fill_value: Number = 0) -> Tensor:
|
@@ -546,84 +541,3 @@ class MaskUtils:
|
|
546
541
|
return (
|
547
542
|
causal_mask & pad_mask.unsqueeze(1).expand(-1, pad_mask.size(1), -1).bool()
|
548
543
|
)
|
549
|
-
|
550
|
-
|
551
|
-
def masked_cross_entropy(
|
552
|
-
logits: torch.Tensor, # [B, T, V]
|
553
|
-
targets: torch.Tensor, # [B, T]
|
554
|
-
lengths: torch.Tensor, # [B]
|
555
|
-
reduction: str = "mean",
|
556
|
-
) -> torch.Tensor:
|
557
|
-
"""
|
558
|
-
CrossEntropyLoss with masking for variable-length sequences.
|
559
|
-
- logits: unnormalized scores [B, T, V]
|
560
|
-
- targets: ground truth indices [B, T]
|
561
|
-
- lengths: actual sequence lengths [B]
|
562
|
-
"""
|
563
|
-
B, T, V = logits.size()
|
564
|
-
logits = logits.view(-1, V)
|
565
|
-
targets = targets.view(-1)
|
566
|
-
|
567
|
-
# Create mask
|
568
|
-
mask = torch.arange(T, device=lengths.device).expand(B, T) < lengths.unsqueeze(1)
|
569
|
-
mask = mask.reshape(-1)
|
570
|
-
|
571
|
-
# Apply CE only where mask == True
|
572
|
-
loss = F.cross_entropy(
|
573
|
-
logits[mask], targets[mask], reduction="mean" if reduction == "mean" else "none"
|
574
|
-
)
|
575
|
-
if reduction == "none":
|
576
|
-
return loss
|
577
|
-
return loss
|
578
|
-
|
579
|
-
|
580
|
-
class NoiseScheduler(nn.Module):
|
581
|
-
def __init__(self, timesteps: int = 512):
|
582
|
-
super().__init__()
|
583
|
-
|
584
|
-
betas = torch.linspace(1e-4, 0.02, timesteps)
|
585
|
-
alphas = 1.0 - betas
|
586
|
-
alpha_cumprod = torch.cumprod(alphas, dim=0)
|
587
|
-
|
588
|
-
self.register_buffer("sqrt_alpha_cumprod", torch.sqrt(alpha_cumprod))
|
589
|
-
self.register_buffer(
|
590
|
-
"sqrt_one_minus_alpha_cumprod", torch.sqrt(1.0 - alpha_cumprod)
|
591
|
-
)
|
592
|
-
|
593
|
-
self.timesteps = timesteps
|
594
|
-
self.default_noise = math.sqrt(1.25)
|
595
|
-
|
596
|
-
def get_random_noise(
|
597
|
-
self, min_max: Tuple[float, float] = (-3, 3), seed: int = 0
|
598
|
-
) -> float:
|
599
|
-
if seed > 0:
|
600
|
-
random.seed(seed)
|
601
|
-
return random.uniform(*min_max)
|
602
|
-
|
603
|
-
def set_noise(
|
604
|
-
self,
|
605
|
-
seed: int = 0,
|
606
|
-
min_max: Tuple[float, float] = (-3, 3),
|
607
|
-
default: bool = False,
|
608
|
-
):
|
609
|
-
self.default_noise = (
|
610
|
-
math.sqrt(1.25) if default else self.get_random_noise(min_max, seed)
|
611
|
-
)
|
612
|
-
|
613
|
-
def forward(
|
614
|
-
self, x_0: Tensor, t: int, noise: Optional[Union[Tensor, float]] = None
|
615
|
-
) -> Tensor:
|
616
|
-
if t < 0 or t >= self.timesteps:
|
617
|
-
raise ValueError(
|
618
|
-
f"Time step t={t} is out of bounds for scheduler with {self.timesteps} steps."
|
619
|
-
)
|
620
|
-
|
621
|
-
if noise is None:
|
622
|
-
noise = self.default_noise
|
623
|
-
|
624
|
-
if isinstance(noise, (float, int)):
|
625
|
-
noise = torch.randn_like(x_0) * noise
|
626
|
-
|
627
|
-
alpha_term = self.sqrt_alpha_cumprod[t] * x_0
|
628
|
-
noise_term = self.sqrt_one_minus_alpha_cumprod[t] * noise
|
629
|
-
return alpha_term + noise_term
|
lt_tensor/model_base.py
CHANGED
@@ -4,6 +4,7 @@ __all__ = ["Model"]
|
|
4
4
|
import warnings
|
5
5
|
from .torch_commons import *
|
6
6
|
from lt_utils.common import *
|
7
|
+
from lt_utils.misc_utils import log_traceback
|
7
8
|
|
8
9
|
T = TypeVar("T")
|
9
10
|
|
@@ -44,7 +45,7 @@ class Model(nn.Module, ABC):
|
|
44
45
|
|
45
46
|
def tp_apply_device_to(self):
|
46
47
|
"""Add here components that are needed to have device applied to them,
|
47
|
-
that
|
48
|
+
that usually the '.to()' function fails to apply
|
48
49
|
|
49
50
|
example:
|
50
51
|
```
|
@@ -54,6 +55,99 @@ class Model(nn.Module, ABC):
|
|
54
55
|
"""
|
55
56
|
pass
|
56
57
|
|
58
|
+
def _freeze_unfreeze(
    self,
    weight: Union[str, nn.Module],
    task: Literal["freeze", "unfreeze"] = "freeze",
    _skip_except: bool = False,
):
    """Toggle ``requires_grad`` on a sub-module and report what happened.

    Args:
        weight: Either the attribute name of a sub-module of ``self`` or the
            module object itself.
        task: ``"freeze"`` sets ``requires_grad`` to False, ``"unfreeze"``
            sets it to True.
        _skip_except: When True, exceptions are swallowed and their text is
            returned instead of being raised.

    Returns:
        A status message. A missing attribute or a non-Module attribute is
        reported through the message, not through an exception.
    """
    try:
        assert isinstance(weight, (str, nn.Module))
        requires_grad = task == "unfreeze"
        if isinstance(weight, str):
            if not hasattr(self, weight):
                return f"Failed to {task} the module '{weight}'. Reason: is not a valid attribute of {self._get_name()}"
            target = getattr(self, weight)
            if not isinstance(target, nn.Module):
                return f"Failed to {task} the module '{weight}'. Reason: is not a Module type."
            target.requires_grad_(requires_grad)
            return f"Successfully {task} the module '{weight}'."
        # A module object was passed directly. The previous code used an
        # unbound local ``w`` here and therefore always failed.
        weight.requires_grad_(requires_grad)
        return f"Successfully '{task}' the module '{weight}'."
    except Exception as e:
        if not _skip_except:
            raise
        return str(e)


def freeze_weight(
    self,
    weight: Union[str, nn.Module],
    _skip_except: bool = False,
):
    """Freeze a sub-module given by attribute name or passed directly.

    An earlier ``freeze_weight(self, weight, freeze)`` definition in the same
    class body was shadowed by this one and never reachable; it was removed.
    """
    return self._freeze_unfreeze(weight, "freeze", _skip_except)


def unfreeze_weight(
    self,
    weight: Union[str, nn.Module],
    _skip_except: bool = False,
):
    """Unfreeze a sub-module given by attribute name or passed directly."""
    # Previously passed "freeze", so this method froze instead of unfreezing.
    return self._freeze_unfreeze(weight, "unfreeze", _skip_except)
|
107
|
+
|
108
|
+
def freeze_all(self, exclude: Optional[List[str]] = None):
    """Freeze every parameter except those whose name matches ``exclude``.

    A parameter is excluded when any string in ``exclude`` is a substring of
    its name from ``named_parameters()``. The previous logic was inverted: it
    froze the matching parameters and reported the others as "Excluded".

    Args:
        exclude: Name fragments of parameters to leave untouched. ``None`` or
            an empty list freezes everything.

    Returns:
        ``{"frozen": [names], "not_frozen": [(name, reason), ...]}`` where the
        reason is ``"Excluded"`` or the text of the exception raised while
        freezing.
    """
    frozen = []
    not_frozen = []
    for name, param in self.named_parameters():
        if exclude and any(layer in name for layer in exclude):
            not_frozen.append((name, "Excluded"))
            continue
        try:
            param.requires_grad_(False)
            frozen.append(name)
        except Exception as e:
            not_frozen.append((name, str(e)))
    return dict(frozen=frozen, not_frozen=not_frozen)
|
128
|
+
|
129
|
+
def unfreeze_all_except(self, exclude: Optional[list[str]] = None):
    """Unfreeze all model parameters except those matching ``exclude``.

    A parameter is excluded when any string in ``exclude`` is a substring of
    its name from ``named_parameters()``. The previous logic was inverted: it
    unfroze only the matching parameters and labelled the rest "Excluded".

    Args:
        exclude: Name fragments of parameters to leave untouched. ``None`` or
            an empty list unfreezes everything.

    Returns:
        ``{"unfrozen": [names], "not_unfrozen": [(name, reason), ...]}`` where
        the reason is ``"Excluded"`` or the text of the exception raised.
    """
    unfrozen = []
    not_unfrozen = []
    for name, param in self.named_parameters():
        if exclude and any(layer in name for layer in exclude):
            not_unfrozen.append((name, "Excluded"))
            continue
        try:
            param.requires_grad_(True)
            unfrozen.append(name)
        except Exception as e:
            not_unfrozen.append((name, str(e)))
    return dict(unfrozen=unfrozen, not_unfrozen=not_unfrozen)
|
150
|
+
|
57
151
|
def to(self, *args, **kwargs):
|
58
152
|
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
|
59
153
|
*args, **kwargs
|
@@ -186,11 +280,16 @@ class Model(nn.Module, ABC):
|
|
186
280
|
)
|
187
281
|
|
188
282
|
def get_weights(self, module_name: Optional[str] = None) -> List[Tensor]:
    """Return detached weights of the whole model or of one named attribute.

    Args:
        module_name: Attribute name on ``self``. When ``None``, the weights
            of all of ``self.parameters()`` are returned.

    Returns:
        A list of detached tensors: all parameters of the module, or a
        single-element list when the attribute is itself a tensor/parameter.

    Raises:
        AssertionError: If ``self`` has no attribute ``module_name``.
        TypeError: If the attribute is neither a module nor a tensor. The old
            code did ``raise (f"...")``, which also surfaced as a TypeError
            but with an unrelated message about exception types.
    """
    if module_name is None:
        return [x.data.detach() for x in self.parameters()]

    assert hasattr(self, module_name), f"Module {module_name} does not exits"
    module = getattr(self, module_name)
    if isinstance(module, nn.Module):
        return [x.data.detach() for x in module.parameters()]
    if isinstance(module, (Tensor, nn.Parameter)):
        return [module.data.detach()]
    raise TypeError(f"{module_name} has no weights")
|
195
294
|
|
196
295
|
def print_trainable_parameters(
|
lt_tensor/model_zoo/disc.py
CHANGED
@@ -11,37 +11,36 @@ class PeriodDiscriminator(Model):
|
|
11
11
|
use_spectral_norm=False,
|
12
12
|
kernel_size: int = 5,
|
13
13
|
stride: int = 3,
|
14
|
-
initial_s: int = 32,
|
15
14
|
):
|
16
15
|
super().__init__()
|
17
16
|
self.period = period
|
17
|
+
self.stride = stride
|
18
|
+
self.kernel_size = kernel_size
|
18
19
|
self.norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
20
|
+
|
21
|
+
self.channels = [32, 128, 512, 1024, 1024]
|
19
22
|
self.first_pass = nn.Sequential(
|
20
23
|
self.norm_f(
|
21
24
|
nn.Conv2d(
|
22
|
-
1,
|
25
|
+
1, self.channels[0], (kernel_size, 1), (stride, 1), padding=(2, 0)
|
23
26
|
)
|
24
27
|
),
|
25
28
|
nn.LeakyReLU(0.1),
|
26
29
|
)
|
27
|
-
self._last_sz = initial_s * 4
|
28
30
|
|
29
|
-
|
31
|
+
|
32
|
+
self.convs = nn.ModuleList([self._get_next(self.channels[i+1], self.channels[i], i == 3) for i in range(4)])
|
30
33
|
|
31
34
|
self.post_conv = nn.Conv2d(1024, 1, (stride, 1), 1, padding=(1, 0))
|
32
|
-
self.kernel_size = kernel_size
|
33
|
-
self.stride = stride
|
34
35
|
|
35
|
-
def _get_next(self, is_last: bool = False):
|
36
|
-
in_dim = self._last_sz
|
37
|
-
self._last_sz *= 4
|
38
|
-
print(self._last_sz, "-----------------------")
|
36
|
+
def _get_next(self, out_dim:int, last_in:int, is_last: bool = False):
|
39
37
|
stride = (self.stride, 1) if not is_last else 1
|
38
|
+
|
40
39
|
return nn.Sequential(
|
41
40
|
self.norm_f(
|
42
41
|
nn.Conv2d(
|
43
|
-
|
44
|
-
|
42
|
+
last_in,
|
43
|
+
out_dim,
|
45
44
|
(self.kernel_size, 1),
|
46
45
|
stride,
|
47
46
|
padding=(2, 0),
|
@@ -91,6 +90,7 @@ class ScaleDiscriminator(nn.Module):
|
|
91
90
|
def __init__(self, use_spectral_norm=False):
|
92
91
|
super().__init__()
|
93
92
|
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
93
|
+
self.activation = nn.LeakyReLU(0.1)
|
94
94
|
self.convs = nn.ModuleList(
|
95
95
|
[
|
96
96
|
norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
|
@@ -103,7 +103,6 @@ class ScaleDiscriminator(nn.Module):
|
|
103
103
|
]
|
104
104
|
)
|
105
105
|
self.post_conv = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
|
106
|
-
self.activation = nn.LeakyReLU(0.1)
|
107
106
|
|
108
107
|
def forward(self, x: torch.Tensor):
|
109
108
|
"""
|
@@ -147,9 +146,10 @@ class GeneralLossDescriminator(Model):
|
|
147
146
|
super().__init__()
|
148
147
|
self.mpd = MultiPeriodDiscriminator()
|
149
148
|
self.msd = MultiScaleDiscriminator()
|
149
|
+
self.print_trainable_parameters()
|
150
150
|
|
151
151
|
def _get_group_(self):
|
152
152
|
pass
|
153
153
|
|
154
154
|
def forward(self, x: Tensor, y_hat: Tensor):
|
155
|
-
return
|
155
|
+
return
|
lt_tensor/model_zoo/istft.py
CHANGED
@@ -106,3 +106,44 @@ class Generator(Model):
|
|
106
106
|
classname = m.__class__.__name__
|
107
107
|
if "Conv" in classname:
|
108
108
|
m.weight.data.normal_(mean, std)
|
109
|
+
|
110
|
+
|
111
|
+
# Below are items found in Rishikesh's repo that might work for this generator.
|
112
|
+
# https://github.com/rishikksh20/iSTFTNet-pytorch/blob/781480e9563d4dff5a8cc9ef1af6c6e0cab025c8/models.py
|
113
|
+
|
114
|
+
|
115
|
+
def feature_loss(fmap_r, fmap_g, weight=2.0):
    """Weighted mean-absolute-difference loss between real and generated feature maps.

    ``fmap_r`` and ``fmap_g`` hold, per discriminator, a list of layer
    outputs; the per-layer means of ``|real - generated|`` are summed and
    multiplied by ``weight``.
    """
    total = 0.0
    for real_layers, gen_layers in zip(fmap_r, fmap_g):
        for real_map, gen_map in zip(real_layers, gen_layers):
            total += (real_map - gen_map).abs().mean()
    return total * weight
|
122
|
+
|
123
|
+
|
124
|
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """LSGAN-style discriminator loss for paired real and generated outputs.

    Returns:
        ``(total, real_losses, generated_losses)``; the two lists contain
        Python floats (via ``.item()``), the total stays a tensor.
    """
    total = 0.0
    real_values = []
    fake_values = []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = (1.0 - real_out).pow(2).mean()
        fake_term = fake_out.pow(2).mean()
        total += real_term + fake_term
        real_values.append(real_term.item())
        fake_values.append(fake_term.item())
    return total, real_values, fake_values
|
137
|
+
|
138
|
+
|
139
|
+
def generator_loss(disc_generated_outputs):
    """LSGAN generator loss: mean squared distance of each output from 1.

    Returns:
        ``(total, per_output_losses)`` where the list holds Python floats.
    """
    total = 0.0
    per_output = []
    for generated in disc_generated_outputs:
        term = (1.0 - generated).pow(2).mean()
        per_output.append(term.item())
        total += term
    return total, per_output
|
lt_tensor/noise_tools.py
ADDED
@@ -0,0 +1,362 @@
|
|
1
|
+
__all__ = [
|
2
|
+
"NoiseSchedulerA",
|
3
|
+
"NoiseSchedulerB",
|
4
|
+
"NoiseSchedulerC",
|
5
|
+
"add_gaussian_noise",
|
6
|
+
"add_uniform_noise",
|
7
|
+
"add_linear_noise",
|
8
|
+
"add_impulse_noise",
|
9
|
+
"add_pink_noise",
|
10
|
+
"add_clipped_gaussian_noise",
|
11
|
+
"add_multiplicative_noise",
|
12
|
+
"apply_noise",
|
13
|
+
]
|
14
|
+
|
15
|
+
from lt_utils.common import *
|
16
|
+
import torch.nn.functional as F
|
17
|
+
from .torch_commons import *
|
18
|
+
import math
|
19
|
+
import random
|
20
|
+
from .misc_utils import set_seed
|
21
|
+
|
22
|
+
|
23
|
+
def add_gaussian_noise(x: Tensor, noise_level=0.025):
    """Return ``x`` plus standard-normal noise scaled by ``noise_level``."""
    return x + torch.randn_like(x) * noise_level
|
26
|
+
|
27
|
+
|
28
|
+
def add_uniform_noise(x: Tensor, noise_level=0.025):
    """Return ``x`` plus uniform noise drawn from ``[-noise_level, noise_level)``."""
    centered = torch.rand_like(x) - 0.5  # uniform in [-0.5, 0.5)
    return x + centered * 2 * noise_level
|
31
|
+
|
32
|
+
|
33
|
+
def add_linear_noise(x, noise_level=0.05):
    """Add a deterministic ramp from 0 to ``noise_level`` along the last dim.

    The ramp has ``x.shape[-1]`` points and is broadcast across all leading
    dimensions of ``x``.
    """
    length = x.shape[-1]
    ramp = torch.linspace(0, noise_level, length, device=x.device)
    # Prepend singleton dims so the ramp lines up with the last axis of x.
    broadcast_shape = [1] * (x.dim() - 1) + [length]
    return x + ramp.view(broadcast_shape).expand_as(x)
|
39
|
+
|
40
|
+
|
41
|
+
def add_impulse_noise(x: Tensor, noise_level=0.025):
    """Replace a random fraction of elements with the extremes 0.0 and 1.0.

    A uniform random value is drawn per element. Elements whose draw is below
    ``noise_level / 2`` become 0.0 and those above ``1 - noise_level / 2``
    become 1.0. The input is not modified; a detached copy is returned.
    """
    draws = torch.rand_like(x)
    low_cut = noise_level / 2
    high_cut = 1 - noise_level / 2
    corrupted = x.detach().clone()
    corrupted.masked_fill_(draws < low_cut, 0.0)
    corrupted.masked_fill_(draws > high_cut, 1.0)
    return corrupted
|
48
|
+
|
49
|
+
|
50
|
+
def add_pink_noise(x: Tensor, noise_level=0.05):
    """Add pink (1/f-shaped) noise along the last dimension of ``x``.

    White Gaussian noise is transformed with a real FFT, each frequency bin is
    divided by ``sqrt(f)`` and the result is transformed back, then scaled by
    ``noise_level`` and added to ``x``.

    Changes from the previous version: the output keeps the input's shape
    (3-D inputs used to come back flattened to 2-D), the frequency vector is
    created on ``x``'s device, and all rows are processed in one batched FFT
    instead of a Python loop.

    Args:
        x: Real floating-point tensor; the last dimension is treated as time.
        noise_level: Scale applied to the generated noise.

    Returns:
        A tensor with the same shape as ``x``.
    """
    n = x.shape[-1]
    white = torch.randn_like(x)
    spectrum = torch.fft.rfft(white, dim=-1)
    freqs = torch.fft.rfftfreq(n, d=1.0, device=x.device)
    freqs[0] = 1.0  # prevent division by zero at the DC bin
    spectrum = spectrum / freqs.sqrt()
    pink = torch.fft.irfft(spectrum, n=n, dim=-1)
    return x + pink * noise_level
|
67
|
+
|
68
|
+
|
69
|
+
def add_clipped_gaussian_noise(x, noise_level=0.025):
    """Add scaled Gaussian noise to ``x`` and clamp the result to ``[0, 1]``."""
    noisy = x + torch.randn_like(x) * noise_level
    return noisy.clamp(0.0, 1.0)
|
72
|
+
|
73
|
+
|
74
|
+
def add_multiplicative_noise(x, noise_level=0.025):
    """Scale each element of ``x`` by ``1 + N(0, 1) * noise_level``."""
    gain = torch.randn_like(x) * noise_level + 1
    return x * gain
|
77
|
+
|
78
|
+
|
79
|
+
# Noise name -> function implementing it. Insertion order defines the order
# of _VALID_NOISES below.
_NOISE_MAP = dict(
    gaussian=add_gaussian_noise,
    uniform=add_uniform_noise,
    linear=add_linear_noise,
    impulse=add_impulse_noise,
    pink=add_pink_noise,
    clipped_gaussian=add_clipped_gaussian_noise,
    multiplicative=add_multiplicative_noise,
)

# Every supported noise name, in declaration order.
_VALID_NOISES = list(_NOISE_MAP)

# Noise name -> tensor ranks (``x.ndim``) the noise is declared to support.
_NOISE_DIM_SUPPORT = {
    **dict.fromkeys(("gaussian", "uniform"), (1, 2)),
    **dict.fromkeys(("multiplicative", "clipped_gaussian"), (1, 2, 3)),
    **dict.fromkeys(("linear", "impulse", "pink"), (2, 3)),
}
|
108
|
+
|
109
|
+
|
110
|
+
def apply_noise(
    x: Tensor,
    noise_type: str = "gaussian",
    noise_level: float = 0.01,
    seed: Optional[int] = None,
    on_error: Literal["raise", "try_others", "return_unchanged"] = "raise",
    _last_tries: Optional[list[str]] = None,
):
    """Apply one of the registered noise functions to ``x``.

    Args:
        x: Input tensor.
        noise_type: Key of ``_NOISE_MAP`` (case-insensitive, surrounding
            whitespace ignored).
        noise_level: Strength forwarded to the noise function.
        seed: When an int, ``set_seed(seed)`` is called before the noise is
            generated.
        on_error: What to do when the noise does not support ``x.ndim`` or the
            noise function raises: ``"raise"`` propagates the error,
            ``"return_unchanged"`` returns ``x``, ``"try_others"`` retries
            with a randomly chosen noise type that has not been tried yet.
        _last_tries: Internal bookkeeping of noise types already attempted.
            It is copied, never mutated. The previous mutable default list
            was appended to, so attempts leaked between unrelated calls.

    Returns:
        The noised tensor, or ``x`` unchanged when no noise could be applied.

    Raises:
        ValueError: If ``noise_type`` is not registered.
        AssertionError: If ``on_error == "raise"`` and ``x.ndim`` is not
            supported by the noise type.
    """
    noise_type = noise_type.lower().strip()

    if noise_type not in _NOISE_MAP:
        raise ValueError(f"Noise type '{noise_type}' not supported.")

    # Private copy that also records the current attempt, so a failing type
    # is never picked again by the fallback.
    tried = list(_last_tries) if _last_tries else []
    if noise_type not in tried:
        tried.append(noise_type)

    def _fallback():
        """Retry with an untried noise type that supports ``x.ndim``, else return ``x``."""
        remaining = [
            n
            for n in _VALID_NOISES
            if n not in tried and x.ndim in _NOISE_DIM_SUPPORT[n]
        ]
        if not remaining:
            return x
        return apply_noise(
            x, random.choice(remaining), noise_level, seed, on_error, tried
        )

    # Check dimension compatibility
    allowed_dims = _NOISE_DIM_SUPPORT.get(noise_type, (1, 2))
    if x.ndim not in allowed_dims:
        if on_error == "raise":
            # Explicit raise (same exception type as the former ``assert``)
            # so the check survives ``python -O``.
            raise AssertionError(
                f"Noise '{noise_type}' is not supported for {x.ndim}D input."
            )
        if on_error == "return_unchanged":
            return x
        if on_error == "try_others":
            return _fallback()
        # Any other ``on_error`` value falls through and attempts the noise
        # anyway, as before.

    try:
        if isinstance(seed, int):
            set_seed(seed)
        return _NOISE_MAP[noise_type](x, noise_level)
    except Exception:
        if on_error == "raise":
            raise
        if on_error == "return_unchanged":
            return x
        return _fallback()
|
160
|
+
|
161
|
+
|
162
|
+
class NoiseSchedulerA(nn.Module):
    """Repeatedly applies ``apply_noise`` to a tensor and records every step."""

    def __init__(self, samples: int = 64):
        """Store the default number of noising steps used by ``forward``."""
        super().__init__()
        self.base_steps = samples

    @staticmethod
    def plot_noise_progression(noise_seq: list[Tensor], titles: list[str] = None):
        """Show the tensors of ``noise_seq`` side by side with matplotlib.

        Declared as a static method: the function never took ``self``, so
        calling it on an instance used to pass the instance as ``noise_seq``.

        Args:
            noise_seq: Tensors to display, one subplot each.
            titles: Optional subplot titles; missing trailing titles are
                skipped.
        """
        import matplotlib.pyplot as plt

        steps = len(noise_seq)
        plt.figure(figsize=(15, 3))
        for i, tensor in enumerate(noise_seq):
            plt.subplot(1, steps, i + 1)
            plt.imshow(tensor.squeeze().cpu().numpy(), aspect="auto", origin="lower")
            if titles and i < len(titles):
                plt.title(titles[i])
            plt.axis("off")
        plt.tight_layout()
        plt.show()

    def forward(
        self,
        source_item: torch.Tensor,
        steps: Optional[int] = None,
        noise_type: Literal[
            "gaussian",
            "uniform",
            "linear",
            "impulse",
            "pink",
            "clipped_gaussian",
            "multiplicative",
        ] = "gaussian",
        seed: Optional[int] = None,
        noise_level: float = 0.01,
        shuffle_noise_types: bool = False,
        return_dict: bool = True,
    ):
        """Noise ``source_item`` step by step.

        Args:
            source_item: Starting tensor; a detached clone is the first entry.
            steps: Number of noising steps (defaults to ``self.base_steps``).
            noise_type: Noise requested for the first step, and for all steps
                unless ``shuffle_noise_types`` is set.
            seed: Forwarded unchanged to ``apply_noise`` on every step.
            noise_level: Noise strength forwarded to ``apply_noise``.
            shuffle_noise_types: Pick a random noise type from the second step on.
            return_dict: Return a dict instead of a tuple.

        Returns:
            ``{"steps", "history", "final", "init"}`` or
            ``(collected, noise_history)``. ``history`` holds the noise type
            requested at each step; because ``apply_noise`` is called with
            ``on_error="try_others"``, the type actually applied may differ.
        """
        if steps is None:
            steps = self.base_steps
        collected = [source_item.detach().clone()]
        noise_history = []
        for i in range(steps):
            if i > 0 and shuffle_noise_types:
                noise_type = random.choice(_VALID_NOISES)
            # apply_noise returns only the tensor; the previous tuple
            # unpacking (``current, noise_name = ...``) was incorrect.
            current = apply_noise(
                collected[-1],
                noise_type,
                noise_level,
                seed=seed,
                on_error="try_others",
            )
            noise_history.append(noise_type)
            collected.append(current)

        if return_dict:
            return {
                "steps": collected,
                "history": noise_history,
                "final": collected[-1],
                "init": collected[0],
            }
        return collected, noise_history
|
224
|
+
|
225
|
+
|
226
|
+
class NoiseSchedulerB(nn.Module):
    """Forward-diffusion noising with a linear beta schedule.

    Buffers ``sqrt_alpha_cumprod`` and ``sqrt_one_minus_alpha_cumprod`` are
    precomputed from ``betas = linspace(1e-4, 0.02, timesteps)``.
    """

    def __init__(self, timesteps: int = 512):
        super().__init__()

        betas = torch.linspace(1e-4, 0.02, timesteps)
        alphas = 1.0 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)

        self.register_buffer("sqrt_alpha_cumprod", torch.sqrt(alpha_cumprod))
        self.register_buffer(
            "sqrt_one_minus_alpha_cumprod", torch.sqrt(1.0 - alpha_cumprod)
        )

        self.timesteps = timesteps
        # Scale used for generated noise when ``forward`` gets none.
        self.default_noise = math.sqrt(1.25)

    def _get_random_noise(
        self,
        min_max: Tuple[float, float] = (-3, 3),
        seed: Optional[int] = None,
    ) -> float:
        """Draw a uniform random scale in ``min_max``.

        Note: seeds Python's global ``random`` module when ``seed`` is an int.
        """
        if isinstance(seed, int):
            random.seed(seed)
        return random.uniform(*min_max)

    def set_noise(
        self,
        noise: Optional[Union[Tensor, Number]] = None,
        seed: Optional[int] = None,
        min_max: Tuple[float, float] = (-3, 3),
        default: bool = False,
    ):
        """Set ``self.default_noise``.

        An explicit ``noise`` wins. Otherwise ``default=True`` restores
        ``sqrt(1.25)`` and ``default=False`` draws a random scale from
        ``min_max``.
        """
        if noise is not None:
            self.default_noise = noise
        else:
            self.default_noise = (
                math.sqrt(1.25) if default else self._get_random_noise(min_max, seed)
            )

    def forward(
        self, x_0: Tensor, t: int, noise: Optional[Union[Tensor, float]] = None
    ) -> Tensor:
        """Return ``x_0`` noised to time step ``t``.

        Computes ``sqrt_alpha_cumprod[t] * x_0 +
        sqrt_one_minus_alpha_cumprod[t] * noise``.

        Args:
            x_0: Clean input tensor.
            t: Time-step index with ``0 <= t < self.timesteps``.
            noise: Noise tensor, a scalar scale for freshly drawn Gaussian
                noise, or ``None`` to use ``self.default_noise`` as the scale.

        Raises:
            ValueError: If ``t`` is out of range. The earlier check
                ``0 >= t < timesteps`` accepted only ``t <= 0``, and a stray
                argument-less ``apply_noise()`` call made every invocation
                fail before reaching it.
        """
        if not 0 <= t < self.timesteps:
            raise ValueError(
                f"Time step t={t} is out of bounds for scheduler with {self.timesteps} steps."
            )

        if noise is None:
            noise = torch.randn_like(x_0) * self.default_noise
        elif isinstance(noise, (float, int)):
            noise = torch.randn_like(x_0) * noise

        alpha_term = self.sqrt_alpha_cumprod[t] * x_0
        noise_term = self.sqrt_one_minus_alpha_cumprod[t] * noise
        return alpha_term + noise_term
|
282
|
+
|
283
|
+
|
284
|
+
class NoiseSchedulerC(nn.Module):
    """Beta-schedule noise mixer with a selectable noise distribution.

    Uses ``timesteps`` betas spaced linearly from 1e-4 to 0.02; the square
    roots of the cumulative product of ``1 - beta`` and of its complement
    are registered as buffers and used as mixing weights in ``forward``.
    """

    def __init__(self, timesteps: int = 512):
        """Precompute the mixing weights and set the default noise settings.

        Args:
            timesteps: Length of the schedule; valid ``t`` values in
                ``forward`` are ``0 .. timesteps - 1``.
        """
        super().__init__()

        betas = torch.linspace(1e-4, 0.02, timesteps)
        alphas = 1.0 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)

        self.register_buffer("sqrt_alpha_cumprod", torch.sqrt(alpha_cumprod))
        self.register_buffer(
            "sqrt_one_minus_alpha_cumprod", torch.sqrt(1.0 - alpha_cumprod)
        )

        self.timesteps = timesteps
        # Defaults used by forward() when the caller does not override them.
        self.default_noise_strength = math.sqrt(1.25)
        self.default_noise_type = "gaussian"
        # Last seed given to set_noise(); None until one is provided.
        self.noise_seed = None

    def _get_random_uniform(
        self, shape, min_val=-1.0, max_val=1.0, device=None, dtype=None
    ):
        """Return a tensor of ``shape`` sampled uniformly from the given range.

        Args:
            shape: Shape of the tensor to create.
            min_val: Lower bound handed to ``uniform_``.
            max_val: Upper bound handed to ``uniform_``.
            device: Device for the new tensor; None keeps torch's default.
            dtype: Dtype for the new tensor; None keeps torch's default.
        """
        return torch.empty(shape, device=device, dtype=dtype).uniform_(
            min_val, max_val
        )

    def _get_noise(self, x: Tensor, noise_type: str, noise_level: float) -> Tensor:
        """Generate a noise tensor shaped like ``x``.

        Args:
            x: Reference tensor.
            noise_type: One of ``"gaussian"``, ``"uniform"``,
                ``"multiplicative"``, ``"clipped_gaussian"``, ``"impulse"``.
            noise_level: Scale of the noise; for ``"impulse"`` it is also the
                per-element probability threshold of the mask.

        Raises:
            ValueError: If ``noise_type`` is not one of the supported names.
        """
        # Basic noise types
        if noise_type == "gaussian":
            return torch.randn_like(x) * noise_level
        elif noise_type == "uniform":
            # Create the sample on x's device/dtype; a bare torch.empty(shape)
            # would always use the default device and dtype.
            return (
                self._get_random_uniform(x.shape, device=x.device, dtype=x.dtype)
                * noise_level
            )
        elif noise_type == "multiplicative":
            # NOTE(review): this branch and "impulse" return x combined with
            # noise rather than a pure noise term, so forward() weights x
            # twice for them - confirm this is intended.
            return x * (1 + (torch.randn_like(x) * noise_level))
        elif noise_type == "clipped_gaussian":
            noise = torch.randn_like(x) * noise_level
            return noise.clamp(-1.0, 1.0)
        elif noise_type == "impulse":
            mask = torch.rand_like(x) < noise_level
            impulses = torch.randn_like(x) * noise_level
            return x + impulses * mask
        else:
            raise ValueError(f"Unsupported noise type: '{noise_type}'")

    def set_noise(
        self,
        noise_strength: Optional[Union[Tensor, float]] = None,
        noise_type: Optional[str] = None,
        seed: Optional[int] = None,
        default: bool = False,
    ):
        """Update the default noise strength, type and/or seed.

        Args:
            noise_strength: New default strength; wins over ``default``.
            noise_type: New default noise type; stored lower-cased and
                stripped of surrounding whitespace.
            seed: When an int, it is stored in ``self.noise_seed`` and used
                to reseed both torch's and Python's global generators (a
                process-wide side effect).
            default: When True and ``noise_strength`` is None, restore the
                initial strength ``sqrt(1.25)``.
        """
        if noise_strength is not None:
            self.default_noise_strength = noise_strength
        elif default:
            self.default_noise_strength = math.sqrt(1.25)

        if noise_type is not None:
            self.default_noise_type = noise_type.lower().strip()

        if isinstance(seed, int):
            self.noise_seed = seed
            torch.manual_seed(seed)
            random.seed(seed)

    def forward(
        self,
        x_0: Tensor,
        t: int,
        noise: Optional[Union[Tensor, float]] = None,
        noise_type: Optional[str] = None,
    ) -> Tensor:
        """Return ``x_0`` noised at step ``t``.

        Computes ``sqrt_alpha_cumprod[t] * x_0 +
        sqrt_one_minus_alpha_cumprod[t] * noise``.

        Args:
            x_0: Clean input tensor.
            t: Schedule index, ``0 <= t < self.timesteps``.
            noise: A tensor used as the noise directly; a number used as the
                noise level for generated noise; or None to generate noise
                at ``self.default_noise_strength``.
            noise_type: Noise type for generated noise; a falsy value falls
                back to ``self.default_noise_type``.

        Raises:
            AssertionError: If ``t`` is outside ``[0, self.timesteps)``.
            ValueError: If noise must be generated and the noise type is
                unsupported.
        """
        assert 0 <= t < self.timesteps, f"t={t} is out of bounds [0, {self.timesteps})"

        noise_type = noise_type or self.default_noise_type
        noise_level = self.default_noise_strength

        if noise is None:
            noise = self._get_noise(x_0, noise_type, noise_level)
        elif isinstance(noise, (float, int)):
            noise = self._get_noise(x_0, noise_type, noise)

        alpha_term = self.sqrt_alpha_cumprod[t] * x_0
        noise_term = self.sqrt_one_minus_alpha_cumprod[t] * noise
        return alpha_term + noise_term
|
lt_tensor/transform.py
CHANGED
@@ -8,8 +8,6 @@ __all__ = [
|
|
8
8
|
"normalize",
|
9
9
|
"min_max_scale",
|
10
10
|
"mel_to_linear",
|
11
|
-
"add_noise",
|
12
|
-
"shift_time",
|
13
11
|
"stretch_tensor",
|
14
12
|
"pad_tensor",
|
15
13
|
"get_sinusoidal_embedding",
|
@@ -39,16 +37,19 @@ def to_mel_spectrogram(
|
|
39
37
|
f_max: Optional[float] = None,
|
40
38
|
) -> torch.Tensor:
|
41
39
|
"""Converts waveform to mel spectrogram."""
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
40
|
+
return (
|
41
|
+
torchaudio.transforms.MelSpectrogram(
|
42
|
+
sample_rate=sample_rate,
|
43
|
+
n_fft=n_fft,
|
44
|
+
hop_length=hop_length,
|
45
|
+
win_length=win_length,
|
46
|
+
n_mels=n_mels,
|
47
|
+
f_min=f_min,
|
48
|
+
f_max=f_max,
|
49
|
+
)
|
50
|
+
.to(device=waveform.device)
|
51
|
+
.forward(waveform)
|
50
52
|
)
|
51
|
-
return mel_spectrogram(waveform)
|
52
53
|
|
53
54
|
|
54
55
|
def stft(
|
@@ -151,16 +152,6 @@ def mel_to_linear(
|
|
151
152
|
return torch.matmul(mel_fb_inv, mel_spec + eps)
|
152
153
|
|
153
154
|
|
154
|
-
def add_noise(x: torch.Tensor, noise_level: float = 0.01) -> torch.Tensor:
|
155
|
-
"""Adds Gaussian noise to tensor."""
|
156
|
-
return x + noise_level * torch.randn_like(x)
|
157
|
-
|
158
|
-
|
159
|
-
def shift_time(x: torch.Tensor, shift: int) -> torch.Tensor:
|
160
|
-
"""Shifts tensor along time axis (last dim)."""
|
161
|
-
return torch.roll(x, shifts=shift, dims=-1)
|
162
|
-
|
163
|
-
|
164
155
|
def stretch_tensor(x: torch.Tensor, rate: float, mode: str = "linear") -> torch.Tensor:
|
165
156
|
"""Time-stretch tensor using interpolation."""
|
166
157
|
B, C, T = x.shape if x.ndim == 3 else (1, 1, x.shape[0])
|
@@ -274,7 +265,7 @@ def window_sumsquare(
|
|
274
265
|
n_fft: int = 2048,
|
275
266
|
dtype: torch.dtype = torch.float32,
|
276
267
|
norm: Optional[Union[int, float]] = None,
|
277
|
-
device: Optional[torch.device] =
|
268
|
+
device: Optional[torch.device] = None,
|
278
269
|
):
|
279
270
|
if win_length is None:
|
280
271
|
win_length = n_fft
|
@@ -0,0 +1,26 @@
|
|
1
|
+
lt_tensor/__init__.py,sha256=D-oEjsuKWhtk1qyiADERgNO78aRCXUJJz0hs65h8LOg,365
|
2
|
+
lt_tensor/losses.py,sha256=TinZJP2ypZ7Tdg6d9nnFWFkPyormfgQ0Z9P2ER3sqzE,4341
|
3
|
+
lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
|
4
|
+
lt_tensor/math_ops.py,sha256=ewIYkvxIy_Lab_9ExjFUgLs-oYLOu8IRRDo7f1pn3i8,2248
|
5
|
+
lt_tensor/misc_utils.py,sha256=sjWUkUaHFhaCdN4rZ6X-cQDbPieimfKchKq9VtjiwEA,17029
|
6
|
+
lt_tensor/model_base.py,sha256=2W4m6hlvMyfRx1efWJ0NIIwctzLjL4rip208vL9_n0Y,13419
|
7
|
+
lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
|
8
|
+
lt_tensor/noise_tools.py,sha256=O4oq5oi0jLJuQNIuxOBZa-rB0S065QXtb1gjQUXVaLs,11212
|
9
|
+
lt_tensor/torch_commons.py,sha256=fntsEU8lhBQo0ebonI1iXBkMbWMN3HpBsG13EWlP5s8,718
|
10
|
+
lt_tensor/transform.py,sha256=hqsP6nXRn4nqMGkN2hBi4y-kHxEQdlIUS0y89Y1mjVI,8589
|
11
|
+
lt_tensor/datasets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
+
lt_tensor/datasets/audio.py,sha256=frftmRYNk0eXqHEiFggC46RMuCoGyuwBAlnPxfFsS7Y,4858
|
13
|
+
lt_tensor/model_zoo/__init__.py,sha256=jipEk50_DTMQbGg8FnDDukxmh7Bcwvl_QVRS3rkb7aY,283
|
14
|
+
lt_tensor/model_zoo/bsc.py,sha256=muxIR7dU-Pvf-HFE-iy3zmRb1sTJlcs1vqdlnbU1Hss,6307
|
15
|
+
lt_tensor/model_zoo/disc.py,sha256=jZPhoSV1hlrba3ohXGutYAAcSl4pWkqGYFpOlOoN3eo,4740
|
16
|
+
lt_tensor/model_zoo/fsn.py,sha256=5ySsg2OHjvTV_coPAdZQ0f7bz4ugJB8mDYsItmd61qA,2102
|
17
|
+
lt_tensor/model_zoo/gns.py,sha256=Tirr_grONp_FFQ_L7K-zV2lvkaC39h8mMl4QDpx9vLQ,6028
|
18
|
+
lt_tensor/model_zoo/istft.py,sha256=0Xms2QNPAgz_ib8XTfaWl1SCHgS53oKC6-EkDkl_qe4,4863
|
19
|
+
lt_tensor/model_zoo/pos.py,sha256=N28v-rF8CELouYxQ9r45Jbd4ri5DNydwDgg7nzmQ4Ig,4471
|
20
|
+
lt_tensor/model_zoo/rsd.py,sha256=5bba50g1Hm5kMexuJ4SwOIJuyQ1qJd8Acrq-Ax6CqE8,6958
|
21
|
+
lt_tensor/model_zoo/tfrms.py,sha256=kauh-A13pk08SZ5OspEE5a-gPKD4rZr6tqMKWu3KGhk,4237
|
22
|
+
lt_tensor-0.0.1a6.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
23
|
+
lt_tensor-0.0.1a6.dist-info/METADATA,sha256=-89IqEHsZD3W8moDuKWR8UodkdR2pwefqrG9C7P7y_Y,968
|
24
|
+
lt_tensor-0.0.1a6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
25
|
+
lt_tensor-0.0.1a6.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
|
26
|
+
lt_tensor-0.0.1a6.dist-info/RECORD,,
|
@@ -1,24 +0,0 @@
|
|
1
|
-
lt_tensor/__init__.py,sha256=bvCjaIsYjbGFbR5MNezgLyRgN4_CsyrjmVEvuClsgOU,303
|
2
|
-
lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
|
3
|
-
lt_tensor/math_ops.py,sha256=ZtnJ9WB-pbFQLsXuNfQl2dAaeob5BWfxmhkwpxITUZ4,2066
|
4
|
-
lt_tensor/misc_utils.py,sha256=e44FCQbjNHP-4WOHIbtqqH0x590DzUE6CrD_4Vl_d38,19880
|
5
|
-
lt_tensor/model_base.py,sha256=tmRu5pTcELKMFcybOiZ1thJPuJWRSPkbUUtp9Y1NJWw,9555
|
6
|
-
lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
|
7
|
-
lt_tensor/torch_commons.py,sha256=fntsEU8lhBQo0ebonI1iXBkMbWMN3HpBsG13EWlP5s8,718
|
8
|
-
lt_tensor/transform.py,sha256=IVAaQlq12OvMVhX3lX4lgsTCJYJce5n5MtMy7IK_AU4,8892
|
9
|
-
lt_tensor/datasets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
10
|
-
lt_tensor/datasets/audio.py,sha256=BZTceP9MlmyrVioHpWLkd_ZcyawYYZUAlVWKfKwyWAg,3318
|
11
|
-
lt_tensor/model_zoo/__init__.py,sha256=jipEk50_DTMQbGg8FnDDukxmh7Bcwvl_QVRS3rkb7aY,283
|
12
|
-
lt_tensor/model_zoo/bsc.py,sha256=muxIR7dU-Pvf-HFE-iy3zmRb1sTJlcs1vqdlnbU1Hss,6307
|
13
|
-
lt_tensor/model_zoo/disc.py,sha256=ND6JR_x6b2Y1VqxZejalv8Cz5_TO3H_Z-0x6UnACbBM,4740
|
14
|
-
lt_tensor/model_zoo/fsn.py,sha256=5ySsg2OHjvTV_coPAdZQ0f7bz4ugJB8mDYsItmd61qA,2102
|
15
|
-
lt_tensor/model_zoo/gns.py,sha256=Tirr_grONp_FFQ_L7K-zV2lvkaC39h8mMl4QDpx9vLQ,6028
|
16
|
-
lt_tensor/model_zoo/istft.py,sha256=RV7KVY7q4CYzzsWXH4NGJQwSqrYWwHh-16Q62lKoA2k,3594
|
17
|
-
lt_tensor/model_zoo/pos.py,sha256=N28v-rF8CELouYxQ9r45Jbd4ri5DNydwDgg7nzmQ4Ig,4471
|
18
|
-
lt_tensor/model_zoo/rsd.py,sha256=5bba50g1Hm5kMexuJ4SwOIJuyQ1qJd8Acrq-Ax6CqE8,6958
|
19
|
-
lt_tensor/model_zoo/tfrms.py,sha256=kauh-A13pk08SZ5OspEE5a-gPKD4rZr6tqMKWu3KGhk,4237
|
20
|
-
lt_tensor-0.0.1a4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
21
|
-
lt_tensor-0.0.1a4.dist-info/METADATA,sha256=sbT9xduzE-huVvSjnak9iCo1Eyp45bsMUarc16oTD3o,968
|
22
|
-
lt_tensor-0.0.1a4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
23
|
-
lt_tensor-0.0.1a4.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
|
24
|
-
lt_tensor-0.0.1a4.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|