lt-tensor 0.0.1.dev0__py3-none-any.whl → 0.0.1.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/__init__.py CHANGED
@@ -0,0 +1 @@
1
+ __version__ = "0.0.1dev1"
lt_tensor/math_ops.py CHANGED
@@ -47,66 +47,6 @@ def normalize_tensor(x: Tensor, eps: float = 1e-8) -> Tensor:
47
47
  return x / (torch.norm(x, dim=-1, keepdim=True) + eps)
48
48
 
49
49
 
50
def stft(
    waveform: Tensor,
    n_fft: int = 512,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window_fn: str = "hann",
    center: bool = True,
    return_complex: bool = True,
) -> Tensor:
    """Compute the short-time Fourier transform of ``waveform`` via ``torch.stft``.

    Args:
        waveform: Signal to analyse; passed straight to ``torch.stft``.
        n_fft: FFT size.
        hop_length: Frame hop; ``None`` lets ``torch.stft`` choose its default.
        win_length: Window length; ``None`` lets ``torch.stft`` choose its default.
        window_fn: ``"hann"`` builds a Hann window of ``win_length or n_fft``
            samples; any other value passes ``window=None`` to ``torch.stft``.
        center: Forwarded to ``torch.stft``.
        return_complex: Forwarded to ``torch.stft``.

    Returns:
        Whatever ``torch.stft`` returns for these arguments.
    """
    if window_fn != "hann":
        # Unrecognised names fall back to no explicit window.
        window = None
    else:
        frame_size = win_length or n_fft
        window = torch.hann_window(frame_size).to(waveform.device)

    stft_kwargs = dict(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        return_complex=return_complex,
    )
    return torch.stft(waveform, **stft_kwargs)
74
-
75
-
76
def istft(
    stft_matrix: Tensor,
    n_fft: int = 512,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window_fn: str = "hann",
    center: bool = True,
    length: Optional[int] = None,
) -> Tensor:
    """Invert a complex STFT back to a time-domain signal via ``torch.istft``.

    Args:
        stft_matrix: Complex spectrogram as produced by ``torch.stft``.
        n_fft: FFT size used for the forward transform.
        hop_length: Frame hop; ``None`` lets ``torch.istft`` choose its default.
        win_length: Window length; ``None`` lets ``torch.istft`` choose its default.
        window_fn: ``"hann"`` builds a Hann window of ``win_length or n_fft``
            samples; any other value passes ``window=None`` to ``torch.istft``.
        center: Forwarded to ``torch.istft``.
        length: Optional output length, forwarded to ``torch.istft``.

    Returns:
        The reconstructed signal returned by ``torch.istft``.
    """
    window = None
    if window_fn == "hann":
        window = torch.hann_window(win_length or n_fft).to(stft_matrix.device)

    return torch.istft(
        stft_matrix,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        length=length,
    )
100
-
101
-
102
def fft(x: Tensor, norm: Optional[str] = "backward") -> Tensor:
    """Return the 1-D discrete Fourier transform of ``x`` over its last dimension.

    Thin wrapper around ``torch.fft.fft``. ``x`` may be real or complex, and
    ``norm`` is forwarded unchanged.
    """
    return torch.fft.fft(x, n=None, dim=-1, norm=norm)
105
-
106
-
107
def ifft(x: Tensor, norm: Optional[str] = "backward") -> Tensor:
    """Return the 1-D inverse discrete Fourier transform of ``x`` over its last dimension.

    Thin wrapper around ``torch.fft.ifft``; ``norm`` is forwarded unchanged.
    """
    return torch.fft.ifft(x, n=None, dim=-1, norm=norm)
110
50
 
111
51
 
112
52
  def log_magnitude(stft_complex: Tensor, eps: float = 1e-5) -> Tensor:
lt_tensor/transform.py CHANGED
@@ -28,6 +28,68 @@ def to_mel_spectrogram(
28
28
  return mel_spectrogram(waveform)
29
29
 
30
30
 
31
def stft(
    waveform: Tensor,
    n_fft: int = 512,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window_fn: str = "hann",
    center: bool = True,
    return_complex: bool = True,
) -> Tensor:
    """Compute the short-time Fourier transform of ``waveform`` via ``torch.stft``.

    Args:
        waveform: Signal to analyse; passed straight to ``torch.stft``.
        n_fft: FFT size.
        hop_length: Frame hop; ``None`` lets ``torch.stft`` choose its default.
        win_length: Window length; ``None`` lets ``torch.stft`` choose its default.
        window_fn: ``"hann"`` builds a Hann window of ``win_length or n_fft``
            samples; any other value passes ``window=None`` to ``torch.stft``.
        center: Forwarded to ``torch.stft``.
        return_complex: Forwarded to ``torch.stft``.

    Returns:
        Whatever ``torch.stft`` returns for these arguments.
    """
    if window_fn != "hann":
        # Unrecognised names fall back to no explicit window.
        window = None
    else:
        frame_size = win_length or n_fft
        window = torch.hann_window(frame_size).to(waveform.device)

    stft_kwargs = dict(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        return_complex=return_complex,
    )
    return torch.stft(waveform, **stft_kwargs)
55
+
56
+
57
def istft(
    stft_matrix: Tensor,
    n_fft: int = 512,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window_fn: str = "hann",
    center: bool = True,
    length: Optional[int] = None,
) -> Tensor:
    """Invert a complex STFT back to a time-domain signal via ``torch.istft``.

    Args:
        stft_matrix: Complex spectrogram as produced by ``torch.stft``.
        n_fft: FFT size used for the forward transform.
        hop_length: Frame hop; ``None`` lets ``torch.istft`` choose its default.
        win_length: Window length; ``None`` lets ``torch.istft`` choose its default.
        window_fn: ``"hann"`` builds a Hann window of ``win_length or n_fft``
            samples; any other value passes ``window=None`` to ``torch.istft``.
        center: Forwarded to ``torch.istft``.
        length: Optional output length, forwarded to ``torch.istft``.

    Returns:
        The reconstructed signal returned by ``torch.istft``.
    """
    window = None
    if window_fn == "hann":
        window = torch.hann_window(win_length or n_fft).to(stft_matrix.device)

    return torch.istft(
        stft_matrix,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        length=length,
    )
81
+
82
+
83
def fft(x: Tensor, norm: Optional[str] = "backward") -> Tensor:
    """Return the 1-D discrete Fourier transform of ``x`` over its last dimension.

    Thin wrapper around ``torch.fft.fft``. ``x`` may be real or complex, and
    ``norm`` is forwarded unchanged.
    """
    return torch.fft.fft(x, n=None, dim=-1, norm=norm)
86
+
87
+
88
def ifft(x: Tensor, norm: Optional[str] = "backward") -> Tensor:
    """Return the 1-D inverse discrete Fourier transform of ``x`` over its last dimension.

    Thin wrapper around ``torch.fft.ifft``; ``norm`` is forwarded unchanged.
    """
    return torch.fft.ifft(x, n=None, dim=-1, norm=norm)
91
+
92
+
31
93
  def to_log_mel_spectrogram(
32
94
  waveform: torch.Tensor, sample_rate: int = 22050, eps: float = 1e-9, **kwargs
33
95
  ) -> torch.Tensor:
@@ -111,3 +173,160 @@ def get_sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
111
173
  emb = timesteps[:, None].float() * emb[None, :] # [B, half_dim]
112
174
  emb = torch.cat((emb.sin(), emb.cos()), dim=-1) # [B, dim]
113
175
  return emb
176
+
177
+
178
def generate_window(
    M: int, alpha: float = 0.5, device: Optional[DeviceType] = None
) -> Tensor:
    """Build a symmetric generalized-cosine window of ``M`` samples.

    ``w[n] = alpha - (1 - alpha) * cos(2 * pi * n / (M - 1))`` for
    ``n = 0 .. M - 1``. ``alpha=0.5`` yields a Hann window and ``alpha=0.54``
    a Hamming window.

    Args:
        M: Number of samples; must be at least 1.
        alpha: Cosine-window coefficient.
        device: Device on which to allocate the window.

    Returns:
        A 1-D float32 tensor of length ``M`` (``torch.ones(1)`` when ``M == 1``).

    Raises:
        ValueError: if ``M < 1``.
    """
    if M < 1:
        raise ValueError("Window length M must be >= 1.")
    if M == 1:
        # Degenerate case: the formula would divide by zero.
        return torch.ones(1, device=device)

    sample_idx = torch.arange(M, dtype=torch.float32, device=device)
    cosine = torch.cos(2.0 * math.pi * sample_idx / (M - 1))
    return alpha - (1.0 - alpha) * cosine
189
+
190
+
191
def pad_center(tensor: torch.Tensor, size: int, axis: int = -1) -> torch.Tensor:
    """Zero-pad ``tensor`` along ``axis`` so that dimension has length ``size``.

    The padding is split as evenly as possible. Any odd extra sample goes to
    the end: ``lpad = (size - n) // 2`` and ``rpad = size - n - lpad``.

    Args:
        tensor: Input tensor of any rank >= 1.
        size: Target length along ``axis``.
        axis: Dimension to pad (negative values count from the end).

    Returns:
        A new tensor whose ``axis`` dimension has length ``size``.

    Raises:
        ValueError: if ``size`` is smaller than the current length along ``axis``.
    """
    n = tensor.shape[axis]  # also validates that ``axis`` is in range
    if size < n:
        raise ValueError(f"Target size ({size}) must be at least input size ({n})")

    lpad = (size - n) // 2
    rpad = size - n - lpad

    # F.pad reads its pad list starting from the LAST dimension
    # ([last_left, last_right, second_to_last_left, ...]). Convert ``axis``
    # into its offset from the end so the right dimension is padded for
    # tensors of any rank.
    offset_from_end = tensor.ndim - 1 - (axis % tensor.ndim)
    pad = [0, 0] * offset_from_end + [lpad, rpad]

    return F.pad(tensor, pad, mode="constant", value=0)
204
+
205
+
206
def normalize(
    S: torch.Tensor,
    norm: float = float("inf"),
    axis: int = 0,
    threshold: float = 1e-10,
    fill: bool = False,
) -> torch.Tensor:
    """Scale ``S`` so that each slice along ``axis`` has unit ``norm``.

    Supported ``norm`` values: ``inf`` (max abs), ``-inf`` (min abs), ``0``
    (count of non-zeros), any positive p (p-norm), or ``None`` (return ``S``
    unchanged). Slices whose norm is below ``threshold`` are handled as follows:
    with ``fill=False`` they become 0 (divided by inf); with ``fill=True``
    every NaN in the result is replaced by 1.0.

    Raises:
        ValueError: for a negative finite ``norm``.
    """
    magnitude = S.abs().float()
    if norm is None:
        return S

    reduce_kw = {"dim": axis, "keepdim": True}
    if norm == float("inf"):
        length = magnitude.max(**reduce_kw).values
    elif norm == float("-inf"):
        length = magnitude.min(**reduce_kw).values
    elif norm == 0:
        length = magnitude.gt(0).sum(**reduce_kw).float()
    elif norm > 0:
        length = magnitude.pow(norm).sum(**reduce_kw).pow(1.0 / norm)
    else:
        raise ValueError(f"Unsupported norm: {norm}")

    too_small = length < threshold
    length = length.clone()

    if not fill:
        # Dividing by inf zeroes out the near-silent slices.
        length[too_small] = float("inf")
        return S / length

    length[too_small] = float("nan")
    result = S / length
    result[torch.isnan(result)] = 1.0
    return result
244
+
245
+
246
def _window_from_spec(
    window_spec: Union[str, int, float, Callable, List[Any], Tuple[Any, ...]],
    win_length: int,
) -> Tensor:
    """Turn ``window_spec`` into a 1-D tensor of ``win_length`` samples.

    Accepted specs:
      * a window name: "hann"/"hanning", "hamming", "blackman", "bartlett",
        or "boxcar"/"rect"/"rectangular"/"ones". Named windows are built
        *periodic*, matching the ``fftbins=True`` intent of the original code;
      * a bare int/float, used as the Kaiser ``beta`` (scipy's convention);
      * a callable, invoked as ``window_spec(win_length)``;
      * an explicit tensor/list/tuple of exactly ``win_length`` samples.

    Raises:
        ValueError: for an unknown name, a parameterized ``(name, ...)`` tuple,
            or a window that is not 1-D with ``win_length`` samples.
    """
    if isinstance(window_spec, str):
        name = window_spec.lower()
        if name in ("boxcar", "rect", "rectangular", "ones"):
            return torch.ones(win_length)
        builders = {
            "hann": torch.hann_window,
            "hanning": torch.hann_window,
            "hamming": torch.hamming_window,
            "blackman": torch.blackman_window,
            "bartlett": torch.bartlett_window,
        }
        if name not in builders:
            raise ValueError(f"Unsupported window name: {window_spec!r}")
        return builders[name](win_length, periodic=True)

    if isinstance(window_spec, (int, float)) and not isinstance(window_spec, bool):
        return torch.kaiser_window(win_length, periodic=True, beta=float(window_spec))

    if (
        isinstance(window_spec, (list, tuple))
        and window_spec
        and isinstance(window_spec[0], str)
    ):
        raise ValueError(
            "Parameterized window tuples are not supported; "
            "pass a tensor or a callable instead."
        )

    samples = window_spec(win_length) if callable(window_spec) else window_spec
    win = torch.as_tensor(samples)
    if win.ndim != 1 or win.shape[0] != win_length:
        raise ValueError(
            f"Window must be 1-D with {win_length} samples, "
            f"got shape {tuple(win.shape)}"
        )
    return win


def window_sumsquare(
    window_spec: Union[str, int, float, Callable, List[Any], Tuple[Any, ...]],
    n_frames: int,
    hop_length: int = 300,
    win_length: Optional[int] = 1200,
    n_fft: int = 2048,
    dtype: torch.dtype = torch.float32,
    norm: Optional[Union[int, float]] = None,
    device: Optional[torch.device] = "cpu",
) -> Tensor:
    """Sum the squared window over ``n_frames`` frames spaced ``hop_length`` apart.

    Each frame adds the (optionally normalized) squared window, centre-padded
    to ``n_fft`` samples, into a buffer of ``n_fft + hop_length * (n_frames - 1)``
    samples.

    Args:
        window_spec: Window description (see ``_window_from_spec``).
        n_frames: Number of frames to accumulate.
        hop_length: Offset between consecutive frames, in samples.
        win_length: Window length; ``None`` means ``n_fft``.
        n_fft: Frame length after zero-padding the window.
        dtype: Dtype of the window and of the returned envelope.
        norm: Optional normalization applied to the window via ``normalize``.
        device: Device for the returned envelope.

    Returns:
        A 1-D tensor holding the accumulated squared-window envelope.

    Raises:
        ValueError: for an unusable ``window_spec`` or ``win_length > n_fft``.
    """
    if win_length is None:
        win_length = n_fft

    total_length = n_fft + hop_length * (n_frames - 1)
    x = torch.zeros(total_length, dtype=dtype, device=device)

    win = _window_from_spec(window_spec, win_length).to(dtype=dtype, device=device)

    # normalize(..., norm=None) is the identity, so skip the call in that case.
    if norm is not None:
        win = normalize(win, norm=norm, axis=0)
    win_sq = win**2
    # Centre the window inside an n_fft-long frame (no-op when already n_fft).
    if win_sq.shape[0] != n_fft:
        win_sq = pad_center(win_sq, size=n_fft, axis=0)

    # Overlap-add the squared window at every frame offset.
    for i in range(n_frames):
        start = i * hop_length
        end = min(total_length, start + n_fft)
        x[start:end] += win_sq[: end - start]

    return x
278
+
279
+
280
def get_window(win_length: int = 1200):
    """Return the default window of ``win_length`` samples.

    Delegates to ``generate_window`` with its default ``alpha`` and device.
    """
    window = generate_window(win_length)
    return window
282
+
283
+
284
def inverse_transform(
    spec: Tensor,
    phase: Tensor,
    window: Optional[Tensor] = None,
    n_fft: int = 2048,
    hop_length: int = 300,
    win_length: int = 1200,
    length: Optional[int] = None,
) -> Tensor:
    """Rebuild a waveform from a magnitude spectrogram and its phase.

    The complex spectrum ``spec * exp(1j * phase)`` is passed to ``torch.istft``.

    Args:
        spec: Magnitude spectrogram.
        phase: Phase (radians), same shape as ``spec``.
        window: Synthesis window; defaults to ``generate_window(win_length)``
            allocated on ``spec``'s device.
        n_fft: FFT size.
        hop_length: Frame hop, in samples.
        win_length: Window length, in samples.
        length: Optional exact output length in samples (``torch.istft``'s
            ``length``). Previously annotated as the non-existent
            ``torch.shape``; the value is an int.

    Returns:
        The reconstructed time-domain signal.
    """
    if window is None:
        # Allocate on the spectrogram's device so torch.istft does not reject
        # a CPU window paired with an accelerator input.
        window = generate_window(win_length, device=spec.device)
    return torch.istft(
        spec * torch.exp(phase * 1j),
        n_fft,
        hop_length,
        win_length,
        window=window,
        length=length,
    )
303
+
304
+
305
def stft_istft_rebuild(
    input_data: Tensor,
    window: Optional[Tensor] = None,
    n_fft: int = 2048,
    hop_length: int = 300,
    win_length: int = 1200,
) -> Tensor:
    """
    Perform STFT followed by ISTFT reconstruction using magnitude and phase.

    Args:
        input_data: Signal passed to ``torch.stft``.
        window: Analysis/synthesis window; defaults to
            ``generate_window(win_length)`` allocated on ``input_data``'s device.
        n_fft: FFT size.
        hop_length: Frame hop, in samples.
        win_length: Window length, in samples.

    Returns:
        The reconstruction, trimmed to ``input_data.shape[-1]`` samples. A
        leading dimension of size 1 is squeezed away.
    """
    if window is None:
        # torch.stft requires the window on the same device as its input.
        window = generate_window(win_length, device=input_data.device)
    st = torch.stft(
        input_data,
        n_fft,
        hop_length,
        win_length,
        window=window,
        return_complex=True,
    )
    # Recombine the spectrum from its magnitude and phase: |X| * e^{i*angle(X)}.
    rebuilt = torch.abs(st) * torch.exp(1j * torch.angle(st))
    return torch.istft(
        rebuilt,
        n_fft,
        hop_length,
        win_length,
        window=window,
        length=input_data.shape[-1],
    ).squeeze(0)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lt-tensor
3
- Version: 0.0.1.dev0
3
+ Version: 0.0.1.dev1
4
4
  Summary: General utilities for PyTorch and others. Built for general use.
5
5
  Home-page: https://github.com/gr1336/lt-tensor/
6
6
  Author: gr1336
@@ -1,11 +1,11 @@
1
- lt_tensor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
1
+ lt_tensor/__init__.py,sha256=fHS3r6eTARkZmS9vI1iPgHEL-H8y3z0skMRkAfilD8g,26
2
2
  lt_tensor/_basics.py,sha256=XS2OrvyzboUTKURGMU1fmbn2gFq2779HW8xz05fm4x8,8181
3
3
  lt_tensor/_torch_commons.py,sha256=_2Eck-MsQ46PxW5ku7NJvNSL5vg54_4GkLCqdzFevwA,402
4
4
  lt_tensor/lr_schedulers.py,sha256=oLYw2X78KSdlOD0pwPO0lsBj1xqLOAzT8rnRBnCO19o,3560
5
- lt_tensor/math_ops.py,sha256=ajHjt5xmS0dU5vdFYV0QE9BPxTp4ymXVwspO6u6Pwaw,3481
5
+ lt_tensor/math_ops.py,sha256=4jUTtS1ZwJnh1AaFGhMuKTcakjwLQ846AKcBoDEwudI,1862
6
6
  lt_tensor/misc_utils.py,sha256=EaQ__986n5J-oAnMXmTp9hXxCDk6NrgvGlU_C05M7G4,19193
7
7
  lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
8
- lt_tensor/transform.py,sha256=ucHy-nuibhoH75brNBPxXFVahAd_NHr27vXLgLlkNGk,3552
8
+ lt_tensor/transform.py,sha256=J6KFmOmQCeCpbK7r8yaUZ8RiW-o-rvtSFvBVM5utBhk,9208
9
9
  lt_tensor/model_zoo/__init__.py,sha256=sIidI3gSoxTq2OTenqADiFlmd9EH8rS_mzjLePdNHqc,77
10
10
  lt_tensor/model_zoo/basic.py,sha256=Ccjg26-gnqbkOUg0aNnLhZmsfvMwqTxXVXimfC5qnl0,1862
11
11
  lt_tensor/model_zoo/residual.py,sha256=iViJ5MrIY1vUbPB0ZiOjh62wfQagbPulMvbZPTXp-OU,7370
@@ -14,8 +14,8 @@ lt_tensor/model_zoo/diffusion/models.py,sha256=LRf5B2MPic4Dwfvg2PxG61KFtqvATV3wO
14
14
  lt_tensor/model_zoo/transformer_models/__init__.py,sha256=NOvCP-EySkNHLbGOxFMURp0AztvJQSNo9J_0hsEzLto,130
15
15
  lt_tensor/model_zoo/transformer_models/models.py,sha256=Xh2nq83w0qMX-Co7EmaQb_auq7Fi7kyULklc0_yq3oo,4088
16
16
  lt_tensor/model_zoo/transformer_models/positional_encoders.py,sha256=ilJDKAonnU0973BT_7gT5ke_8PbWHMxsr8EvNQHAcUY,3529
17
- lt_tensor-0.0.1.dev0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
18
- lt_tensor-0.0.1.dev0.dist-info/METADATA,sha256=2hs2eLrpyXhkLdWdzBre_trMJsizLgC0wLPnE4eHVCE,1055
19
- lt_tensor-0.0.1.dev0.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
20
- lt_tensor-0.0.1.dev0.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
21
- lt_tensor-0.0.1.dev0.dist-info/RECORD,,
17
+ lt_tensor-0.0.1.dev1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
18
+ lt_tensor-0.0.1.dev1.dist-info/METADATA,sha256=Os0vIySmcMifWqS32PHnLRTz2dW_VfF-tf_jhNcmPBk,1055
19
+ lt_tensor-0.0.1.dev1.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
20
+ lt_tensor-0.0.1.dev1.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
21
+ lt_tensor-0.0.1.dev1.dist-info/RECORD,,