lt-tensor 0.0.1a22__py3-none-any.whl → 0.0.1a27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,11 @@
-__all__ = ["AudioProcessor"]
+__all__ = ["AudioProcessor", "AudioProcessorConfig"]
 from lt_tensor.torch_commons import *
 from lt_utils.common import *
 import librosa
 import torchaudio
 import numpy as np
 from lt_tensor.model_base import Model
+from lt_utils.misc_utils import default
 from lt_utils.type_utils import is_file, is_array
 from lt_utils.file_ops import FileScan, get_file_name, path_to_str
 from torchaudio.functional import detect_pitch_frequency
@@ -12,8 +13,27 @@ import torch.nn.functional as F
 
 DEFAULT_DEVICE = torch.tensor([0]).device
 
+from lt_tensor.config_templates import ModelConfig
+
+
+class AudioProcessorConfig(ModelConfig):
+    sample_rate: int = 24000
+    n_mels: int = 80
+    n_fft: int = 1024
+    win_length: int = 1024
+    hop_length: int = 256
+    f_min: float = 0
+    f_max: float = 8000.0
+    center: bool = True
+    mel_scale: Literal["htk" "slaney"] = "htk"
+    std: int = 4
+    mean: int = -4
+    n_iter: int = 32
+    window: Optional[Tensor] = None
+    normalized: bool = False
+    onesided: Optional[bool] = None
+    n_stft: int = None
 
-class AudioProcessor(Model):
     def __init__(
         self,
         sample_rate: int = 24000,
@@ -21,63 +41,95 @@ class AudioProcessor(Model):
         n_fft: int = 1024,
         win_length: Optional[int] = None,
         hop_length: Optional[int] = None,
-        f_min: float = 0,
+        f_min: float = 1,
         f_max: float = 12000.0,
         center: bool = True,
         mel_scale: Literal["htk", "slaney"] = "htk",
         std: int = 4,
         mean: int = -4,
-        n_iter: int = 32,
-        window: Optional[Tensor] = None,
         normalized: bool = False,
         onesided: Optional[bool] = None,
+        *args,
+        **kwargs,
+    ):
+        settings = {
+            "sample_rate": sample_rate,
+            "n_mels": n_mels,
+            "n_fft": n_fft,
+            "win_length": win_length,
+            "hop_length": hop_length,
+            "f_min": f_min,
+            "f_max": f_max,
+            "center": center,
+            "mel_scale": mel_scale,
+            "std": std,
+            "mean": mean,
+            "normalized": normalized,
+            "onesided": onesided,
+        }
+        super().__init__(**settings)
+        self.post_process()
+
+    def post_process(self):
+        self.f_min = max(self.f_min, 1)
+        self.f_max = max(min(self.f_max, self.n_fft // 2), self.f_min + 1)
+        self.n_stft = self.n_fft // 2 + 1
+        self.hop_length = default(self.hop_length, self.n_fft // 4)
+        self.win_length = default(self.win_length, self.n_fft)
+
+
+def _comp_rms_helper(i: int, audio: Tensor, mel: Optional[Tensor]):
+    if mel is None:
+        return {"y": audio[i, :]}
+    return {"y": audio[i, :], "S": mel[i, :, :]}
+
+
+class AudioProcessor(Model):
+    def __init__(
+        self,
+        config: AudioProcessorConfig = AudioProcessorConfig(),
+        window: Optional[Tensor] = None,
     ):
         super().__init__()
-        self.mean = mean
-        self.std = std
-        self.n_mels = n_mels
-        self.n_fft = n_fft
-        self.n_stft = n_fft // 2 + 1
-        self.f_min = f_min
-        self.f_max = max(min(f_max, 12000), self.f_min + 1)
-        self.n_iter = n_iter
-        self.hop_length = hop_length or n_fft // 4
-        self.win_length = win_length or n_fft
-        self.sample_rate = sample_rate
-        self.center = center
+        self.cfg = config
         self._mel_spec = torchaudio.transforms.MelSpectrogram(
-            sample_rate=sample_rate,
-            n_mels=n_mels,
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            center=center,
-            f_min=f_min,
-            f_max=f_max,
-            mel_scale=mel_scale,
+            sample_rate=self.cfg.sample_rate,
+            n_mels=self.cfg.n_mels,
+            n_fft=self.cfg.n_fft,
+            win_length=self.cfg.win_length,
+            hop_length=self.cfg.hop_length,
+            center=self.cfg.center,
+            f_min=self.cfg.f_min,
+            f_max=self.cfg.f_max,
+            mel_scale=self.cfg.mel_scale,
+            onesided=self.cfg.onesided,
+            normalized=self.cfg.normalized,
         )
+        self.griffin_lm_iters = 32
         self.mel_rscale = torchaudio.transforms.InverseMelScale(
-            n_stft=self.n_stft,
-            n_mels=n_mels,
-            sample_rate=sample_rate,
-            f_min=f_min,
-            f_max=f_max,
-            mel_scale=mel_scale,
+            n_stft=self.cfg.n_stft,
+            n_mels=self.cfg.n_mels,
+            sample_rate=self.cfg.sample_rate,
+            f_min=self.cfg.f_min,
+            f_max=self.cfg.f_max,
+            mel_scale=self.cfg.mel_scale,
         )
         self.giffin_lim = torchaudio.transforms.GriffinLim(
-            n_fft=n_fft,
-            n_iter=n_iter,
-            win_length=win_length,
-            hop_length=hop_length,
+            n_fft=self.cfg.n_fft,
+            win_length=self.cfg.win_length,
+            hop_length=self.cfg.hop_length,
         )
-        self.normalized = normalized
-        self.onesided = onesided
-
         self.register_buffer(
             "window",
-            (torch.hann_window(self.win_length) if window is None else window),
+            (torch.hann_window(self.cfg.win_length) if window is None else window),
         )
 
+    def _apply_device(self):
+        print(f"Audio Processor Device: {self.device.type}")
+        self.giffin_lim.to(device=self.device)
+        self._mel_spec.to(device=self.device)
+        self.mel_rscale.to(device=self.device)
+
     def from_numpy(
         self,
         array: np.ndarray,
@@ -85,8 +137,8 @@ class AudioProcessor(Model):
         dtype: Optional[torch.dtype] = None,
     ):
         converted = torch.from_numpy(array)
-        if not any([device is not None, dtype is not None]):
-            return converted
+        if device is None:
+            device = self.device
         return converted.to(device=device, dtype=dtype)
 
     def from_numpy_batch(
@@ -96,18 +148,32 @@ class AudioProcessor(Model):
         dtype: Optional[torch.dtype] = None,
    ):
         stacked = torch.stack([torch.from_numpy(x) for x in arrays])
-        if not any([device is not None, dtype is not None]):
-            return stacked
+        if device is None:
+            device = self.device
         return stacked.to(device=device, dtype=dtype)
 
-    def to_numpy_safe(self, tensor: Tensor):
+    def to_numpy_safe(self, tensor: Union[Tensor, np.ndarray]):
+        if isinstance(tensor, np.ndarray):
+            return tensor
         return tensor.detach().to(DEFAULT_DEVICE).numpy(force=True)
 
     def compute_rms(
         self,
-        audio: Union[Tensor, np.ndarray],
+        audio: Optional[Union[Tensor, np.ndarray]] = None,
         mel: Optional[Tensor] = None,
+        frame_length: Optional[int] = None,
+        hop_length: Optional[int] = None,
+        center: Optional[int] = None,
     ):
+        assert any([audio is not None, mel is not None])
+        rms_kwargs = dict(
+            frame_length=default(frame_length, self.cfg.n_fft),
+            hop_length=default(hop_length, self.cfg.hop_length),
+            center=default(center, self.cfg.center),
+        )
+
+        if audio is None and mel is not None:
+            return self.from_numpy(librosa.feature.rms(S=mel, **rms_kwargs)[0])
         default_dtype = audio.dtype
         default_device = audio.device
         if audio.ndim > 1:
@@ -123,34 +189,32 @@ class AudioProcessor(Model):
             else:
                 assert B == mel.shape[0], "Batch from mel and audio must be the same!"
             mel = self.to_numpy_safe(mel)
-            gt_mel = lambda idx: mel[idx, :, :]
-
-        else:
-            gt_mel = lambda idx: None
         audio = self.to_numpy_safe(audio)
         if B == 1:
-            _r = librosa.feature.rms(
-                y=audio, frame_length=self.n_fft, hop_length=self.hop_length
-            )[0]
-            rms = self.from_numpy(_r, default_device, default_dtype)
-
+            if mel is None:
+                return self.from_numpy(librosa.feature.rms(y=audio, **rms_kwargs)[0])
+            return self.from_numpy(librosa.feature.rms(y=audio, S=mel, **rms_kwargs)[0])
         else:
             rms_ = []
             for i in range(B):
-                _r = librosa.feature.rms(
-                    y=audio[i, :],
-                    S=gt_mel(i),
-                    frame_length=self.n_fft,
-                    hop_length=self.hop_length,
-                )[0]
+                _r = librosa.feature.rms(_comp_rms_helper(i, audio, mel), **rms_kwargs)[
+                    0
+                ]
                 rms_.append(_r)
-            rms = self.from_numpy_batch(rms_, default_device, default_dtype)
-
-        return rms
+            return self.from_numpy_batch(rms_, default_device, default_dtype)
 
     def compute_pitch(
         self,
         audio: Tensor,
+        *,
+        pad_mode: str = "constant",
+        trough_threshold: float = 0.1,
+        fmin: Optional[float] = None,
+        fmax: Optional[float] = None,
+        sr: Optional[float] = None,
+        frame_length: Optional[int] = None,
+        hop_length: Optional[int] = None,
+        center: Optional[bool] = None,
     ):
         default_dtype = audio.dtype
         default_device = audio.device
@@ -158,18 +222,25 @@ class AudioProcessor(Model):
             B = audio.shape[0]
         else:
             B = 1
-        fmin = max(self.f_min, 80)
+        sr = default(sr, self.cfg.sample_rate)
+        fmin = max(default(fmin, self.cfg.f_min), 65)
+        fmax = min(default(fmax, self.cfg.f_max), sr // 2)
+        frame_length = default(frame_length, self.cfg.n_fft)
+        hop_length = default(hop_length, self.cfg.hop_length)
+        center = default(center, self.cfg.center)
+        yn_kwargs = dict(
+            fmin=fmin,
+            fmax=fmax,
+            frame_length=frame_length,
+            sr=sr,
+            hop_length=hop_length,
+            center=center,
+            trough_threshold=trough_threshold,
+            pad_mode=pad_mode,
+        )
         if B == 1:
             f0 = self.from_numpy(
-                librosa.yin(
-                    self.to_numpy_safe(audio),
-                    fmin=fmin,
-                    fmax=min(self.f_max, self.sample_rate // 2 - 1),
-                    frame_length=self.n_fft,
-                    sr=self.sample_rate,
-                    hop_length=self.hop_length,
-                    center=self.center,
-                ),
+                librosa.yin(self.to_numpy_safe(audio), **yn_kwargs),
                 default_device,
                 default_dtype,
             )
@@ -177,32 +248,34 @@ class AudioProcessor(Model):
         else:
             f0_ = []
             for i in range(B):
-                r = librosa.yin(
-                    self.to_numpy_safe(audio[i, :]),
-                    fmin=fmin,
-                    fmax=self.f_max,
-                    frame_length=self.n_fft,
-                    sr=self.sample_rate,
-                    hop_length=self.hop_length,
-                    center=self.center,
-                )
-                f0_.append(r)
+                f0_.append(librosa.yin(self.to_numpy_safe(audio[i, :]), **yn_kwargs))
             f0 = self.from_numpy_batch(f0_, default_device, default_dtype)
+        return f0.squeeze()
 
-        # librosa.pyin(self.f_min, self.f_max)
-        return f0  # dict(f0=f0, attention_mask=f0 != f_max)
-
-    def compute_pitch_torch(self, audio: Tensor):
+    def compute_pitch_torch(
+        self,
+        audio: Tensor,
+        fmin: Optional[float] = None,
+        fmax: Optional[float] = None,
+        sr: Optional[float] = None,
+        win_length: Optional[Number] = None,
+        frame_length: Optional[Number] = None,
+    ):
+        sr = default(sr, self.sample_rate)
+        fmin = max(default(fmin, self.f_min), 1)
+        fmax = min(default(fmax, self.f_max), sr // 2)
+        win_length = default(win_length, self.win_length)
+        frame_length = default(frame_length, self.n_fft)
         return detect_pitch_frequency(
             audio,
-            sample_rate=self.sample_rate,
-            frame_time=self.n_fft,
-            win_length=self.win_length,
-            freq_low=max(self.f_min, 35),
-            freq_high=self.f_max,
+            sample_rate=sr,
+            frame_time=frame_length,
+            win_length=win_length,
+            freq_low=fmin,
+            freq_high=fmax,
         )
 
-    def interpolate_tensor(
+    def interpolate(
        self,
         tensor: Tensor,
         target_len: int,
@@ -227,6 +300,8 @@ class AudioProcessor(Model):
 
         if tensor.ndim == 2:  # [1, T]
             tensor = tensor.unsqueeze(1)  # [1, 1, T]
+        elif tensor.ndim == 1:
+            tensor = tensor.unsqueeze(0).unsqueeze(0)  # [1, 1, T]
         return F.interpolate(
             tensor,
             size=target_len,
@@ -248,18 +323,21 @@ class AudioProcessor(Model):
         *,
         _recall: bool = False,
     ):
+        if win_length is not None and win_length != self.cfg.win_length:
+            window = torch.hann_window(win_length, device=spec.device)
+        else:
+            window = self.window
+
         try:
             return torch.istft(
                 spec * torch.exp(phase * 1j),
-                n_fft=n_fft or self.n_fft,
-                hop_length=hop_length or self.hop_length,
-                win_length=win_length or self.win_length,
-                window=torch.hann_window(
-                    win_length or self.win_length, device=spec.device
-                ),
-                center=self.center,
-                normalized=self.normalized,
-                onesided=self.onesided,
+                n_fft=n_fft or self.cfg.n_fft,
+                hop_length=hop_length or self.cfg.hop_length,
+                win_length=win_length or self.cfg.win_length,
+                window=window,
+                center=self.cfg.center,
+                normalized=self.cfg.normalized,
+                onesided=self.cfg.onesided,
                 length=length,
                 return_complex=False,
             )
@@ -281,15 +359,15 @@ class AudioProcessor(Model):
         try:
             spectrogram = torch.stft(
                 input=wave,
-                n_fft=self.n_fft,
-                hop_length=self.hop_length,
-                win_length=self.win_length,
+                n_fft=self.cfg.n_fft,
+                hop_length=self.cfg.hop_length,
+                win_length=self.cfg.win_length,
                 window=self.window,
-                center=self.center,
+                center=self.cfg.center,
                 pad_mode="reflect",
-                normalized=self.normalized,
-                onesided=self.onesided,
-                return_complex=True,  # needed for the istft
+                normalized=self.cfg.normalized,
+                onesided=self.cfg.onesided,
+                return_complex=True,
             )
             return torch.istft(
                 spectrogram
@@ -298,14 +376,14 @@ class AudioProcessor(Model):
                     fill_value=1,
                     device=spectrogram.device,
                 ),
-                n_fft=self.n_fft,
-                hop_length=self.hop_length,
-                win_length=self.win_length,
+                n_fft=self.cfg.n_fft,
+                hop_length=self.cfg.hop_length,
+                win_length=self.cfg.win_length,
                 window=self.window,
                 length=length,
-                center=self.center,
-                normalized=self.normalized,
-                onesided=self.onesided,
+                center=self.cfg.center,
+                normalized=self.cfg.normalized,
+                onesided=self.cfg.onesided,
                 return_complex=False,
             )
         except RuntimeError as e:
@@ -317,21 +395,30 @@ class AudioProcessor(Model):
     def compute_mel(
         self,
         wave: Tensor,
-        base: float = 1e-5,
-        add_base: bool = True,
+        raw_mel_only: bool = False,
+        eps: float = 1e-5,
+        *,
+        _recall: bool = False,
     ) -> Tensor:
-        """Returns: [B, M, ML]"""
-        mel_tensor = self._mel_spec(wave.to(self.device))  # [M, ML]
-        if not add_base:
-            return (mel_tensor - self.mean) / self.std
-        return (
-            (torch.log(base + mel_tensor.unsqueeze(0)) - self.mean) / self.std
-        ).squeeze()
+        """Returns: [B, M, T]"""
+        try:
+            mel_tensor = self._mel_spec(wave.to(self.device))  # [M, T]
+            if not raw_mel_only:
+                mel_tensor = (
+                    torch.log(eps + mel_tensor.unsqueeze(0)) - self.cfg.mean
+                ) / self.cfg.std
+            return mel_tensor.squeeze()
+
+        except RuntimeError as e:
+            if not _recall:
+                self._mel_spec.to(self.device)
+                return self.compute_mel(wave, raw_mel_only, eps, _recall=True)
+            raise e
 
     def inverse_mel_spectogram(self, mel: Tensor, n_iter: Optional[int] = None):
-        if isinstance(n_iter, int) and n_iter != self.n_iter:
+        if isinstance(n_iter, int) and n_iter != self.griffin_lm_iters:
             self.giffin_lim.n_iter = n_iter
-            self.n_iter = n_iter
+            self.griffin_lm_iters = n_iter
         return self.giffin_lim.forward(
             self.mel_rscale(mel),
         )
@@ -341,18 +428,59 @@ class AudioProcessor(Model):
         path: PathLike,
         top_db: float = 30,
         normalize: bool = False,
-        alpha: float = 1.0,
+        *,
+        ref: float | Callable[[np.ndarray], Any] = np.max,
+        frame_length: int = 2048,
+        hop_length: int = 512,
+        mono: bool = True,
+        offset: float = 0.0,
+        duration: Optional[float] = None,
+        dtype: Any = np.float32,
+        res_type: str = "soxr_hq",
+        fix: bool = True,
+        scale: bool = False,
+        axis: int = -1,
+        norm: Optional[float] = np.inf,
+        norm_axis: int = 0,
+        norm_threshold: Optional[float] = None,
+        norm_fill: Optional[bool] = None,
    ) -> Tensor:
         is_file(path, True)
-        wave, sr = librosa.load(str(path), sr=self.sample_rate)
-        wave, _ = librosa.effects.trim(wave, top_db=top_db)
-        if sr != self.sample_rate:
-            wave = librosa.resample(wave, orig_sr=sr, target_sr=self.sample_rate)
+        wave, sr = librosa.load(
+            str(path),
+            sr=self.cfg.sample_rate,
+            mono=mono,
+            offset=offset,
+            duration=duration,
+            dtype=dtype,
+            res_type=res_type,
+        )
+        wave, _ = librosa.effects.trim(
+            wave,
+            top_db=top_db,
+            ref=ref,
+            frame_length=frame_length,
+            hop_length=hop_length,
+        )
+        if sr != self.cfg.sample_rate:
+            wave = librosa.resample(
+                wave,
+                orig_sr=sr,
+                target_sr=self.cfg.sample_rate,
+                res_type=res_type,
+                fix=fix,
+                scale=scale,
+                axis=axis,
+            )
         if normalize:
-            wave = librosa.util.normalize(wave)
-        if alpha not in [0.0, 1.0]:
-            wave = wave * alpha
-        return torch.from_numpy(wave).float().unsqueeze(0)
+            wave = librosa.util.normalize(
+                wave,
+                norm=norm,
+                axis=norm_axis,
+                threshold=norm_threshold,
+                fill=norm_fill,
+            )
+        return torch.from_numpy(wave).float().unsqueeze(0).to(self.device)
 
     def find_audios(
         self,
@@ -378,79 +506,18 @@ class AudioProcessor(Model):
             maximum,
         )
 
-    def find_audio_text_pairs(
-        self,
-        path,
-        additional_extensions: List[str] = [],
-        text_file_patterns: List[str] = [".normalized.txt", ".original.txt"],
-    ):
-        is_array(text_file_patterns, True, validate=True)  # Rases if empty or not valid
-        additional_extensions = [
-            x
-            for x in additional_extensions
-            if isinstance(x, str)
-            and "*" in x
-            and not any(list(map(lambda y: y in x), text_file_patterns))
-        ]
-        audio_files = self.find_audios(path, additional_extensions)
-        results = []
-        for audio in audio_files:
-            base_audio_dir = Path(audio).parent
-            audio_name = get_file_name(audio, False)
-            for pattern in text_file_patterns:
-                possible_txt_file = Path(base_audio_dir, audio_name + pattern)
-                if is_file(possible_txt_file):
-                    results.append((audio, path_to_str(possible_txt_file)))
-                    break
-        return results
-
     def stft_loss(
         self,
         signal: Tensor,
         ground: Tensor,
     ):
-        ground = F.interpolate(ground, signal.shape[-1]).to(signal.device)
-        if ground.ndim != signal.ndim:
-            assert ground.ndim in [1, 2, 3]
-            assert signal.ndim in [1, 2, 3]
-            if ground.ndim == 3:
-                ground = ground.squeeze(1)
-            elif ground.ndim == 1:
-                ground = ground.unsqueeze(0)
-            if signal.ndim == 3:
-                signal = signal.squeeze(1)
-            elif signal.ndim == 1:
-                signal = signal.unsqueeze(0)
+        with torch.no_grad():
+            ground = F.interpolate(ground, signal.shape[-1]).to(signal.device)
         return F.l1_loss(signal, ground)
 
-    @staticmethod
-    def plot_spectrogram(spectrogram, ax_):
-        import matplotlib.pylab as plt
-
-        fig, ax = plt.subplots(figsize=(10, 2))
-        im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
-        plt.colorbar(im, ax=ax)
-
-        fig.canvas.draw()
-        plt.close()
-
-        return fig
-
     def forward(
         self,
         *inputs: Union[Tensor, float],
-        ap_task: Literal[
-            "get_mel", "get_loss", "inv_transform", "revert_mel"
-        ] = "get_mel",
         **inputs_kwargs,
     ):
-        if ap_task == "get_mel":
-            return self.compute_mel(*inputs, **inputs_kwargs)
-        elif ap_task == "get_loss":
-            return self.stft_loss(*inputs, **inputs_kwargs)
-        elif ap_task == "inv_transform":
-            return self.inverse_transform(*inputs, **inputs_kwargs)
-        elif ap_task == "revert_mel":
-            return self.inverse_mel_spectogram(*inputs, **inputs_kwargs)
-        else:
-            raise ValueError(f"Invalid task '{ap_task}'")
+        return self.compute_mel(*inputs, **inputs_kwargs)

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a22
+Version: 0.0.1a27
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
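
For orientation, the first diff replaces the old keyword-argument constructor of AudioProcessor with a separate AudioProcessorConfig object (a ModelConfig subclass) and reworks compute_mel, compute_pitch, compute_rms and load_audio around it. Below is a minimal usage sketch of the new interface, inferred only from the hunks above; the module's import path inside lt_tensor is not visible in this diff, so the import line is a placeholder, and "speech.wav" is just an example file name.

    # Placeholder import: the diff does not show where this module lives inside lt_tensor.
    from lt_tensor import AudioProcessor, AudioProcessorConfig

    # Field defaults shown in the diff: sample_rate=24000, n_mels=80, n_fft=1024, hop_length=256, ...
    cfg = AudioProcessorConfig(sample_rate=24000, n_mels=80, n_fft=1024)
    ap = AudioProcessor(config=cfg)

    wave = ap.load_audio("speech.wav")                  # [1, T], loaded and resampled to cfg.sample_rate
    mel = ap.compute_mel(wave)                          # log-mel, shifted by cfg.mean and scaled by cfg.std
    raw = ap.compute_mel(wave, raw_mel_only=True)       # mel spectrogram without the log/normalization step
    recon = ap.inverse_mel_spectogram(raw, n_iter=64)   # Griffin-Lim reconstruction from the raw mel
    f0 = ap.compute_pitch(wave)                         # YIN pitch track via librosa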