lt-tensor 0.0.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,158 @@
1
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "spectral_norm_select",
    "ResBlock1D",
    "ResBlock2D",
    "ResBlock1D_S",
]
7
+
8
+ from .._torch_commons import *
9
+ from .._basics import Model
10
+ import math
11
+ from ..misc_utils import log_tensor
12
+
13
+
14
def spectral_norm_select(module: Module, enabled: bool):
    """Optionally wrap ``module`` with spectral normalization.

    Args:
        module: The layer to (maybe) wrap.
        enabled: When truthy, ``spectral_norm(module)`` is returned; otherwise
            ``module`` is handed back untouched.
    """
    return spectral_norm(module) if enabled else module
18
+
19
+
20
class ResBlock1D(Model):
    """Stack of dilated 1-D residual units followed by a projection conv.

    Each unit computes ``x = norm(net(x) + x)`` where ``net`` is
    ``conv(dilation=d) -> activation -> conv(dilation=1) -> activation``
    (both convs weight-normed, ``in_channels -> in_channels``). After all
    units, ``final`` (conv + activation) maps ``in_channels`` to
    ``out_channels``.

    Args:
        in_channels: Channel count of the input and of every residual unit.
        out_channels: Channel count produced by ``final``.
        kernel_size: Kernel size of every conv. Padding keeps the time length
            unchanged for odd kernel sizes.
        dilation: One dilation per residual unit; a single int means one unit.
        activation: Activation module. The very same instance is placed in
            every position (and the default instance is shared by every
            ``ResBlock1D`` built with the default), so an activation with
            parameters would share them everywhere.
        num_groups: ``num_groups`` for ``nn.GroupNorm`` when ``batched``.
        batched: True for ``[B, C, T]`` input (``nn.GroupNorm`` per unit);
            False for unbatched ``[C, T]`` input (``nn.LayerNorm`` over the
            channel axis per unit).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        dilation: Union[Sequence[int], int] = (1, 3, 5),
        activation: nn.Module = nn.LeakyReLU(0.1),
        num_groups: int = 1,
        batched: bool = True,
    ):
        super().__init__()
        self.conv = nn.ModuleList()
        if isinstance(dilation, int):
            dilation = [dilation]

        def make_norm(channels: int) -> nn.Module:
            # GroupNorm needs a batch axis; unbatched input uses LayerNorm,
            # which `forward` applies with channels moved to the last axis.
            if batched:
                return nn.GroupNorm(num_groups=num_groups, num_channels=channels)
            return nn.LayerNorm(normalized_shape=channels)

        for dil in dilation:
            self.conv.append(
                nn.ModuleDict(
                    dict(
                        net=nn.Sequential(
                            self._get_conv_layer(
                                in_channels, in_channels, kernel_size, dil
                            ),
                            activation,
                            self._get_conv_layer(
                                in_channels, in_channels, kernel_size, 1, True
                            ),
                            activation,
                        ),
                        l_norm=make_norm(in_channels),
                    )
                )
            )
        self.final = nn.Sequential(
            self._get_conv_layer(in_channels, out_channels, kernel_size, 1, True),
            activation,
        )
        # Only the residual units are re-initialised; `final` keeps its
        # default initialisation.
        self.conv.apply(self.init_weights)

    def _get_conv_layer(
        self,
        channels_in: int,
        channels_out: int,
        kernel_size: int,
        dilation: int,
        pad_gate: bool = False,
    ):
        """Build a weight-normed, stride-1 ``nn.Conv1d``.

        Padding is ``(kernel_size * dilation - dilation) / 2`` (truncated), i.e.
        "same" padding for the given dilation. With ``pad_gate=True`` the
        undilated padding ``(kernel_size - 1) / 2`` is used instead, whatever
        ``dilation`` is (callers only pass ``pad_gate=True`` with dilation 1).
        """
        return weight_norm(
            nn.Conv1d(
                in_channels=channels_in,
                out_channels=channels_out,
                kernel_size=kernel_size,
                stride=1,
                dilation=dilation,
                padding=(
                    int((kernel_size * dilation - dilation) / 2)
                    if not pad_gate
                    else int((kernel_size * 1 - 1) / 2)
                ),
            )
        )

    @staticmethod
    def _apply_norm(norm: nn.Module, x: Tensor) -> Tensor:
        """Apply a residual-unit norm with the channel axis where it expects it."""
        if isinstance(norm, nn.LayerNorm):
            # LayerNorm normalizes the trailing axis, but conv activations are
            # [..., C, T]; move C to the end and back so it normalizes channels
            # (without this, unbatched input failed unless T == C).
            return norm(x.transpose(-1, -2)).transpose(-1, -2)
        return norm(x)

    def forward(self, x: Tensor):
        """Run every residual unit, then the output projection ``final``."""
        for layer in self.conv:
            x = layer["net"](x) + x
            x = self._apply_norm(layer["l_norm"], x)
        return self.final(x)

    def remove_weight_norm(self):
        """Strip weight normalization from every submodule that has it."""
        for module in self.modules():
            try:
                # Resolves to the module-level `remove_weight_norm` function,
                # not this method (class attributes are not in method scope).
                remove_weight_norm(module)
            except ValueError:
                pass  # Not weight-normed; deliberately skipped.

    @staticmethod
    def init_weights(m, mean=0.0, std=0.01):
        """Re-draw conv weights from ``normal(mean, std)`` in place.

        NOTE(review): with legacy ``torch.nn.utils.weight_norm`` the ``weight``
        attribute is recomputed from ``weight_g``/``weight_v`` before each
        forward, so this in-place init may not persist — confirm which
        ``weight_norm`` `_torch_commons` exports.
        """
        classname = m.__class__.__name__
        if "Conv" in classname:
            m.weight.data.normal_(mean, std)
107
+
108
+
109
class ResBlock2D(Model):
    """2-D residual block returning ``(block(x) + skip(x)) / sqrt(2)``.

    ``block`` is conv3x3 (stride 2 when ``downsample``) -> LeakyReLU(0.2) ->
    conv3x3. ``skip`` is the identity, or a 1x1 conv with the same stride when
    the block downsamples or changes the channel count. Every conv is wrapped
    with spectral normalization when ``spec_norm`` is True.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        downsample=False,
        spec_norm: bool = False,
    ):
        super().__init__()
        stride = 1
        if downsample:
            stride = 2

        def maybe_sn(conv):
            return spectral_norm_select(conv, spec_norm)

        self.block = nn.Sequential(
            maybe_sn(nn.Conv2d(in_channels, out_channels, 3, stride, 1)),
            nn.LeakyReLU(0.2),
            maybe_sn(nn.Conv2d(out_channels, out_channels, 3, 1, 1)),
        )

        needs_projection = downsample or in_channels != out_channels
        self.skip = (
            maybe_sn(nn.Conv2d(in_channels, out_channels, 1, stride))
            if needs_projection
            else nn.Identity()
        )
        # Precomputed once so forward does not recompute sqrt(2) every call.
        self.sqrt_2 = math.sqrt(2)

    def forward(self, x):
        main = self.block(x)
        shortcut = self.skip(x)
        return (main + shortcut) / self.sqrt_2
140
+
141
+
142
class ResBlock1D_S(Model):
    """Simplified residual unit: ``LeakyReLU(x + conv2(LeakyReLU(conv1(x))))``.

    ``conv1`` uses the requested ``dilation`` and ``conv2`` is undilated. Each
    conv gets the padding that keeps the time length unchanged (for odd
    ``kernel_size``), which the residual add requires.

    Args:
        channels: Input and output channel count.
        kernel_size: Kernel size of both convs.
        dilation: Dilation of the first conv.
    """

    def __init__(self, channels: int, kernel_size: int = 3, dilation: int = 1):
        super().__init__()
        # "Same" padding for a dilated conv: dilation * (kernel_size - 1) / 2.
        padding = (kernel_size - 1) // 2 * dilation
        self.net = nn.Sequential(
            nn.Conv1d(
                channels, channels, kernel_size, padding=padding, dilation=dilation
            ),
            nn.LeakyReLU(0.1),
            # Undilated conv -> undilated padding. Reusing the dilated padding
            # here lengthened the output and broke the residual add for
            # dilation > 1.
            nn.Conv1d(
                channels,
                channels,
                kernel_size,
                padding=(kernel_size - 1) // 2,
                dilation=1,
            ),
        )
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.activation(x + self.net(x))
@@ -0,0 +1,140 @@
1
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "TransformerEncoderLayer",
    "TransformerDecoderLayer",
    "TransformerEncoder",
    "TransformerDecoder",
    "init_weights",
]
8
+
9
+ import math
10
+ from .._torch_commons import *
11
+ from .._basics import Model
12
+ from lt_utils.misc_utils import default
13
+
14
+ from .pos import *
15
+ from .bsc import FeedForward
16
+
17
+
18
def init_weights(module):
    """Initialize one module in place; intended for ``model.apply(init_weights)``.

    - ``nn.Linear``: Xavier-uniform weight, zero bias (when a bias exists).
    - ``nn.Embedding``: weight drawn from ``normal(0, 0.02)``.
    - ``nn.LayerNorm``: bias 0 and weight 1 — only for the affine parameters
      that exist. ``elementwise_affine=False`` (or ``bias=False``) leaves them as
      ``None``, which previously crashed.

    Other module types are left untouched.
    """
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
        if module.weight is not None:
            nn.init.constant_(module.weight, 1.0)
28
+
29
+
30
class TransformerEncoderLayer(Model):
    """Post-norm encoder layer: self-attention, then a feed-forward block.

    Each sub-layer output goes through ``norm(x + dropout(sublayer_out))``.
    Attention runs with ``batch_first=True``, so ``x`` is ``[B, T, d_model]``.

    Args:
        d_model: Model width.
        n_heads: Number of attention heads.
        ff_size: Hidden size passed to ``FeedForward``.
        dropout: Dropout rate passed to the attention, to ``FeedForward`` and
            used on both residual branches.
    """

    def __init__(self, d_model: int, n_heads: int, ff_size: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, ff_size, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, x: Tensor, branch: Tensor, norm: nn.Module) -> Tensor:
        """Residual connection with dropout on the branch, then ``norm``."""
        return norm(x + self.dropout(branch))

    def forward(self, x: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
        attended = self.self_attn(x, x, x, attn_mask=src_mask)[0]
        x = self._add_and_norm(x, attended, self.norm1)
        return self._add_and_norm(x, self.ff(x), self.norm2)
47
+
48
+
49
class TransformerDecoderLayer(Model):
    """Post-norm decoder layer: masked self-attention, cross-attention, feed-forward.

    Every sub-layer output goes through ``norm(x + dropout(sublayer_out))``.
    Attention runs with ``batch_first=True``.

    Args:
        d_model: Model width.
        n_heads: Number of heads for both attention blocks.
        ff_size: Hidden size passed to ``FeedForward``.
        dropout: Dropout rate passed to both attentions, to ``FeedForward`` and
            used on every residual branch.
    """

    def __init__(self, d_model: int, n_heads: int, ff_size: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)

        self.cross_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(d_model)

        self.ff = FeedForward(d_model, ff_size, dropout)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, x: Tensor, branch: Tensor, norm: nn.Module) -> Tensor:
        """Residual connection with dropout on the branch, then ``norm``."""
        return norm(x + self.dropout(branch))

    def forward(
        self,
        x: Tensor,  # Decoder input [B, T, d_model]
        encoder_out: Tensor,  # Encoder output [B, S, d_model]
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
    ) -> Tensor:
        # Masked self-attention over the decoder sequence.
        self_ctx = self.self_attn(x, x, x, attn_mask=tgt_mask)[0]
        x = self._add_and_norm(x, self_ctx, self.norm1)

        # Cross-attention: decoder queries attend to the encoder output.
        memory_ctx = self.cross_attn(
            x, encoder_out, encoder_out, attn_mask=memory_mask
        )[0]
        x = self._add_and_norm(x, memory_ctx, self.norm2)

        # Position-wise feed-forward.
        return self._add_and_norm(x, self.ff(x), self.norm3)
87
+
88
+
89
class TransformerEncoder(Model):
    """``num_layers`` ``TransformerEncoderLayer`` modules applied one after another.

    All layers share the same hyper-parameters and the same ``src_mask``.
    """

    def __init__(
        self,
        d_model: int = 64,
        n_heads: int = 4,
        ff_size: int = 256,
        num_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        stack = [
            TransformerEncoderLayer(d_model, n_heads, ff_size, dropout)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(stack)

    def forward(self, x: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask)
        return hidden
111
+
112
+
113
class TransformerDecoder(Model):
    """``num_layers`` ``TransformerDecoderLayer`` modules applied one after another.

    Every layer attends to the same ``encoder_out`` and uses the same
    ``tgt_mask`` / ``memory_mask``.
    """

    def __init__(
        self,
        d_model: int = 64,
        n_heads: int = 2,
        ff_size: int = 256,
        num_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        stack = [
            TransformerDecoderLayer(d_model, n_heads, ff_size, dropout)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(stack)

    def forward(
        self,
        x: Tensor,
        encoder_out: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
    ) -> Tensor:
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, encoder_out, tgt_mask, memory_mask)
        return hidden
@@ -0,0 +1,70 @@
1
+ from numba import njit, prange
2
+
3
+
4
@njit()
def maximum_path_each(path, value, t_x, t_y, max_neg_val):
    """Max-score monotonic path for one sample; updates ``path`` and ``value`` in place.

    ``value`` is indexed ``value[x, y]``; only the ``0 <= x < t_x``,
    ``0 <= y < t_y`` window is read or written.

    Forward pass: each reachable cell ``value[x, y]`` becomes its own score
    plus the better of its two predecessors, ``value[x, y - 1]`` (same x) and
    ``value[x - 1, y - 1]`` (x advanced by one).
    Backward pass: starting at ``x = t_x - 1``, walk ``y`` from ``t_y - 1``
    down to 0 and set ``path[x, y] = 1`` for the chosen x at every step.

    NOTE: requires ``t_x <= t_y``; otherwise the forward loop's inner range is
    empty for every ``y`` and no scores are accumulated.

    Args:
        path: 2-D array; entries on the chosen path are set to 1 and nothing
            else is touched (so it presumably should start as zeros).
        value: 2-D array of scores, overwritten with accumulated scores.
        t_x: Number of valid x positions (rows).
        t_y: Number of valid y steps (columns).
        max_neg_val: Large negative score for disallowed predecessors.
    """
    index = t_x - 1
    # Forward pass: accumulate the best path score into every reachable cell.
    for y in range(t_y):
        # x must be reachable from (0, 0) (x <= y) and still able to reach
        # (t_x - 1, t_y - 1) (x >= t_x - t_y + y).
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            # Same-x predecessor; not allowed on the diagonal x == y.
            v_cur = max_neg_val if x == y else value[x, y - 1]
            # Diagonal predecessor; the start cell (0, 0) has none and uses 0.
            v_prev = (
                0.0
                if (x == 0 and y == 0)
                else (max_neg_val if x == 0 else value[x - 1, y - 1])
            )
            value[x, y] = max(v_cur, v_prev) + value[x, y]

    # Backtrack to store the path
    for y in range(t_y - 1, -1, -1):
        path[index, y] = 1
        # Move to x - 1 when forced (index == y) or when the diagonal
        # predecessor scored strictly higher; ties keep the same x.
        if index != 0 and (index == y or value[index, y - 1] < value[index - 1, y - 1]):
            index -= 1
23
+
24
+
25
@njit()  # Benchmarked ~10x slower with parallel=True, so it is compiled serially.
def maximum_path(paths, values, t_xs, t_ys, max_neg_val=-1e9):
    """Batched monotonic alignment search, computed in place.

    For every batch item ``i`` this calls ``maximum_path_each(paths[i],
    values[i], t_xs[i], t_ys[i], max_neg_val)``.

    Args:
        paths: ``[B, X, Y]`` array; 1s are written along each item's path.
            Entries are only ever set to 1, so pass a zero-initialised array.
        values: ``[B, X, Y]`` array of scores, overwritten with the
            accumulated path scores.
        t_xs: Per-item number of valid x positions (length ``B``).
        t_ys: Per-item number of valid y steps (length ``B``).
        max_neg_val: Score used for disallowed predecessors.

    Example:
    ```python
    import numpy as np

    paths = np.zeros((1, 2, 3), dtype=np.int32)
    values = np.array([[[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0]]], dtype=np.float32)
    maximum_path(paths, values, np.array([2]), np.array([3]))

    paths[0]   # [[1, 0, 0],
               #  [0, 1, 1]]
    values[0]  # [[1., 3.,  3.],
               #  [4., 6., 12.]]
    ```
    Note the ``[X, Y]`` axis order; this may differ from other
    implementations, so check it against your project's layout.
    """
    batch_size = values.shape[0]
    # Per the numba docs, prange without parallel=True runs as a plain range.
    for i in prange(batch_size):
        maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
lt_tensor/transform.py ADDED
@@ -0,0 +1,349 @@
1
# Public API of this module: the names exported by `from <module> import *`.
# Each name is listed once ("normalize" was previously duplicated).
__all__ = [
    "to_mel_spectrogram",
    "stft",
    "istft",
    "fft",
    "ifft",
    "to_log_mel_spectrogram",
    "normalize",
    "min_max_scale",
    "mel_to_linear",
    "add_noise",
    "shift_time",
    "stretch_tensor",
    "pad_tensor",
    "get_sinusoidal_embedding",
    "pad_center",
    "window_sumsquare",
    "inverse_transform",
    "stft_istft_rebuild",
]
22
+
23
+ from ._torch_commons import *
24
+ import torchaudio
25
+ import math
26
+ from .misc_utils import log_tensor
27
+
28
+
29
def to_mel_spectrogram(
    waveform: torch.Tensor,
    sample_rate: int = 22050,
    n_fft: int = 1024,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    n_mels: int = 80,
    f_min: float = 0.0,
    f_max: Optional[float] = None,
) -> torch.Tensor:
    """Convert a waveform to a mel spectrogram with ``torchaudio``.

    A ``torchaudio.transforms.MelSpectrogram`` is built from the arguments
    (its other settings are torchaudio's defaults) and applied to
    ``waveform``. The transform is moved to ``waveform.device`` first, so
    non-CPU inputs work; previously its window and filterbank buffers
    stayed on the CPU.
    """
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mels=n_mels,
        f_min=f_min,
        f_max=f_max,
    ).to(waveform.device)
    return mel_spectrogram(waveform)
50
+
51
+
52
def stft(
    waveform: Tensor,
    n_fft: int = 512,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window_fn: str = "hann",
    center: bool = True,
    return_complex: bool = True,
) -> Tensor:
    """Short-time Fourier transform via ``torch.stft``.

    Args:
        waveform: Signal tensor accepted by ``torch.stft``.
        n_fft: FFT size.
        hop_length: Hop between frames (``torch.stft`` default when None).
        win_length: Window length; ``n_fft`` when None.
        window_fn: ``"hann"``, ``"hamming"``, ``"blackman"`` or ``"bartlett"``
            (torch's periodic window functions). Any other value means no
            window (rectangular), as before; previously every name other
            than ``"hann"`` fell back to rectangular silently.
        center: Forwarded to ``torch.stft``.
        return_complex: Forwarded to ``torch.stft``.
    """
    window_factories = {
        "hann": torch.hann_window,
        "hamming": torch.hamming_window,
        "blackman": torch.blackman_window,
        "bartlett": torch.bartlett_window,
    }
    factory = window_factories.get(window_fn)
    window = (
        factory(win_length or n_fft).to(waveform.device)
        if factory is not None
        else None
    )
    return torch.stft(
        input=waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        return_complex=return_complex,
    )
76
+
77
+
78
def istft(
    stft_matrix: Tensor,
    n_fft: int = 512,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window_fn: str = "hann",
    center: bool = True,
    length: Optional[int] = None,
) -> Tensor:
    """Inverse short-time Fourier transform via ``torch.istft``.

    Args:
        stft_matrix: Complex STFT, as produced by ``torch.stft``.
        n_fft: FFT size used for the forward transform.
        hop_length: Hop between frames (``torch.istft`` default when None).
        win_length: Window length; ``n_fft`` when None.
        window_fn: ``"hann"``, ``"hamming"``, ``"blackman"`` or ``"bartlett"``
            (torch's periodic window functions); any other value means no
            window (rectangular). Use the same window as the forward STFT.
        center: Forwarded to ``torch.istft``.
        length: Output length forwarded to ``torch.istft``.
    """
    window_factories = {
        "hann": torch.hann_window,
        "hamming": torch.hamming_window,
        "blackman": torch.blackman_window,
        "bartlett": torch.bartlett_window,
    }
    factory = window_factories.get(window_fn)
    window = (
        factory(win_length or n_fft).to(stft_matrix.device)
        if factory is not None
        else None
    )
    return torch.istft(
        input=stft_matrix,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        length=length,
    )
102
+
103
+
104
def fft(
    x: Tensor,
    norm: Optional[str] = "backward",
    *,
    n: Optional[int] = None,
    dim: int = -1,
) -> Tensor:
    """Return the 1-D discrete Fourier transform of ``x`` (real or complex).

    Thin wrapper over ``torch.fft.fft``.

    Args:
        x: Input tensor.
        norm: Normalization mode (``"backward"``, ``"forward"`` or ``"ortho"``).
        n: Optional signal length to zero-pad/trim to along ``dim``.
        dim: Dimension to transform (last by default).
    """
    return torch.fft.fft(x, n=n, dim=dim, norm=norm)
107
+
108
+
109
def ifft(
    x: Tensor,
    norm: Optional[str] = "backward",
    *,
    n: Optional[int] = None,
    dim: int = -1,
) -> Tensor:
    """Return the 1-D inverse discrete Fourier transform of ``x``.

    Thin wrapper over ``torch.fft.ifft``.

    Args:
        x: Input tensor (typically complex).
        norm: Normalization mode (``"backward"``, ``"forward"`` or ``"ortho"``).
        n: Optional signal length to zero-pad/trim to along ``dim``.
        dim: Dimension to transform (last by default).
    """
    return torch.fft.ifft(x, n=n, dim=dim, norm=norm)
112
+
113
+
114
def to_log_mel_spectrogram(
    waveform: torch.Tensor, sample_rate: int = 22050, eps: float = 1e-9, **kwargs
) -> torch.Tensor:
    """Return ``log(mel + eps)`` for ``waveform``.

    ``kwargs`` are forwarded to ``to_mel_spectrogram`` (e.g. ``n_fft``,
    ``hop_length``, ``n_mels``); ``eps`` keeps the log finite on silent bins.
    """
    return to_mel_spectrogram(waveform, sample_rate, **kwargs).add(eps).log()
120
+
121
+
122
def normalize(
    x: torch.Tensor,
    mean: Optional[float] = None,
    std: Optional[float] = None,
    eps: float = 1e-9,
) -> torch.Tensor:
    """Standardize ``x`` as ``(x - mean) / (std + eps)``.

    A missing ``mean`` / ``std`` is taken over the whole tensor
    (``x.mean()`` / ``x.std()``).

    NOTE(review): another ``normalize`` (norm-based) is defined later in this
    module and rebinds the name, so this version is not reachable as
    ``normalize`` at module level. Consider renaming one of them.
    """
    centre = x.mean() if mean is None else mean
    scale = x.std() if std is None else std
    return (x - centre) / (scale + eps)
134
+
135
+
136
def min_max_scale(
    x: torch.Tensor, min_val: float = 0.0, max_val: float = 1.0
) -> torch.Tensor:
    """Linearly map ``x`` from its own ``[min, max]`` onto ``[min_val, max_val]``.

    A ``1e-8`` term in the denominator keeps constant inputs finite; they map
    to ``min_val``.
    """
    lo = x.min()
    hi = x.max()
    unit = (x - lo) / (hi - lo + 1e-8)
    return unit * (max_val - min_val) + min_val
142
+
143
+
144
def mel_to_linear(
    mel_spec: torch.Tensor, mel_fb: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    """Approximately invert a mel projection with the pseudo-inverse of ``mel_fb``.

    Computes ``pinverse(mel_fb) @ (mel_spec + eps)``. Therefore
    ``mel_fb.shape[0]`` must equal ``mel_spec.shape[-2]``, and the result has
    ``mel_fb.shape[-1]`` rows. The pseudo-inverse may produce negative values.
    """
    return torch.pinverse(mel_fb) @ (mel_spec + eps)
150
+
151
+
152
def add_noise(x: torch.Tensor, noise_level: float = 0.01) -> torch.Tensor:
    """Return ``x`` plus standard-normal noise scaled by ``noise_level``.

    The noise has the shape, dtype and device of ``x`` and is drawn from
    torch's global RNG.
    """
    noise = torch.randn_like(x)
    return x + noise * noise_level
155
+
156
+
157
def shift_time(x: torch.Tensor, shift: int) -> torch.Tensor:
    """Circularly shift ``x`` by ``shift`` steps along its last dimension.

    Samples pushed off one end wrap around to the other (``torch.roll``).
    """
    return x.roll(shift, -1)
160
+
161
+
162
def stretch_tensor(x: torch.Tensor, rate: float, mode: str = "linear") -> torch.Tensor:
    """Time-stretch ``x`` along its last dimension by interpolation.

    The last dimension (length ``T``) is resampled to ``int(T * rate)``
    samples and every leading dimension is kept. This works for ``[T]``,
    ``[C, T]``, ``[B, C, T]`` and higher ranks; previously only 1-D and 3-D
    inputs worked.

    Args:
        x: Floating-point tensor whose last dimension is time.
        rate: Length multiplier (``> 1`` lengthens, ``< 1`` shortens).
        mode: Interpolation mode forwarded to
            ``torch.nn.functional.interpolate`` on a ``[N, 1, T]`` tensor
            (e.g. ``"linear"`` or ``"nearest"``).

    Returns:
        Tensor of shape ``x.shape[:-1] + (int(T * rate),)``.
    """
    *leading, T = x.shape
    new_T = int(T * rate)
    # Fold all leading dims into the batch so each track is resampled on its
    # own; reshape (unlike view) also accepts non-contiguous input.
    flat = x.reshape(-1, 1, T)
    stretched = torch.nn.functional.interpolate(flat, size=new_T, mode=mode)
    return stretched.reshape(*leading, new_T)
171
+
172
+
173
def pad_tensor(
    x: torch.Tensor, target_len: int, pad_value: float = 0.0
) -> torch.Tensor:
    """Right-pad or truncate ``x`` along its last dimension to ``target_len``.

    Args:
        x: Input tensor of any rank.
        target_len: Desired size of the last dimension.
        pad_value: Constant used for the added samples.
    """
    current_len = x.shape[-1]
    if current_len >= target_len:
        return x[..., :target_len]
    # F.pad's first (left, right) pair applies to the LAST dimension. The old
    # code put the pair at the end of the list, which padded dim 0 instead
    # whenever x.ndim >= 2.
    return F.pad(x, (0, target_len - current_len), value=pad_value)
182
+
183
+
184
def get_sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    """Sinusoidal embedding of ``timesteps``, shaped ``[B, dim]``.

    The first ``dim // 2`` columns are ``sin(t * f_i)`` and the next
    ``dim // 2`` are ``cos(t * f_i)``, with
    ``f_i = exp(-log(10000) * i / (dim // 2))``. For odd ``dim`` a zero column
    is appended, so the width is always ``dim``; previously it was ``dim - 1``,
    and ``dim == 1`` divided by zero.

    Args:
        timesteps: Any-shaped tensor (a scalar, ``[B]`` or ``[B, 1]``); it is
            flattened to ``[B]``.
        dim: Output embedding width.
    """
    if timesteps.dim() != 1:
        # reshape also handles 0-d and non-contiguous inputs.
        timesteps = timesteps.reshape(-1)

    device = timesteps.device
    half_dim = dim // 2
    # max(..., 1) only matters for dim < 2, where there are no frequencies.
    freq_scale = -(math.log(10000.0) / max(half_dim, 1))
    freqs = torch.exp(torch.arange(half_dim, device=device) * freq_scale)
    emb = timesteps[:, None].float() * freqs[None, :]  # [B, half_dim]
    emb = torch.cat((emb.sin(), emb.cos()), dim=-1)  # [B, 2 * half_dim]
    if dim % 2 == 1:
        emb = torch.cat((emb, emb.new_zeros(emb.shape[0], 1)), dim=-1)
    return emb
197
+
198
+
199
def _generate_window(
    M: int, alpha: float = 0.5, device: Optional[DeviceType] = None
) -> Tensor:
    """Symmetric generalized cosine window of length ``M``.

    Computes ``alpha - (1 - alpha) * cos(2 * pi * n / (M - 1))`` for
    ``n = 0 .. M - 1`` in float32. ``alpha=0.5`` gives a symmetric Hann window
    and ``alpha=0.54`` a symmetric Hamming window. ``M == 1`` returns
    ``[1.0]``.

    Raises:
        ValueError: if ``M < 1``.
    """
    if M == 1:
        return torch.ones(1, device=device)
    if M < 1:
        raise ValueError("Window length M must be >= 1.")

    phase = 2.0 * math.pi * torch.arange(M, dtype=torch.float32, device=device) / (M - 1)
    return alpha - (1.0 - alpha) * torch.cos(phase)
210
+
211
+
212
def pad_center(tensor: torch.Tensor, size: int, axis: int = -1) -> torch.Tensor:
    """Zero-pad ``tensor`` along ``axis`` so its length becomes ``size``, centred.

    ``(size - n) // 2`` zeros go before the data and the rest after it, where
    ``n = tensor.shape[axis]``.

    Raises:
        ValueError: if ``size`` is smaller than ``n``.
        IndexError: if ``axis`` is out of range.
    """
    n = tensor.shape[axis]
    if size < n:
        raise ValueError(f"Target size ({size}) must be at least input size ({n})")

    lpad = (size - n) // 2
    rpad = size - n - lpad

    # F.pad reads (left, right) pairs starting from the LAST dimension, so the
    # pair for `axis` sits at position `ndim - 1 - axis` (normalized axis).
    # Indexing from the front padded the mirrored axis for ndim >= 2.
    pos = tensor.ndim - 1 - (axis % tensor.ndim)
    pad = [0] * (2 * (pos + 1))
    pad[2 * pos] = lpad
    pad[2 * pos + 1] = rpad

    return F.pad(tensor, pad, mode="constant", value=0)
225
+
226
+
227
def normalize(
    S: torch.Tensor,
    norm: float = float("inf"),
    axis: int = 0,
    threshold: float = 1e-10,
    fill: bool = False,
) -> torch.Tensor:
    """Scale ``S`` so every slice along ``axis`` has unit ``norm``.

    Args:
        S: Real or complex tensor.
        norm: ``inf`` (max abs), ``-inf`` (min abs), ``0`` (count of non-zeros),
            any ``p > 0`` (p-norm), or ``None`` to return ``S`` unchanged.
        axis: Dimension along which norms are computed.
        threshold: Slices whose norm is below this are treated as "small".
        fill: Small slices become all ones when True and all zeros when False.

    Raises:
        ValueError: for a negative finite ``norm``.

    Small slices are masked explicitly rather than divided by inf/NaN
    sentinels. Previously ``fill=True`` replaced every NaN in the result,
    hiding NaNs that came from ``S``; those now propagate.
    """
    if norm is None:
        return S

    mag = S.abs().float()

    if norm == float("inf"):
        length = mag.max(dim=axis, keepdim=True).values
    elif norm == float("-inf"):
        length = mag.min(dim=axis, keepdim=True).values
    elif norm == 0:
        length = (mag > 0).sum(dim=axis, keepdim=True).float()
    elif norm > 0:
        length = (mag**norm).sum(dim=axis, keepdim=True) ** (1.0 / norm)
    else:
        raise ValueError(f"Unsupported norm: {norm}")

    small_idx = length < threshold
    # Divide small slices by 1 so the division never produces inf/NaN there
    # (also keeps gradients finite), then overwrite them with the fill value.
    safe_length = length.masked_fill(small_idx, 1.0)
    Snorm = S / safe_length
    return Snorm.masked_fill(small_idx, 1.0 if fill else 0.0)
265
+
266
+
267
def _window_for_sumsquare(window_spec, win_length: int, dtype, device) -> torch.Tensor:
    """Resolve ``window_spec`` into a 1-D window tensor of length ``win_length``."""
    if isinstance(window_spec, str):
        factories = {
            "hann": torch.hann_window,
            "hamming": torch.hamming_window,
            "blackman": torch.blackman_window,
            "bartlett": torch.bartlett_window,
        }
        if window_spec not in factories:
            raise ValueError(f"Unsupported window name: {window_spec!r}")
        # periodic=True is the FFT-bin ("fftbins=True") window the original asked for.
        return factories[window_spec](
            win_length, periodic=True, dtype=dtype, device=device
        )
    # A callable is treated as a window factory taking the window length;
    # anything else must already hold the window values.
    win = window_spec(win_length) if callable(window_spec) else window_spec
    win = torch.as_tensor(win, dtype=dtype, device=device)
    if win.ndim != 1 or win.shape[0] != win_length:
        raise ValueError(
            f"Window must be 1-D with length {win_length}, got shape {tuple(win.shape)}"
        )
    return win


def window_sumsquare(
    window_spec: Union[str, int, float, Callable, List[Any], Tuple[Any, ...]],
    n_frames: int,
    hop_length: int = 300,
    win_length: int = 1200,
    n_fft: int = 2048,
    dtype: torch.dtype = torch.float32,
    norm: Optional[Union[int, float]] = None,
    device: Optional[torch.device] = "cpu",
):
    """Sum of squared, overlap-added windows (the ISTFT window envelope).

    Args:
        window_spec: ``"hann"``, ``"hamming"``, ``"blackman"`` or
            ``"bartlett"`` (periodic torch windows); a callable taking the
            window length; or explicit 1-D window values (tensor, list or
            tuple) of length ``win_length``. Scalar specs are rejected.
        n_frames: Number of frames.
        hop_length: Samples between frame starts.
        win_length: Window length (``n_fft`` when None); must be <= ``n_fft``.
        n_fft: Frame size; the window is zero-padded (centred) to this length.
        dtype: Output dtype.
        norm: If not None, the window is first normalized with this norm via
            the module-level ``normalize``.
        device: Output device.

    Returns:
        1-D tensor of length ``n_fft + hop_length * (n_frames - 1)``.

    Raises:
        ValueError: for an unsupported ``window_spec`` or ``win_length > n_fft``.

    Previously this called ``_generate_window`` with arguments that function
    does not accept, so every call raised ``TypeError``.
    """
    if win_length is None:
        win_length = n_fft

    total_length = n_fft + hop_length * (n_frames - 1)
    x = torch.zeros(total_length, dtype=dtype, device=device)

    win = _window_for_sumsquare(window_spec, win_length, dtype, device)
    if norm is not None:
        win = normalize(win, norm=norm, axis=0)
    win_sq = win**2

    # Centre the squared window inside an n_fft-long frame.
    if win_length > n_fft:
        raise ValueError(
            f"Target size ({n_fft}) must be at least input size ({win_length})"
        )
    lpad = (n_fft - win_length) // 2
    win_sq = F.pad(win_sq, (lpad, n_fft - win_length - lpad))

    # Overlap-add the squared window at every frame position.
    for i in range(n_frames):
        start = i * hop_length
        stop = min(total_length, start + n_fft)
        x[start:stop] += win_sq[: stop - start]

    return x
299
+
300
+
301
def inverse_transform(
    spec: Tensor,
    phase: Tensor,
    window: Optional[Tensor] = None,
    n_fft: int = 2048,
    hop_length: int = 300,
    win_length: int = 1200,
    length: Optional[Any] = None,
):
    """Rebuild a waveform from a magnitude ``spec`` and ``phase`` with ``torch.istft``.

    The complex STFT is ``spec * exp(1j * phase)``. When ``window`` is None,
    ``_generate_window(win_length)`` (a symmetric Hann window) is used, built
    on ``spec``'s device; previously it was always built on the CPU.
    ``length`` is forwarded to ``torch.istft``.
    """
    if window is None:
        window = _generate_window(win_length, device=spec.device)
    return torch.istft(
        spec * torch.exp(phase * 1j),
        n_fft,
        hop_length,
        win_length,
        window=window,
        length=length,
    )
320
+
321
+
322
def stft_istft_rebuild(
    input_data: Tensor,
    window: Optional[Tensor] = None,
    n_fft: int = 2048,
    hop_length: int = 300,
    win_length: int = 1200,
):
    """Round-trip ``input_data`` through an STFT, split into magnitude and phase.

    Computes ``torch.stft``, recombines ``|X| * exp(1j * angle(X))`` and inverts
    it with ``torch.istft``, trimmed to ``input_data.shape[-1]`` samples. A
    leading size-1 dimension is then squeezed. When ``window`` is None,
    ``_generate_window(win_length)`` (a symmetric Hann window) is used, built
    on ``input_data``'s device; previously it was always built on the CPU.
    """
    if window is None:
        window = _generate_window(win_length, device=input_data.device)
    st = torch.stft(
        input_data,
        n_fft,
        hop_length,
        win_length,
        window=window,
        return_complex=True,
    )
    return torch.istft(
        torch.abs(st) * torch.exp(1j * torch.angle(st)),
        n_fft,
        hop_length,
        win_length,
        window=window,
        length=input_data.shape[-1],
    ).squeeze(0)