lt-tensor 0.0.1a27__py3-none-any.whl → 0.0.1a29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/model_base.py CHANGED
@@ -137,16 +137,24 @@ class _Devices_Base(nn.Module):
         )

     def _apply_device(self):
-        """Add here components that are needed to have device applied to them,
-        that usually the '.to()' function fails to apply
-
-        example:
-        ```
-        def _apply_device_to(self):
-            self.my_tensor = self.my_tensor.to(device=self.device)
-        ```
-        """
-        pass
+        """This may seem like overkill, but it is necessary."""
+        for modules in self.modules():
+            try:
+                modules.to(self.device)
+            except:
+                pass
+
+        for buffer in self.buffers():
+            try:
+                buffer.to(self.device)
+            except:
+                pass
+
+        for tensor in self.parameters():
+            try:
+                tensor.to(self.device)
+            except:
+                pass

     def _to_dvc(
         self, device_name: str, device_id: Optional[Union[int, torch.device]] = None
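For context, a minimal illustration of the behaviour this hook targets; `TinyNet` and its attributes are hypothetical, and only standard `nn.Module.to()` semantics are being shown:

```
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)                   # parameters follow .to()
        self.register_buffer("scale", torch.ones(8))  # registered buffers follow .to()

device = "cuda:0" if torch.cuda.is_available() else "cpu"
net = TinyNet().to(device)
# Module.to() relocates registered parameters and buffers in place; the loops
# added above re-apply it defensively to every submodule, buffer, and parameter
# after self.device changes.
print(net.proj.weight.device, net.scale.device)
```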
lt_tensor/model_zoo/__init__.py CHANGED
@@ -9,6 +9,7 @@ __all__ = [
     "audio_models",
     "hifigan",
     "istft",
+    "losses",
 ]
 from .audio_models import hifigan, istft
 from . import (
@@ -19,4 +20,5 @@ from . import (
     pos_encoder,
     residual,
     transformer,
+    losses,
 )
lt_tensor/model_zoo/losses/__init__.py ADDED
@@ -0,0 +1,3 @@
+from . import discriminators
+
+__all__ = ["discriminators"]
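A minimal sketch of reaching the new module through the updated package namespace (the class names come from the `discriminators` module added below):

```
from lt_tensor.model_zoo import losses

mpd = losses.discriminators.MultiPeriodDiscriminator()
mrd = losses.discriminators.MultiResolutionDiscriminator()
```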
lt_tensor/model_zoo/losses/discriminators.py ADDED
@@ -0,0 +1,610 @@
+from lt_tensor.model_zoo.audio_models.hifigan import ConvNets
+from lt_utils.common import *
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model
+from lt_tensor.model_zoo.convs import ConvNets
+from torch.nn import functional as F
+from torchaudio import transforms as T
+
+MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
+    List[Tensor],
+    List[Tensor],
+    List[List[Tensor]],
+    List[List[Tensor]],
+]
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+class MultiDiscriminatorWrapper(ConvNets):
+    """Base class for all multi-step discriminators."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.leaky_relu = nn.LeakyReLU(kwargs.get("negative_slope", 0.1))
+
+    def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
+        pass
+
+    # for type hinting
+    def __call__(self, *args, **kwds) -> MULTI_DISC_OUT_TYPE:
+        return super().__call__(*args, **kwds)
+
+    def gen_step(self, y: Tensor, y_hat: Tensor) -> tuple[Tensor, Tensor, List[float]]:
+        """For generator loss step [feature loss, generator loss, list of generator losses (float)]"""
+        _, y_hat_gen, feat_map_real, feat_map_gen = self.train_step(y, y_hat)
+        loss_feat = self.feature_loss(feat_map_real, feat_map_gen)
+        loss_generator, losses_gen_s = self.generator_loss(y_hat_gen)
+        return loss_feat, loss_generator, losses_gen_s
+
+    def disc_step(
+        self, y: Tensor, y_hat: Tensor
+    ) -> tuple[Tensor, tuple[List[float], List[float]]]:
+        """For discriminator loss step [discriminator loss, (disc losses real, disc losses generated)]"""
+        y_hat_real, y_hat_gen, _, _ = self.train_step(y, y_hat)
+
+        loss_disc, losses_disc_real, losses_disc_generated = self.discriminator_loss(
+            y_hat_real, y_hat_gen
+        )
+        return loss_disc, (losses_disc_real, losses_disc_generated)
+
+    @staticmethod
+    def discriminator_loss(
+        disc_real_outputs, disc_generated_outputs
+    ) -> Tuple[Tensor, List[float], List[float]]:
+        loss = 0
+        r_losses = []
+        g_losses = []
+        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+            r_loss = torch.mean((1 - dr) ** 2)
+            g_loss = torch.mean(dg**2)
+            loss += r_loss + g_loss
+            r_losses.append(r_loss.item())
+            g_losses.append(g_loss.item())
+
+        return loss, r_losses, g_losses
+
+    @staticmethod
+    def feature_loss(fmap_r, fmap_g) -> Tensor:
+        loss = 0
+        for dr, dg in zip(fmap_r, fmap_g):
+            for rl, gl in zip(dr, dg):
+                loss += torch.mean(torch.abs(rl - gl))
+
+        return loss * 2
+
+    @staticmethod
+    def generator_loss(disc_outputs) -> Tuple[Tensor, List[float]]:
+        loss = 0
+        gen_losses = []
+        for dg in disc_outputs:
+            l = torch.mean((1 - dg) ** 2)
+            gen_losses.append(l.item())
+            loss += l
+
+        return loss, gen_losses
+
+
+class DiscriminatorP(ConvNets):
+    def __init__(
+        self,
+        period: List[int],
+        discriminator_channel_mult: Number = 1,
+        kernel_size: int = 5,
+        stride: int = 3,
+        use_spectral_norm: bool = False,
+    ):
+        super().__init__()
+        self.period = period
+        norm_f = weight_norm if not use_spectral_norm else spectral_norm
+        dsc = lambda x: int(x * discriminator_channel_mult)
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    nn.Conv2d(
+                        1,
+                        dsc(32),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(32),
+                        dsc(128),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(128),
+                        dsc(512),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(512),
+                        dsc(1024),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(1024),
+                        dsc(1024),
+                        (kernel_size, 1),
+                        1,
+                        padding=(2, 0),
+                    )
+                ),
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv2d(dsc(1024), 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, 0.1)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        return x.flatten(1, -1), fmap
+
+
+class MultiPeriodDiscriminator(MultiDiscriminatorWrapper):
+    def __init__(
+        self,
+        discriminator_channel_mult: Number = 1,
+        mpd_reshapes: list[int] = [2, 3, 5, 7, 11],
+        use_spectral_norm: bool = False,
+    ):
+        super().__init__()
+        self.mpd_reshapes = mpd_reshapes
+        print(f"mpd_reshapes: {self.mpd_reshapes}")
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorP(
+                    rs,
+                    use_spectral_norm=use_spectral_norm,
+                    discriminator_channel_mult=discriminator_channel_mult,
+                )
+                for rs in self.mpd_reshapes
+            ]
+        )
+
+    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> MULTI_DISC_OUT_TYPE:
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class EnvelopeExtractor(nn.Module):
+    """Extracts the amplitude envelope of the audio signal."""
+
+    def __init__(self, kernel_size=101):
+        super().__init__()
+        # Lowpass filter for smoothing envelope (moving average)
+        self.kernel_size = kernel_size
+        self.register_buffer("kernel", torch.ones(1, 1, kernel_size) / kernel_size)
+
+    def forward(self, x):
+        # x: (B, 1, T) -> abs(x)
+        envelope = torch.abs(x)
+        # Apply low-pass smoothing (via conv1d)
+        envelope = F.pad(
+            envelope, (self.kernel_size // 2, self.kernel_size // 2), mode="reflect"
+        )
+        envelope = F.conv1d(envelope, self.kernel)
+        return envelope
+
+
+class DiscriminatorEnvelope(ConvNets):
+    def __init__(self, use_spectral_norm=False):
+        super().__init__()
+        norm_f = weight_norm if not use_spectral_norm else spectral_norm
+        self.extractor = EnvelopeExtractor(kernel_size=101)
+        self.convs = nn.ModuleList(
+            [
+                norm_f(nn.Conv1d(1, 64, 15, stride=1, padding=7)),
+                norm_f(nn.Conv1d(64, 128, 41, stride=2, groups=4, padding=20)),
+                norm_f(nn.Conv1d(128, 256, 41, stride=2, groups=16, padding=20)),
+                norm_f(nn.Conv1d(256, 512, 41, stride=4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(512, 512, 41, stride=4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(512, 512, 5, stride=1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv1d(512, 1, 3, stride=1, padding=1))
+        self.activation = nn.LeakyReLU(0.1)
+
+    def forward(self, x):
+        # Input: raw audio (B, 1, T)
+        x = self.extractor(x)
+        fmap = []
+        for layer in self.convs:
+            x = self.activation(layer(x))
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        return x.flatten(1), fmap
+
+
+class MultiEnvelopeDiscriminator(MultiDiscriminatorWrapper):
+    def __init__(self, use_spectral_norm: bool = False):
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorEnvelope(use_spectral_norm),  # raw envelope
+                DiscriminatorEnvelope(use_spectral_norm),  # downsampled once
+                DiscriminatorEnvelope(use_spectral_norm),  # downsampled twice
+            ]
+        )
+        self.meanpools = nn.ModuleList(
+            [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
+        )
+
+    def forward(self, y, y_hat):
+        y_d_rs, y_d_gs = [], []
+        fmap_rs, fmap_gs = [], []
+
+        for i, d in enumerate(self.discriminators):
+            if i != 0:
+                y = self.meanpools[i - 1](y)
+                y_hat = self.meanpools[i - 1](y_hat)
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            y_d_gs.append(y_d_g)
+            fmap_rs.append(fmap_r)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorB(ConvNets):
+    """
+    Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
+    and the modified code adapted from https://github.com/gemelo-ai/vocos.
+    """
+
+    def __init__(
+        self,
+        window_length: int,
+        channels: int = 32,
+        hop_factor: float = 0.25,
+        bands: Tuple[Tuple[float, float], ...] = (
+            (0.0, 0.1),
+            (0.1, 0.25),
+            (0.25, 0.5),
+            (0.5, 0.75),
+            (0.75, 1.0),
+        ),
+    ):
+        super().__init__()
+        self.window_length = window_length
+        self.hop_factor = hop_factor
+        self.spec_fn = T.Spectrogram(
+            n_fft=window_length,
+            hop_length=int(window_length * hop_factor),
+            win_length=window_length,
+            power=None,
+        )
+        n_fft = window_length // 2 + 1
+        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+        self.bands = bands
+        convs = lambda: nn.ModuleList(
+            [
+                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
+                ),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
+                ),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
+                ),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
+                ),
+            ]
+        )
+        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+
+        self.conv_post = weight_norm(
+            nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
+        )
+
+    def spectrogram(self, x: Tensor) -> List[Tensor]:
+        # Remove DC offset
+        x = x - x.mean(dim=-1, keepdims=True)
+        # Peak normalize the volume of input audio
+        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+        x = self.spec_fn(x)
+        x = torch.view_as_real(x)
+        x = x.permute(0, 3, 2, 1)  # [B, F, T, C] -> [B, C, T, F]
+        # Split into bands
+        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+        return x_bands
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+        x_bands = self.spectrogram(x.squeeze(1))
+        fmap = []
+        x = []
+
+        for band, stack in zip(x_bands, self.band_convs):
+            for i, layer in enumerate(stack):
+                band = layer(band)
+                band = torch.nn.functional.leaky_relu(band, 0.1)
+                if i > 0:
+                    fmap.append(band)
+            x.append(band)
+
+        x = torch.cat(x, dim=-1)
+        x = self.conv_post(x)
+        fmap.append(x)
+
+        return x, fmap
+
+
+class MultiBandDiscriminator(MultiDiscriminatorWrapper):
+    """
+    Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
+    and the modified code adapted from https://github.com/gemelo-ai/vocos.
+    """
+
+    def __init__(
+        self,
+        mbd_fft_sizes: list[int] = [2048, 1024, 512],
+    ):
+        super().__init__()
+        self.fft_sizes = mbd_fft_sizes
+        self.discriminators = nn.ModuleList(
+            [DiscriminatorB(window_length=w) for w in self.fft_sizes]
+        )
+
+    def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
+
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for d in self.discriminators:
+
+            y_d_r, fmap_r = d(x=y)
+            y_d_g, fmap_g = d(x=y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorR(ConvNets):
+    def __init__(
+        self,
+        resolution: List[int],
+        use_spectral_norm: bool = False,
+        discriminator_channel_mult: int = 1,
+    ):
+        super().__init__()
+
+        self.resolution = resolution
+        assert (
+            len(self.resolution) == 3
+        ), f"MRD layer requires list with len=3, got {self.resolution}"
+        self.lrelu_slope = 0.1
+
+        self.register_buffer("window", torch.hann_window(self.resolution[-1]))
+
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    nn.Conv2d(
+                        1, int(32 * discriminator_channel_mult), (3, 9), padding=(1, 4)
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 3),
+                        padding=(1, 1),
+                    )
+                ),
+            ]
+        )
+        self.conv_post = norm_f(
+            nn.Conv2d(int(32 * discriminator_channel_mult), 1, (3, 3), padding=(1, 1))
+        )
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+        fmap = []
+        x = self.spectrogram(x)
+        x = x.unsqueeze(1)
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, self.lrelu_slope)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+    def spectrogram(self, x: Tensor) -> Tensor:
+        n_fft, hop_length, win_length = self.resolution
+        x = F.pad(
+            x,
+            (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
+            mode="reflect",
+        )
+        x = x.squeeze(1)
+        x = torch.stft(
+            x,
+            n_fft=n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            center=False,
+            return_complex=True,
+            window=self.window,
+        )
+        x = torch.view_as_real(x)  # [B, F, TT, 2]
+        mag = torch.norm(x, p=2, dim=-1)  # [B, F, TT]
+
+        return mag
+
+
+class MultiResolutionDiscriminator(MultiDiscriminatorWrapper):
+    def __init__(
+        self,
+        use_spectral_norm: bool = False,
+        discriminator_channel_mult: int = 1,
+        resolutions: List[List[int]] = [
+            [1024, 120, 600],
+            [2048, 240, 1200],
+            [512, 50, 240],
+        ],
+    ):
+        super().__init__()
+        self.resolutions = resolutions
+        assert (
+            len(self.resolutions) == 3
+        ), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}, type: {type(self.resolutions)}"
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorR(
+                    resolution, use_spectral_norm, discriminator_channel_mult
+                )
+                for resolution in self.resolutions
+            ]
+        )
+
+    def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for disc in self.discriminators:
+            y_d_r, fmap_r = disc(x=y)
+            y_d_g, fmap_g = disc(x=y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class MultiDiscriminatorStep(Model):
+    def __init__(self, list_discriminator: List[MultiDiscriminatorWrapper]):
+        super().__init__()
+        self.disc: Sequence[MultiDiscriminatorWrapper] = nn.ModuleList(
+            list_discriminator
+        )
+        self.total = len(self.disc)
+
+    def forward(
+        self,
+        y: Tensor,
+        y_hat: Tensor,
+        step_type: Literal["discriminator", "generator"],
+    ) -> Union[
+        Tuple[Tensor, Tensor, List[float]], Tuple[Tensor, List[float], List[float]]
+    ]:
+        """
+        Returns different outputs depending on "step_type", which is either
+        'discriminator' or 'generator'.
+
+        For 'generator' it returns:
+            Tuple[Tensor, Tensor, List[float]]
+            "gen_loss, feat_loss, all_g_losses"
+
+        For 'discriminator' it returns:
+            Tuple[Tensor, List[float], List[float]]
+            "disc_loss, disc_real_losses, disc_gen_losses"
+        """
+        if step_type == "generator":
+            all_g_losses: List[float] = []
+            feat_loss: Tensor = 0
+            gen_loss: Tensor = 0
+        else:
+            disc_loss: Tensor = 0
+            disc_real_losses: List[float] = []
+            disc_gen_losses: List[float] = []
+
+        for disc in self.disc:
+            if step_type == "generator":
+                # [feature loss, generator loss, list of generator losses (float)]
+                f_loss, g_loss, g_losses = disc.gen_step(y, y_hat)
+                gen_loss += g_loss
+                feat_loss += f_loss
+                all_g_losses.extend(g_losses)
+            else:
+                # [discriminator loss, (disc losses real, disc losses generated)]
+                d_loss, (d_real_losses, d_gen_losses) = disc.disc_step(y, y_hat)
+                disc_loss += d_loss
+                disc_real_losses.extend(d_real_losses)
+                disc_gen_losses.extend(d_gen_losses)
+
+        if step_type == "generator":
+            return gen_loss, feat_loss, all_g_losses
+        return disc_loss, disc_real_losses, disc_gen_losses
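To show how the new classes compose, here is a minimal, hypothetical training-step sketch; the generator `gen`, both optimizers, and the `mel`/`wave` tensors are placeholders, and only `MultiDiscriminatorStep` plus the discriminator classes come from this module:

```
import torch
from lt_tensor.model_zoo.losses.discriminators import (
    MultiDiscriminatorStep,
    MultiPeriodDiscriminator,
    MultiResolutionDiscriminator,
)

discs = MultiDiscriminatorStep(
    [MultiPeriodDiscriminator(), MultiResolutionDiscriminator()]
)

def train_step(gen, opt_g, opt_d, mel, wave):
    # wave: real audio (B, 1, T); mel: whatever the generator consumes
    fake = gen(mel)

    # discriminator pass: summed loss plus per-discriminator float losses
    opt_d.zero_grad()
    d_loss, d_real_losses, d_gen_losses = discs(wave, fake.detach(), "discriminator")
    d_loss.backward()
    opt_d.step()

    # generator pass: adversarial loss plus feature-matching loss
    opt_g.zero_grad()
    g_loss, feat_loss, all_g_losses = discs(wave, fake, "generator")
    (g_loss + feat_loss).backward()
    opt_g.step()

    return d_loss.item(), g_loss.item()
```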
lt_tensor/processors/audio.py CHANGED
@@ -105,7 +105,6 @@ class AudioProcessor(Model):
             onesided=self.cfg.onesided,
             normalized=self.cfg.normalized,
         )
-        self.griffin_lm_iters = 32
         self.mel_rscale = torchaudio.transforms.InverseMelScale(
             n_stft=self.cfg.n_stft,
             n_mels=self.cfg.n_mels,
@@ -114,22 +113,12 @@ class AudioProcessor(Model):
             f_max=self.cfg.f_max,
             mel_scale=self.cfg.mel_scale,
         )
-        self.giffin_lim = torchaudio.transforms.GriffinLim(
-            n_fft=self.cfg.n_fft,
-            win_length=self.cfg.win_length,
-            hop_length=self.cfg.hop_length,
-        )
+
         self.register_buffer(
             "window",
             (torch.hann_window(self.cfg.win_length) if window is None else window),
         )

-    def _apply_device(self):
-        print(f"Audio Processor Device: {self.device.type}")
-        self.giffin_lim.to(device=self.device)
-        self._mel_spec.to(device=self.device)
-        self.mel_rscale.to(device=self.device)
-
     def from_numpy(
         self,
         array: np.ndarray,
@@ -173,7 +162,9 @@ class AudioProcessor(Model):
         )

         if audio is None and mel is not None:
-            return self.from_numpy(librosa.feature.rms(S=mel, **rms_kwargs)[0])
+            return self.from_numpy(
+                librosa.feature.rms(S=mel, **rms_kwargs)[0]
+            ).squeeze()
         default_dtype = audio.dtype
         default_device = audio.device
         if audio.ndim > 1:
@@ -192,8 +183,12 @@ class AudioProcessor(Model):
         audio = self.to_numpy_safe(audio)
         if B == 1:
             if mel is None:
-                return self.from_numpy(librosa.feature.rms(y=audio, **rms_kwargs)[0])
-            return self.from_numpy(librosa.feature.rms(y=audio, S=mel, **rms_kwargs)[0])
+                return self.from_numpy(
+                    librosa.feature.rms(y=audio, **rms_kwargs)[0]
+                ).squeeze()
+            return self.from_numpy(
+                librosa.feature.rms(y=audio, S=mel, **rms_kwargs)[0]
+            ).squeeze()
         else:
             rms_ = []
             for i in range(B):
@@ -201,7 +196,7 @@ class AudioProcessor(Model):
                     0
                 ]
                 rms_.append(_r)
-            return self.from_numpy_batch(rms_, default_device, default_dtype)
+            return self.from_numpy_batch(rms_, default_device, default_dtype).squeeze()

     def compute_pitch(
         self,
@@ -273,7 +268,7 @@ class AudioProcessor(Model):
             win_length=win_length,
             freq_low=fmin,
             freq_high=fmax,
-        )
+        ).squeeze()

     def interpolate(
         self,
@@ -312,7 +307,7 @@ class AudioProcessor(Model):
             antialias=antialias,
         )

-    def inverse_transform(
+    def istft(
         self,
         spec: Tensor,
         phase: Tensor,
@@ -320,6 +315,10 @@ class AudioProcessor(Model):
         hop_length: Optional[int] = None,
         win_length: Optional[int] = None,
         length: Optional[int] = None,
+        center: Optional[bool] = None,
+        normalized: Optional[bool] = None,
+        onesided: Optional[bool] = None,
+        return_complex: bool = False,
         *,
         _recall: bool = False,
     ):
@@ -331,25 +330,25 @@ class AudioProcessor(Model):
         try:
             return torch.istft(
                 spec * torch.exp(phase * 1j),
-                n_fft=n_fft or self.cfg.n_fft,
-                hop_length=hop_length or self.cfg.hop_length,
-                win_length=win_length or self.cfg.win_length,
+                n_fft=default(n_fft, self.cfg.n_fft),
+                hop_length=default(hop_length, self.cfg.hop_length),
+                win_length=default(win_length, self.cfg.win_length),
                 window=window,
-                center=self.cfg.center,
-                normalized=self.cfg.normalized,
-                onesided=self.cfg.onesided,
+                center=default(center, self.cfg.center),
+                normalized=default(normalized, self.cfg.normalized),
+                onesided=default(onesided, self.cfg.onesided),
                 length=length,
-                return_complex=False,
+                return_complex=return_complex,
             )
         except RuntimeError as e:
             if not _recall and spec.device != self.window.device:
                 self.window = self.window.to(spec.device)
-                return self.inverse_transform(
+                return self.istft(
                     spec, phase, n_fft, hop_length, win_length, length, _recall=True
                 )
             raise e

-    def normalize_audio(
+    def istft_norm(
         self,
         wave: Tensor,
         length: Optional[int] = None,
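A short usage sketch of the renamed call (the `AudioProcessor` instance `ap` and the `spec`/`phase` tensors are assumed to already exist; any argument left as `None` still falls back to `ap.cfg` through `default()`):

```
# previously: wave = ap.inverse_transform(spec, phase, length=expected_len)
wave = ap.istft(spec, phase, length=expected_len)

# per-call overrides of the config values are now possible
wave = ap.istft(spec, phase, hop_length=256, center=True)
```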
@@ -389,7 +388,7 @@ class AudioProcessor(Model):
         except RuntimeError as e:
             if not _recall and wave.device != self.window.device:
                 self.window = self.window.to(wave.device)
-                return self.normalize_audio(wave, length, _recall=True)
+                return self.istft_norm(wave, length, _recall=True)
             raise e

     def compute_mel(
@@ -415,14 +414,6 @@ class AudioProcessor(Model):
                 return self.compute_mel(wave, raw_mel_only, eps, _recall=True)
             raise e

-    def inverse_mel_spectogram(self, mel: Tensor, n_iter: Optional[int] = None):
-        if isinstance(n_iter, int) and n_iter != self.griffin_lm_iters:
-            self.giffin_lim.n_iter = n_iter
-            self.griffin_lm_iters = n_iter
-        return self.giffin_lim.forward(
-            self.mel_rscale(mel),
-        )
-
     def load_audio(
         self,
         path: PathLike,
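Since the built-in Griffin-Lim path (`giffin_lim`, `inverse_mel_spectogram`) was dropped, callers that still need mel inversion can compose it themselves; a minimal sketch using torchaudio and the `mel_rscale` transform kept above (`ap` is an assumed `AudioProcessor` instance):

```
import torchaudio

griffin_lim = torchaudio.transforms.GriffinLim(
    n_fft=ap.cfg.n_fft,
    win_length=ap.cfg.win_length,
    hop_length=ap.cfg.hop_length,
    n_iter=32,
)
# mel -> linear spectrogram -> waveform
wave = griffin_lim(ap.mel_rscale(mel))
```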
@@ -506,14 +497,9 @@ class AudioProcessor(Model):
             maximum,
         )

-    def stft_loss(
-        self,
-        signal: Tensor,
-        ground: Tensor,
-    ):
-        with torch.no_grad():
-            ground = F.interpolate(ground, signal.shape[-1]).to(signal.device)
-        return F.l1_loss(signal, ground)
+    def stft_loss(self, signal: Tensor, ground: Tensor, magnitude: float = 1.0):
+        ground = F.interpolate(ground, signal.shape[-1]).to(signal.device)
+        return F.l1_loss(signal.squeeze(), ground.squeeze()) * magnitude

     def forward(
         self,
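A quick sketch of the revised loss (placeholder tensors; note the ground signal is now resized without the previous `torch.no_grad()` guard and the result is scaled by `magnitude`):

```
pred = torch.randn(4, 1, 8192, requires_grad=True)  # generated signal
target = torch.randn(4, 1, 8000)                    # reference of a different length
loss = ap.stft_loss(pred, target, magnitude=0.5)    # L1 on squeezed, length-matched tensors
loss.backward()
```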
lt_tensor-0.0.1a27.dist-info/METADATA → lt_tensor-0.0.1a29.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a27
+Version: 0.0.1a29
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
lt_tensor-0.0.1a27.dist-info/RECORD → lt_tensor-0.0.1a29.dist-info/RECORD
@@ -4,12 +4,12 @@ lt_tensor/losses.py,sha256=zvkCOnE5XpF3v6ymivRIdqPTsMM5zc94ZMom7YDi3zM,4946
 lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
 lt_tensor/math_ops.py,sha256=TkD4WQG42KsQ9Fg7FXOjf8f-ixtW0apf2XjaooecVx4,2257
 lt_tensor/misc_utils.py,sha256=N2r3UmxC4RM2BZBQhpjDZ_BKLrzsyIlKzopTzJbnjFU,28962
-lt_tensor/model_base.py,sha256=DTg44N6eTXLmpIAj_ac29-M5dI_iY_sC0yA_K3E13GI,17446
+lt_tensor/model_base.py,sha256=8qcFXe0_y8f1_tAwt18gwQjyyapbnVEKcjCMrKnQatw,17614
 lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
 lt_tensor/noise_tools.py,sha256=wFeAsHhLhSlEc5XU5LbFKaXoHeVxrWjiMeljjGdIKyM,11363
 lt_tensor/torch_commons.py,sha256=8l0bxmrAzwvyqjivCIVISXlbvKarlg4DdE0BOGSnMuQ,812
 lt_tensor/transform.py,sha256=dZm8T_ov0blHMQu6nGiehsdG1VSB7bZBUVmTkT-PBdc,13257
-lt_tensor/model_zoo/__init__.py,sha256=ltVTvmOlbOCfDc5Trvg0-Ta_Ujgkw0UVF9V5rqHx-RI,378
+lt_tensor/model_zoo/__init__.py,sha256=yPUVchgVhU2nAJ2ocA4HFfG7IMEiBu8qOi8I1KWTTkU,404
 lt_tensor/model_zoo/basic.py,sha256=pI8HyiHK-cmWcEEaVY_EduUJOjZW6HOtXvJd8Rbhq30,15452
 lt_tensor/model_zoo/convs.py,sha256=YQRxek75Qpsha8nfc7wLhmJS9XxPeCa4WxuftLg6IcE,3927
 lt_tensor/model_zoo/features.py,sha256=DO8dlE0kmPKTNC1Xkv9wKegOOYkQa_rkxM4hhcNwJWA,15655
@@ -26,10 +26,12 @@ lt_tensor/model_zoo/audio_models/__init__.py,sha256=MoG9YjxLyvscq_6njK1ljGBletK9
 lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=PDuDYN1omD1RoAXcmxH3tEgfAuM3ZHAWzimD6ElMqEQ,9073
 lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=7GJqKLw7-juXpfp5IFzjASLut0uouDhjZ1CQknf3H68,16533
 lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=ltIuD9t1gmS3bTmCqZIwJHKrhC6DYya3OaXlskWX9kw,17606
+lt_tensor/model_zoo/losses/__init__.py,sha256=B9RAUxBiOZwooztnij1oLeRwZ7_MjnN3mPoum7saD6s,59
+lt_tensor/model_zoo/losses/discriminators.py,sha256=yYh7HzRTUtr0RVTG7cWpcYsJZsRCz6yzg6Loq8FtyOk,20405
 lt_tensor/processors/__init__.py,sha256=Pvxhh0KR65zLCgUd53_k5Z0y5JWWcO0ZBXFK9rv0o5w,109
-lt_tensor/processors/audio.py,sha256=mZY7LOeYACnX8PLz5AeFe0zqEebPoN-Q44Bi3yrlZMQ,16881
-lt_tensor-0.0.1a27.dist-info/licenses/LICENSE,sha256=tQHc38scHOba4kDBNG4U0U6PpObaloiZG-FvKSgv2b0,11336
-lt_tensor-0.0.1a27.dist-info/METADATA,sha256=NpXqioPXZMvXo-HzhXrS6O1qiftDnoc8ZzOfhfUMBaY,1062
-lt_tensor-0.0.1a27.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-lt_tensor-0.0.1a27.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
-lt_tensor-0.0.1a27.dist-info/RECORD,,
+lt_tensor/processors/audio.py,sha256=1JuxxexfUsXkLjVjWUk-oTRU-QNnCCwvKX3eP0m7LGE,16452
+lt_tensor-0.0.1a29.dist-info/licenses/LICENSE,sha256=tQHc38scHOba4kDBNG4U0U6PpObaloiZG-FvKSgv2b0,11336
+lt_tensor-0.0.1a29.dist-info/METADATA,sha256=F03dNMnEydcKjjZF3IntNaIj34FwLdoy-L0pBB_yz0E,1062
+lt_tensor-0.0.1a29.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a29.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a29.dist-info/RECORD,,