lt-tensor 0.0.1a11__py3-none-any.whl → 0.0.1a13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,12 +7,12 @@ __all__ = [
 ]

 import math
-from ..torch_commons import *
-from ..model_base import Model
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model
 from lt_utils.misc_utils import default
 from typing import Optional
-from .pos import *
-from .bsc import FeedForward
+from lt_tensor.model_zoo.pos_encoder import *
+from lt_tensor.model_zoo.basic import FeedForward


 def init_weights(module):
lt_tensor/noise_tools.py CHANGED
@@ -14,10 +14,10 @@ __all__ = [

 from lt_utils.common import *
 import torch.nn.functional as F
-from .torch_commons import *
+from lt_tensor.torch_commons import *
 import math
 import random
-from .misc_utils import set_seed
+from lt_tensor.misc_utils import set_seed


 def add_gaussian_noise(x: Tensor, noise_level=0.025):
@@ -5,11 +5,14 @@ from lt_utils.type_utils import is_file, is_array
 from lt_tensor.misc_utils import log_tensor
 import librosa
 import torchaudio
-
+import numpy as np
 from lt_tensor.transform import InverseTransformConfig, InverseTransform
 from lt_utils.file_ops import FileScan, get_file_name, path_to_str
-
+from torchaudio.functional import detect_pitch_frequency
 from lt_tensor.model_base import Model
+import torch.nn.functional as F
+
+DEFAULT_DEVICE = torch.tensor([0]).device


 class AudioProcessor(Model):
@@ -21,32 +24,30 @@ class AudioProcessor(Model):
         win_length: Optional[int] = None,
         hop_length: Optional[int] = None,
         f_min: float = 0,
-        f_max: float | None = None,
+        f_max: float = 12000.0,
         center: bool = True,
         mel_scale: Literal["htk", "slaney"] = "htk",
         std: int = 4,
         mean: int = -4,
-        inverse_transform_config: Union[
-            Dict[str, Union[Number, Tensor, bool]], InverseTransformConfig
-        ] = dict(n_fft=16, hop_length=4, win_length=16, center=True),
         n_iter: int = 32,
-        *__,
-        **_,
+        window: Optional[Tensor] = None,
+        normalized: bool = False,
+        onesided: Optional[bool] = None,
     ):
         super().__init__()
-        assert isinstance(inverse_transform_config, (InverseTransformConfig, dict))
         self.mean = mean
         self.std = std
         self.n_mels = n_mels
         self.n_fft = n_fft
         self.n_stft = n_fft // 2 + 1
         self.f_min = f_min
-        self.f_max = f_max
+        self.f_max = max(min(f_max, 12000), self.f_min + 1)
         self.n_iter = n_iter
         self.hop_length = hop_length or n_fft // 4
         self.win_length = win_length or n_fft
         self.sample_rate = sample_rate
-        self.mel_spec = torchaudio.transforms.MelSpectrogram(
+        self.center = center
+        self._mel_spec = torchaudio.transforms.MelSpectrogram(
             sample_rate=sample_rate,
             n_mels=n_mels,
             n_fft=n_fft,
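The constructor no longer accepts an `inverse_transform_config`; `window`, `normalized`, and `onesided` are now plain keyword arguments, and `f_max` defaults to 12000.0 and is clamped to `[f_min + 1, 12000]`. A minimal construction sketch based on the signature above; the import path and the concrete values are assumptions, not taken from the package:

```python
import torch

# Hypothetical import path: the module defining AudioProcessor is not named in this diff.
from lt_tensor.processors import AudioProcessor

ap = AudioProcessor(
    sample_rate=24000,   # assumed value
    n_mels=80,           # assumed value
    n_fft=1024,          # assumed value
    f_min=0,
    f_max=12000.0,       # clamped internally to [f_min + 1, 12000]
    window=None,         # None -> a Hann window is registered as a buffer
    normalized=False,
    onesided=None,
)
```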
@@ -71,14 +72,260 @@ class AudioProcessor(Model):
             win_length=win_length,
             hop_length=hop_length,
         )
-        if isinstance(inverse_transform_config, dict):
-            inverse_transform_config = InverseTransformConfig(
-                **inverse_transform_config
+        self.normalized = normalized
+        self.onesided = onesided
+
+        self.register_buffer(
+            "window",
+            (torch.hann_window(self.win_length) if window is None else window),
+        )
+
+    def from_numpy(
+        self,
+        array: np.ndarray,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        converted = torch.from_numpy(array)
+        if not any([device is not None, dtype is not None]):
+            return converted
+        return converted.to(device=device, dtype=dtype)
+
+    def from_numpy_batch(
+        self,
+        arrays: List[np.ndarray],
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        stacked = torch.stack([torch.from_numpy(x) for x in arrays])
+        if not any([device is not None, dtype is not None]):
+            return stacked
+        return stacked.to(device=device, dtype=dtype)
+
+    def to_numpy_safe(self, tensor: Tensor):
+        return tensor.detach().to(DEFAULT_DEVICE).numpy(force=True)
+
+    def compute_rms(
+        self, audio: Union[Tensor, np.ndarray], mel: Optional[Tensor] = None
+    ):
+        default_dtype = audio.dtype
+        default_device = audio.device
+        assert audio.ndim in [1, 2], (
+            f"Audio should have 1D for unbatched and 2D for batched"
+            ", received instead a: {audio.ndim}D"
+        )
+        if mel is not None:
+            assert mel.ndim in [2, 3], (
+                "Mel spectogram should have 2D dim for non-batched or 3D dim for both non-batched or batched"
+                f". Received instead {mel.ndim}D."
+            )
+        if audio.ndim == 2:
+            B = audio.shape[0]
+        else:
+            B = 1
+            audio = audio.unsqueeze(0)
+
+        if mel is not None:
+            if mel.ndim == 2:
+                assert B == 1, "Batch from mel and audio must be the same!"
+                mel = mel.unsqueeze(0)
+            else:
+                assert B == mel.shape[0], "Batch from mel and audio must be the same!"
+            mel = self.to_numpy_safe(mel)
+            gt_mel = lambda idx: mel[idx, :, :]
+
+        else:
+            gt_mel = lambda idx: None
+        audio = self.to_numpy_safe(audio)
+        if B == 1:
+            _r = librosa.feature.rms(
+                y=audio, frame_length=self.n_fft, hop_length=self.hop_length
+            )[0]
+            rms = self.from_numpy(_r, default_device, default_dtype)
+
+        else:
+            rms_ = []
+            for i in range(B):
+                _r = librosa.feature.rms(
+                    y=audio[i, :],
+                    S=gt_mel(i),
+                    frame_length=self.n_fft,
+                    hop_length=self.hop_length,
+                )[0]
+                rms_.append(_r)
+            rms = self.from_numpy_batch(rms_, default_device, default_dtype)
+
+        return rms
+
+    def compute_pitch(
+        self,
+        audio: Tensor,
+    ):
+        default_dtype = audio.dtype
+        default_device = audio.device
+        assert audio.ndim in [1, 2], (
+            f"Audio should have 1D for unbatched and 2D for batched"
+            ", received instead a: {audio.ndim}D"
+        )
+        if audio.ndim == 2:
+            B = audio.shape[0]
+        else:
+            B = 1
+        fmin = max(self.f_min, 80)
+        if B == 1:
+            f0 = self.from_numpy(
+                librosa.yin(
+                    self.to_numpy_safe(audio),
+                    fmin=fmin,
+                    fmax=self.f_max,
+                    frame_length=self.n_fft,
+                    sr=self.sample_rate,
+                    hop_length=self.hop_length,
+                    center=self.center,
+                ),
+                default_device,
+                default_dtype,
+            )
+
+        else:
+            f0_ = []
+            for i in range(B):
+                r = librosa.yin(
+                    self.to_numpy_safe(audio[i, :]),
+                    fmin=fmin,
+                    fmax=self.f_max,
+                    frame_length=self.n_fft,
+                    sr=self.sample_rate,
+                    hop_length=self.hop_length,
+                    center=self.center,
+                )
+                f0_.append(r)
+            f0 = self.from_numpy_batch(f0_, default_device, default_dtype)
+
+        # librosa.pyin(self.f_min, self.f_max)
+        return f0  # dict(f0=f0, attention_mask=f0 != f_max)
+
+    def compute_pitch_torch(self, audio: Tensor):
+        return detect_pitch_frequency(
+            audio,
+            sample_rate=self.sample_rate,
+            frame_time=self.n_fft,
+            win_length=self.win_length,
+            freq_low=max(self.f_min, 35),
+            freq_high=self.f_max,
+        )
+
+    def interpolate_tensor(
+        self,
+        tensor: Tensor,
+        target_len: int,
+        mode: Literal[
+            "nearest",
+            "linear",
+            "bilinear",
+            "bicubic",
+            "trilinear",
+            "area",
+            "nearest-exact",
+        ] = "nearest",
+        align_corners: Optional[bool] = None,
+        scale_factor: Optional[list[float]] = None,
+        recompute_scale_factor: Optional[bool] = None,
+        antialias: bool = False,
+    ):
+        """
+        The modes available for upsampling are: `nearest`, `linear` (3D-only),
+        `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only)
+        """
+
+        if tensor.ndim == 2:  # [1, T]
+            tensor = tensor.unsqueeze(1)  # [1, 1, T]
+        return F.interpolate(
+            tensor,
+            size=target_len,
+            mode=mode,
+            align_corners=align_corners,
+            scale_factor=scale_factor,
+            recompute_scale_factor=recompute_scale_factor,
+            antialias=antialias,
+        )
+
+    def inverse_transform(
+        self,
+        spec: Tensor,
+        phase: Tensor,
+        n_fft: Optional[int] = None,
+        hop_length: Optional[int] = None,
+        win_length: Optional[int] = None,
+        length: Optional[int] = None,
+        *,
+        _recall: bool = False,
+    ):
+        try:
+            return torch.istft(
+                spec * torch.exp(phase * 1j),
+                n_fft=n_fft or self.n_fft,
+                hop_length=hop_length or self.hop_length,
+                win_length=win_length or self.win_length,
+                window=torch.hann_window(
+                    win_length or self.win_length, device=spec.device
+                ),
+                center=self.center,
+                normalized=self.normalized,
+                onesided=self.onesided,
+                length=length,
+                return_complex=False,
             )
-        self._inv_transform = InverseTransform(**inverse_transform_config.to_dict())
+        except RuntimeError as e:
+            if not _recall and spec.device != self.window.device:
+                self.window = self.window.to(spec.device)
+                return self.inverse_transform(
+                    spec, phase, n_fft, hop_length, win_length, length, _recall=True
+                )
+            raise e

-    def inverse_transform(self, spec: Tensor, phase: Tensor, *_, **kwargs):
-        return self._inv_transform(spec, phase, **kwargs)
+    def normalize_audio(
+        self,
+        wave: Tensor,
+        length: Optional[int] = None,
+        *,
+        _recall: bool = False,
+    ):
+        try:
+            spectrogram = torch.stft(
+                input=wave,
+                n_fft=self.n_fft,
+                hop_length=self.hop_length,
+                win_length=self.win_length,
+                window=self.window,
+                center=self.center,
+                pad_mode="reflect",
+                normalized=self.normalized,
+                onesided=self.onesided,
+                return_complex=True,  # needed for the istft
+            )
+            return torch.istft(
+                spectrogram
+                * torch.full(
+                    spectrogram.size(),
+                    fill_value=1,
+                    device=spectrogram.device,
+                ),
+                n_fft=self.n_fft,
+                hop_length=self.hop_length,
+                win_length=self.win_length,
+                window=self.window,
+                length=length,
+                center=self.center,
+                normalized=self.normalized,
+                onesided=self.onesided,
+                return_complex=False,
+            )
+        except RuntimeError as e:
+            if not _recall and wave.device != self.window.device:
+                self.window = self.window.to(wave.device)
+                return self.normalize_audio(wave, length, _recall=True)
+            raise e

     def compute_mel(
         self,
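The new helpers above wrap librosa for framewise RMS and YIN pitch extraction, and `inverse_transform` now calls `torch.istft` directly, retrying once on the input's device when the window buffer lives elsewhere. A hedged usage sketch (shapes and values are illustrative; `ap` is the `AudioProcessor` from the earlier sketch):

```python
import torch

wave = torch.randn(2, 24000)              # [B, T] batched audio (illustrative)

rms = ap.compute_rms(wave)                # framewise RMS via librosa.feature.rms
f0 = ap.compute_pitch(wave)               # YIN f0, fmin clamped to max(f_min, 80)

# inverse_transform builds a Hann window on spec.device and falls back to a
# single retry if a device-mismatch RuntimeError is raised.
spec = torch.randn(2, ap.n_stft, 101).abs()
phase = torch.randn(2, ap.n_stft, 101)
audio = ap.inverse_transform(spec, phase)
```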
@@ -87,21 +334,16 @@ class AudioProcessor(Model):
         add_base: bool = True,
     ) -> Tensor:
         """Returns: [B, M, ML]"""
-        mel_tensor = self.mel_spec(wave.to(self.device))  # [M, ML]
+        mel_tensor = self._mel_spec(wave.to(self.device))  # [M, ML]
         if not add_base:
             return (mel_tensor - self.mean) / self.std
         return (
             (torch.log(base + mel_tensor.unsqueeze(0)) - self.mean) / self.std
         ).squeeze()

-    def reverse_mel(self, mel: Tensor, n_iter: Optional[int] = None):
+    def inverse_mel_spectogram(self, mel: Tensor, n_iter: Optional[int] = None):
         if isinstance(n_iter, int) and n_iter != self.n_iter:
-            self.giffin_lim = torchaudio.transforms.GriffinLim(
-                n_fft=self.n_fft,
-                n_iter=n_iter,
-                win_length=self.win_length,
-                hop_length=self.hop_length,
-            )
+            self.giffin_lim.n_iter = n_iter
             self.n_iter = n_iter
         return self.giffin_lim.forward(
             self.mel_rscale(mel),
@@ -111,21 +353,26 @@ class AudioProcessor(Model):
         self,
         path: PathLike,
         top_db: float = 30,
+        normalize: bool = False,
+        alpha: float = 1.0,
     ) -> Tensor:
         is_file(path, True)
         wave, sr = librosa.load(str(path), sr=self.sample_rate)
         wave, _ = librosa.effects.trim(wave, top_db=top_db)
-        return (
-            torch.from_numpy(
-                librosa.resample(wave, orig_sr=sr, target_sr=self.sample_rate)
-                if sr != self.sample_rate
-                else wave
-            )
-            .float()
-            .unsqueeze(0)
-        )
+        if sr != self.sample_rate:
+            wave = librosa.resample(wave, orig_sr=sr, target_sr=self.sample_rate)
+        if normalize:
+            wave = librosa.util.normalize(wave)
+        if alpha not in [0.0, 1.0]:
+            wave = wave * alpha
+        return torch.from_numpy(wave).float().unsqueeze(0)

-    def find_audios(self, path: PathLike, additional_extensions: List[str] = []):
+    def find_audios(
+        self,
+        path: PathLike,
+        additional_extensions: List[str] = [],
+        maximum: int | None = None,
+    ):
         extensions = [
             "*.wav",
             "*.aac",
@@ -141,6 +388,7 @@ class AudioProcessor(Model):
         return FileScan.files(
             path,
             extensions,
+            maximum,
         )

     def find_audio_text_pairs(
@@ -169,57 +417,6 @@ class AudioProcessor(Model):
                 break
         return results

-    def slice_mismatch_outputs(
-        self,
-        tensor_1: Tensor,
-        tensor_2: Tensor,
-        smallest_size: Optional[int] = None,
-        left_to_right: bool = True,
-    ):
-        assert tensor_1.ndim == tensor_2.ndim, (
-            "Tensors must have the same dimentions to be sliced! \n"
-            f"Received instead a tensor_1 with {tensor_1.ndim}D and tensor_2 with {tensor_1.ndim}D."
-        )
-        dim = tensor_1.ndim
-        assert dim < 5, (
-            "Excpected to receive tensors with from 1D up to 4D. "
-            f"Received instead a {dim}D tensor."
-        )
-
-        if tensor_1.shape[-1] == tensor_2.shape[-1]:
-            return tensor_1, tensor_2
-
-        if smallest_size is None:
-            smallest_size = min(tensor_1.shape[-1], tensor_2.shape[-1])
-        if dim == 0:
-            tensor_1 = tensor_1.unsqueeze(0)
-            tensor_2 = tensor_2.unsqueeze(0)
-            dim = 1
-
-        if dim == 1:
-            if left_to_right:
-                return tensor_1[:smallest_size], tensor_2[:smallest_size]
-            return tensor_1[-smallest_size:], tensor_2[-smallest_size:]
-        elif dim == 2:
-            if left_to_right:
-                return tensor_1[:, :smallest_size], tensor_2[:, :smallest_size]
-            return tensor_1[:, -smallest_size:], tensor_2[:, -smallest_size:]
-        elif dim == 3:
-            if left_to_right:
-                return tensor_1[:, :, :smallest_size], tensor_2[:, :, :smallest_size]
-            return tensor_1[:, :, -smallest_size:], tensor_2[:, :, -smallest_size:]
-
-        # else:
-        if left_to_right:
-            return (
-                tensor_1[:, :, :, :smallest_size],
-                tensor_2[:, :, :, :smallest_size],
-            )
-        return (
-            tensor_1[:, :, :, -smallest_size:],
-            tensor_2[:, :, :, -smallest_size:],
-        )
-
     def stft_loss(
         self,
         signal: Tensor,
@@ -231,11 +428,23 @@ class AudioProcessor(Model):
         smallest = min(signal.shape[-1], ground.shape[-1])
         signal = signal[:, -smallest:]
         ground = ground[:, -smallest:]
-        sig_mel = self.compute_mel(signal, base, True).detach().cpu()
-        gnd_mel = self.compute_mel(ground, base, True).detach().cpu()
+        sig_mel = self.compute_mel(signal, base, True)
+        gnd_mel = self.compute_mel(ground, base, True)
         return torch.norm(gnd_mel - sig_mel, p=1) / torch.norm(gnd_mel, p=1)

-    # def forward(self, wave: Tensor, base: Optional[float] = None):
+    @staticmethod
+    def plot_spectrogram(spectrogram, ax_):
+        import matplotlib.pylab as plt
+
+        fig, ax = plt.subplots(figsize=(10, 2))
+        im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+        plt.colorbar(im, ax=ax)
+
+        fig.canvas.draw()
+        plt.close()
+
+        return fig
+
     def forward(
         self,
         *inputs: Union[Tensor, float],
@@ -251,6 +460,6 @@ class AudioProcessor(Model):
         elif ap_task == "inv_transform":
             return self.inverse_transform(*inputs, **inputs_kwargs)
         elif ap_task == "revert_mel":
-            return self.reverse_mel(*inputs, **inputs_kwargs)
+            return self.inverse_mel_spectogram(*inputs, **inputs_kwargs)
         else:
             raise ValueError(f"Invalid task '{ap_task}'")
lt_tensor/transform.py CHANGED
@@ -20,14 +20,14 @@ __all__ = [
     "stft_istft_rebuild",
 ]

-from .torch_commons import *
+from lt_tensor.torch_commons import *
 import torchaudio
 import math
-from .misc_utils import log_tensor
+from lt_tensor.misc_utils import log_tensor
 from lt_utils.common import *
 from lt_utils.misc_utils import cache_wrapper, default
 import torch.nn.functional as F
-from .model_base import Model
+from lt_tensor.model_base import Model
 import warnings


@@ -316,24 +316,12 @@ def inverse_transform(
     )


-def is_nand(a: bool, b: bool):
-    """[a -> b = result]
-    ```
-    False -> False = True
-    False -> True = True
-    True -> False = True
-    True -> True = False
-    ```
-    """
-    return not (a and b)
-
-
 class InverseTransformConfig:
     def __init__(
         self,
         n_fft: int = 1024,
-        hop_length: Optional[int] = None,
         win_length: Optional[int] = None,
+        hop_length: Optional[int] = None,
         length: Optional[int] = None,
         window: Optional[Tensor] = None,
         onesided: Optional[bool] = None,
@@ -342,8 +330,8 @@ class InverseTransformConfig:
         center: bool = True,
     ):
         self.n_fft = n_fft
-        self.hop_length = hop_length
         self.win_length = win_length
+        self.hop_length = hop_length
         self.length = length
         self.onesided = onesided
         self.return_complex = return_complex
@@ -359,8 +347,8 @@ class InverseTransform(Model):
     def __init__(
         self,
         n_fft: int = 1024,
-        hop_length: Optional[int] = None,
         win_length: Optional[int] = None,
+        hop_length: Optional[int] = None,
         length: Optional[int] = None,
         window: Optional[Tensor] = None,
         onesided: Optional[bool] = None,
@@ -378,10 +366,10 @@ class InverseTransform(Model):
         ----------
         n_fft : int, optional
             Size of FFT to use during inversion. Default is 1024.
-        hop_length : int, optional
-            Number of audio samples between STFT columns. Defaults to `n_fft`.
         win_length : int, optional
             Size of the window function. Defaults to `n_fft // 4`.
+        hop_length : int, optional
+            Number of audio samples between STFT columns. Defaults to `n_fft`.
         length : int, optional
             Output waveform length. If not provided, length will be inferred.
         window : Tensor, optional
@@ -403,14 +391,8 @@ class InverseTransform(Model):
             Updates ISTFT parameters dynamically (used internally during forward).
         """
         super().__init__()
-        assert window is None or isinstance(window, Tensor)
-        assert any(
-            (
-                (not return_complex and not onesided),
-                (not onesided and return_complex),
-                (not return_complex and onesided),
-            )
-        )
+        assert window is None or isinstance(window, (Tensor, nn.Module))
+        assert not bool(return_complex and onesided)
         self.n_fft = n_fft
         self.length = length
         self.win_length = win_length or n_fft // 4
@@ -419,13 +401,11 @@ class InverseTransform(Model):
         self.return_complex = return_complex
         self.onesided = onesided
         self.normalized = normalized
-        self.window = torch.hann_window(win_length) if window is None else window
-
-    def _apply_device_to(self):
-        """Applies to device while used with module `Model`"""
-        self.window = self.window.to(device=self.device)
+        self.register_buffer(
+            "window", torch.hann_window(self.win_length) if window is None else window
+        )

-    def forward(self, spec: Tensor, phase: Tensor, **kwargs):
+    def forward(self, spec: Tensor, phase: Tensor, *, _recall: bool = False):
         """
         Perform the inverse short-time Fourier transform.

@@ -435,24 +415,28 @@ class InverseTransform(Model):
             Magnitude spectrogram of shape (batch, freq, time).
         phase : Tensor
             Phase angles tensor, same shape as `spec`, in radians.
-        **kwargs : dict, optional
-            Optional ISTFT override parameters (same as in `update_settings`).

         Returns
         -------
         Tensor
             Time-domain waveform reconstructed from `spec` and `phase`.
         """
+        try:
+            return torch.istft(
+                spec * torch.exp(phase * 1j),
+                n_fft=self.n_fft,
+                hop_length=self.hop_length,
+                win_length=self.win_length,
+                window=self.window,
+                center=self.center,
+                normalized=self.normalized,
+                onesided=self.onesided,
+                length=self.length,
+                return_complex=self.return_complex,
+            )
+        except RuntimeError as e:
+            if not _recall and spec.device != self.window.device:
+                self.window = self.window.to(spec.device)
+                return self.forward(spec, phase, _recall=True)
+            raise e

-        return torch.istft(
-            spec * torch.exp(phase * 1j),
-            n_fft = self.n_fft,
-            hop_length=self.hop_length,
-            win_length=self.win_length,
-            window=self.window,
-            center=self.center,
-            normalized=self.normalized,
-            onesided=self.onesided,
-            length=self.length,
-            return_complex=self.return_complex,
-        )
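`InverseTransform` now registers its window as a buffer (so it follows `.to()` and `.cuda()` with the module), simplifies the onesided/return_complex check to `assert not bool(return_complex and onesided)`, and retries `forward` once on the spectrogram's device when a `RuntimeError` is raised. A minimal sketch with assumed sizes:

```python
import torch
from lt_tensor.transform import InverseTransform

ist = InverseTransform(n_fft=1024, win_length=1024, hop_length=256)

spec = torch.randn(1, 513, 50).abs()   # magnitude, n_fft // 2 + 1 frequency bins
phase = torch.randn(1, 513, 50)        # phase angles in radians
wave_out = ist(spec, phase)            # torch.istft(spec * exp(1j * phase), ...)
```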
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a11
+Version: 0.0.1a13
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
@@ -11,14 +11,17 @@ Classifier: Topic :: Software Development :: Libraries
 Classifier: Topic :: Utilities
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: torch>=2.2.0
-Requires-Dist: torchaudio>=2.2.0
+Requires-Dist: torch>=2.7.0
+Requires-Dist: torchaudio>=2.7.0
 Requires-Dist: numpy>=1.26.4
 Requires-Dist: tokenizers
 Requires-Dist: pyyaml>=6.0.0
 Requires-Dist: numba>0.60.0
-Requires-Dist: lt-utils==0.0.1
-Requires-Dist: librosa>=0.11.0
+Requires-Dist: lt-utils>=0.0.2a1
+Requires-Dist: librosa==0.11.*
+Requires-Dist: einops
+Requires-Dist: plotly
+Requires-Dist: scipy
 Dynamic: author
 Dynamic: classifier
 Dynamic: description