lt-tensor 0.0.1a36__py3-none-any.whl → 0.0.1a37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.0.1a36"
1
+ __version__ = "0.0.1a37"
2
2
 
3
3
  from . import (
4
4
  lr_schedulers,
@@ -10,6 +10,7 @@ from lt_utils.type_utils import is_file, is_array
10
10
  from lt_utils.file_ops import FileScan, get_file_name, path_to_str
11
11
  from torchaudio.functional import detect_pitch_frequency
12
12
  import torch.nn.functional as F
13
+ from librosa.filters import mel as _mel_filter_bank
13
14
 
14
15
  DEFAULT_DEVICE = torch.tensor([0]).device
15
16
 
@@ -25,7 +26,7 @@ class AudioProcessorConfig(ModelConfig):
25
26
  f_min: float = 0
26
27
  f_max: Optional[float] = None
27
28
  center: bool = True
28
- mel_scale: Literal["htk" "slaney"] = "htk"
29
+ mel_scale: Literal["htk", "slaney"] = "htk"
29
30
  std: int = 4
30
31
  mean: int = -4
31
32
  n_iter: int = 32
@@ -33,6 +34,7 @@ class AudioProcessorConfig(ModelConfig):
33
34
  normalized: bool = False
34
35
  onesided: Optional[bool] = None
35
36
  n_stft: int = None
37
+ mel_default: Literal["torch", "librosa"] = "librosa"
36
38
 
37
39
  def __init__(
38
40
  self,
@@ -49,6 +51,7 @@ class AudioProcessorConfig(ModelConfig):
49
51
  mean: int = -4,
50
52
  normalized: bool = False,
51
53
  onesided: Optional[bool] = None,
54
+ mel_default: Literal["torch", "librosa"] = "librosa",
52
55
  *args,
53
56
  **kwargs,
54
57
  ):
@@ -66,6 +69,7 @@ class AudioProcessorConfig(ModelConfig):
66
69
  "mean": mean,
67
70
  "normalized": normalized,
68
71
  "onesided": onesided,
72
+ "mel_default": mel_default,
69
73
  }
70
74
  super().__init__(**settings)
71
75
  self.post_process()
@@ -88,14 +92,10 @@ def _comp_rms_helper(i: int, audio: Tensor, mel: Optional[Tensor]):
88
92
 
89
93
 
90
94
  class AudioProcessor(Model):
91
- def __init__(
92
- self,
93
- config: AudioProcessorConfig = AudioProcessorConfig(),
94
- window: Optional[Tensor] = None,
95
- ):
95
+ def __init__(self, config: AudioProcessorConfig = AudioProcessorConfig()):
96
96
  super().__init__()
97
97
  self.cfg = config
98
- self._mel_spec = torchaudio.transforms.MelSpectrogram(
98
+ self._mel_spec_torch = torchaudio.transforms.MelSpectrogram(
99
99
  sample_rate=self.cfg.sample_rate,
100
100
  n_mels=self.cfg.n_mels,
101
101
  n_fft=self.cfg.n_fft,
@@ -107,6 +107,7 @@ class AudioProcessor(Model):
107
107
  mel_scale=self.cfg.mel_scale,
108
108
  normalized=self.cfg.normalized,
109
109
  )
110
+
110
111
  self._mel_rscale = torchaudio.transforms.InverseMelScale(
111
112
  n_stft=self.cfg.n_stft,
112
113
  n_mels=self.cfg.n_mels,
@@ -115,34 +116,119 @@ class AudioProcessor(Model):
115
116
  f_max=self.cfg.f_max,
116
117
  mel_scale=self.cfg.mel_scale,
117
118
  )
118
-
119
+ self.mel_lib_padding = (self.cfg.n_fft - self.cfg.hop_length) // 2
119
120
  self.register_buffer(
120
121
  "window",
121
- (torch.hann_window(self.cfg.win_length) if window is None else window),
122
+ torch.hann_window(self.cfg.win_length),
122
123
  )
124
+ self.register_buffer(
125
+ "mel_filter_bank",
126
+ torch.from_numpy(
127
+ _mel_filter_bank(
128
+ sr=self.cfg.sample_rate,
129
+ n_fft=self.cfg.n_fft,
130
+ n_mels=self.cfg.n_mels,
131
+ fmin=self.cfg.f_min,
132
+ fmax=self.cfg.f_max,
133
+ )
134
+ ).float(),
135
+ )
136
+
137
+ def spectral_norm(self, x: Tensor, c: int = 1, eps: float = 1e-5):
138
+ return torch.log(torch.clamp(x, min=eps) * c)
139
+
140
+ def spectral_de_norm(self, x: Tensor, c: int = 1):
141
+ return torch.exp(x) / c
142
+
143
+ def log_norm(
144
+ self,
145
+ entry: Tensor,
146
+ eps: float = 1e-5,
147
+ mean: Optional[Number] = None,
148
+ std: Optional[Number] = None,
149
+ ) -> Tensor:
150
+ mean = default(mean, self.cfg.mean)
151
+ std = default(std, self.cfg.std)
152
+ return (torch.log(eps + entry.unsqueeze(0)) - mean) / std
123
153
 
124
154
  def compute_mel(
125
155
  self,
126
156
  wave: Tensor,
127
- eps: float = 1e-5,
128
- raw_mel_only: bool = False,
129
- *,
130
- _recall: bool = False,
157
+ method: Optional[Literal["torch", "librosa"]] = None,
158
+ apply_norm: bool = False,
159
+ eps: Optional[float] = None,
160
+ **kwargs,
161
+ ) -> Tensor:
162
+ method = default(method, self.cfg.mel_default)
163
+ if method == "torch":
164
+ return self.compute_mel_torch(
165
+ wave,
166
+ log_norm=apply_norm,
167
+ eps=eps,
168
+ mean=kwargs.get("mean", None),
169
+ std=kwargs.get("std", None),
170
+ )
171
+ return self.compute_mel_librosa(
172
+ wave,
173
+ log_norm=apply_norm,
174
+ eps=eps,
175
+ )
176
+
177
+ def compute_mel_torch(
178
+ self,
179
+ wave: Tensor,
180
+ log_norm: bool = False,
181
+ eps: Optional[float] = None,
182
+ mean: Optional[Number] = None,
183
+ std: Optional[Number] = None,
184
+ *args,
185
+ **kwargs,
131
186
  ) -> Tensor:
132
187
  """Returns: (M, T) or (B, M, T) if batched"""
133
188
  try:
134
- mel_tensor = self._mel_spec(wave.to(self.device)) # [M, T]
135
- if not raw_mel_only:
136
- mel_tensor = (
137
- torch.log(eps + mel_tensor.unsqueeze(0)) - self.cfg.mean
138
- ) / self.cfg.std
139
- return mel_tensor.squeeze()
189
+ mel_tensor = self._mel_spec_torch.forward(wave.to(self.device)) # [M, T]
140
190
 
141
191
  except RuntimeError as e:
142
- if not _recall:
143
- self._mel_spec.to(self.device)
144
- return self.compute_mel(wave, raw_mel_only, eps, _recall=True)
145
- raise e
192
+ mel_tensor = self._mel_spec_torch.forward(wave.to(self.device)) # [M, T]
193
+ if log_norm:
194
+ return self.log_norm(mel_tensor, eps, mean, std).squeeze()
195
+ return mel_tensor.squeeze()
196
+
197
+ def compute_mel_librosa(
198
+ self,
199
+ wave: Tensor,
200
+ eps: float = 1e-5,
201
+ spectral_norm: bool = False,
202
+ *args,
203
+ **kwargs,
204
+ ):
205
+ wave = torch.nn.functional.pad(
206
+ wave.unsqueeze(1),
207
+ (self.mel_lib_padding, self.mel_lib_padding),
208
+ mode="reflect",
209
+ ).squeeze(1)
210
+ spec = torch.stft(
211
+ wave,
212
+ self.cfg.n_fft,
213
+ hop_length=self.cfg.hop_length,
214
+ win_length=self.cfg.win_length,
215
+ window=self.window,
216
+ center=self.cfg.center,
217
+ pad_mode="reflect",
218
+ normalized=False,
219
+ onesided=True,
220
+ return_complex=True,
221
+ )
222
+ spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-12)
223
+ try:
224
+ results = torch.matmul(self.mel_filter_bank, spec)
225
+ except RuntimeError:
226
+ self.mel_filter_bank = self.mel_filter_bank.to(self.device)
227
+ self.window = self.window.to(self.device)
228
+ results = torch.matmul(self.mel_filter_bank, spec)
229
+ if spectral_norm:
230
+ return self.spectral_norm(results, eps=eps).squeeze()
231
+ return results.squeeze()
146
232
 
147
233
  def compute_inverse_mel(self, melspec: Tensor, *, _recall=False):
148
234
  try:
@@ -382,7 +468,7 @@ class AudioProcessor(Model):
382
468
  antialias=antialias,
383
469
  )
384
470
 
385
- def istft(
471
+ def istft_spec_phase(
386
472
  self,
387
473
  spec: Tensor,
388
474
  phase: Tensor,
@@ -394,34 +480,91 @@ class AudioProcessor(Model):
394
480
  normalized: Optional[bool] = None,
395
481
  onesided: Optional[bool] = None,
396
482
  return_complex: bool = False,
397
- *,
398
- _recall: bool = False,
399
483
  ):
400
- if win_length is not None and win_length != self.cfg.win_length:
401
- window = torch.hann_window(win_length, device=spec.device)
402
- else:
403
- window = self.window
484
+ """Util for models that needs to reconstruct the audio using inverse stft"""
485
+ window = (
486
+ torch.hann_window(win_length, device=spec.device)
487
+ if win_length is not None and win_length != self.cfg.win_length
488
+ else self.window.to(spec.device)
489
+ )
490
+ return torch.istft(
491
+ spec * torch.exp(phase * 1j),
492
+ n_fft=default(n_fft, self.cfg.n_fft),
493
+ hop_length=default(hop_length, self.cfg.hop_length),
494
+ win_length=default(win_length, self.cfg.win_length),
495
+ window=window,
496
+ center=center,
497
+ normalized=default(normalized, self.cfg.normalized),
498
+ onesided=default(onesided, self.cfg.onesided),
499
+ length=length,
500
+ return_complex=return_complex,
501
+ )
404
502
 
405
- try:
406
- return torch.istft(
407
- spec * torch.exp(phase * 1j),
408
- n_fft=default(n_fft, self.cfg.n_fft),
409
- hop_length=default(hop_length, self.cfg.hop_length),
410
- win_length=default(win_length, self.cfg.win_length),
411
- window=window,
412
- center=center,
413
- normalized=default(normalized, self.cfg.normalized),
414
- onesided=default(onesided, self.cfg.onesided),
415
- length=length,
416
- return_complex=return_complex,
417
- )
418
- except RuntimeError as e:
419
- if not _recall and spec.device != self.window.device:
420
- self.window = self.window.to(spec.device)
421
- return self.istft(
422
- spec, phase, n_fft, hop_length, win_length, length, _recall=True
423
- )
424
- raise e
503
+ def istft(
504
+ self,
505
+ wave: Tensor,
506
+ n_fft: Optional[int] = None,
507
+ hop_length: Optional[int] = None,
508
+ win_length: Optional[int] = None,
509
+ length: Optional[int] = None,
510
+ center: bool = True,
511
+ normalized: Optional[bool] = None,
512
+ onesided: Optional[bool] = None,
513
+ return_complex: bool = False,
514
+ ):
515
+ window = (
516
+ torch.hann_window(win_length, device=wave.device)
517
+ if win_length is not None and win_length != self.cfg.win_length
518
+ else self.window.to(wave.device)
519
+ )
520
+ if not torch.is_complex(wave):
521
+ wave = wave * 1j
522
+ return torch.istft(
523
+ wave,
524
+ n_fft=default(n_fft, self.cfg.n_fft),
525
+ hop_length=default(hop_length, self.cfg.hop_length),
526
+ win_length=default(win_length, self.cfg.win_length),
527
+ window=window,
528
+ center=center,
529
+ normalized=default(normalized, self.cfg.normalized),
530
+ onesided=default(onesided, self.cfg.onesided),
531
+ length=length,
532
+ return_complex=return_complex,
533
+ )
534
+
535
+ def stft(
536
+ self,
537
+ wave: Tensor,
538
+ center: bool = True,
539
+ n_fft: Optional[int] = None,
540
+ hop_length: Optional[int] = None,
541
+ win_length: Optional[int] = None,
542
+ normalized: Optional[bool] = None,
543
+ onesided: Optional[bool] = None,
544
+ return_complex: bool = True,
545
+ ):
546
+
547
+ window = (
548
+ torch.hann_window(win_length, device=wave.device)
549
+ if win_length is not None and win_length != self.cfg.win_length
550
+ else self.window.to(wave.device)
551
+ )
552
+
553
+ results = torch.stft(
554
+ input=wave,
555
+ n_fft=default(n_fft, self.cfg.n_fft),
556
+ hop_length=default(hop_length, self.cfg.hop_length),
557
+ win_length=default(win_length, self.cfg.win_length),
558
+ window=window,
559
+ center=center,
560
+ pad_mode="reflect",
561
+ normalized=default(normalized, self.cfg.normalized),
562
+ onesided=default(onesided, self.cfg.onesided),
563
+ return_complex=True, # always, then if we need a not complex type we use view as real.
564
+ )
565
+ if not return_complex:
566
+ return torch.view_as_real(results)
567
+ return results
425
568
 
426
569
  def istft_norm(
427
570
  self,
@@ -435,11 +578,11 @@ class AudioProcessor(Model):
435
578
  onesided: Optional[bool] = None,
436
579
  return_complex: bool = False,
437
580
  ):
438
-
439
- if win_length is not None and win_length != self.cfg.win_length:
440
- window = torch.hann_window(win_length, device=wave.device)
441
- else:
442
- window = self.window
581
+ window = (
582
+ torch.hann_window(win_length, device=wave.device)
583
+ if win_length is not None and win_length != self.cfg.win_length
584
+ else self.window.to(wave.device)
585
+ )
443
586
  spectrogram = torch.stft(
444
587
  input=wave,
445
588
  n_fft=default(n_fft, self.cfg.n_fft),
@@ -473,15 +616,15 @@ class AudioProcessor(Model):
473
616
  def load_audio(
474
617
  self,
475
618
  path: PathLike,
476
- top_db: float = 30,
619
+ top_db: Optional[float] = None,
477
620
  normalize: bool = False,
621
+ mono: bool = True,
478
622
  *,
479
- ref: float | Callable[[np.ndarray], Any] = np.max,
480
- frame_length: int = 2048,
623
+ sample_rate: Optional[float] = None,
481
624
  hop_length: int = 512,
482
- mono: bool = True,
483
- offset: float = 0.0,
625
+ frame_length: int = 2048,
484
626
  duration: Optional[float] = None,
627
+ offset: float = 0.0,
485
628
  dtype: Any = np.float32,
486
629
  res_type: str = "soxr_hq",
487
630
  fix: bool = True,
@@ -491,29 +634,32 @@ class AudioProcessor(Model):
491
634
  norm_axis: int = 0,
492
635
  norm_threshold: Optional[float] = None,
493
636
  norm_fill: Optional[bool] = None,
637
+ ref: float | Callable[[np.ndarray], Any] = np.max,
494
638
  ) -> Tensor:
495
639
  is_file(path, True)
640
+ sample_rate = default(sample_rate, self.cfg.sample_rate)
496
641
  wave, sr = librosa.load(
497
642
  str(path),
498
- sr=self.cfg.sample_rate,
643
+ sr=sample_rate,
499
644
  mono=mono,
500
645
  offset=offset,
501
646
  duration=duration,
502
647
  dtype=dtype,
503
648
  res_type=res_type,
504
649
  )
505
- wave, _ = librosa.effects.trim(
506
- wave,
507
- top_db=top_db,
508
- ref=ref,
509
- frame_length=frame_length,
510
- hop_length=hop_length,
511
- )
512
- if sr != self.cfg.sample_rate:
650
+ if top_db is not None:
651
+ wave, _ = librosa.effects.trim(
652
+ wave,
653
+ top_db=top_db,
654
+ ref=ref,
655
+ frame_length=frame_length,
656
+ hop_length=hop_length,
657
+ )
658
+ if sr != sample_rate:
513
659
  wave = librosa.resample(
514
660
  wave,
515
661
  orig_sr=sr,
516
- target_sr=self.cfg.sample_rate,
662
+ target_sr=sample_rate,
517
663
  res_type=res_type,
518
664
  fix=fix,
519
665
  scale=scale,
@@ -553,10 +699,6 @@ class AudioProcessor(Model):
553
699
  maximum,
554
700
  )
555
701
 
556
- def stft_loss(self, signal: Tensor, ground: Tensor, magnitude: float = 1.0):
557
- ground = F.interpolate(ground, signal.shape[-1]).to(signal.device)
558
- return F.l1_loss(signal.squeeze(), ground.squeeze()) * magnitude
559
-
560
702
  def forward(
561
703
  self,
562
704
  *inputs: Union[Tensor, float],
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lt-tensor
3
- Version: 0.0.1a36
3
+ Version: 0.0.1a37
4
4
  Summary: General utilities for PyTorch and others. Built for general use.
5
5
  Home-page: https://github.com/gr1336/lt-tensor/
6
6
  Author: gr1336
@@ -18,7 +18,7 @@ Requires-Dist: tokenizers
18
18
  Requires-Dist: pyyaml>=6.0.0
19
19
  Requires-Dist: numba>0.60.0
20
20
  Requires-Dist: lt-utils>=0.0.4
21
- Requires-Dist: librosa==0.11.*
21
+ Requires-Dist: librosa<1,>=0.10.2.post1
22
22
  Requires-Dist: einops
23
23
  Requires-Dist: plotly
24
24
  Requires-Dist: scipy
@@ -1,4 +1,4 @@
1
- lt_tensor/__init__.py,sha256=nBbiGH1byHU0aTTKKorRj8MIEO2oEMBXl7kt5DOCatU,441
1
+ lt_tensor/__init__.py,sha256=CFVK5h2Y-p3xFJ6mCW8dI1FOFeObsOyDjyUqJtxmkmg,441
2
2
  lt_tensor/config_templates.py,sha256=F9UvL8paAjkSvio890kp8WznpYeI50pYnm9iqQroBxk,2797
3
3
  lt_tensor/losses.py,sha256=Heco_WyoC1HkNkcJEircOAzS9umusATHiNAG-FKGyzc,8918
4
4
  lt_tensor/lr_schedulers.py,sha256=6_vcfaPHrozfH3wvmNEdKSFYl6iTIijYoHL8vuG-45U,7651
@@ -35,9 +35,9 @@ lt_tensor/model_zoo/losses/CQT/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5
35
35
  lt_tensor/model_zoo/losses/CQT/transforms.py,sha256=Vkid0J9dqLnlINfyyUlQf-qB3gOQAgU7W9j7xLOjDFw,13218
36
36
  lt_tensor/model_zoo/losses/CQT/utils.py,sha256=twGw6FVD7V5Ksfx_1BUEN3EP1tAS6wo-9LL3VnuHB8c,16751
37
37
  lt_tensor/processors/__init__.py,sha256=Pvxhh0KR65zLCgUd53_k5Z0y5JWWcO0ZBXFK9rv0o5w,109
38
- lt_tensor/processors/audio.py,sha256=3YzyEpMwh124rb1KMAly62qweeruF200BnM-vQIbzy0,18645
39
- lt_tensor-0.0.1a36.dist-info/licenses/LICENSE,sha256=TbiyJWLgNqqgqhfCnrGwFIxy7EqGNrIZZcKhHrefcuU,11354
40
- lt_tensor-0.0.1a36.dist-info/METADATA,sha256=mTmnoWn8EG48j_VOM3rr_8RLLgaxB5pWZE1tkPdFrac,1062
41
- lt_tensor-0.0.1a36.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
42
- lt_tensor-0.0.1a36.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
43
- lt_tensor-0.0.1a36.dist-info/RECORD,,
38
+ lt_tensor/processors/audio.py,sha256=QadO6e7uXRkheNU8ba-SNw72HPD1XvR-6VJltoF8YRA,23535
39
+ lt_tensor-0.0.1a37.dist-info/licenses/LICENSE,sha256=TbiyJWLgNqqgqhfCnrGwFIxy7EqGNrIZZcKhHrefcuU,11354
40
+ lt_tensor-0.0.1a37.dist-info/METADATA,sha256=6EkGRk9fT_wsvl_pqKZ0S8I-x1Awkm9pvr3MKnW8OPM,1071
41
+ lt_tensor-0.0.1a37.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
42
+ lt_tensor-0.0.1a37.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
43
+ lt_tensor-0.0.1a37.dist-info/RECORD,,