lt-tensor 0.0.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +1 -0
- lt_tensor/_basics.py +270 -0
- lt_tensor/_torch_commons.py +12 -0
- lt_tensor/lr_schedulers.py +114 -0
- lt_tensor/math_ops.py +71 -0
- lt_tensor/misc_utils.py +628 -0
- lt_tensor/model_zoo/__init__.py +9 -0
- lt_tensor/model_zoo/bsc.py +210 -0
- lt_tensor/model_zoo/dfs.py +181 -0
- lt_tensor/model_zoo/fsn.py +67 -0
- lt_tensor/model_zoo/pos.py +121 -0
- lt_tensor/model_zoo/rsd.py +158 -0
- lt_tensor/model_zoo/tfr.py +140 -0
- lt_tensor/monotonic_align.py +70 -0
- lt_tensor/transform.py +349 -0
- lt_tensor-0.0.1a0.dist-info/METADATA +31 -0
- lt_tensor-0.0.1a0.dist-info/RECORD +20 -0
- lt_tensor-0.0.1a0.dist-info/WHEEL +5 -0
- lt_tensor-0.0.1a0.dist-info/licenses/LICENSE +201 -0
- lt_tensor-0.0.1a0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,158 @@
|
|
1
|
+
__all__ = [
|
2
|
+
"spectral_norm_select",
|
3
|
+
"ResBlock1D",
|
4
|
+
"ResBlock2D",
|
5
|
+
"ResBlock1D_S",
|
6
|
+
]
|
7
|
+
|
8
|
+
from .._torch_commons import *
|
9
|
+
from .._basics import Model
|
10
|
+
import math
|
11
|
+
from ..misc_utils import log_tensor
|
12
|
+
|
13
|
+
|
14
|
+
def spectral_norm_select(module: Module, enabled: bool):
    """Optionally wrap ``module`` with spectral normalization.

    Args:
        module: The module to (maybe) normalize.
        enabled: When truthy, ``spectral_norm(module)`` is returned;
            otherwise ``module`` is returned untouched.

    Returns:
        The spectral-normalized module, or the original ``module``.
    """
    return spectral_norm(module) if enabled else module
|
18
|
+
|
19
|
+
|
20
|
+
class ResBlock1D(Model):
    """Stack of dilated residual 1D-convolution stages plus a projection.

    Each stage computes ``conv(dilated) -> activation -> conv -> activation``,
    adds the stage input back and normalizes the sum. A final convolution
    followed by the activation maps ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Channel count used by every residual stage.
        out_channels: Channel count produced by the final convolution.
        kernel_size: Kernel size shared by all convolutions.
        dilation: A single dilation or a sequence of dilations (one stage
            per entry).
        activation: Activation module instance; the same object is reused
            at every activation position.
        num_groups: ``num_groups`` for ``nn.GroupNorm`` (only if ``batched``).
        batched: Use ``nn.GroupNorm`` when true, ``nn.LayerNorm`` otherwise.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        dilation: Union[Sequence[int], int] = (1, 3, 5),
        activation: nn.Module = nn.LeakyReLU(0.1),
        num_groups: int = 1,
        batched: bool = True,
    ):
        super().__init__()
        dilations = [dilation] if isinstance(dilation, int) else dilation

        def make_norm(channels: int) -> nn.Module:
            # Normalization applied after each residual sum.
            if batched:
                return nn.GroupNorm(num_groups=num_groups, num_channels=channels)
            return nn.LayerNorm(normalized_shape=channels)

        self.conv = nn.ModuleList()
        for rate in dilations:
            dilated_conv = self._get_conv_layer(
                in_channels, in_channels, kernel_size, rate
            )
            plain_conv = self._get_conv_layer(
                in_channels, in_channels, kernel_size, 1, True
            )
            stage = nn.ModuleDict(
                {
                    "net": nn.Sequential(
                        dilated_conv, activation, plain_conv, activation
                    ),
                    "l_norm": make_norm(in_channels),
                }
            )
            self.conv.append(stage)

        projection = self._get_conv_layer(
            in_channels, out_channels, kernel_size, 1, True
        )
        self.final = nn.Sequential(projection, activation)
        # Only the residual stages are re-initialized; ``final`` keeps the
        # default initialization.
        self.conv.apply(self.init_weights)

    def _get_conv_layer(
        self,
        channels_in: int,
        channels_out: int,
        kernel_size: int,
        dilation: int,
        pad_gate: bool = False,
    ):
        """Build a weight-normalized, stride-1 ``nn.Conv1d``.

        With ``pad_gate`` the padding is computed as if the dilation were 1;
        otherwise it is scaled by ``dilation``.
        """
        if pad_gate:
            pad = int((kernel_size * 1 - 1) / 2)
        else:
            pad = int((kernel_size * dilation - dilation) / 2)
        conv = nn.Conv1d(
            in_channels=channels_in,
            out_channels=channels_out,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            padding=pad,
        )
        return weight_norm(conv)

    def forward(self, x: Tensor):
        """Run all residual stages, then the final projection."""
        for stage in self.conv:
            residual = x
            x = stage["l_norm"](stage["net"](residual) + residual)
        return self.final(x)

    def remove_weight_norm(self):
        """Strip weight normalization from every sub-module that has it."""
        for submodule in self.modules():
            try:
                remove_weight_norm(submodule)
            except ValueError:
                # This sub-module carries no weight norm; nothing to remove.
                continue

    @staticmethod
    def init_weights(m, mean=0.0, std=0.01):
        """Draw ``m.weight`` from N(mean, std) when the class name contains "Conv"."""
        if "Conv" not in type(m).__name__:
            return
        m.weight.data.normal_(mean, std)
|
107
|
+
|
108
|
+
|
109
|
+
class ResBlock2D(Model):
    """2D residual block: two 3x3 convolutions plus a skip connection.

    The skip path is the identity unless the block downsamples or changes
    the channel count, in which case it is a 1x1 convolution with the same
    stride as the first 3x3 convolution. The sum of both paths is divided
    by ``sqrt(2)``.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        downsample: Use stride 2 (instead of 1) in the first convolution
            and in the skip projection.
        spec_norm: Wrap every convolution with spectral normalization.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        downsample=False,
        spec_norm: bool = False,
    ):
        super().__init__()
        stride = 2 if downsample else 1

        first_conv = spectral_norm_select(
            nn.Conv2d(in_channels, out_channels, 3, stride, 1), spec_norm
        )
        second_conv = spectral_norm_select(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1), spec_norm
        )
        self.block = nn.Sequential(first_conv, nn.LeakyReLU(0.2), second_conv)

        if downsample or in_channels != out_channels:
            self.skip = spectral_norm_select(
                nn.Conv2d(in_channels, out_channels, 1, stride), spec_norm
            )
        else:
            self.skip = nn.Identity()

        # Computed once here so it is not recomputed on every forward pass.
        self.sqrt_2 = math.sqrt(2)

    def forward(self, x):
        """Return ``(block(x) + skip(x)) / sqrt(2)``."""
        main_path = self.block(x)
        return (main_path + self.skip(x)) / self.sqrt_2
|
140
|
+
|
141
|
+
|
142
|
+
class ResBlock1D_S(Model):
    """Simplified residual 1D block computing ``activation(x + net(x))``.

    ``net`` is ``Conv1d(dilated) -> LeakyReLU(0.1) -> Conv1d(dilation=1)``.
    For an odd ``kernel_size`` both convolutions keep the length of the
    last dimension, which the residual sum requires.

    Args:
        channels: Input and output channel count of both convolutions.
        kernel_size: Kernel size of both convolutions.
        dilation: Dilation of the first convolution only.
    """

    def __init__(self, channels: int, kernel_size: int = 3, dilation: int = 1):
        super().__init__()
        # Each convolution needs the padding that matches ITS OWN dilation.
        # Reusing the dilated padding for the undilated second convolution
        # lengthens the output whenever ``dilation > 1`` and makes the
        # residual sum in ``forward`` fail with a shape mismatch.
        half_kernel = (kernel_size - 1) // 2
        dilated_padding = half_kernel * dilation
        self.net = nn.Sequential(
            nn.Conv1d(
                channels,
                channels,
                kernel_size,
                padding=dilated_padding,
                dilation=dilation,
            ),
            nn.LeakyReLU(0.1),
            nn.Conv1d(channels, channels, kernel_size, padding=half_kernel, dilation=1),
        )
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual branch, add the input and activate."""
        return self.activation(x + self.net(x))
|
@@ -0,0 +1,140 @@
|
|
1
|
+
__all__ = [
|
2
|
+
"TransformerEncoderLayer",
|
3
|
+
"TransformerDecoderLayer",
|
4
|
+
"TransformerEncoder",
|
5
|
+
"TransformerDecoder",
|
6
|
+
"init_weights",
|
7
|
+
]
|
8
|
+
|
9
|
+
import math
|
10
|
+
from .._torch_commons import *
|
11
|
+
from .._basics import Model
|
12
|
+
from lt_utils.misc_utils import default
|
13
|
+
|
14
|
+
from .pos import *
|
15
|
+
from .bsc import FeedForward
|
16
|
+
|
17
|
+
|
18
|
+
def init_weights(module):
    """Initialize ``module`` in place according to its type.

    * ``nn.Linear``: Xavier-uniform weight, zero bias (when a bias exists).
    * ``nn.Embedding``: normal weight with mean 0.0 and std 0.02.
    * ``nn.LayerNorm``: weight 1.0 and bias 0 (each only when present).

    Any other module type is left untouched. Intended for use with
    ``nn.Module.apply``.
    """
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        # LayerNorm built with ``elementwise_affine=False`` (or without a
        # bias) stores ``None`` here; initializing ``None`` would raise.
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
        if module.weight is not None:
            nn.init.constant_(module.weight, 1.0)
|
28
|
+
|
29
|
+
|
30
|
+
class TransformerEncoderLayer(Model):
    """Post-norm Transformer encoder layer (self-attention + feed-forward).

    Each sub-layer is applied as ``norm(x + dropout(sublayer(x)))``. The
    attention module is created with ``batch_first=True``.

    Args:
        d_model: Embedding size of the input and output.
        n_heads: Number of attention heads.
        ff_size: Hidden size handed to ``FeedForward``.
        dropout: Dropout rate for attention, feed-forward and residuals.
    """

    def __init__(self, d_model: int, n_heads: int, ff_size: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, ff_size, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
        """Encode ``x``; ``src_mask`` is forwarded as ``attn_mask``."""
        # Self-attention sub-layer (attention weights are discarded).
        attended = self.self_attn(x, x, x, attn_mask=src_mask)[0]
        x = self.norm1(x + self.dropout(attended))
        # Feed-forward sub-layer.
        transformed = self.ff(x)
        return self.norm2(x + self.dropout(transformed))
|
47
|
+
|
48
|
+
|
49
|
+
class TransformerDecoderLayer(Model):
    """Post-norm Transformer decoder layer.

    Applies, in order, self-attention, cross-attention over the encoder
    output and a feed-forward network; each sub-layer is wrapped as
    ``norm(x + dropout(sublayer(x)))``. Both attention modules use
    ``batch_first=True``.

    Args:
        d_model: Embedding size of the input and output.
        n_heads: Number of attention heads (both attention modules).
        ff_size: Hidden size handed to ``FeedForward``.
        dropout: Dropout rate for attention, feed-forward and residuals.
    """

    def __init__(self, d_model: int, n_heads: int, ff_size: int, dropout: float = 0.1):
        super().__init__()
        attention_kwargs = dict(dropout=dropout, batch_first=True)

        self.self_attn = nn.MultiheadAttention(d_model, n_heads, **attention_kwargs)
        self.norm1 = nn.LayerNorm(d_model)

        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, **attention_kwargs)
        self.norm2 = nn.LayerNorm(d_model)

        self.ff = FeedForward(d_model, ff_size, dropout)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: Tensor,  # Decoder input [B, T, d_model]
        encoder_out: Tensor,  # Encoder output [B, S, d_model]
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Decode ``x`` against ``encoder_out``.

        ``tgt_mask`` is the ``attn_mask`` of the self-attention and
        ``memory_mask`` the ``attn_mask`` of the cross-attention.
        """
        # Step 1: (masked) self-attention over the decoder input.
        self_attended = self.self_attn(x, x, x, attn_mask=tgt_mask)[0]
        x = self.norm1(x + self.dropout(self_attended))

        # Step 2: cross-attention, queries from the decoder and
        # keys/values from the encoder output.
        cross_attended = self.cross_attn(
            x, encoder_out, encoder_out, attn_mask=memory_mask
        )[0]
        x = self.norm2(x + self.dropout(cross_attended))

        # Step 3: position-wise feed-forward network.
        transformed = self.ff(x)
        return self.norm3(x + self.dropout(transformed))
|
87
|
+
|
88
|
+
|
89
|
+
class TransformerEncoder(Model):
    """Sequential stack of ``TransformerEncoderLayer`` modules.

    Args:
        d_model: Embedding size used by every layer.
        n_heads: Number of attention heads per layer.
        ff_size: Feed-forward hidden size per layer.
        num_layers: Number of stacked layers.
        dropout: Dropout rate per layer.
    """

    def __init__(
        self,
        d_model: int = 64,
        n_heads: int = 4,
        ff_size: int = 256,
        num_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerEncoderLayer(d_model, n_heads, ff_size, dropout)
            )

    def forward(self, x: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
        """Pass ``x`` through every layer, sharing the same ``src_mask``."""
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask)
        return hidden
|
111
|
+
|
112
|
+
|
113
|
+
class TransformerDecoder(Model):
    """Sequential stack of ``TransformerDecoderLayer`` modules.

    Args:
        d_model: Embedding size used by every layer.
        n_heads: Number of attention heads per layer.
        ff_size: Feed-forward hidden size per layer.
        num_layers: Number of stacked layers.
        dropout: Dropout rate per layer.
    """

    def __init__(
        self,
        d_model: int = 64,
        n_heads: int = 2,
        ff_size: int = 256,
        num_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerDecoderLayer(d_model, n_heads, ff_size, dropout)
            )

    def forward(
        self,
        x: Tensor,
        encoder_out: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Pass ``x`` through every layer.

        ``encoder_out`` and both masks are handed unchanged to each layer.
        """
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, encoder_out, tgt_mask, memory_mask)
        return hidden
|
@@ -0,0 +1,70 @@
|
|
1
|
+
from numba import njit, prange
|
2
|
+
|
3
|
+
|
4
|
+
@njit()
def maximum_path_each(path, value, t_x, t_y, max_neg_val):
    """Find one monotonic path through ``value`` and mark it in ``path``.

    Both arrays are indexed as ``[x, y]`` with ``x < t_x`` and ``y < t_y``
    and are modified in place: ``value`` is overwritten with accumulated
    scores and ``path`` gets a 1 at the chosen ``x`` for every ``y``.

    Args:
        path: 2D array receiving the path marks.
        value: 2D array of scores; overwritten with cumulative scores.
        t_x: Number of valid ``x`` positions.
        t_y: Number of valid ``y`` positions.
        max_neg_val: Score assigned to transitions that are not allowed.
    """
    # Forward pass: accumulate the best score reaching each reachable cell.
    for y in range(t_y):
        x_start = max(0, t_x + y - t_y)
        x_stop = min(t_x, y + 1)
        for x in range(x_start, x_stop):
            # Score when the previous step had the same x.
            if x == y:
                stay = max_neg_val
            else:
                stay = value[x, y - 1]
            # Score when the previous step had x - 1.
            if x == 0:
                if y == 0:
                    advance = 0.0
                else:
                    advance = max_neg_val
            else:
                advance = value[x - 1, y - 1]
            value[x, y] = max(stay, advance) + value[x, y]

    # Backtrack from the last x position, marking one cell per y.
    index = t_x - 1
    for y in range(t_y - 1, -1, -1):
        path[index, y] = 1
        if index != 0 and (index == y or value[index, y - 1] < value[index - 1, y - 1]):
            index -= 1
|
23
|
+
|
24
|
+
|
25
|
+
@njit()  # Took almost 10x the time while testing using "parallel=True".
def maximum_path(paths, values, t_xs, t_ys, max_neg_val=-1e9):
    """Run ``maximum_path_each`` on every item of a batch, in place.

    Args:
        paths: Batched array; ``paths[i]`` receives the path marks of item ``i``.
        values: Batched score array; ``values[i]`` is overwritten with
            accumulated scores.
        t_xs: Per-item ``t_x`` lengths, indexed by batch position.
        t_ys: Per-item ``t_y`` lengths, indexed by batch position.
        max_neg_val: Score used for transitions that are not allowed.

    Example (arrays laid out as ``[batch, t_x, t_y]``)::

        import numpy as np

        values = np.random.randn(2, 3, 3).astype(np.float32)
        paths = np.zeros((2, 3, 3), dtype=np.int32)
        t_xs = np.array([3, 3])
        t_ys = np.array([3, 3])

        maximum_path(paths, values, t_xs, t_ys)
        # ``paths`` now holds 1 along each selected path and ``values``
        # holds the accumulated scores.

    This may not be the standard, but may work for your project.
    """
    for i in prange(values.shape[0]):
        maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
|
lt_tensor/transform.py
ADDED
@@ -0,0 +1,349 @@
|
|
1
|
+
__all__ = [
|
2
|
+
"to_mel_spectrogram",
|
3
|
+
"stft",
|
4
|
+
"istft",
|
5
|
+
"fft",
|
6
|
+
"ifft",
|
7
|
+
"to_log_mel_spectrogram",
|
8
|
+
"normalize",
|
9
|
+
"min_max_scale",
|
10
|
+
"mel_to_linear",
|
11
|
+
"add_noise",
|
12
|
+
"shift_time",
|
13
|
+
"stretch_tensor",
|
14
|
+
"pad_tensor",
|
15
|
+
"get_sinusoidal_embedding",
|
16
|
+
"pad_center",
|
17
|
+
"normalize",
|
18
|
+
"window_sumsquare",
|
19
|
+
"inverse_transform",
|
20
|
+
"stft_istft_rebuild",
|
21
|
+
]
|
22
|
+
|
23
|
+
from ._torch_commons import *
|
24
|
+
import torchaudio
|
25
|
+
import math
|
26
|
+
from .misc_utils import log_tensor
|
27
|
+
|
28
|
+
|
29
|
+
def to_mel_spectrogram(
    waveform: torch.Tensor,
    sample_rate: int = 22050,
    n_fft: int = 1024,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    n_mels: int = 80,
    f_min: float = 0.0,
    f_max: Optional[float] = None,
) -> torch.Tensor:
    """Convert a waveform to a mel spectrogram.

    A ``torchaudio.transforms.MelSpectrogram`` is built from the given
    parameters on every call and applied to ``waveform``.

    Args:
        waveform: Audio tensor passed to the transform.
        sample_rate: Sample rate of ``waveform``.
        n_fft: FFT size.
        hop_length: Hop between frames (torchaudio default when ``None``).
        win_length: Window size (torchaudio default when ``None``).
        n_mels: Number of mel bins.
        f_min: Lowest frequency of the mel filter bank.
        f_max: Highest frequency of the mel filter bank.

    Returns:
        The mel spectrogram produced by the transform.
    """
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mels=n_mels,
        f_min=f_min,
        f_max=f_max,
    )
    # The transform owns buffers that are created on the CPU; move them to
    # the input's device so non-CPU waveforms do not hit a device mismatch.
    mel_spectrogram = mel_spectrogram.to(waveform.device)
    return mel_spectrogram(waveform)
|
50
|
+
|
51
|
+
|
52
|
+
def stft(
    waveform: Tensor,
    n_fft: int = 512,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window_fn: str = "hann",
    center: bool = True,
    return_complex: bool = True,
) -> Tensor:
    """Perform a short-time Fourier transform with ``torch.stft``.

    Args:
        waveform: Input signal.
        n_fft: FFT size.
        hop_length: Hop between frames (``torch.stft`` default when ``None``).
        win_length: Window size (``torch.stft`` default when ``None``).
        window_fn: ``"hann"`` selects a Hann window of ``win_length or
            n_fft`` samples; any other value passes no window.
        center: Forwarded to ``torch.stft``.
        return_complex: When true a complex tensor is returned; otherwise
            a real tensor with a trailing dimension of size 2 holding the
            real and imaginary parts.

    Returns:
        The STFT of ``waveform`` in the requested representation.
    """
    window = (
        torch.hann_window(win_length or n_fft, device=waveform.device)
        if window_fn == "hann"
        else None
    )
    # ``return_complex=False`` is deprecated in ``torch.stft``; always compute
    # the complex result and convert it afterwards, which yields the same
    # layout the deprecated flag used to produce.
    spectrum = torch.stft(
        input=waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        return_complex=True,
    )
    return spectrum if return_complex else torch.view_as_real(spectrum)
|
76
|
+
|
77
|
+
|
78
|
+
def istft(
    stft_matrix: Tensor,
    n_fft: int = 512,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window_fn: str = "hann",
    center: bool = True,
    length: Optional[int] = None,
) -> Tensor:
    """Invert a short-time Fourier transform with ``torch.istft``.

    Args:
        stft_matrix: Spectrogram to invert.
        n_fft: FFT size.
        hop_length: Hop between frames (``torch.istft`` default when ``None``).
        win_length: Window size (``torch.istft`` default when ``None``).
        window_fn: ``"hann"`` selects a Hann window of ``win_length or
            n_fft`` samples; any other value passes no window.
        center: Forwarded to ``torch.istft``.
        length: Optional output length, forwarded to ``torch.istft``.

    Returns:
        The reconstructed signal.
    """
    if window_fn == "hann":
        window_size = win_length or n_fft
        window = torch.hann_window(window_size).to(stft_matrix.device)
    else:
        window = None

    return torch.istft(
        input=stft_matrix,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        length=length,
    )
|
102
|
+
|
103
|
+
|
104
|
+
def fft(x: Tensor, norm: Optional[str] = "backward") -> Tensor:
    """Compute ``torch.fft.fft`` of ``x`` with the given ``norm`` mode."""
    spectrum = torch.fft.fft(x, norm=norm)
    return spectrum
|
107
|
+
|
108
|
+
|
109
|
+
def ifft(x: Tensor, norm: Optional[str] = "backward") -> Tensor:
    """Compute ``torch.fft.ifft`` of ``x`` with the given ``norm`` mode."""
    signal = torch.fft.ifft(x, norm=norm)
    return signal
|
112
|
+
|
113
|
+
|
114
|
+
def to_log_mel_spectrogram(
    waveform: torch.Tensor, sample_rate: int = 22050, eps: float = 1e-9, **kwargs
) -> torch.Tensor:
    """Convert a waveform to a log-mel spectrogram.

    Args:
        waveform: Audio tensor handed to ``to_mel_spectrogram``.
        sample_rate: Sample rate of ``waveform``.
        eps: Added to the mel spectrogram before the logarithm.
        **kwargs: Extra keyword arguments for ``to_mel_spectrogram``.

    Returns:
        ``log(mel + eps)``.
    """
    mel_values = to_mel_spectrogram(waveform, sample_rate, **kwargs)
    return (mel_values + eps).log()
|
120
|
+
|
121
|
+
|
122
|
+
def normalize(
    x: torch.Tensor,
    mean: Optional[float] = None,
    std: Optional[float] = None,
    eps: float = 1e-9,
) -> torch.Tensor:
    """Standardize ``x`` as ``(x - mean) / (std + eps)``.

    NOTE(review): this module defines a second function named ``normalize``
    further down; that later definition replaces this one at import time,
    so this version is not reachable by name. Confirm which one is intended
    to be public.

    Args:
        x: Tensor to standardize.
        mean: Value to subtract; ``x.mean()`` when ``None``.
        std: Value to divide by; ``x.std()`` when ``None``.
        eps: Added to ``std`` before dividing.

    Returns:
        The standardized tensor.
    """
    center = x.mean() if mean is None else mean
    scale = x.std() if std is None else std
    return (x - center) / (scale + eps)
|
134
|
+
|
135
|
+
|
136
|
+
def min_max_scale(
    x: torch.Tensor, min_val: float = 0.0, max_val: float = 1.0, eps: float = 1e-8
) -> torch.Tensor:
    """Linearly rescale ``x`` into the ``[min_val, max_val]`` range.

    The global minimum and maximum of ``x`` are used. A constant tensor
    maps to ``min_val`` everywhere, because ``eps`` keeps the denominator
    non-zero.

    Args:
        x: Tensor to rescale (must be non-empty).
        min_val: Lower bound of the target range.
        max_val: Upper bound of the target range.
        eps: Added to ``x.max() - x.min()`` to avoid division by zero.

    Returns:
        The rescaled tensor.
    """
    x_min, x_max = x.min(), x.max()
    unit = (x - x_min) / (x_max - x_min + eps)
    return unit * (max_val - min_val) + min_val
|
142
|
+
|
143
|
+
|
144
|
+
def mel_to_linear(
    mel_spec: torch.Tensor, mel_fb: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    """Approximately invert a mel spectrogram with a pseudo-inverse.

    Args:
        mel_spec: Mel spectrogram to invert.
        mel_fb: Mel filter bank whose pseudo-inverse is applied.
        eps: Added to ``mel_spec`` before the multiplication.

    Returns:
        ``pinverse(mel_fb) @ (mel_spec + eps)``.
    """
    inverse_filter_bank = torch.pinverse(mel_fb)
    shifted_mel = mel_spec + eps
    return inverse_filter_bank @ shifted_mel
|
150
|
+
|
151
|
+
|
152
|
+
def add_noise(x: torch.Tensor, noise_level: float = 0.01) -> torch.Tensor:
    """Return ``x`` plus standard-normal noise scaled by ``noise_level``."""
    noise = torch.randn_like(x)
    return x + noise_level * noise
|
155
|
+
|
156
|
+
|
157
|
+
def shift_time(x: torch.Tensor, shift: int) -> torch.Tensor:
    """Circularly shift ``x`` by ``shift`` positions along its last dimension."""
    return x.roll(shifts=shift, dims=-1)
|
160
|
+
|
161
|
+
|
162
|
+
def stretch_tensor(x: torch.Tensor, rate: float, mode: str = "linear") -> torch.Tensor:
    """Time-stretch ``x`` along its last dimension using interpolation.

    All leading dimensions are preserved, so 1D, 2D and 3D (or higher)
    inputs are supported. The new length is ``int(T * rate)`` where ``T`` is
    the size of the last dimension.

    Args:
        x: Tensor whose last dimension is resampled.
        rate: Length multiplier (``> 1`` lengthens, ``< 1`` shortens).
        mode: Interpolation mode for ``torch.nn.functional.interpolate``.

    Returns:
        A tensor shaped like ``x`` except for the last dimension.

    Raises:
        ValueError: If the stretched length would be smaller than 1.
    """
    T = x.shape[-1]
    new_T = int(T * rate)
    if new_T < 1:
        raise ValueError(
            f"rate={rate} gives an empty output for an input of length {T}"
        )
    # Fold every leading dimension into the batch axis; ``reshape`` (unlike
    # ``view``) also accepts non-contiguous input.
    flat = x.reshape(-1, 1, T)
    stretched = torch.nn.functional.interpolate(flat, size=new_T, mode=mode)
    # Restore the exact leading shape instead of a blanket ``squeeze()``,
    # which would also drop legitimate size-1 dimensions.
    return stretched.reshape(*x.shape[:-1], new_T)
|
171
|
+
|
172
|
+
|
173
|
+
def pad_tensor(
    x: torch.Tensor, target_len: int, pad_value: float = 0.0
) -> torch.Tensor:
    """Pad or truncate ``x`` to ``target_len`` along its last dimension.

    Args:
        x: Input tensor.
        target_len: Desired size of the last dimension.
        pad_value: Fill value for padded positions.

    Returns:
        ``x`` truncated to ``target_len`` if it is already long enough,
        otherwise ``x`` right-padded with ``pad_value``.
    """
    current_len = x.shape[-1]
    if current_len >= target_len:
        return x[..., :target_len]
    # ``F.pad`` reads its pad list starting from the LAST dimension, so a
    # two-element ``(left, right)`` pair is all that is needed here. The
    # previous full-length list put the padding on the first dimension
    # for inputs with more than one dimension.
    return F.pad(x, (0, target_len - current_len), value=pad_value)
|
182
|
+
|
183
|
+
|
184
|
+
def get_sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    """Build sinusoidal embeddings for a batch of timesteps.

    The first ``dim // 2`` columns hold sines and the next ``dim // 2``
    columns hold cosines of ``timestep * frequency``, with frequencies
    ``exp(-log(10000) * i / (dim // 2))``. For an odd ``dim`` one zero
    column is appended so the result always has ``dim`` columns.

    Args:
        timesteps: Tensor of timesteps; any shape, flattened to ``[B]``.
        dim: Embedding size (must be at least 2).

    Returns:
        Float tensor of shape ``[B, dim]`` on the device of ``timesteps``.

    Raises:
        ValueError: If ``dim`` is smaller than 2.
    """
    if dim < 2:
        # ``dim // 2`` would be zero and the frequency term would divide by it.
        raise ValueError(f"dim must be >= 2, got {dim}")

    # Accept scalars and [B, 1]-style inputs alike.
    timesteps = timesteps.reshape(-1)

    device = timesteps.device
    half_dim = dim // 2
    frequencies = torch.exp(
        torch.arange(half_dim, device=device) * -(math.log(10000.0) / half_dim)
    )
    angles = timesteps[:, None].float() * frequencies[None, :]  # [B, half_dim]
    emb = torch.cat((angles.sin(), angles.cos()), dim=-1)  # [B, 2 * half_dim]
    if dim % 2 == 1:
        # Keep the promised [B, dim] shape for odd sizes.
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb
|
197
|
+
|
198
|
+
|
199
|
+
def _generate_window(
    M: int, alpha: float = 0.5, device: Optional[DeviceType] = None
) -> Tensor:
    """Generate the window ``alpha - (1 - alpha) * cos(2 * pi * n / (M - 1))``.

    Args:
        M: Window length; must be at least 1.
        alpha: Cosine-window coefficient.
        device: Device on which the window is created.

    Returns:
        A float32 tensor with ``M`` samples (a single 1.0 when ``M == 1``).

    Raises:
        ValueError: If ``M`` is smaller than 1.
    """
    if M < 1:
        raise ValueError("Window length M must be >= 1.")
    if M == 1:
        return torch.ones(1, device=device)

    positions = torch.arange(M, dtype=torch.float32, device=device)
    phase = 2.0 * math.pi * positions / (M - 1)
    return alpha - (1.0 - alpha) * torch.cos(phase)
|
210
|
+
|
211
|
+
|
212
|
+
def pad_center(tensor: torch.Tensor, size: int, axis: int = -1) -> torch.Tensor:
    """Zero-pad ``tensor`` along ``axis`` so that it is centered in ``size``.

    When the amount of padding is odd, the extra element goes to the end.

    Args:
        tensor: Input tensor.
        size: Target length of ``axis``.
        axis: Dimension to pad (negative values count from the end).

    Returns:
        The padded tensor.

    Raises:
        ValueError: If ``size`` is smaller than the current length.
    """
    n = tensor.shape[axis]
    if size < n:
        raise ValueError(f"Target size ({size}) must be at least input size ({n})")

    lpad = (size - n) // 2
    rpad = size - n - lpad

    # ``F.pad`` orders its (left, right) pairs from the LAST dimension to
    # the first, so the pair for dimension ``d`` starts at index
    # ``2 * (ndim - 1 - d)``; indexing with ``2 * axis`` selects the wrong
    # dimension for anything but 1D input.
    dim = axis % tensor.ndim
    slot = 2 * (tensor.ndim - 1 - dim)
    pad = [0] * (2 * tensor.ndim)
    pad[slot] = lpad
    pad[slot + 1] = rpad

    return F.pad(tensor, pad, mode="constant", value=0)
|
225
|
+
|
226
|
+
|
227
|
+
def normalize(
    S: torch.Tensor,
    norm: float = float("inf"),
    axis: int = 0,
    threshold: float = 1e-10,
    fill: bool = False,
) -> torch.Tensor:
    """Normalize ``S`` along ``axis`` by a chosen norm of its magnitude.

    NOTE(review): an earlier function in this module is also called
    ``normalize`` (mean/std based); this later definition is the one bound
    to the name after import.

    Args:
        S: Tensor to normalize.
        norm: ``inf`` (max magnitude), ``-inf`` (min magnitude), ``0``
            (count of non-zero entries), a positive ``p`` (p-norm), or
            ``None`` to return ``S`` unchanged.
        axis: Dimension along which the norm is computed.
        threshold: Norms below this value are treated as too small.
        fill: Controls the result where the norm is too small: ``False``
            leaves zeros (division by ``inf``), ``True`` yields 1.0.

    Returns:
        The normalized tensor.

    Raises:
        ValueError: If ``norm`` is negative and not ``-inf``.
    """
    mag = S.abs().float()

    if norm is None:
        return S

    if norm == float("inf"):
        length = mag.max(dim=axis, keepdim=True).values
    elif norm == float("-inf"):
        length = mag.min(dim=axis, keepdim=True).values
    elif norm == 0:
        length = (mag > 0).sum(dim=axis, keepdim=True).float()
    elif norm > 0:
        length = (mag**norm).sum(dim=axis, keepdim=True) ** (1.0 / norm)
    else:
        raise ValueError(f"Unsupported norm: {norm}")

    too_small = length < threshold
    if not fill:
        # Dividing by inf zeroes the slices whose norm is negligible.
        return S / length.masked_fill(too_small, float("inf"))

    # Dividing by nan marks the negligible slices, which are then set to 1.0
    # (any nan already present in the quotient is replaced as well).
    Snorm = S / length.masked_fill(too_small, float("nan"))
    Snorm[torch.isnan(Snorm)] = 1.0
    return Snorm
|
265
|
+
|
266
|
+
|
267
|
+
def _window_from_spec(window_spec, win_length, dtype, device):
    """Resolve ``window_spec`` into a 1D window tensor of ``win_length`` samples."""
    named_windows = {
        "hann": torch.hann_window,
        "hamming": torch.hamming_window,
        "blackman": torch.blackman_window,
        "bartlett": torch.bartlett_window,
    }
    if isinstance(window_spec, str):
        try:
            factory = named_windows[window_spec.lower()]
        except KeyError:
            raise ValueError(
                f"Unsupported window name: {window_spec!r}; "
                f"expected one of {sorted(named_windows)}"
            ) from None
        return factory(win_length, periodic=True, dtype=dtype, device=device)
    if isinstance(window_spec, torch.Tensor):
        win = window_spec.to(dtype=dtype, device=device)
    elif isinstance(window_spec, (int, float)):
        # Following the scipy ``get_window`` convention, a bare number is the
        # ``beta`` parameter of a Kaiser window.
        return torch.kaiser_window(
            win_length,
            periodic=True,
            beta=float(window_spec),
            dtype=dtype,
            device=device,
        )
    elif callable(window_spec):
        win = torch.as_tensor(window_spec(win_length), dtype=dtype, device=device)
    else:
        # Lists/tuples are taken as explicit window samples.
        win = torch.as_tensor(window_spec, dtype=dtype, device=device)
    if win.ndim != 1 or win.shape[0] != win_length:
        raise ValueError(
            f"Window must be 1D with {win_length} samples, got shape {tuple(win.shape)}"
        )
    return win


def window_sumsquare(
    window_spec: Union[str, int, float, Callable, List[Any], Tuple[Any, ...]],
    n_frames: int,
    hop_length: int = 300,
    win_length: int = 1200,
    n_fft: int = 2048,
    dtype: torch.dtype = torch.float32,
    norm: Optional[Union[int, float]] = None,
    device: Optional[torch.device] = "cpu",
):
    """Compute the overlap-added sum of squared windows.

    The window is normalized with ``normalize(..., norm=norm)``, squared,
    centered inside ``n_fft`` samples and accumulated every ``hop_length``
    samples for ``n_frames`` frames.

    Args:
        window_spec: Window description. A name (``"hann"``, ``"hamming"``,
            ``"blackman"``, ``"bartlett"``) or a number (Kaiser ``beta``)
            produces a periodic window; a callable is invoked with
            ``win_length``; a tensor or sequence is used as the samples.
        n_frames: Number of frames to accumulate.
        hop_length: Hop between consecutive frames.
        win_length: Window length; ``n_fft`` when ``None``.
        n_fft: FFT size each window is centered in.
        dtype: Data type of the result.
        norm: Norm handed to ``normalize`` (``None`` disables it).
        device: Device of the result.

    Returns:
        1D tensor of length ``n_fft + hop_length * (n_frames - 1)``.

    Raises:
        ValueError: For an unknown window name or a window whose shape does
            not match ``win_length``.
    """
    if win_length is None:
        win_length = n_fft

    total_length = n_fft + hop_length * (n_frames - 1)
    x = torch.zeros(total_length, dtype=dtype, device=device)

    # The previous call passed ``window_spec``/``fftbins`` to
    # ``_generate_window``, whose signature accepts neither, so it always
    # raised ``TypeError``.
    win = _window_from_spec(window_spec, win_length, dtype, device)

    # Normalize and square, then center the window inside an FFT frame.
    win_sq = normalize(win, norm=norm, axis=0) ** 2
    win_sq = pad_center(win_sq, size=n_fft, axis=0)

    # Accumulate the squared window at every frame position.
    for i in range(n_frames):
        sample = i * hop_length
        end = min(total_length, sample + n_fft)
        x[sample:end] += win_sq[: end - sample]

    return x
|
299
|
+
|
300
|
+
|
301
|
+
def inverse_transform(
    spec: Tensor,
    phase: Tensor,
    window: Optional[Tensor] = None,
    n_fft: int = 2048,
    hop_length: int = 300,
    win_length: int = 1200,
    length: Optional[Any] = None,
):
    """Rebuild a signal from magnitude and phase with ``torch.istft``.

    The complex spectrogram is formed as ``spec * exp(1j * phase)``.

    Args:
        spec: Magnitude spectrogram.
        phase: Phase in radians, broadcastable against ``spec``.
        window: Synthesis window; generated from ``win_length`` when ``None``.
        n_fft: FFT size.
        hop_length: Hop between frames.
        win_length: Window length.
        length: Optional output length, forwarded to ``torch.istft``.

    Returns:
        The reconstructed signal.
    """
    if window is None:
        # Create the default window next to the data; a CPU window would
        # fail against a spectrogram living on another device.
        window = _generate_window(win_length, device=spec.device)
    else:
        window = window.to(spec.device)
    complex_spec = spec * torch.exp(phase * 1j)
    return torch.istft(
        complex_spec,
        n_fft,
        hop_length,
        win_length,
        window=window,
        length=length,
    )
|
320
|
+
|
321
|
+
|
322
|
+
def stft_istft_rebuild(
    input_data: Tensor,
    window: Optional[Tensor] = None,
    n_fft: int = 2048,
    hop_length: int = 300,
    win_length: int = 1200,
):
    """Perform an STFT followed by an ISTFT reconstruction.

    The spectrogram is split into magnitude and phase and recombined
    before the inverse transform. The result is trimmed or padded to the
    length of ``input_data``'s last dimension, and a leading dimension of
    size 1 is squeezed away.

    Args:
        input_data: Signal to analyze and rebuild.
        window: Analysis/synthesis window; generated from ``win_length``
            when ``None``.
        n_fft: FFT size.
        hop_length: Hop between frames.
        win_length: Window length.

    Returns:
        The reconstructed signal.
    """
    if window is None:
        # Create the default window next to the data; a CPU window would
        # fail against input living on another device.
        window = _generate_window(win_length, device=input_data.device)
    else:
        window = window.to(input_data.device)

    st = torch.stft(
        input_data,
        n_fft,
        hop_length,
        win_length,
        window=window,
        return_complex=True,
    )
    magnitude = torch.abs(st)
    phase = torch.angle(st)
    return torch.istft(
        magnitude * torch.exp(1j * phase),
        n_fft,
        hop_length,
        win_length,
        window=window,
        length=input_data.shape[-1],
    ).squeeze(0)
|