lt-tensor 0.0.1a36__py3-none-any.whl → 0.0.1a38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.0.1a36"
+__version__ = "0.0.1a38"
 
 from . import (
     lr_schedulers,
lt_tensor/losses.py CHANGED
@@ -133,7 +133,7 @@ class MultiMelScaleLoss(Model):
         loss_mel_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
         loss_pitch_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
         loss_rms_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
-        center: bool = True,
+        center: bool = False,
         power: float = 1.0,
         normalized: bool = False,
         pad_mode: str = "reflect",
@@ -149,6 +149,7 @@ class MultiMelScaleLoss(Model):
         lambda_rms: float = 1.0,
         lambda_pitch: float = 1.0,
         weight: float = 1.0,
+        mel: Literal["librosa", "torch"] = "torch",
     ):
         super().__init__()
         assert (
@@ -188,6 +189,7 @@ class MultiMelScaleLoss(Model):
             onesided,
             std,
             mean,
+            mel,
         )
 
     def _setup_mels(
@@ -206,6 +208,7 @@ class MultiMelScaleLoss(Model):
         onesided: Optional[bool],
         std: int,
         mean: int,
+        mel: str,
     ):
         assert (
             len(n_mels)
@@ -224,6 +227,7 @@ class MultiMelScaleLoss(Model):
             pad_mode=pad_mode,
             std=std,
             mean=mean,
+            mel_default=mel,
         )
         self.mel_spectrograms: List[AudioProcessor] = nn.ModuleList(
             [
@@ -247,12 +251,14 @@ class MultiMelScaleLoss(Model):
     def forward(
         self, input_wave: torch.Tensor, target_wave: torch.Tensor
     ) -> torch.Tensor:
-        assert self.use_istft_norm or input_wave.shape[-1] == target_wave.shape[-1]
+        assert self.use_istft_norm or input_wave.shape[-1] == target_wave.shape[-1], (
+            f"Size mismatch! input_wave {input_wave.shape[-1]} must match target_wave: {target_wave.shape[-1]}. "
+            "Alternatively, 'use_istft_norm' can be set to True, which will automatically force the audio to that size."
+        )
         target_wave = target_wave.to(input_wave.device)
         losses = 0.0
         for M in self.mel_spectrograms:
-            # Apply normalization if requested
-            if self.use_istft_norm:
+            if self.use_istft_norm and input_wave.shape[-1] != target_wave.shape[-1]:
                 input_proc = M.istft_norm(input_wave, length=target_wave.shape[-1])
                 target_proc = M.istft_norm(target_wave, length=target_wave.shape[-1])
             else:
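
Note on the new `mel` argument: it is forwarded to each internal AudioProcessor as `mel_default`, selecting the mel-spectrogram backend per loss scale. A minimal usage sketch, assuming the constructor defaults shown above suffice (the waveform shapes are hypothetical):

import torch
from lt_tensor.losses import MultiMelScaleLoss

# "librosa" uses the librosa filter bank over torch.stft;
# "torch" (the default) uses torchaudio's MelSpectrogram transform.
loss_fn = MultiMelScaleLoss(mel="librosa")

pred = torch.randn(1, 22050)    # generated waveform (B, T)
target = torch.randn(1, 22050)  # reference; same length unless use_istft_norm=True
loss = loss_fn(pred, target)    # combined mel/pitch/RMS loss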
lt_tensor/model_zoo/audio_models/diffwave/__init__.py CHANGED
@@ -1,14 +1,15 @@
-__all__ = ["DiffWave", "DiffWaveConfig", "SpectrogramUpsample", "DiffusionEmbedding"]
+__all__ = ["DiffWave", "DiffWaveConfig", "SpectrogramUpsampler", "DiffusionEmbedding"]
 
 import numpy as np
 from lt_tensor.torch_commons import *
 from torch.nn import functional as F
 from lt_tensor.config_templates import ModelConfig
 from lt_tensor.torch_commons import *
-from lt_tensor.model_zoo.convs import ConvNets, Conv1dEXT
+from lt_tensor.model_zoo.convs import ConvNets, ConvEXT
 from lt_tensor.model_base import Model
 from math import sqrt
 from lt_utils.common import *
+from lt_tensor.misc_utils import log_tensor
 
 
 class DiffWaveConfig(ModelConfig):
@@ -21,12 +22,8 @@ class DiffWaveConfig(ModelConfig):
     unconditional = False
     apply_norm: Optional[Literal["weight", "spectral"]] = None
     apply_norm_resblock: Optional[Literal["weight", "spectral"]] = None
-    noise_schedule: list[int] = np.linspace(1e-4, 0.05, 50).tolist()
+    noise_schedule: list[int] = np.linspace(1e-4, 0.05, 25).tolist()
     # settings for auto-fixes
-    interpolate = False
-    interpolation_mode: Literal[
-        "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
-    ] = "nearest"
 
     def __init__(
         self,
@@ -37,16 +34,6 @@ class DiffWaveConfig(ModelConfig):
         dilation_cycle_length=10,
         unconditional=False,
         noise_schedule: list[int] = np.linspace(1e-4, 0.05, 50).tolist(),
-        interpolate_cond=False,
-        interpolation_mode: Literal[
-            "nearest",
-            "linear",
-            "bilinear",
-            "bicubic",
-            "trilinear",
-            "area",
-            "nearest-exact",
-        ] = "nearest",
         apply_norm: Optional[Literal["weight", "spectral"]] = None,
         apply_norm_resblock: Optional[Literal["weight", "spectral"]] = None,
     ):
@@ -58,8 +45,6 @@ class DiffWaveConfig(ModelConfig):
             "residual_channels": residual_channels,
             "unconditional": unconditional,
             "noise_schedule": noise_schedule,
-            "interpolate": interpolate_cond,
-            "interpolation_mode": interpolation_mode,
             "apply_norm": apply_norm,
             "apply_norm_resblock": apply_norm_resblock,
         }
@@ -102,19 +87,34 @@ class DiffusionEmbedding(Model):
         return table
 
 
-class SpectrogramUpsample(Model):
+class SpectrogramUpsampler(Model):
     def __init__(self):
         super().__init__()
-        self.conv1 = nn.ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
-        self.conv2 = nn.ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
-        self.activation = nn.LeakyReLU(0.4)
+        self.conv_net = nn.Sequential(
+            ConvEXT(
+                1,
+                1,
+                [3, 32],
+                stride=[1, 16],
+                padding=[1, 8],
+                module_type="2d",
+                transpose=True,
+            ),
+            nn.LeakyReLU(0.1),
+            ConvEXT(
+                1,
+                1,
+                [3, 32],
+                stride=[1, 16],
+                padding=[1, 8],
+                module_type="2d",
+                transpose=True,
+            ),
+            nn.LeakyReLU(0.1),
+        )
 
-    def forward(self, x):
-        x = torch.unsqueeze(x, 1)
-        x = self.activation(self.conv1(x))
-        x = self.activation(self.conv2(x))
-        x = torch.squeeze(x, 1)
-        return x
+    def forward(self, x: Tensor):
+        return self.conv_net(x.unsqueeze(1)).squeeze(1)
 
 
 class ResidualBlock(Model):
@@ -133,7 +133,7 @@ class ResidualBlock(Model):
         :param uncond: disable spectrogram conditional
         """
         super().__init__()
-        self.dilated_conv = Conv1dEXT(
+        self.dilated_conv = ConvEXT(
             residual_channels,
             2 * residual_channels,
             3,
@@ -142,18 +142,18 @@ class ResidualBlock(Model):
             apply_norm=apply_norm,
         )
         self.diffusion_projection = nn.Linear(512, residual_channels)
-        if not uncond:  # conditional model
-            self.conditioner_projection = Conv1dEXT(
+        self.unconditional = uncond
+        self.conditioner_projection = None
+        if not uncond:
+            self.conditioner_projection = ConvEXT(
                 n_mels,
                 2 * residual_channels,
                 1,
                 apply_norm=apply_norm,
             )
-        else:  # unconditional model
-            self.conditioner_projection = None
 
-        self.output_projection = Conv1dEXT(
-            residual_channels, 2 * residual_channels, 1, apply_norm == apply_norm
+        self.output_projection = ConvEXT(
+            residual_channels, 2 * residual_channels, 1, apply_norm=apply_norm
         )
 
     def forward(
@@ -164,20 +164,15 @@ class ResidualBlock(Model):
     ):
 
         diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
-        y = x + diffusion_step
-        if (
-            conditioner is None or self.conditioner_projection is None
-        ):  # using a unconditional model
-            y = self.dilated_conv(y)
-        else:
-            conditioner = self.conditioner_projection(conditioner)
-            y = self.dilated_conv(y) + conditioner
-
-        gate, filter = torch.chunk(y, 2, dim=1)
-        y = torch.sigmoid(gate) * torch.tanh(filter)
+        y = (x + diffusion_step).squeeze(1)
+        y = self.dilated_conv(y)
+        if not self.unconditional and conditioner is not None:
+            y = y + self.conditioner_projection(conditioner)
 
+        gate, _filter = y.chunk(2, dim=1)
+        y = gate.sigmoid() * _filter.tanh()
         y = self.output_projection(y)
-        residual, skip = torch.chunk(y, 2, dim=1)
+        residual, skip = y.chunk(2, dim=1)
         return (x + residual) / sqrt(2.0), skip
 
 
@@ -186,19 +181,17 @@ class DiffWave(Model):
         super().__init__()
         self.params = params
         self.n_hop = self.params.hop_samples
-        self.interpolate = self.params.interpolate
-        self.interpolate_mode = self.params.interpolation_mode
-        self.input_projection = Conv1dEXT(
+        self.input_projection = ConvEXT(
             in_channels=1,
             out_channels=params.residual_channels,
             kernel_size=1,
             apply_norm=self.params.apply_norm,
+            activation_out=nn.LeakyReLU(0.1),
         )
         self.diffusion_embedding = DiffusionEmbedding(len(params.noise_schedule))
-        if self.params.unconditional:  # use unconditional model
-            self.spectrogram_upsample = None
-        else:
-            self.spectrogram_upsample = SpectrogramUpsample()
+        self.spectrogram_upsampler = (
+            SpectrogramUpsampler() if not self.params.unconditional else None
+        )
 
         self.residual_layers = nn.ModuleList(
             [
@@ -212,18 +205,18 @@ class DiffWave(Model):
                 for i in range(params.residual_layers)
             ]
         )
-        self.skip_projection = Conv1dEXT(
+        self.skip_projection = ConvEXT(
             in_channels=params.residual_channels,
             out_channels=params.residual_channels,
             kernel_size=1,
             apply_norm=self.params.apply_norm,
+            activation_out=nn.LeakyReLU(0.1),
         )
-        self.output_projection = Conv1dEXT(
-            params.residual_channels, 1, 1, apply_norm=self.params.apply_norm
+        self.output_projection = ConvEXT(
+            params.residual_channels, 1, 1, apply_norm=self.params.apply_norm, init_weights=True
         )
         self.activation = nn.LeakyReLU(0.1)
-        self.r_sqrt = sqrt(len(self.residual_layers))
-        nn.init.zeros_(self.output_projection.weight)
+        self._res_d = sqrt(len(self.residual_layers))
 
     def forward(
         self,
@@ -231,31 +224,25 @@ class DiffWave(Model):
         diffusion_step: Tensor,
         spectrogram: Optional[Tensor] = None,
     ):
-        T = x.shape[-1]
-        if x.ndim == 2:
-            x = audio.unsqueeze(1)
-        x = self.activation(self.input_projection(x))
+        if not self.params.unconditional:
+            assert spectrogram is not None
+        if audio.ndim < 3:
+            if audio.ndim == 2:
+                audio = audio.unsqueeze(1)
+            else:
+                audio = audio.unsqueeze(0).unsqueeze(0)
 
+        x = self.input_projection(audio)
         diffusion_step = self.diffusion_embedding(diffusion_step)
-        if spectrogram is not None and self.spectrogram_upsample is not None:
-            if self.auto_interpolate:
-                # a little heavy, but helps a lot to fix mismatched shapes,
-                # not always recommended due to data loss
-                spectrogram = F.interpolate(
-                    input=spectrogram,
-                    size=int(T * self.n_hop),
-                    mode=self.interpolate_mode,
-                )
-            spectrogram = self.spectrogram_upsample(spectrogram)
+        if not self.params.unconditional:  # use conditional model
+            spectrogram = self.spectrogram_upsampler(spectrogram)
 
-        skip = None
+        skip = torch.zeros_like(x, device=x.device)
         for i, layer in enumerate(self.residual_layers):
             x, skip_connection = layer(x, diffusion_step, spectrogram)
-            if i == 0:
-                skip = skip_connection
-            else:
-                skip = skip_connection + skip
-        x = skip / self.r_sqrt
-        x = self.activation(self.skip_projection(x))
+            skip += skip_connection
+
+        x = skip / self._res_d
+        x = self.skip_projection(x)
         x = self.output_projection(x)
         return x
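
With the interpolation auto-fix removed from DiffWaveConfig, the conditioner must already cover the audio length once upsampled: the two stride-16 transposed convolutions in SpectrogramUpsampler expand the frame axis by a factor of 256. A rough shape sketch; the batch size and frame count are hypothetical, and the attribute names (cfg.hop_samples, cfg.noise_schedule) follow the fields referenced in this diff, assuming hop_samples == 256 so the x256 upsampling matches T:

import torch
from lt_tensor.model_zoo.audio_models.diffwave import DiffWave, DiffWaveConfig

cfg = DiffWaveConfig()
model = DiffWave(cfg)

frames = 32
audio = torch.randn(2, 256 * frames)                     # (B, T); unsqueezed to (B, 1, T) internally
steps = torch.randint(0, len(cfg.noise_schedule), (2,))  # diffusion step indices
mel = torch.randn(2, 80, frames)                         # (B, n_mels, frames); 256 * frames must equal T
noise_pred = model(audio, steps, mel)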
lt_tensor/model_zoo/convs.py CHANGED
@@ -1,4 +1,4 @@
-__all__ = ["ConvNets", "Conv1dEXT"]
+__all__ = ["ConvNets", "ConvEXT"]
 import math
 from lt_utils.common import *
 import torch.nn.functional as F
@@ -6,6 +6,7 @@ from lt_tensor.torch_commons import *
 from lt_tensor.model_base import Model
 from lt_tensor.misc_utils import log_tensor
 from lt_tensor.model_zoo.fusion import AdaFusion1D, AdaIN1D
+from lt_utils.misc_utils import default
 
 
 def spectral_norm_select(module: nn.Module, enabled: bool):
@@ -52,10 +53,7 @@ class ConvNets(Model):
         m.weight.data.normal_(mean, std)
 
 
-class Conv1dEXT(ConvNets):
-
-    # TODO: Use this module to replace all that are using normalizations, mostly those in `audio_models`
-
+class ConvEXT(ConvNets):
     def __init__(
         self,
         in_channels: int,
@@ -72,6 +70,10 @@ class Conv1dEXT(ConvNets):
         apply_norm: Optional[Literal["weight", "spectral"]] = None,
         activation_in: nn.Module = nn.Identity(),
         activation_out: nn.Module = nn.Identity(),
+        module_type: Literal["1d", "2d", "3d"] = "1d",
+        transpose: bool = False,
+        weight_init: Optional[Callable[[nn.Module], None]] = None,
+        init_weights: bool = True,
         *args,
         **kwargs,
     ):
@@ -91,23 +93,30 @@ class Conv1dEXT(ConvNets):
             device=device,
             dtype=dtype,
         )
+        match module_type.lower():
+            case "1d":
+                md = nn.Conv1d if not transpose else nn.ConvTranspose1d
+            case "2d":
+                md = nn.Conv2d if not transpose else nn.ConvTranspose2d
+            case "3d":
+                md = nn.Conv3d if not transpose else nn.ConvTranspose3d
+            case _:
+                raise ValueError(
+                    f"module_type {module_type} is not a valid module type! use '1d', '2d' or '3d'"
+                )
+
         if apply_norm is None:
-            self.cnn = nn.Conv1d(**cnn_kwargs)
-            self.has_wn = False
+            self.cnn = md(**cnn_kwargs)
         else:
-            self.has_wn = True
             if apply_norm == "spectral":
-                self.cnn = spectral_norm(nn.Conv1d(**cnn_kwargs))
+                self.cnn = spectral_norm(md(**cnn_kwargs))
             else:
-                self.cnn = weight_norm(nn.Conv1d(**cnn_kwargs))
+                self.cnn = weight_norm(md(**cnn_kwargs))
         self.actv_in = activation_in
         self.actv_out = activation_out
-        self.cnn.apply(self.init_weights)
+        if init_weights:
+            weight_init = default(weight_init, self.init_weights)
+            self.cnn.apply(weight_init)
 
     def forward(self, input: Tensor):
         return self.actv_out(self.cnn(self.actv_in(input)))
-
-    def remove_norms(self, name="weight"):
-        if self.has_wn:
-            remove_norm(self.cnn, name)
-            self.has_wn = False
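
ConvEXT folds the old Conv1dEXT behavior into a dimension-agnostic wrapper: module_type picks the Conv1d/2d/3d module and transpose swaps in the matching ConvTranspose variant. A minimal sketch using only parameters visible in this diff:

import torch
from lt_tensor.model_zoo.convs import ConvEXT

# Drop-in for the old Conv1dEXT, with weight normalization.
conv1d = ConvEXT(64, 128, 3, apply_norm="weight")

# 2-D transposed convolution, as SpectrogramUpsampler now builds it.
up2d = ConvEXT(1, 1, [3, 32], stride=[1, 16], padding=[1, 8],
               module_type="2d", transpose=True)

x = torch.randn(1, 1, 80, 32)  # (B, C, n_mels, frames)
y = up2d(x)                    # last axis grows: (T - 1) * 16 - 2 * 8 + 32 = 16 * T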
lt_tensor/processors/audio.py CHANGED
@@ -10,6 +10,7 @@ from lt_utils.type_utils import is_file, is_array
 from lt_utils.file_ops import FileScan, get_file_name, path_to_str
 from torchaudio.functional import detect_pitch_frequency
 import torch.nn.functional as F
+from librosa.filters import mel as _mel_filter_bank
 
 DEFAULT_DEVICE = torch.tensor([0]).device
 
@@ -25,7 +26,7 @@ class AudioProcessorConfig(ModelConfig):
     f_min: float = 0
     f_max: Optional[float] = None
     center: bool = True
-    mel_scale: Literal["htk" "slaney"] = "htk"
+    mel_scale: Literal["htk", "slaney"] = "htk"
     std: int = 4
     mean: int = -4
     n_iter: int = 32
@@ -33,6 +34,7 @@ class AudioProcessorConfig(ModelConfig):
     normalized: bool = False
     onesided: Optional[bool] = None
     n_stft: int = None
+    mel_default: Literal["torch", "librosa"] = "librosa"
 
     def __init__(
         self,
@@ -49,6 +51,7 @@ class AudioProcessorConfig(ModelConfig):
         mean: int = -4,
         normalized: bool = False,
         onesided: Optional[bool] = None,
+        mel_default: Literal["torch", "librosa"] = "librosa",
         *args,
         **kwargs,
     ):
@@ -66,6 +69,7 @@ class AudioProcessorConfig(ModelConfig):
             "mean": mean,
             "normalized": normalized,
             "onesided": onesided,
+            "mel_default": mel_default,
         }
         super().__init__(**settings)
         self.post_process()
@@ -88,14 +92,10 @@ def _comp_rms_helper(i: int, audio: Tensor, mel: Optional[Tensor]):
 
 
 class AudioProcessor(Model):
-    def __init__(
-        self,
-        config: AudioProcessorConfig = AudioProcessorConfig(),
-        window: Optional[Tensor] = None,
-    ):
+    def __init__(self, config: AudioProcessorConfig = AudioProcessorConfig()):
         super().__init__()
         self.cfg = config
-        self._mel_spec = torchaudio.transforms.MelSpectrogram(
+        self._mel_spec_torch = torchaudio.transforms.MelSpectrogram(
             sample_rate=self.cfg.sample_rate,
             n_mels=self.cfg.n_mels,
             n_fft=self.cfg.n_fft,
@@ -107,6 +107,7 @@ class AudioProcessor(Model):
             mel_scale=self.cfg.mel_scale,
             normalized=self.cfg.normalized,
         )
+
         self._mel_rscale = torchaudio.transforms.InverseMelScale(
             n_stft=self.cfg.n_stft,
             n_mels=self.cfg.n_mels,
@@ -115,34 +116,121 @@ class AudioProcessor(Model):
             f_max=self.cfg.f_max,
             mel_scale=self.cfg.mel_scale,
         )
-
+        self.mel_lib_padding = (self.cfg.n_fft - self.cfg.hop_length) // 2
         self.register_buffer(
             "window",
-            (torch.hann_window(self.cfg.win_length) if window is None else window),
+            torch.hann_window(self.cfg.win_length),
         )
+        self.register_buffer(
+            "mel_filter_bank",
+            torch.from_numpy(
+                _mel_filter_bank(
+                    sr=self.cfg.sample_rate,
+                    n_fft=self.cfg.n_fft,
+                    n_mels=self.cfg.n_mels,
+                    fmin=self.cfg.f_min,
+                    fmax=self.cfg.f_max,
+                )
+            ).float(),
+        )
+
+    def spectral_norm(self, x: Tensor, c: int = 1, eps: float = 1e-5):
+        return torch.log(torch.clamp(x, min=eps) * c)
+
+    def spectral_de_norm(self, x: Tensor, c: int = 1):
+        return torch.exp(x) / c
+
+    def log_norm(
+        self,
+        entry: Tensor,
+        eps: float = 1e-5,
+        mean: Optional[Number] = None,
+        std: Optional[Number] = None,
+    ) -> Tensor:
+        mean = default(mean, self.cfg.mean)
+        std = default(std, self.cfg.std)
+        return (torch.log(eps + entry.unsqueeze(0)) - mean) / std
 
     def compute_mel(
         self,
         wave: Tensor,
-        eps: float = 1e-5,
-        raw_mel_only: bool = False,
-        *,
-        _recall: bool = False,
+        method: Optional[Literal["torch", "librosa"]] = None,
+        apply_norm: bool = False,
+        eps: Optional[float] = None,
+        **kwargs,
+    ) -> Tensor:
+        method = default(method, self.cfg.mel_default)
+        if method == "torch":
+            return self.compute_mel_torch(
+                wave,
+                log_norm=apply_norm,
+                eps=eps,
+                mean=kwargs.get("mean", None),
+                std=kwargs.get("std", None),
+            )
+        return self.compute_mel_librosa(
+            wave,
+            spectral_norm=apply_norm,
+            eps=default(eps, 1e-5),
+        )
+
+    def compute_mel_torch(
+        self,
+        wave: Tensor,
+        log_norm: bool = False,
+        eps: Optional[float] = None,
+        mean: Optional[Number] = None,
+        std: Optional[Number] = None,
+        *args,
+        **kwargs,
     ) -> Tensor:
         """Returns: (M, T) or (B, M, T) if batched"""
         try:
-            mel_tensor = self._mel_spec(wave.to(self.device))  # [M, T]
-            if not raw_mel_only:
-                mel_tensor = (
-                    torch.log(eps + mel_tensor.unsqueeze(0)) - self.cfg.mean
-                ) / self.cfg.std
-            return mel_tensor.squeeze()
+            mel_tensor = self._mel_spec_torch.forward(wave.to(self.device))  # [M, T]
 
         except RuntimeError as e:
-            if not _recall:
-                self._mel_spec.to(self.device)
-                return self.compute_mel(wave, raw_mel_only, eps, _recall=True)
-            raise e
+            self._mel_spec_torch.to(self.device)
+            mel_tensor = self._mel_spec_torch.forward(wave.to(self.device))  # [M, T]
+        if log_norm:
+            return self.log_norm(mel_tensor, default(eps, 1e-5), mean, std).squeeze()
+        return mel_tensor.squeeze()
+
+    def compute_mel_librosa(
+        self,
+        wave: Tensor,
+        eps: float = 1e-5,
+        spectral_norm: bool = False,
+        *args,
+        **kwargs,
+    ):
+        if wave.ndim == 1:
+            wave = wave.unsqueeze(0)
+        wave = torch.nn.functional.pad(
+            wave.unsqueeze(1),
+            (self.mel_lib_padding, self.mel_lib_padding),
+            mode="reflect",
+        ).squeeze(1)
+        spec = torch.stft(
+            wave,
+            self.cfg.n_fft,
+            hop_length=self.cfg.hop_length,
+            win_length=self.cfg.win_length,
+            window=self.window,
+            center=self.cfg.center,
+            pad_mode="reflect",
+            normalized=False,
+            onesided=True,
+            return_complex=True,
+        )
+        spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-12)
+        try:
+            results = torch.matmul(self.mel_filter_bank, spec)
+        except RuntimeError:
+            self.mel_filter_bank = self.mel_filter_bank.to(self.device)
+            self.window = self.window.to(self.device)
+            results = torch.matmul(self.mel_filter_bank, spec)
+        if spectral_norm:
+            return self.spectral_norm(results, eps=eps).squeeze()
+        return results.squeeze()
 
     def compute_inverse_mel(self, melspec: Tensor, *, _recall=False):
         try:
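
compute_mel now dispatches between two backends instead of taking eps/raw_mel_only directly. A minimal sketch of the two paths, assuming an AudioProcessor built from the defaults above and a hypothetical one-second waveform:

import torch
from lt_tensor.processors.audio import AudioProcessor, AudioProcessorConfig

ap = AudioProcessor(AudioProcessorConfig(mel_default="librosa"))
wave = torch.randn(22050)

mel_lib = ap.compute_mel(wave)                    # librosa filter bank over torch.stft
mel_torch = ap.compute_mel(wave, method="torch")  # torchaudio MelSpectrogram path
mel_norm = ap.compute_mel(wave, apply_norm=True)  # normalized variant of the chosen backend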
@@ -382,7 +470,7 @@ class AudioProcessor(Model):
             antialias=antialias,
         )
 
-    def istft(
+    def istft_spec_phase(
         self,
         spec: Tensor,
         phase: Tensor,
@@ -394,34 +482,91 @@ class AudioProcessor(Model):
         normalized: Optional[bool] = None,
         onesided: Optional[bool] = None,
         return_complex: bool = False,
-        *,
-        _recall: bool = False,
     ):
-        if win_length is not None and win_length != self.cfg.win_length:
-            window = torch.hann_window(win_length, device=spec.device)
-        else:
-            window = self.window
+        """Utility for models that need to reconstruct audio via inverse STFT."""
+        window = (
+            torch.hann_window(win_length, device=spec.device)
+            if win_length is not None and win_length != self.cfg.win_length
+            else self.window.to(spec.device)
+        )
+        return torch.istft(
+            spec * torch.exp(phase * 1j),
+            n_fft=default(n_fft, self.cfg.n_fft),
+            hop_length=default(hop_length, self.cfg.hop_length),
+            win_length=default(win_length, self.cfg.win_length),
+            window=window,
+            center=center,
+            normalized=default(normalized, self.cfg.normalized),
+            onesided=default(onesided, self.cfg.onesided),
+            length=length,
+            return_complex=return_complex,
+        )
 
-        try:
-            return torch.istft(
-                spec * torch.exp(phase * 1j),
-                n_fft=default(n_fft, self.cfg.n_fft),
-                hop_length=default(hop_length, self.cfg.hop_length),
-                win_length=default(win_length, self.cfg.win_length),
-                window=window,
-                center=center,
-                normalized=default(normalized, self.cfg.normalized),
-                onesided=default(onesided, self.cfg.onesided),
-                length=length,
-                return_complex=return_complex,
-            )
-        except RuntimeError as e:
-            if not _recall and spec.device != self.window.device:
-                self.window = self.window.to(spec.device)
-                return self.istft(
-                    spec, phase, n_fft, hop_length, win_length, length, _recall=True
-                )
-            raise e
+    def istft(
+        self,
+        wave: Tensor,
+        n_fft: Optional[int] = None,
+        hop_length: Optional[int] = None,
+        win_length: Optional[int] = None,
+        length: Optional[int] = None,
+        center: bool = True,
+        normalized: Optional[bool] = None,
+        onesided: Optional[bool] = None,
+        return_complex: bool = False,
+    ):
+        window = (
+            torch.hann_window(win_length, device=wave.device)
+            if win_length is not None and win_length != self.cfg.win_length
+            else self.window.to(wave.device)
+        )
+        if not torch.is_complex(wave):
+            wave = wave * 1j
+        return torch.istft(
+            wave,
+            n_fft=default(n_fft, self.cfg.n_fft),
+            hop_length=default(hop_length, self.cfg.hop_length),
+            win_length=default(win_length, self.cfg.win_length),
+            window=window,
+            center=center,
+            normalized=default(normalized, self.cfg.normalized),
+            onesided=default(onesided, self.cfg.onesided),
+            length=length,
+            return_complex=return_complex,
+        )
+
+    def stft(
+        self,
+        wave: Tensor,
+        center: bool = True,
+        n_fft: Optional[int] = None,
+        hop_length: Optional[int] = None,
+        win_length: Optional[int] = None,
+        normalized: Optional[bool] = None,
+        onesided: Optional[bool] = None,
+        return_complex: bool = True,
+    ):
+        window = (
+            torch.hann_window(win_length, device=wave.device)
+            if win_length is not None and win_length != self.cfg.win_length
+            else self.window.to(wave.device)
+        )
+
+        results = torch.stft(
+            input=wave,
+            n_fft=default(n_fft, self.cfg.n_fft),
+            hop_length=default(hop_length, self.cfg.hop_length),
+            win_length=default(win_length, self.cfg.win_length),
+            window=window,
+            center=center,
+            pad_mode="reflect",
+            normalized=default(normalized, self.cfg.normalized),
+            onesided=default(onesided, self.cfg.onesided),
+            return_complex=True,  # always complex here; view_as_real is applied below when a real output is requested
+        )
+        if not return_complex:
+            return torch.view_as_real(results)
+        return results
 
     def istft_norm(
         self,
@@ -435,11 +580,11 @@ class AudioProcessor(Model):
         onesided: Optional[bool] = None,
         return_complex: bool = False,
     ):
-
-        if win_length is not None and win_length != self.cfg.win_length:
-            window = torch.hann_window(win_length, device=wave.device)
-        else:
-            window = self.window
+        window = (
+            torch.hann_window(win_length, device=wave.device)
+            if win_length is not None and win_length != self.cfg.win_length
+            else self.window.to(wave.device)
+        )
         spectrogram = torch.stft(
             input=wave,
             n_fft=default(n_fft, self.cfg.n_fft),
@@ -473,15 +618,15 @@ class AudioProcessor(Model):
     def load_audio(
         self,
         path: PathLike,
-        top_db: float = 30,
+        top_db: Optional[float] = None,
         normalize: bool = False,
+        mono: bool = True,
         *,
-        ref: float | Callable[[np.ndarray], Any] = np.max,
-        frame_length: int = 2048,
+        sample_rate: Optional[float] = None,
         hop_length: int = 512,
-        mono: bool = True,
-        offset: float = 0.0,
+        frame_length: int = 2048,
         duration: Optional[float] = None,
+        offset: float = 0.0,
         dtype: Any = np.float32,
         res_type: str = "soxr_hq",
         fix: bool = True,
@@ -491,29 +636,32 @@ class AudioProcessor(Model):
         norm_axis: int = 0,
         norm_threshold: Optional[float] = None,
         norm_fill: Optional[bool] = None,
+        ref: float | Callable[[np.ndarray], Any] = np.max,
     ) -> Tensor:
         is_file(path, True)
+        sample_rate = default(sample_rate, self.cfg.sample_rate)
         wave, sr = librosa.load(
             str(path),
-            sr=self.cfg.sample_rate,
+            sr=sample_rate,
             mono=mono,
             offset=offset,
             duration=duration,
             dtype=dtype,
             res_type=res_type,
         )
-        wave, _ = librosa.effects.trim(
-            wave,
-            top_db=top_db,
-            ref=ref,
-            frame_length=frame_length,
-            hop_length=hop_length,
-        )
-        if sr != self.cfg.sample_rate:
+        if top_db is not None:
+            wave, _ = librosa.effects.trim(
+                wave,
+                top_db=top_db,
+                ref=ref,
+                frame_length=frame_length,
+                hop_length=hop_length,
+            )
+        if sr != sample_rate:
             wave = librosa.resample(
                 wave,
                 orig_sr=sr,
-                target_sr=self.cfg.sample_rate,
+                target_sr=sample_rate,
                 res_type=res_type,
                 fix=fix,
                 scale=scale,
@@ -553,10 +701,6 @@ class AudioProcessor(Model):
             maximum,
         )
 
-    def stft_loss(self, signal: Tensor, ground: Tensor, magnitude: float = 1.0):
-        ground = F.interpolate(ground, signal.shape[-1]).to(signal.device)
-        return F.l1_loss(signal.squeeze(), ground.squeeze()) * magnitude
-
     def forward(
         self,
         *inputs: Union[Tensor, float],
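
load_audio now makes silence trimming opt-in (librosa.effects.trim runs only when top_db is given) and accepts a per-call sample_rate override. A usage sketch with a hypothetical file path:

from lt_tensor.processors.audio import AudioProcessor

ap = AudioProcessor()
wave = ap.load_audio("sample.wav")                        # no trimming, resampled to cfg.sample_rate
trimmed = ap.load_audio("sample.wav", top_db=30)          # trims leading/trailing silence
wave16k = ap.load_audio("sample.wav", sample_rate=16000)  # one-off sample-rate override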
{lt_tensor-0.0.1a36.dist-info → lt_tensor-0.0.1a38.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a36
+Version: 0.0.1a38
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
@@ -18,7 +18,7 @@ Requires-Dist: tokenizers
 Requires-Dist: pyyaml>=6.0.0
 Requires-Dist: numba>0.60.0
 Requires-Dist: lt-utils>=0.0.4
-Requires-Dist: librosa==0.11.*
+Requires-Dist: librosa<1,>=0.10.2.post1
 Requires-Dist: einops
 Requires-Dist: plotly
 Requires-Dist: scipy
{lt_tensor-0.0.1a36.dist-info → lt_tensor-0.0.1a38.dist-info}/RECORD RENAMED
@@ -1,6 +1,6 @@
-lt_tensor/__init__.py,sha256=nBbiGH1byHU0aTTKKorRj8MIEO2oEMBXl7kt5DOCatU,441
+lt_tensor/__init__.py,sha256=2C0DCdesX13deU-nV8EbiF0poSsHWU0VuZvFcTZhJQk,441
 lt_tensor/config_templates.py,sha256=F9UvL8paAjkSvio890kp8WznpYeI50pYnm9iqQroBxk,2797
-lt_tensor/losses.py,sha256=Heco_WyoC1HkNkcJEircOAzS9umusATHiNAG-FKGyzc,8918
+lt_tensor/losses.py,sha256=e-YyKMmI0FwWQ3VLfJLDGSH4_rNpnYj0-htuk4eYboE,9283
 lt_tensor/lr_schedulers.py,sha256=6_vcfaPHrozfH3wvmNEdKSFYl6iTIijYoHL8vuG-45U,7651
 lt_tensor/math_ops.py,sha256=ahX6Z1Mt3X-FhmwSZYZea5mB1B0S8GDuvKPfAm5e_FQ,2646
 lt_tensor/misc_utils.py,sha256=stL6q3M7S2N4FBICFYbgYpdPDrJRlwmr24-iCXMRifM,28933
@@ -11,7 +11,7 @@ lt_tensor/torch_commons.py,sha256=8l0bxmrAzwvyqjivCIVISXlbvKarlg4DdE0BOGSnMuQ,81
 lt_tensor/transform.py,sha256=dZm8T_ov0blHMQu6nGiehsdG1VSB7bZBUVmTkT-PBdc,13257
 lt_tensor/model_zoo/__init__.py,sha256=yPUVchgVhU2nAJ2ocA4HFfG7IMEiBu8qOi8I1KWTTkU,404
 lt_tensor/model_zoo/basic.py,sha256=pI8HyiHK-cmWcEEaVY_EduUJOjZW6HOtXvJd8Rbhq30,15452
-lt_tensor/model_zoo/convs.py,sha256=Tws0jrPfs9m7OLmJ30W0AfkAvZgppW7lNi4xt0e-qRU,3518
+lt_tensor/model_zoo/convs.py,sha256=Ri_8BV2dho-KS57GS8xg6SJOKPkYEmJs4Rl_byofiSE,3995
 lt_tensor/model_zoo/features.py,sha256=DO8dlE0kmPKTNC1Xkv9wKegOOYkQa_rkxM4hhcNwJWA,15655
 lt_tensor/model_zoo/fusion.py,sha256=usC1bcjQRNivDc8xzkIS5T1glm78OLcs2V_tPqfp-eI,5422
 lt_tensor/model_zoo/pos_encoder.py,sha256=3d1EYLinCU9UAy-WuEWeYMGhMqaGknCiQ5qEmhw_UYM,4487
@@ -26,7 +26,7 @@ lt_tensor/model_zoo/activations/snake/__init__.py,sha256=AtOAbJuMinxmKkppITGMzRb
 lt_tensor/model_zoo/audio_models/__init__.py,sha256=WwiP9MekJreMOfKPWLl24VkRJIpLk6hhL8ch0aKgOss,103
 lt_tensor/model_zoo/audio_models/resblocks.py,sha256=u-foHxaFDUICjxSkpyHXljQYQG9zMxVYaOGqLR_nJ-k,7978
 lt_tensor/model_zoo/audio_models/bigvgan/__init__.py,sha256=4EZG8Non75dHoDCizMHbMTvPrKwdUlPYGHc7hkfT_nw,8526
-lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=PDuDYN1omD1RoAXcmxH3tEgfAuM3ZHAWzimD6ElMqEQ,9073
+lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=g9tSLjRgl7whafA9aunNna-n41Afdqz13EiaiHOHwo0,8249
 lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=ITSXHg3c0Um1P2HaPaXkQKI7meG5Ne60wTbyyYju3hY,6360
 lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=blICjLX_z_IFmR3_TCz_dJiSayLYGza9eG6fd9aKyvE,7448
 lt_tensor/model_zoo/losses/__init__.py,sha256=B9RAUxBiOZwooztnij1oLeRwZ7_MjnN3mPoum7saD6s,59
@@ -35,9 +35,9 @@ lt_tensor/model_zoo/losses/CQT/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5
 lt_tensor/model_zoo/losses/CQT/transforms.py,sha256=Vkid0J9dqLnlINfyyUlQf-qB3gOQAgU7W9j7xLOjDFw,13218
 lt_tensor/model_zoo/losses/CQT/utils.py,sha256=twGw6FVD7V5Ksfx_1BUEN3EP1tAS6wo-9LL3VnuHB8c,16751
 lt_tensor/processors/__init__.py,sha256=Pvxhh0KR65zLCgUd53_k5Z0y5JWWcO0ZBXFK9rv0o5w,109
-lt_tensor/processors/audio.py,sha256=3YzyEpMwh124rb1KMAly62qweeruF200BnM-vQIbzy0,18645
-lt_tensor-0.0.1a36.dist-info/licenses/LICENSE,sha256=TbiyJWLgNqqgqhfCnrGwFIxy7EqGNrIZZcKhHrefcuU,11354
-lt_tensor-0.0.1a36.dist-info/METADATA,sha256=mTmnoWn8EG48j_VOM3rr_8RLLgaxB5pWZE1tkPdFrac,1062
-lt_tensor-0.0.1a36.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-lt_tensor-0.0.1a36.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
-lt_tensor-0.0.1a36.dist-info/RECORD,,
+lt_tensor/processors/audio.py,sha256=QaEbzoCxl7zJNv6ELFwX6AO--8NuOGscgqxwNpV8Czw,23599
+lt_tensor-0.0.1a38.dist-info/licenses/LICENSE,sha256=TbiyJWLgNqqgqhfCnrGwFIxy7EqGNrIZZcKhHrefcuU,11354
+lt_tensor-0.0.1a38.dist-info/METADATA,sha256=g87aQm1aw-2dlCEvss9CcQ4iNl1Bi_mlqafGIeR1AdU,1071
+lt_tensor-0.0.1a38.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a38.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a38.dist-info/RECORD,,