lt-tensor 0.0.1a27__py3-none-any.whl → 0.0.1a28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

lt_tensor/model_zoo/__init__.py

@@ -9,6 +9,7 @@ __all__ = [
     "audio_models",
     "hifigan",
     "istft",
+    "losses",
 ]
 from .audio_models import hifigan, istft
 from . import (
@@ -19,4 +20,5 @@ from . import (
     pos_encoder,
     residual,
     transformer,
+    losses,
 )

lt_tensor/model_zoo/losses/__init__.py (new file)

@@ -0,0 +1,3 @@
+from . import discriminators
+
+__all__ = ["discriminators"]
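
The new losses subpackage is also re-exported from lt_tensor.model_zoo (see the __init__.py hunk above). A minimal import sketch, assuming the 0.0.1a28 wheel is installed; the paths are taken from this diff and the RECORD entries, not from package documentation:

    from lt_tensor.model_zoo import losses
    from lt_tensor.model_zoo.losses import discriminators

    print(losses.__all__)  # ["discriminators"], per the __init__.py shown above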

lt_tensor/model_zoo/losses/discriminators.py (new file)

@@ -0,0 +1,610 @@
+from lt_tensor.model_zoo.audio_models.hifigan import ConvNets
+from lt_utils.common import *
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model
+from lt_tensor.model_zoo.convs import ConvNets
+from torch.nn import functional as F
+from torchaudio import transforms as T
+
+MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
+    List[Tensor],
+    List[Tensor],
+    List[List[Tensor]],
+    List[List[Tensor]],
+]
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+class MultiDiscriminatorWrapper(ConvNets):
+    """Base for all multi-steps type of discriminators"""
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.leaky_relu = nn.LeakyReLU(kwargs.get("negative_slope", 0.1))
+
+    def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
+        pass
+
+    # for type hinting
+    def __call__(self, *args, **kwds) -> MULTI_DISC_OUT_TYPE:
+        return super().__call__(*args, **kwds)
+
+    def gen_step(self, y: Tensor, y_hat: Tensor) -> tuple[Tensor, Tensor, List[float]]:
+        """For generator loss step [feature loss, generator loss, list of generator losses (float)]"""
+        _, y_hat_gen, feat_map_real, feat_map_gen = self.train_step(y, y_hat)
+        loss_feat = self.feature_loss(feat_map_real, feat_map_gen)
+        loss_generator, losses_gen_s = self.generator_loss(y_hat_gen)
+        return loss_feat, loss_generator, losses_gen_s
+
+    def disc_step(
+        self, y: Tensor, y_hat: Tensor
+    ) -> tuple[Tensor, tuple[List[float], List[float]]]:
+        """For discriminator loss step [discriminator loss, (disc losses real, disc losses generated)]"""
+        y_hat_real, y_hat_gen, _, _ = self.train_step(y, y_hat)
+
+        loss_disc, losses_disc_real, losses_disc_generated = self.discriminator_loss(
+            y_hat_real, y_hat_gen
+        )
+        return loss_disc, (losses_disc_real, losses_disc_generated)
+
+    @staticmethod
+    def discriminator_loss(
+        disc_real_outputs, disc_generated_outputs
+    ) -> Tuple[Tensor, List[float], List[float]]:
+        loss = 0
+        r_losses = []
+        g_losses = []
+        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+            r_loss = torch.mean((1 - dr) ** 2)
+            g_loss = torch.mean(dg**2)
+            loss += r_loss + g_loss
+            r_losses.append(r_loss.item())
+            g_losses.append(g_loss.item())
+
+        return loss, r_losses, g_losses
+
+    @staticmethod
+    def feature_loss(fmap_r, fmap_g) -> Tensor:
+        loss = 0
+        for dr, dg in zip(fmap_r, fmap_g):
+            for rl, gl in zip(dr, dg):
+                loss += torch.mean(torch.abs(rl - gl))
+
+        return loss * 2
+
+    @staticmethod
+    def generator_loss(disc_outputs) -> Tuple[Tensor, List[float]]:
+        loss = 0
+        gen_losses = []
+        for dg in disc_outputs:
+            l = torch.mean((1 - dg) ** 2)
+            gen_losses.append(l.item())
+            loss += l
+
+        return loss, gen_losses
+
+
+class DiscriminatorP(ConvNets):
+    def __init__(
+        self,
+        period: List[int],
+        discriminator_channel_mult: Number = 1,
+        kernel_size: int = 5,
+        stride: int = 3,
+        use_spectral_norm: bool = False,
+    ):
+        super().__init__()
+        self.period = period
+        norm_f = weight_norm if not use_spectral_norm else spectral_norm
+        dsc = lambda x: int(x * discriminator_channel_mult)
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    nn.Conv2d(
+                        1,
+                        dsc(32),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(32),
+                        dsc(128),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(128),
+                        dsc(512),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(512),
+                        dsc(1024),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(1024),
+                        dsc(1024),
+                        (kernel_size, 1),
+                        1,
+                        padding=(2, 0),
+                    )
+                ),
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv2d(dsc(1024), 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, 0.1)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        return x.flatten(1, -1), fmap
+
+
+class MultiPeriodDiscriminator(MultiDiscriminatorWrapper):
+    def __init__(
+        self,
+        discriminator_channel_mult: Number = 1,
+        mpd_reshapes: list[int] = [2, 3, 5, 7, 11],
+        use_spectral_norm: bool = False,
+    ):
+        super().__init__()
+        self.mpd_reshapes = mpd_reshapes
+        print(f"mpd_reshapes: {self.mpd_reshapes}")
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorP(
+                    rs,
+                    use_spectral_norm=use_spectral_norm,
+                    discriminator_channel_mult=discriminator_channel_mult,
+                )
+                for rs in self.mpd_reshapes
+            ]
+        )
+
+    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> MULTI_DISC_OUT_TYPE:
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class EnvelopeExtractor(nn.Module):
+    """Extracts the amplitude envelope of the audio signal."""
+
+    def __init__(self, kernel_size=101):
+        super().__init__()
+        # Lowpass filter for smoothing envelope (moving average)
+        self.kernel_size = kernel_size
+        self.register_buffer("kernel", torch.ones(1, 1, kernel_size) / kernel_size)
+
+    def forward(self, x):
+        # x: (B, 1, T) -> abs(x)
+        envelope = torch.abs(x)
+        # Apply low-pass smoothing (via conv1d)
+        envelope = F.pad(
+            envelope, (self.kernel_size // 2, self.kernel_size // 2), mode="reflect"
+        )
+        envelope = F.conv1d(envelope, self.kernel)
+        return envelope
+
+
+class DiscriminatorEnvelope(ConvNets):
+    def __init__(self, use_spectral_norm=False):
+        super().__init__()
+        norm_f = weight_norm if not use_spectral_norm else spectral_norm
+        self.extractor = EnvelopeExtractor(kernel_size=101)
+        self.convs = nn.ModuleList(
+            [
+                norm_f(nn.Conv1d(1, 64, 15, stride=1, padding=7)),
+                norm_f(nn.Conv1d(64, 128, 41, stride=2, groups=4, padding=20)),
+                norm_f(nn.Conv1d(128, 256, 41, stride=2, groups=16, padding=20)),
+                norm_f(nn.Conv1d(256, 512, 41, stride=4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(512, 512, 41, stride=4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(512, 512, 5, stride=1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv1d(512, 1, 3, stride=1, padding=1))
+        self.activation = nn.LeakyReLU(0.1)
+
+    def forward(self, x):
+        # Input: raw audio (B, 1, T)
+        x = self.extractor(x)
+        fmap = []
+        for layer in self.convs:
+            x = self.activation(layer(x))
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        return x.flatten(1), fmap
+
+
+class MultiEnvelopeDiscriminator(MultiDiscriminatorWrapper):
+    def __init__(self, use_spectral_norm: bool = False):
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorEnvelope(use_spectral_norm),  # raw envelope
+                DiscriminatorEnvelope(use_spectral_norm),  # downsampled once
+                DiscriminatorEnvelope(use_spectral_norm),  # downsampled twice
+            ]
+        )
+        self.meanpools = nn.ModuleList(
+            [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
+        )
+
+    def forward(self, y, y_hat):
+        y_d_rs, y_d_gs = [], []
+        fmap_rs, fmap_gs = [], []
+
+        for i, d in enumerate(self.discriminators):
+            if i != 0:
+                y = self.meanpools[i - 1](y)
+                y_hat = self.meanpools[i - 1](y_hat)
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            y_d_gs.append(y_d_g)
+            fmap_rs.append(fmap_r)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorB(ConvNets):
+    """
+    Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
+    and the modified code adapted from https://github.com/gemelo-ai/vocos.
+    """
+
+    def __init__(
+        self,
+        window_length: int,
+        channels: int = 32,
+        hop_factor: float = 0.25,
+        bands: Tuple[Tuple[float, float], ...] = (
+            (0.0, 0.1),
+            (0.1, 0.25),
+            (0.25, 0.5),
+            (0.5, 0.75),
+            (0.75, 1.0),
+        ),
+    ):
+        super().__init__()
+        self.window_length = window_length
+        self.hop_factor = hop_factor
+        self.spec_fn = T.Spectrogram(
+            n_fft=window_length,
+            hop_length=int(window_length * hop_factor),
+            win_length=window_length,
+            power=None,
+        )
+        n_fft = window_length // 2 + 1
+        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+        self.bands = bands
+        convs = lambda: nn.ModuleList(
+            [
+                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
+                ),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
+                ),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
+                ),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
+                ),
+            ]
+        )
+        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+
+        self.conv_post = weight_norm(
+            nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
+        )
+
+    def spectrogram(self, x: Tensor) -> List[Tensor]:
+        # Remove DC offset
+        x = x - x.mean(dim=-1, keepdims=True)
+        # Peak normalize the volume of input audio
+        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+        x = self.spec_fn(x)
+        x = torch.view_as_real(x)
+        x = x.permute(0, 3, 2, 1)  # [B, F, T, C] -> [B, C, T, F]
+        # Split into bands
+        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+        return x_bands
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+        x_bands = self.spectrogram(x.squeeze(1))
+        fmap = []
+        x = []
+
+        for band, stack in zip(x_bands, self.band_convs):
+            for i, layer in enumerate(stack):
+                band = layer(band)
+                band = torch.nn.functional.leaky_relu(band, 0.1)
+                if i > 0:
+                    fmap.append(band)
+            x.append(band)
+
+        x = torch.cat(x, dim=-1)
+        x = self.conv_post(x)
+        fmap.append(x)
+
+        return x, fmap
+
+
+class MultiBandDiscriminator(MultiDiscriminatorWrapper):
+    """
+    Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
+    and the modified code adapted from https://github.com/gemelo-ai/vocos.
+    """
+
+    def __init__(
+        self,
+        mbd_fft_sizes: list[int] = [2048, 1024, 512],
+    ):
+        super().__init__()
+        self.fft_sizes = mbd_fft_sizes
+        self.discriminators = nn.ModuleList(
+            [DiscriminatorB(window_length=w) for w in self.fft_sizes]
+        )
+
+    def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
+
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for d in self.discriminators:
+
+            y_d_r, fmap_r = d(x=y)
+            y_d_g, fmap_g = d(x=y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorR(ConvNets):
+    def __init__(
+        self,
+        resolution: List[int],
+        use_spectral_norm: bool = False,
+        discriminator_channel_mult: int = 1,
+    ):
+        super().__init__()
+
+        self.resolution = resolution
+        assert (
+            len(self.resolution) == 3
+        ), f"MRD layer requires list with len=3, got {self.resolution}"
+        self.lrelu_slope = 0.1
+
+        self.register_buffer("window", torch.hann_window(self.resolution[-1]))
+
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    nn.Conv2d(
+                        1, int(32 * discriminator_channel_mult), (3, 9), padding=(1, 4)
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 3),
+                        padding=(1, 1),
+                    )
+                ),
+            ]
+        )
+        self.conv_post = norm_f(
+            nn.Conv2d(int(32 * discriminator_channel_mult), 1, (3, 3), padding=(1, 1))
+        )
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+        fmap = []
+        x = self.spectrogram(x)
+        x = x.unsqueeze(1)
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, self.lrelu_slope)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+    def spectrogram(self, x: Tensor) -> Tensor:
+        n_fft, hop_length, win_length = self.resolution
+        x = F.pad(
+            x,
+            (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
+            mode="reflect",
+        )
+        x = x.squeeze(1)
+        x = torch.stft(
+            x,
+            n_fft=n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            center=False,
+            return_complex=True,
+            window=self.window,
+        )
+        x = torch.view_as_real(x)  # [B, F, TT, 2]
+        mag = torch.norm(x, p=2, dim=-1)  # [B, F, TT]
+
+        return mag
+
+
+class MultiResolutionDiscriminator(MultiDiscriminatorWrapper):
+    def __init__(
+        self,
+        use_spectral_norm: bool = False,
+        discriminator_channel_mult: int = 1,
+        resolutions: List[List[int]] = [
+            [1024, 120, 600],
+            [2048, 240, 1200],
+            [512, 50, 240],
+        ],
+    ):
+        super().__init__()
+        self.resolutions = resolutions
+        assert (
+            len(self.resolutions) == 3
+        ), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}, type: {type(self.resolutions)}"
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorR(
+                    resolution, use_spectral_norm, discriminator_channel_mult
+                )
+                for resolution in self.resolutions
+            ]
+        )
+
+    def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for disc in self.discriminators:
+            y_d_r, fmap_r = disc(x=y)
+            y_d_g, fmap_g = disc(x=y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class MultiDiscriminatorStep(Model):
+    def __init__(self, list_discriminator: List[MultiDiscriminatorWrapper]):
+        super().__init__()
+        self.disc: Sequence[MultiDiscriminatorWrapper] = nn.ModuleList(
+            list_discriminator
+        )
+        self.total = len(self.disc)
+
+    def forward(
+        self,
+        y: Tensor,
+        y_hat: Tensor,
+        step_type: Literal["discriminator", "generator"],
+    ) -> Union[
+        Tuple[Tensor, Tensor, List[float]], Tuple[Tensor, List[float], List[float]]
+    ]:
+        """
+        It returns the content based on the choice of "step_type", being it a
+        'discriminator' or 'generator'
+
+        For generator it returns:
+            Tuple[Tensor, Tensor, List[float]]
+            "gen_loss, feat_loss, all_g_losses"
+
+        For 'discriminator' it returns:
+            Tuple[Tensor, List[float], List[float]]
+            "disc_loss, disc_real_losses, disc_gen_losses"
+        """
+        if step_type == "generator":
+            all_g_losses: List[float] = []
+            feat_loss: Tensor = 0
+            gen_loss: Tensor = 0
+        else:
+            disc_loss: Tensor = 0
+            disc_real_losses: List[float] = []
+            disc_gen_losses: List[float] = []
+
+        for disc in self.disc:
+            if step_type == "generator":
+                # feature loss, generator loss, list of generator losses (float)]
+                f_loss, g_loss, g_losses = disc.gen_step(y, y_hat)
+                gen_loss += g_loss
+                feat_loss += f_loss
+                all_g_losses.extend(g_losses)
+            else:
+                # [discriminator loss, (disc losses real, disc losses generated)]
+                d_loss, (d_real_losses, d_gen_losses) = disc.disc_step(y, y_hat)
+                disc_loss += d_loss
+                disc_real_losses.extend(d_real_losses)
+                disc_gen_losses.extend(d_gen_losses)
+
+        if step_type == "generator":
+            return gen_loss, feat_loss, all_g_losses
+        return disc_loss, disc_real_losses, disc_gen_losses
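
All of the wrapper classes above share one calling convention: each multi-discriminator returns (real scores, generated scores, real feature maps, generated feature maps), and MultiDiscriminatorStep folds several of them into a single call selected by step_type, using the least-squares (LSGAN-style) losses defined in MultiDiscriminatorWrapper. A minimal training-step sketch follows; the tensor shapes and the behavior of train_step (inherited from ConvNets, which is not part of this diff) are assumptions, while the class names and return orders come from discriminators.py above:

    import torch
    from lt_tensor.model_zoo.losses.discriminators import (
        MultiPeriodDiscriminator,
        MultiResolutionDiscriminator,
        MultiDiscriminatorStep,
    )

    stepper = MultiDiscriminatorStep(
        [MultiPeriodDiscriminator(), MultiResolutionDiscriminator()]
    )
    y = torch.randn(2, 1, 16384)      # real audio, assumed shape (batch, 1, samples)
    y_hat = torch.randn(2, 1, 16384)  # generated audio from the vocoder

    # Discriminator update: summed loss plus per-discriminator floats for logging.
    disc_loss, disc_real_losses, disc_gen_losses = stepper(y, y_hat, "discriminator")

    # Generator update: summed adversarial loss, feature-matching loss, float logs.
    gen_loss, feat_loss, all_g_losses = stepper(y, y_hat, "generator")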

lt_tensor/processors/audio.py

@@ -105,7 +105,6 @@ class AudioProcessor(Model):
             onesided=self.cfg.onesided,
             normalized=self.cfg.normalized,
         )
-        self.griffin_lm_iters = 32
         self.mel_rscale = torchaudio.transforms.InverseMelScale(
             n_stft=self.cfg.n_stft,
             n_mels=self.cfg.n_mels,
@@ -114,21 +113,19 @@ class AudioProcessor(Model):
             f_max=self.cfg.f_max,
             mel_scale=self.cfg.mel_scale,
         )
-        self.giffin_lim = torchaudio.transforms.GriffinLim(
-            n_fft=self.cfg.n_fft,
-            win_length=self.cfg.win_length,
-            hop_length=self.cfg.hop_length,
-        )
+
         self.register_buffer(
             "window",
             (torch.hann_window(self.cfg.win_length) if window is None else window),
         )

     def _apply_device(self):
-        print(f"Audio Processor Device: {self.device.type}")
-        self.giffin_lim.to(device=self.device)
         self._mel_spec.to(device=self.device)
         self.mel_rscale.to(device=self.device)
+        try:
+            self.window.to(device=self.device)
+        except:
+            pass

     def from_numpy(
         self,
@@ -173,7 +170,9 @@ class AudioProcessor(Model):
         )

         if audio is None and mel is not None:
-            return self.from_numpy(librosa.feature.rms(S=mel, **rms_kwargs)[0])
+            return self.from_numpy(
+                librosa.feature.rms(S=mel, **rms_kwargs)[0]
+            ).squeeze()
         default_dtype = audio.dtype
         default_device = audio.device
         if audio.ndim > 1:
@@ -192,8 +191,12 @@ class AudioProcessor(Model):
         audio = self.to_numpy_safe(audio)
         if B == 1:
             if mel is None:
-                return self.from_numpy(librosa.feature.rms(y=audio, **rms_kwargs)[0])
-            return self.from_numpy(librosa.feature.rms(y=audio, S=mel, **rms_kwargs)[0])
+                return self.from_numpy(
+                    librosa.feature.rms(y=audio, **rms_kwargs)[0]
+                ).squeeze()
+            return self.from_numpy(
+                librosa.feature.rms(y=audio, S=mel, **rms_kwargs)[0]
+            ).squeeze()
         else:
             rms_ = []
             for i in range(B):
@@ -201,7 +204,7 @@ class AudioProcessor(Model):
                     0
                 ]
                 rms_.append(_r)
-            return self.from_numpy_batch(rms_, default_device, default_dtype)
+            return self.from_numpy_batch(rms_, default_device, default_dtype).squeeze()

     def compute_pitch(
         self,
@@ -273,7 +276,7 @@ class AudioProcessor(Model):
             win_length=win_length,
             freq_low=fmin,
             freq_high=fmax,
-        )
+        ).squeeze()

     def interpolate(
         self,
@@ -312,7 +315,7 @@ class AudioProcessor(Model):
             antialias=antialias,
         )

-    def inverse_transform(
+    def istft(
         self,
         spec: Tensor,
         phase: Tensor,
@@ -320,6 +323,10 @@ class AudioProcessor(Model):
         hop_length: Optional[int] = None,
         win_length: Optional[int] = None,
         length: Optional[int] = None,
+        center: Optional[bool] = None,
+        normalized: Optional[bool] = None,
+        onesided: Optional[bool] = None,
+        return_complex: bool = False,
         *,
         _recall: bool = False,
     ):
@@ -331,25 +338,25 @@ class AudioProcessor(Model):
         try:
             return torch.istft(
                 spec * torch.exp(phase * 1j),
-                n_fft=n_fft or self.cfg.n_fft,
-                hop_length=hop_length or self.cfg.hop_length,
-                win_length=win_length or self.cfg.win_length,
+                n_fft=default(n_fft, self.cfg.n_fft),
+                hop_length=default(hop_length, self.cfg.hop_length),
+                win_length=default(win_length, self.cfg.win_length),
                 window=window,
-                center=self.cfg.center,
-                normalized=self.cfg.normalized,
-                onesided=self.cfg.onesided,
+                center=default(center, self.cfg.center),
+                normalized=default(normalized, self.cfg.normalized),
+                onesided=default(onesided, self.cfg.onesided),
                 length=length,
-                return_complex=False,
+                return_complex=return_complex,
             )
         except RuntimeError as e:
             if not _recall and spec.device != self.window.device:
                 self.window = self.window.to(spec.device)
-                return self.inverse_transform(
+                return self.istft(
                     spec, phase, n_fft, hop_length, win_length, length, _recall=True
                 )
             raise e

-    def normalize_audio(
+    def istft_norm(
         self,
         wave: Tensor,
         length: Optional[int] = None,
@@ -389,7 +396,7 @@ class AudioProcessor(Model):
         except RuntimeError as e:
             if not _recall and wave.device != self.window.device:
                 self.window = self.window.to(wave.device)
-                return self.normalize_audio(wave, length, _recall=True)
+                return self.istft_norm(wave, length, _recall=True)
             raise e

     def compute_mel(
@@ -415,14 +422,6 @@ class AudioProcessor(Model):
                 return self.compute_mel(wave, raw_mel_only, eps, _recall=True)
             raise e

-    def inverse_mel_spectogram(self, mel: Tensor, n_iter: Optional[int] = None):
-        if isinstance(n_iter, int) and n_iter != self.griffin_lm_iters:
-            self.giffin_lim.n_iter = n_iter
-            self.griffin_lm_iters = n_iter
-        return self.giffin_lim.forward(
-            self.mel_rscale(mel),
-        )
-
     def load_audio(
         self,
         path: PathLike,
@@ -506,14 +505,9 @@ class AudioProcessor(Model):
             maximum,
         )

-    def stft_loss(
-        self,
-        signal: Tensor,
-        ground: Tensor,
-    ):
-        with torch.no_grad():
-            ground = F.interpolate(ground, signal.shape[-1]).to(signal.device)
-        return F.l1_loss(signal, ground)
+    def stft_loss(self, signal: Tensor, ground: Tensor, magnitude: float = 1.0):
+        ground = F.interpolate(ground, signal.shape[-1]).to(signal.device)
+        return F.l1_loss(signal.squeeze(), ground.squeeze()) * magnitude

     def forward(
         self,
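
On the AudioProcessor side, the hunks above rename inverse_transform to istft and normalize_audio to istft_norm, add per-call center/normalized/onesided/return_complex overrides to istft, remove the Griffin-Lim members (giffin_lim, griffin_lm_iters, inverse_mel_spectogram), and give stft_loss a magnitude factor while dropping its torch.no_grad() guard. A hedged migration sketch; the constructor and configuration are not shown in this diff, so an already-built instance is assumed, and the import path follows the RECORD entry:

    from typing import Optional
    from torch import Tensor
    from lt_tensor.processors.audio import AudioProcessor

    def reconstruct(
        ap: AudioProcessor, spec: Tensor, phase: Tensor, length: Optional[int] = None
    ) -> Tensor:
        # 0.0.1a27: ap.inverse_transform(spec, phase) -> 0.0.1a28: ap.istft(spec, phase)
        wave = ap.istft(spec, phase, length=length, return_complex=False)
        # 0.0.1a27: ap.normalize_audio(wave) -> 0.0.1a28: ap.istft_norm(wave)
        return ap.istft_norm(wave, length=length)

    def spectral_l1(ap: AudioProcessor, signal: Tensor, ground: Tensor) -> Tensor:
        # magnitude is the new scaling factor; the interpolation of ground is
        # no longer wrapped in torch.no_grad() in this release.
        return ap.stft_loss(signal, ground, magnitude=1.0)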

lt_tensor-0.0.1a28.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a27
+Version: 0.0.1a28
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336

lt_tensor-0.0.1a28.dist-info/RECORD

@@ -9,7 +9,7 @@ lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,
 lt_tensor/noise_tools.py,sha256=wFeAsHhLhSlEc5XU5LbFKaXoHeVxrWjiMeljjGdIKyM,11363
 lt_tensor/torch_commons.py,sha256=8l0bxmrAzwvyqjivCIVISXlbvKarlg4DdE0BOGSnMuQ,812
 lt_tensor/transform.py,sha256=dZm8T_ov0blHMQu6nGiehsdG1VSB7bZBUVmTkT-PBdc,13257
-lt_tensor/model_zoo/__init__.py,sha256=ltVTvmOlbOCfDc5Trvg0-Ta_Ujgkw0UVF9V5rqHx-RI,378
+lt_tensor/model_zoo/__init__.py,sha256=yPUVchgVhU2nAJ2ocA4HFfG7IMEiBu8qOi8I1KWTTkU,404
 lt_tensor/model_zoo/basic.py,sha256=pI8HyiHK-cmWcEEaVY_EduUJOjZW6HOtXvJd8Rbhq30,15452
 lt_tensor/model_zoo/convs.py,sha256=YQRxek75Qpsha8nfc7wLhmJS9XxPeCa4WxuftLg6IcE,3927
 lt_tensor/model_zoo/features.py,sha256=DO8dlE0kmPKTNC1Xkv9wKegOOYkQa_rkxM4hhcNwJWA,15655
@@ -26,10 +26,12 @@ lt_tensor/model_zoo/audio_models/__init__.py,sha256=MoG9YjxLyvscq_6njK1ljGBletK9
 lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=PDuDYN1omD1RoAXcmxH3tEgfAuM3ZHAWzimD6ElMqEQ,9073
 lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=7GJqKLw7-juXpfp5IFzjASLut0uouDhjZ1CQknf3H68,16533
 lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=ltIuD9t1gmS3bTmCqZIwJHKrhC6DYya3OaXlskWX9kw,17606
+lt_tensor/model_zoo/losses/__init__.py,sha256=B9RAUxBiOZwooztnij1oLeRwZ7_MjnN3mPoum7saD6s,59
+lt_tensor/model_zoo/losses/discriminators.py,sha256=yYh7HzRTUtr0RVTG7cWpcYsJZsRCz6yzg6Loq8FtyOk,20405
 lt_tensor/processors/__init__.py,sha256=Pvxhh0KR65zLCgUd53_k5Z0y5JWWcO0ZBXFK9rv0o5w,109
-lt_tensor/processors/audio.py,sha256=mZY7LOeYACnX8PLz5AeFe0zqEebPoN-Q44Bi3yrlZMQ,16881
-lt_tensor-0.0.1a27.dist-info/licenses/LICENSE,sha256=tQHc38scHOba4kDBNG4U0U6PpObaloiZG-FvKSgv2b0,11336
-lt_tensor-0.0.1a27.dist-info/METADATA,sha256=NpXqioPXZMvXo-HzhXrS6O1qiftDnoc8ZzOfhfUMBaY,1062
-lt_tensor-0.0.1a27.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-lt_tensor-0.0.1a27.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
-lt_tensor-0.0.1a27.dist-info/RECORD,,
+lt_tensor/processors/audio.py,sha256=rsnnNi8MtxPq9vAYoiRQ7lGjorfJIpRvrKEe3zA8YJk,16668
+lt_tensor-0.0.1a28.dist-info/licenses/LICENSE,sha256=tQHc38scHOba4kDBNG4U0U6PpObaloiZG-FvKSgv2b0,11336
+lt_tensor-0.0.1a28.dist-info/METADATA,sha256=2LLguzaCAM2bcAdy_D66j4PS9Oh5PU3ZnA9qy7xcx0w,1062
+lt_tensor-0.0.1a28.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a28.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a28.dist-info/RECORD,,