lt-tensor 0.0.1a12__py3-none-any.whl → 0.0.1a14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,252 @@
+__all__ = [
+    "spectral_norm_select",
+    "get_weight_norm",
+    "ResBlock1D",
+    "ResBlock2D",
+    "ResBlock1DShuffled",
+    "AdaResBlock1D",
+    "ResBlocks",
+]
+import math
+from lt_utils.common import *
+import torch.nn.functional as F
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model
+from lt_tensor.misc_utils import log_tensor
+from lt_tensor.model_zoo.fusion import AdaFusion1D, AdaIN1D
+
+
+def spectral_norm_select(module: nn.Module, enabled: bool):
+    if enabled:
+        return spectral_norm(module)
+    return module
+
+
+def get_weight_norm(norm_type: Optional[Literal["weight", "spectral"]] = None):
+    if not norm_type:
+        return lambda x: x
+    if norm_type == "weight":
+        return lambda x: weight_norm(x)
+    return lambda x: spectral_norm(x)
+
+
+class ConvNets(Model):
+    def remove_weight_norm(self):
+        for module in self.modules():
+            try:
+                remove_weight_norm(module)
+            except ValueError:
+                pass
+
+    @staticmethod
+    def init_weights(m, mean=0.0, std=0.01):
+        classname = m.__class__.__name__
+        if "Conv" in classname:
+            m.weight.data.normal_(mean, std)
+
+
+class ResBlocks(ConvNets):
+    def __init__(
+        self,
+        channels: int,
+        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+        resblock_dilation_sizes: List[Union[int, List[int]]] = [
+            [1, 3, 5],
+            [1, 3, 5],
+            [1, 3, 5],
+        ],
+        activation: nn.Module = nn.LeakyReLU(0.1),
+    ):
+        super().__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.rb = nn.ModuleList()
+        self.activation = activation
+
+        for k, j in zip(resblock_kernel_sizes, resblock_dilation_sizes):
+            self.rb.append(ResBlock1D(channels, k, j, activation))
+
+        self.rb.apply(self.init_weights)
+
+    def forward(self, x: torch.Tensor):
+        xs = None
+        for i, block in enumerate(self.rb):
+            if i == 0:
+                xs = block(x)
+            else:
+                xs += block(x)
+        x = xs / self.num_kernels
+        return x
+
+
+
+class ResBlock1D(ConvNets):
+    def __init__(
+        self,
+        channels,
+        kernel_size=3,
+        dilation=(1, 3, 5),
+        activation: nn.Module = nn.LeakyReLU(0.1),
+    ):
+        super().__init__()
+
+        self.conv_nets = nn.ModuleList(
+            [
+                self._get_conv_layer(i, channels, kernel_size, 1, dilation, activation)
+                for i in range(3)
+            ]
+        )
+        self.conv_nets.apply(self.init_weights)
+        self.last_index = len(self.conv_nets) - 1
+
+    def _get_conv_layer(self, id, ch, k, stride, d, actv):
+        get_padding = lambda ks, d: int((ks * d - d) / 2)
+        return nn.Sequential(
+            actv,  # 1
+            weight_norm(
+                nn.Conv1d(
+                    ch, ch, k, stride, dilation=d[id], padding=get_padding(k, d[id])
+                )
+            ),  # 2
+            actv,  # 3
+            weight_norm(
+                nn.Conv1d(ch, ch, k, stride, dilation=1, padding=get_padding(k, 1))
+            ),  # 4
+        )
+
+    def forward(self, x: Tensor):
+        for cnn in self.conv_nets:
+            x = cnn(x) + x
+        return x
+
+
+class ResBlock1DShuffled(ConvNets):
+    def __init__(
+        self,
+        channels,
+        kernel_size=3,
+        dilation=(1, 3, 5),
+        activation: nn.Module = nn.LeakyReLU(0.1),
+        add_channel_shuffle: bool = False,  # requires pytorch 2.7.0 +
+        channel_shuffle_groups=1,
+    ):
+        super().__init__()
+
+        self.channel_shuffle = (
+            nn.ChannelShuffle(channel_shuffle_groups)
+            if add_channel_shuffle
+            else nn.Identity()
+        )
+
+        self.conv_nets = nn.ModuleList(
+            [
+                self._get_conv_layer(i, channels, kernel_size, 1, dilation, activation)
+                for i in range(3)
+            ]
+        )
+        self.conv_nets.apply(self.init_weights)
+        self.last_index = len(self.conv_nets) - 1
+
+    def _get_conv_layer(self, id, ch, k, stride, d, actv):
+        get_padding = lambda ks, d: int((ks * d - d) / 2)
+        return nn.Sequential(
+            actv,  # 1
+            weight_norm(
+                nn.Conv1d(
+                    ch, ch, k, stride, dilation=d[id], padding=get_padding(k, d[id])
+                )
+            ),  # 2
+            actv,  # 3
+            weight_norm(
+                nn.Conv1d(ch, ch, k, stride, dilation=1, padding=get_padding(k, 1))
+            ),  # 4
+        )
+
+    def forward(self, x: Tensor):
+        b = x.clone() * 0.5
+        for cnn in self.conv_nets:
+            x = cnn(self.channel_shuffle(x)) + b
+        return x
+
+
+class ResBlock2D(Model):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        downsample=False,
+    ):
+        super().__init__()
+        stride = 2 if downsample else 1
+
+        self.block = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, 3, stride, 1),
+            nn.LeakyReLU(0.2),
+            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
+        )
+
+        self.skip = nn.Identity()
+        if downsample or in_channels != out_channels:
+            self.skip = spectral_norm_select(
+                nn.Conv2d(in_channels, out_channels, 1, stride)
+            )
+        # one less computation to handle every cycle
+        self.sqrt_2 = math.sqrt(2)
+
+    def forward(self, x: Tensor):
+        return (self.block(x) + self.skip(x)) / self.sqrt_2
+
+
+class AdaResBlock1D(ConvNets):
+    def __init__(
+        self,
+        res_block_channels: int,
+        ada_channel_in: int,
+        kernel_size=3,
+        dilation=(1, 3, 5),
+        activation: nn.Module = nn.LeakyReLU(0.1),
+    ):
+        super().__init__()
+
+        self.conv_nets = nn.ModuleList(
+            [
+                self._get_conv_layer(
+                    i,
+                    res_block_channels,
+                    ada_channel_in,
+                    kernel_size,
+                    1,
+                    dilation,
+                )
+                for i in range(3)
+            ]
+        )
+        self.conv_nets.apply(self.init_weights)
+        self.last_index = len(self.conv_nets) - 1
+        self.activation = activation
+
+    def _get_conv_layer(self, id, ch, ada_ch, k, stride, d):
+        get_padding = lambda ks, d: int((ks * d - d) / 2)
+        return nn.ModuleDict(
+            dict(
+                norm1=AdaFusion1D(ada_ch, ch),
+                norm2=AdaFusion1D(ada_ch, ch),
+                alpha1=nn.Parameter(torch.ones(1, ada_ch, 1)),
+                alpha2=nn.Parameter(torch.ones(1, ada_ch, 1)),
+                conv1=weight_norm(
+                    nn.Conv1d(
+                        ch, ch, k, stride, dilation=d[id], padding=get_padding(k, d[id])
+                    )
+                ),  # 2
+                conv2=weight_norm(
+                    nn.Conv1d(ch, ch, k, stride, dilation=1, padding=get_padding(k, 1))
+                ),  # 4
+            )
+        )
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        for cnn in self.conv_nets:
+            xt = self.activation(cnn["norm1"](x, y, cnn["alpha1"]))
+            xt = cnn["conv1"](xt)
+            xt = self.activation(cnn["norm2"](xt, y, cnn["alpha2"]))
+            x = cnn["conv2"](xt) + x
+        return x
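Note: the new file above adds HiFi-GAN-style residual stacks to the model zoo. A minimal usage sketch of ResBlocks follows; the file's import path is not shown in this diff, so the module name used here is an assumption, and the shapes simply follow the constructor defaults above.

    # Hypothetical usage sketch -- module path assumed, not confirmed by this diff.
    import torch
    from lt_tensor.model_zoo.residual import ResBlocks  # assumed location of the new file

    block = ResBlocks(channels=256)  # defaults: kernels [3, 7, 11], dilations [1, 3, 5]
    x = torch.randn(2, 256, 1024)    # [batch, channels, time]
    y = block(x)                     # mean over the three ResBlock1D branches
    assert y.shape == x.shape        # padding keeps the time dimension unchanged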
@@ -11,8 +11,8 @@ from lt_tensor.torch_commons import *
 from lt_tensor.model_base import Model
 from lt_utils.misc_utils import default
 from typing import Optional
-from lt_tensor.model_zoo.pos import *
-from lt_tensor.model_zoo.bsc import FeedForward
+from lt_tensor.model_zoo.pos_encoder import *
+from lt_tensor.model_zoo.basic import FeedForward
 
 
 def init_weights(module):
@@ -5,11 +5,14 @@ from lt_utils.type_utils import is_file, is_array
 from lt_tensor.misc_utils import log_tensor
 import librosa
 import torchaudio
-
+import numpy as np
 from lt_tensor.transform import InverseTransformConfig, InverseTransform
 from lt_utils.file_ops import FileScan, get_file_name, path_to_str
-
+from torchaudio.functional import detect_pitch_frequency
 from lt_tensor.model_base import Model
+import torch.nn.functional as F
+
+DEFAULT_DEVICE = torch.tensor([0]).device
 
 
 class AudioProcessor(Model):
@@ -21,14 +24,14 @@ class AudioProcessor(Model):
         win_length: Optional[int] = None,
         hop_length: Optional[int] = None,
         f_min: float = 0,
-        f_max: float | None = None,
+        f_max: float = 12000.0,
         center: bool = True,
         mel_scale: Literal["htk", "slaney"] = "htk",
         std: int = 4,
         mean: int = -4,
         n_iter: int = 32,
         window: Optional[Tensor] = None,
-        normalized: bool =False,
+        normalized: bool = False,
         onesided: Optional[bool] = None,
     ):
         super().__init__()
@@ -38,7 +41,7 @@ class AudioProcessor(Model):
         self.n_fft = n_fft
         self.n_stft = n_fft // 2 + 1
         self.f_min = f_min
-        self.f_max = f_max
+        self.f_max = max(min(f_max, 12000), self.f_min + 1)
         self.n_iter = n_iter
         self.hop_length = hop_length or n_fft // 4
         self.win_length = win_length or n_fft
@@ -76,7 +79,165 @@ class AudioProcessor(Model):
             "window",
             (torch.hann_window(self.win_length) if window is None else window),
         )
-        # self._inv_transform = InverseTransform(**inverse_transform_config.to_dict())
+
+    def from_numpy(
+        self,
+        array: np.ndarray,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        converted = torch.from_numpy(array)
+        if not any([device is not None, dtype is not None]):
+            return converted
+        return converted.to(device=device, dtype=dtype)
+
+    def from_numpy_batch(
+        self,
+        arrays: List[np.ndarray],
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        stacked = torch.stack([torch.from_numpy(x) for x in arrays])
+        if not any([device is not None, dtype is not None]):
+            return stacked
+        return stacked.to(device=device, dtype=dtype)
+
+    def to_numpy_safe(self, tensor: Tensor):
+        return tensor.detach().to(DEFAULT_DEVICE).numpy(force=True)
+
+    def compute_rms(
+        self,
+        audio: Union[Tensor, np.ndarray],
+        mel: Optional[Tensor] = None,
+    ):
+        default_dtype = audio.dtype
+        default_device = audio.device
+        if audio.ndim > 1:
+            B = audio.shape[0]
+        else:
+            B = 1
+            audio = audio.unsqueeze(0)
+
+        if mel is not None:
+            if mel.ndim == 2:
+                assert B == 1, "Batch from mel and audio must be the same!"
+                mel = mel.unsqueeze(0)
+            else:
+                assert B == mel.shape[0], "Batch from mel and audio must be the same!"
+            mel = self.to_numpy_safe(mel)
+            gt_mel = lambda idx: mel[idx, :, :]
+
+        else:
+            gt_mel = lambda idx: None
+        audio = self.to_numpy_safe(audio)
+        if B == 1:
+            _r = librosa.feature.rms(
+                y=audio, frame_length=self.n_fft, hop_length=self.hop_length
+            )[0]
+            rms = self.from_numpy(_r, default_device, default_dtype)
+
+        else:
+            rms_ = []
+            for i in range(B):
+                _r = librosa.feature.rms(
+                    y=audio[i, :],
+                    S=gt_mel(i),
+                    frame_length=self.n_fft,
+                    hop_length=self.hop_length,
+                )[0]
+                rms_.append(_r)
+            rms = self.from_numpy_batch(rms_, default_device, default_dtype)
+
+        return rms
+
+    def compute_pitch(
+        self,
+        audio: Tensor,
+    ):
+        default_dtype = audio.dtype
+        default_device = audio.device
+        if audio.ndim > 1:
+            B = audio.shape[0]
+        else:
+            B = 1
+        fmin = max(self.f_min, 80)
+        if B == 1:
+            f0 = self.from_numpy(
+                librosa.yin(
+                    self.to_numpy_safe(audio),
+                    fmin=fmin,
+                    fmax=self.f_max,
+                    frame_length=self.n_fft,
+                    sr=self.sample_rate,
+                    hop_length=self.hop_length,
+                    center=self.center,
+                ),
+                default_device,
+                default_dtype,
+            )
+
+        else:
+            f0_ = []
+            for i in range(B):
+                r = librosa.yin(
+                    self.to_numpy_safe(audio[i, :]),
+                    fmin=fmin,
+                    fmax=self.f_max,
+                    frame_length=self.n_fft,
+                    sr=self.sample_rate,
+                    hop_length=self.hop_length,
+                    center=self.center,
+                )
+                f0_.append(r)
+            f0 = self.from_numpy_batch(f0_, default_device, default_dtype)
+
+        # librosa.pyin(self.f_min, self.f_max)
+        return f0  # dict(f0=f0, attention_mask=f0 != f_max)
+
+    def compute_pitch_torch(self, audio: Tensor):
+        return detect_pitch_frequency(
+            audio,
+            sample_rate=self.sample_rate,
+            frame_time=self.n_fft,
+            win_length=self.win_length,
+            freq_low=max(self.f_min, 35),
+            freq_high=self.f_max,
+        )
+
+    def interpolate_tensor(
+        self,
+        tensor: Tensor,
+        target_len: int,
+        mode: Literal[
+            "nearest",
+            "linear",
+            "bilinear",
+            "bicubic",
+            "trilinear",
+            "area",
+            "nearest-exact",
+        ] = "nearest",
+        align_corners: Optional[bool] = None,
+        scale_factor: Optional[list[float]] = None,
+        recompute_scale_factor: Optional[bool] = None,
+        antialias: bool = False,
+    ):
+        """
+        The modes available for upsampling are: `nearest`, `linear` (3D-only),
+        `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only)
+        """
+
+        if tensor.ndim == 2:  # [1, T]
+            tensor = tensor.unsqueeze(1)  # [1, 1, T]
+        return F.interpolate(
+            tensor,
+            size=target_len,
+            mode=mode,
+            align_corners=align_corners,
+            scale_factor=scale_factor,
+            recompute_scale_factor=recompute_scale_factor,
+            antialias=antialias,
+        )
 
     def inverse_transform(
         self,
@@ -95,7 +256,9 @@ class AudioProcessor(Model):
                 n_fft=n_fft or self.n_fft,
                 hop_length=hop_length or self.hop_length,
                 win_length=win_length or self.win_length,
-                window=torch.hann_window(win_length or self.win_length, device=spec.device),
+                window=torch.hann_window(
+                    win_length or self.win_length, device=spec.device
+                ),
                 center=self.center,
                 normalized=self.normalized,
                 onesided=self.onesided,
@@ -105,10 +268,12 @@ class AudioProcessor(Model):
         except RuntimeError as e:
             if not _recall and spec.device != self.window.device:
                 self.window = self.window.to(spec.device)
-                return self.inverse_transform(spec, phase, n_fft, hop_length, win_length, length, _recall=True)
+                return self.inverse_transform(
+                    spec, phase, n_fft, hop_length, win_length, length, _recall=True
+                )
             raise e
 
-    def rebuild_spectrogram(
+    def normalize_audio(
         self,
         wave: Tensor,
         length: Optional[int] = None,
@@ -148,7 +313,7 @@ class AudioProcessor(Model):
         except RuntimeError as e:
             if not _recall and wave.device != self.window.device:
                 self.window = self.window.to(wave.device)
-                return self.rebuild_spectrogram(wave, length, _recall=True)
+                return self.normalize_audio(wave, length, _recall=True)
             raise e
 
     def compute_mel(
@@ -167,12 +332,7 @@ class AudioProcessor(Model):
 
     def inverse_mel_spectogram(self, mel: Tensor, n_iter: Optional[int] = None):
         if isinstance(n_iter, int) and n_iter != self.n_iter:
-            self.giffin_lim = torchaudio.transforms.GriffinLim(
-                n_fft=self.n_fft,
-                n_iter=n_iter,
-                win_length=self.win_length,
-                hop_length=self.hop_length,
-            )
+            self.giffin_lim.n_iter = n_iter
             self.n_iter = n_iter
         return self.giffin_lim.forward(
             self.mel_rscale(mel),
@@ -182,21 +342,26 @@ class AudioProcessor(Model):
         self,
         path: PathLike,
         top_db: float = 30,
+        normalize: bool = False,
+        alpha: float = 1.0,
     ) -> Tensor:
         is_file(path, True)
         wave, sr = librosa.load(str(path), sr=self.sample_rate)
         wave, _ = librosa.effects.trim(wave, top_db=top_db)
-        return (
-            torch.from_numpy(
-                librosa.resample(wave, orig_sr=sr, target_sr=self.sample_rate)
-                if sr != self.sample_rate
-                else wave
-            )
-            .float()
-            .unsqueeze(0)
-        )
+        if sr != self.sample_rate:
+            wave = librosa.resample(wave, orig_sr=sr, target_sr=self.sample_rate)
+        if normalize:
+            wave = librosa.util.normalize(wave)
+        if alpha not in [0.0, 1.0]:
+            wave = wave * alpha
+        return torch.from_numpy(wave).float().unsqueeze(0)
 
-    def find_audios(self, path: PathLike, additional_extensions: List[str] = []):
+    def find_audios(
+        self,
+        path: PathLike,
+        additional_extensions: List[str] = [],
+        maximum: int | None = None,
+    ):
         extensions = [
             "*.wav",
             "*.aac",
@@ -212,6 +377,7 @@ class AudioProcessor(Model):
         return FileScan.files(
             path,
             extensions,
+            maximum,
         )
 
     def find_audio_text_pairs(
@@ -240,57 +406,6 @@ class AudioProcessor(Model):
                 break
         return results
 
-    def slice_mismatch_outputs(
-        self,
-        tensor_1: Tensor,
-        tensor_2: Tensor,
-        smallest_size: Optional[int] = None,
-        left_to_right: bool = True,
-    ):
-        assert tensor_1.ndim == tensor_2.ndim, (
-            "Tensors must have the same dimentions to be sliced! \n"
-            f"Received instead a tensor_1 with {tensor_1.ndim}D and tensor_2 with {tensor_1.ndim}D."
-        )
-        dim = tensor_1.ndim
-        assert dim < 5, (
-            "Excpected to receive tensors with from 1D up to 4D. "
-            f"Received instead a {dim}D tensor."
-        )
-
-        if tensor_1.shape[-1] == tensor_2.shape[-1]:
-            return tensor_1, tensor_2
-
-        if smallest_size is None:
-            smallest_size = min(tensor_1.shape[-1], tensor_2.shape[-1])
-        if dim == 0:
-            tensor_1 = tensor_1.unsqueeze(0)
-            tensor_2 = tensor_2.unsqueeze(0)
-            dim = 1
-
-        if dim == 1:
-            if left_to_right:
-                return tensor_1[:smallest_size], tensor_2[:smallest_size]
-            return tensor_1[-smallest_size:], tensor_2[-smallest_size:]
-        elif dim == 2:
-            if left_to_right:
-                return tensor_1[:, :smallest_size], tensor_2[:, :smallest_size]
-            return tensor_1[:, -smallest_size:], tensor_2[:, -smallest_size:]
-        elif dim == 3:
-            if left_to_right:
-                return tensor_1[:, :, :smallest_size], tensor_2[:, :, :smallest_size]
-            return tensor_1[:, :, -smallest_size:], tensor_2[:, :, -smallest_size:]
-
-        # else:
-        if left_to_right:
-            return (
-                tensor_1[:, :, :, :smallest_size],
-                tensor_2[:, :, :, :smallest_size],
-            )
-        return (
-            tensor_1[:, :, :, -smallest_size:],
-            tensor_2[:, :, :, -smallest_size:],
-        )
-
     def stft_loss(
         self,
         signal: Tensor,
@@ -302,11 +417,23 @@ class AudioProcessor(Model):
         smallest = min(signal.shape[-1], ground.shape[-1])
         signal = signal[:, -smallest:]
         ground = ground[:, -smallest:]
-        sig_mel = self.compute_mel(signal, base, True).detach().cpu()
-        gnd_mel = self.compute_mel(ground, base, True).detach().cpu()
+        sig_mel = self.compute_mel(signal, base, True)
+        gnd_mel = self.compute_mel(ground, base, True)
         return torch.norm(gnd_mel - sig_mel, p=1) / torch.norm(gnd_mel, p=1)
 
-    # def forward(self, wave: Tensor, base: Optional[float] = None):
+    @staticmethod
+    def plot_spectrogram(spectrogram, ax_):
+        import matplotlib.pylab as plt
+
+        fig, ax = plt.subplots(figsize=(10, 2))
+        im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+        plt.colorbar(im, ax=ax)
+
+        fig.canvas.draw()
+        plt.close()
+
+        return fig
+
     def forward(
         self,
         *inputs: Union[Tensor, float],
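Taken together, the hunks above add numpy round-trip helpers, RMS and pitch extraction, and tensor interpolation to AudioProcessor. A hedged sketch of how these pieces combine is shown below; neither the module path nor the full constructor signature is visible in this diff, so both are assumptions.

    # Hedged sketch -- import location and constructor defaults assumed, not confirmed by this diff.
    from lt_tensor.processors import AudioProcessor  # assumed import location

    ap = AudioProcessor()                                  # f_max now defaults to 12000.0 and is clamped in __init__
    wave = ap.load_audio("sample.wav", normalize=True)     # [1, T], trimmed, optionally peak-normalized
    f0 = ap.compute_pitch(wave)                            # YIN-based f0 track via librosa, one value per hop
    rms = ap.compute_rms(wave)                             # frame-wise RMS energy via librosa
    f0_aligned = ap.interpolate_tensor(f0, rms.shape[-1])  # resample f0 to the RMS frame count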
lt_tensor/transform.py CHANGED
@@ -316,18 +316,6 @@ def inverse_transform(
     )
 
 
-def is_nand(a: bool, b: bool):
-    """[a -> b = result]
-    ```
-    False -> False = True
-    False -> True = True
-    True -> False = True
-    True -> True = False
-    ```
-    """
-    return not (a and b)
-
-
 class InverseTransformConfig:
     def __init__(
         self,
@@ -413,9 +401,11 @@ class InverseTransform(Model):
         self.return_complex = return_complex
         self.onesided = onesided
         self.normalized = normalized
-        self.register_buffer('window', torch.hann_window(self.win_length) if window is None else window)
+        self.register_buffer(
+            "window", torch.hann_window(self.win_length) if window is None else window
+        )
 
-    def forward(self, spec: Tensor, phase: Tensor, *, _recall:bool = False):
+    def forward(self, spec: Tensor, phase: Tensor, *, _recall: bool = False):
         """
         Perform the inverse short-time Fourier transform.
 
@@ -434,7 +424,7 @@ class InverseTransform(Model):
         try:
             return torch.istft(
                 spec * torch.exp(phase * 1j),
-                n_fft = self.n_fft,
+                n_fft=self.n_fft,
                 hop_length=self.hop_length,
                 win_length=self.win_length,
                 window=self.window,
@@ -448,4 +438,5 @@ class InverseTransform(Model):
             if not _recall and spec.device != self.window.device:
                 self.window = self.window.to(spec.device)
                 return self.forward(spec, phase, _recall=True)
-            raise e
+            raise e
+
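For completeness, a small sketch of the call path through the reformatted InverseTransform.forward; only forward() and a slice of __init__ appear in this diff, so the constructor arguments used here are assumptions.

    # Hedged sketch -- constructor signature assumed; forward() matches the istft call shown above.
    import torch
    from lt_tensor.transform import InverseTransform

    n_fft = 1024
    itx = InverseTransform(n_fft=n_fft)         # hop_length / win_length left at assumed defaults
    spec = torch.rand(1, n_fft // 2 + 1, 200)   # magnitude spectrogram [batch, freq_bins, frames]
    phase = torch.rand(1, n_fft // 2 + 1, 200)  # phase in radians
    wave = itx(spec, phase)                     # istft(spec * exp(1j * phase)) -> [batch, samples]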