lt-tensor 0.0.1a34__py3-none-any.whl → 0.0.1a36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. lt_tensor/__init__.py +1 -1
  2. lt_tensor/losses.py +11 -7
  3. lt_tensor/lr_schedulers.py +147 -21
  4. lt_tensor/misc_utils.py +35 -42
  5. lt_tensor/model_zoo/activations/__init__.py +3 -0
  6. lt_tensor/model_zoo/activations/alias_free/__init__.py +3 -0
  7. lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/act.py +8 -6
  8. lt_tensor/model_zoo/activations/snake/__init__.py +41 -43
  9. lt_tensor/model_zoo/audio_models/__init__.py +2 -2
  10. lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +243 -0
  11. lt_tensor/model_zoo/audio_models/hifigan/__init__.py +22 -357
  12. lt_tensor/model_zoo/audio_models/istft/__init__.py +14 -349
  13. lt_tensor/model_zoo/audio_models/resblocks.py +248 -0
  14. lt_tensor/model_zoo/convs.py +21 -32
  15. lt_tensor/model_zoo/losses/CQT/__init__.py +0 -0
  16. lt_tensor/model_zoo/losses/CQT/transforms.py +336 -0
  17. lt_tensor/model_zoo/losses/CQT/utils.py +519 -0
  18. lt_tensor/model_zoo/losses/discriminators.py +375 -37
  19. lt_tensor/processors/audio.py +67 -57
  20. {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/METADATA +1 -1
  21. lt_tensor-0.0.1a36.dist-info/RECORD +43 -0
  22. lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +0 -1
  23. lt_tensor-0.0.1a34.dist-info/RECORD +0 -37
  24. /lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/filter.py +0 -0
  25. /lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/resample.py +0 -0
  26. {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/WHEEL +0 -0
  27. {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/licenses/LICENSE +0 -0
  28. {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/top_level.txt +0 -0
@@ -7,8 +7,6 @@ from lt_tensor.model_base import Model
 from lt_tensor.model_zoo.convs import ConvNets
 from torch.nn import functional as F
 from torchaudio import transforms as T
-from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
-
 
 MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
     List[Tensor],
@@ -19,9 +17,11 @@ MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
 
 
 class MultiDiscriminatorWrapper(Model):
-    def __init__(self, list_discriminator: List["_MultiDiscriminatorT"]):
+    def __init__(
+        self, list_discriminator: Union[List["_MultiDiscriminatorT"], nn.ModuleList]
+    ):
         """Setup example:
-        model_d = MultiDiscriminatorStep(
+        model_d = MultiDiscriminatorWrapper(
             [
                 MultiEnvelopeDiscriminator(),
                 MultiBandDiscriminator(),
@@ -31,7 +31,12 @@ class MultiDiscriminatorWrapper(Model):
             )
         """
         super().__init__()
-        self.disc: Sequence[_MultiDiscriminatorT] = nn.ModuleList(list_discriminator)
+
+        self.disc: Sequence[_MultiDiscriminatorT] = (
+            nn.ModuleList(list_discriminator)
+            if isinstance(list_discriminator, (list, tuple, set))
+            else list_discriminator
+        )
         self.total = len(self.disc)
 
     def forward(
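With this change the wrapper accepts either a plain Python sequence or a pre-built `nn.ModuleList`. A minimal sketch of both call styles (the discriminator choices are illustrative; training code omitted):

```python
import torch.nn as nn
from lt_tensor.model_zoo.losses.discriminators import (
    MultiDiscriminatorWrapper,
    MultiEnvelopeDiscriminator,
    MultiPeriodDiscriminator,
)

# Both constructions are accepted after this change:
wrapper_a = MultiDiscriminatorWrapper(
    [MultiEnvelopeDiscriminator(), MultiPeriodDiscriminator()]
)
wrapper_b = MultiDiscriminatorWrapper(
    nn.ModuleList([MultiEnvelopeDiscriminator(), MultiPeriodDiscriminator()])
)
assert wrapper_a.total == wrapper_b.total == 2
```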
@@ -96,7 +101,6 @@ class _MultiDiscriminatorT(ConvNets):
     def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
         pass
 
-    # for type hinting
     def __call__(self, *args, **kwds) -> MULTI_DISC_OUT_TYPE:
         return super().__call__(*args, **kwds)
 
@@ -159,7 +163,7 @@ class DiscriminatorP(ConvNets):
     def __init__(
         self,
         period: List[int],
-        discriminator_channel_mult: Number = 1,
+        discriminator_channel_multi: Number = 1,
         kernel_size: int = 5,
         stride: int = 3,
         use_spectral_norm: bool = False,
@@ -167,7 +171,7 @@
         super().__init__()
         self.period = period
         norm_f = weight_norm if not use_spectral_norm else spectral_norm
-        dsc = lambda x: int(x * discriminator_channel_mult)
+        dsc = lambda x: int(x * discriminator_channel_multi)
         self.convs = nn.ModuleList(
             [
                 norm_f(
@@ -242,19 +246,18 @@ class DiscriminatorP(ConvNets):
 class MultiPeriodDiscriminator(_MultiDiscriminatorT):
     def __init__(
         self,
-        discriminator_channel_mult: Number = 1,
+        discriminator_channel_multi: Number = 1,
         mpd_reshapes: list[int] = [2, 3, 5, 7, 11],
         use_spectral_norm: bool = False,
     ):
         super().__init__()
         self.mpd_reshapes = mpd_reshapes
-        print(f"mpd_reshapes: {self.mpd_reshapes}")
         self.discriminators = nn.ModuleList(
             [
                 DiscriminatorP(
                     rs,
                     use_spectral_norm=use_spectral_norm,
-                    discriminator_channel_mult=discriminator_channel_mult,
+                    discriminator_channel_multi=discriminator_channel_multi,
                 )
                 for rs in self.mpd_reshapes
             ]
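Note that `discriminator_channel_mult` is renamed to `discriminator_channel_multi` throughout this file, so keyword calls written against 0.0.1a34 break (the constructor also loses its `mpd_reshapes` debug print). A minimal migration sketch; `0.5` is an illustrative value that halves every channel count via `int(x * discriminator_channel_multi)`:

```python
from lt_tensor.model_zoo.losses.discriminators import MultiPeriodDiscriminator

# 0.0.1a34:
# mpd = MultiPeriodDiscriminator(discriminator_channel_mult=0.5)

# 0.0.1a36:
mpd = MultiPeriodDiscriminator(discriminator_channel_multi=0.5)
```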
@@ -276,6 +279,79 @@ class MultiPeriodDiscriminator(_MultiDiscriminatorT):
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
 
 
+class DiscriminatorS(ConvNets):
+    def __init__(
+        self,
+        use_spectral_norm=False,
+        discriminator_channel_multi: Number = 1,
+    ):
+        super().__init__()
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        dsc = lambda x: int(x * discriminator_channel_multi)
+        self.convs = nn.ModuleList(
+            [
+                norm_f(nn.Conv1d(1, dsc(128), 15, 1, padding=7)),
+                norm_f(nn.Conv1d(dsc(128), dsc(128), 41, 2, groups=4, padding=20)),
+                norm_f(nn.Conv1d(dsc(128), dsc(256), 41, 2, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(256), dsc(512), 41, 4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(512), dsc(1024), 41, 4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(1024), dsc(1024), 41, 1, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(1024), dsc(1024), 5, 1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv1d(dsc(1024), 1, 3, 1, padding=1))
+        self.activation = nn.LeakyReLU(0.1)
+
+    def forward(self, x):
+        fmap = []
+        for l in self.convs:
+            x = l(x)
+            x = self.activation(x)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        return x.flatten(1, -1), fmap
+
+
+class MultiScaleDiscriminator(ConvNets):
+    def __init__(
+        self,
+        discriminator_channel_multi: Number = 1,
+    ):
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorS(
+                    use_spectral_norm=True,
+                    discriminator_channel_multi=discriminator_channel_multi,
+                ),
+                DiscriminatorS(discriminator_channel_multi=discriminator_channel_multi),
+                DiscriminatorS(discriminator_channel_multi=discriminator_channel_multi),
+            ]
+        )
+        self.meanpools = nn.ModuleList(
+            [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
+        )
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            if i > 0:
+                y = self.meanpools[i - 1](y)
+                y_hat = self.meanpools[i - 1](y_hat)
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
 class EnvelopeExtractor(Model):
     """Extracts the amplitude envelope of the audio signal."""
 
@@ -297,21 +373,35 @@ class EnvelopeExtractor(Model):
 
 
 class DiscriminatorEnvelope(ConvNets):
-    def __init__(self, use_spectral_norm=False):
+    def __init__(
+        self,
+        use_spectral_norm=False,
+        discriminator_channel_multi: Number = 1,
+        kernel_size: int = 101,
+    ):
         super().__init__()
         norm_f = weight_norm if not use_spectral_norm else spectral_norm
-        self.extractor = EnvelopeExtractor(kernel_size=101)
+        self.extractor = EnvelopeExtractor(kernel_size=kernel_size)
+        dsc = lambda x: int(x * discriminator_channel_multi)
         self.convs = nn.ModuleList(
             [
-                norm_f(nn.Conv1d(1, 64, 15, stride=1, padding=7)),
-                norm_f(nn.Conv1d(64, 128, 41, stride=2, groups=4, padding=20)),
-                norm_f(nn.Conv1d(128, 256, 41, stride=2, groups=16, padding=20)),
-                norm_f(nn.Conv1d(256, 512, 41, stride=4, groups=16, padding=20)),
-                norm_f(nn.Conv1d(512, 512, 41, stride=4, groups=16, padding=20)),
-                norm_f(nn.Conv1d(512, 512, 5, stride=1, padding=2)),
+                norm_f(nn.Conv1d(1, dsc(64), 15, stride=1, padding=7)),
+                norm_f(
+                    nn.Conv1d(dsc(64), dsc(128), 41, stride=2, groups=4, padding=20)
+                ),
+                norm_f(
+                    nn.Conv1d(dsc(128), dsc(256), 41, stride=2, groups=16, padding=20)
+                ),
+                norm_f(
+                    nn.Conv1d(dsc(256), dsc(512), 41, stride=4, groups=16, padding=20)
+                ),
+                norm_f(
+                    nn.Conv1d(dsc(512), dsc(512), 41, stride=4, groups=16, padding=20)
+                ),
+                norm_f(nn.Conv1d(dsc(512), dsc(512), 5, stride=1, padding=2)),
             ]
         )
-        self.conv_post = norm_f(nn.Conv1d(512, 1, 3, stride=1, padding=1))
+        self.conv_post = norm_f(nn.Conv1d(dsc(512), 1, 3, stride=1, padding=1))
         self.activation = nn.LeakyReLU(0.1)
 
     def forward(self, x):
@@ -327,11 +417,17 @@ class DiscriminatorEnvelope(ConvNets):
 
 
 class MultiEnvelopeDiscriminator(_MultiDiscriminatorT):
-    def __init__(self, use_spectral_norm: bool = False):
+    def __init__(
+        self,
+        use_spectral_norm: bool = False,
+        discriminator_channel_multi: Number = 1,
+    ):
         super().__init__()
         self.discriminators = nn.ModuleList(
             [
-                DiscriminatorEnvelope(use_spectral_norm),  # raw envelope
+                DiscriminatorEnvelope(
+                    use_spectral_norm, discriminator_channel_multi
+                ),  # raw envelope
                 DiscriminatorEnvelope(use_spectral_norm),  # downsampled once
                 DiscriminatorEnvelope(use_spectral_norm),  # downsampled twice
             ]
@@ -431,7 +527,7 @@ class DiscriminatorB(ConvNets):
         for band, stack in zip(x_bands, self.band_convs):
             for i, layer in enumerate(stack):
                 band = layer(band)
-                band = torch.nn.functional.leaky_relu(band, 0.1)
+                band = F.leaky_relu(band, 0.1)
                 if i > 0:
                     fmap.append(band)
             x.append(band)
@@ -452,11 +548,21 @@ class MultiBandDiscriminator(_MultiDiscriminatorT):
     def __init__(
         self,
         mbd_fft_sizes: list[int] = [2048, 1024, 512],
+        channels: int = 32,
+        hop_factor: float = 0.25,
+        bands: Tuple[Tuple[float, float], ...] = (
+            (0.0, 0.1),
+            (0.1, 0.25),
+            (0.25, 0.5),
+            (0.5, 0.75),
+            (0.75, 1.0),
+        ),
     ):
         super().__init__()
         self.fft_sizes = mbd_fft_sizes
+        kwargs_disc = dict(channels=channels, hop_factor=hop_factor, bands=bands)
         self.discriminators = nn.ModuleList(
-            [DiscriminatorB(window_length=w) for w in self.fft_sizes]
+            [DiscriminatorB(window_length=w, **kwargs_disc) for w in self.fft_sizes]
         )
 
     def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
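`MultiBandDiscriminator` now exposes the `DiscriminatorB` band configuration instead of hard-coding it. A sketch with a custom split, assuming (as in the DAC-style band discriminator this mirrors) that each pair is a (start, stop) fraction of the STFT frequency bins:

```python
from lt_tensor.model_zoo.losses.discriminators import MultiBandDiscriminator

mbd = MultiBandDiscriminator(
    mbd_fft_sizes=[2048, 1024, 512],
    channels=32,      # base channel width of each DiscriminatorB
    hop_factor=0.25,  # hop length as a fraction of the window
    bands=((0.0, 0.25), (0.25, 0.5), (0.5, 1.0)),  # coarser 3-band split
)
```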
@@ -483,7 +589,7 @@ class DiscriminatorR(ConvNets):
         self,
         resolution: List[int],
         use_spectral_norm: bool = False,
-        discriminator_channel_mult: int = 1,
+        discriminator_channel_multi: Number = 1,
     ):
         super().__init__()
 
@@ -501,13 +607,13 @@
             [
                 norm_f(
                     nn.Conv2d(
-                        1, int(32 * discriminator_channel_mult), (3, 9), padding=(1, 4)
+                        1, int(32 * discriminator_channel_multi), (3, 9), padding=(1, 4)
                     )
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 * discriminator_channel_mult),
-                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 9),
                         stride=(1, 2),
                         padding=(1, 4),
@@ -515,8 +621,8 @@
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 * discriminator_channel_mult),
-                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 9),
                         stride=(1, 2),
                         padding=(1, 4),
@@ -524,8 +630,8 @@
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 * discriminator_channel_mult),
-                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 9),
                         stride=(1, 2),
                         padding=(1, 4),
@@ -533,8 +639,8 @@
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 * discriminator_channel_mult),
-                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 3),
                         padding=(1, 1),
                     )
@@ -542,7 +648,7 @@
             ]
         )
         self.conv_post = norm_f(
-            nn.Conv2d(int(32 * discriminator_channel_mult), 1, (3, 3), padding=(1, 1))
+            nn.Conv2d(int(32 * discriminator_channel_multi), 1, (3, 3), padding=(1, 1))
         )
 
     def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
@@ -586,7 +692,7 @@ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
     def __init__(
         self,
         use_spectral_norm: bool = False,
-        discriminator_channel_mult: int = 1,
+        discriminator_channel_multi: Number = 1,
         resolutions: List[List[int]] = [
             [1024, 120, 600],
             [2048, 240, 1200],
@@ -601,7 +707,7 @@ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
         self.discriminators = nn.ModuleList(
             [
                 DiscriminatorR(
-                    resolution, use_spectral_norm, discriminator_channel_mult
+                    resolution, use_spectral_norm, discriminator_channel_multi
                 )
                 for resolution in self.resolutions
             ]
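The same rename lands in `MultiResolutionDiscriminator`, and the multiplier's type hint widens from `int` to `Number`, so fractional widths are allowed. A sketch, assuming each resolution is an `[n_fft, hop_length, win_length]` triplet as in UnivNet-style MRDs:

```python
from lt_tensor.model_zoo.losses.discriminators import MultiResolutionDiscriminator

mrd = MultiResolutionDiscriminator(
    discriminator_channel_multi=1.5,  # fractional values now type-hinted as Number
    resolutions=[[1024, 120, 600], [2048, 240, 1200]],
)
```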
@@ -620,3 +726,235 @@ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
             y_d_gs.append(y_d_g)
             fmap_gs.append(fmap_g)
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorCQT(ConvNets):
+    """Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license."""
+
+    def __init__(
+        self,
+        hop_length: int,
+        n_octaves: int,
+        bins_per_octave: int,
+        sampling_rate: int,
+        cqtd_filters: int = 128,
+        cqtd_max_filters: int = 1024,
+        cqtd_filters_scale: int = 1,
+        cqtd_dilations: list = [1, 2, 4],
+        cqtd_in_channels: int = 1,
+        cqtd_out_channels: int = 1,
+        cqtd_normalize_volume: bool = False,
+    ):
+        super().__init__()
+        self.filters = cqtd_filters
+        self.max_filters = cqtd_max_filters
+        self.filters_scale = cqtd_filters_scale
+        self.kernel_size = (3, 9)
+        self.dilations = cqtd_dilations
+        self.stride = (1, 2)
+
+        self.fs = sampling_rate
+        self.in_channels = cqtd_in_channels
+        self.out_channels = cqtd_out_channels
+        self.hop_length = hop_length
+        self.n_octaves = n_octaves
+        self.bins_per_octave = bins_per_octave
+
+        # Lazy-load
+        from lt_tensor.model_zoo.losses.CQT.transforms import CQT2010v2
+
+        self.cqt_transform = CQT2010v2(
+            sr=self.fs * 2,
+            hop_length=self.hop_length,
+            n_bins=self.bins_per_octave * self.n_octaves,
+            bins_per_octave=self.bins_per_octave,
+            output_format="Complex",
+            pad_mode="constant",
+        )
+
+        self.conv_pres = nn.ModuleList()
+        for _ in range(self.n_octaves):
+            self.conv_pres.append(
+                nn.Conv2d(
+                    self.in_channels * 2,
+                    self.in_channels * 2,
+                    kernel_size=self.kernel_size,
+                    padding=self.get_2d_padding(self.kernel_size),
+                )
+            )
+
+        self.convs = nn.ModuleList()
+
+        self.convs.append(
+            nn.Conv2d(
+                self.in_channels * 2,
+                self.filters,
+                kernel_size=self.kernel_size,
+                padding=self.get_2d_padding(self.kernel_size),
+            )
+        )
+
+        in_chs = min(self.filters_scale * self.filters, self.max_filters)
+        for i, dilation in enumerate(self.dilations):
+            out_chs = min(
+                (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
+            )
+            self.convs.append(
+                weight_norm(
+                    nn.Conv2d(
+                        in_chs,
+                        out_chs,
+                        kernel_size=self.kernel_size,
+                        stride=self.stride,
+                        dilation=(dilation, 1),
+                        padding=self.get_2d_padding(self.kernel_size, (dilation, 1)),
+                    )
+                )
+            )
+            in_chs = out_chs
+        out_chs = min(
+            (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
+            self.max_filters,
+        )
+        self.convs.append(
+            weight_norm(
+                nn.Conv2d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=(self.kernel_size[0], self.kernel_size[0]),
+                    padding=self.get_2d_padding(
+                        (self.kernel_size[0], self.kernel_size[0])
+                    ),
+                )
+            )
+        )
+
+        self.conv_post = weight_norm(
+            nn.Conv2d(
+                out_chs,
+                self.out_channels,
+                kernel_size=(self.kernel_size[0], self.kernel_size[0]),
+                padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
+            )
+        )
+
+        self.activation = torch.nn.LeakyReLU(negative_slope=0.1)
+        self.resample = T.Resample(orig_freq=self.fs, new_freq=self.fs * 2)
+
+        self.cqtd_normalize_volume = cqtd_normalize_volume
+        if self.cqtd_normalize_volume:
+            print(
+                f"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
+            )
+
+    def get_2d_padding(
+        self,
+        kernel_size: Tuple[int, int],
+        dilation: Tuple[int, int] = (1, 1),
+    ):
+        return (
+            ((kernel_size[0] - 1) * dilation[0]) // 2,
+            ((kernel_size[1] - 1) * dilation[1]) // 2,
+        )
+
+    def forward(self, x: torch.tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        fmap = []
+
+        if self.cqtd_normalize_volume:
+            # Remove DC offset
+            x = x - x.mean(dim=-1, keepdims=True)
+            # Peak normalize the volume of input audio
+            x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+
+        x = self.resample(x)
+
+        z = self.cqt_transform(x)
+
+        z_amplitude = z[:, :, :, 0].unsqueeze(1)
+        z_phase = z[:, :, :, 1].unsqueeze(1)
+
+        z = torch.cat([z_amplitude, z_phase], dim=1)
+        z = torch.permute(z, (0, 1, 3, 2))  # [B, C, W, T] -> [B, C, T, W]
+
+        latent_z = []
+        for i in range(self.n_octaves):
+            latent_z.append(
+                self.conv_pres[i](
+                    z[
+                        :,
+                        :,
+                        :,
+                        i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
+                    ]
+                )
+            )
+        latent_z = torch.cat(latent_z, dim=-1)
+
+        for i, l in enumerate(self.convs):
+            latent_z = l(latent_z)
+
+            latent_z = self.activation(latent_z)
+            fmap.append(latent_z)
+
+        latent_z = self.conv_post(latent_z)
+
+        return latent_z, fmap
+
+
+class MultiScaleSubbandCQTDiscriminator(_MultiDiscriminatorT):
+    def __init__(
+        self,
+        sampling_rate: int,
+        cqtd_filters: int = 128,
+        cqtd_max_filters: int = 1024,
+        cqtd_filters_scale: Number = 1,
+        cqtd_dilations: list = [1, 2, 4],
+        cqtd_hop_lengths: list = [512, 256, 256],
+        cqtd_n_octaves: list = [9, 9, 9],
+        cqtd_bins_per_octaves: list = [24, 36, 48],
+        cqtd_in_channels: int = 1,
+        cqtd_out_channels: int = 1,
+        cqtd_normalize_volume: bool = False,
+    ):
+        super().__init__()
+
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorCQT(
+                    hop_length=cqtd_hop_lengths[i],
+                    n_octaves=cqtd_n_octaves[i],
+                    bins_per_octave=cqtd_bins_per_octaves[i],
+                    sampling_rate=sampling_rate,
+                    cqtd_filters=cqtd_filters,
+                    cqtd_max_filters=cqtd_max_filters,
+                    cqtd_filters_scale=cqtd_filters_scale,
+                    cqtd_dilations=cqtd_dilations,
+                    cqtd_in_channels=cqtd_in_channels,
+                    cqtd_out_channels=cqtd_out_channels,
+                    cqtd_normalize_volume=cqtd_normalize_volume,
+                )
+                for i in range(len(cqtd_hop_lengths))
+            ]
+        )
+
+    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+        List[torch.Tensor],
+        List[torch.Tensor],
+        List[List[torch.Tensor]],
+        List[List[torch.Tensor]],
+    ]:
+
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for disc in self.discriminators:
+            y_d_r, fmap_r = disc(y)
+            y_d_g, fmap_g = disc(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
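The new `MultiScaleSubbandCQTDiscriminator` builds one `DiscriminatorCQT` per hop-length/octave/bins triple and lazily imports the bundled CQT transform (the new `lt_tensor/model_zoo/losses/CQT` package added in this release). A hedged usage sketch with the defaults shown above:

```python
import torch
from lt_tensor.model_zoo.losses.discriminators import MultiScaleSubbandCQTDiscriminator

cqtd = MultiScaleSubbandCQTDiscriminator(sampling_rate=24000)
y = torch.randn(1, 1, 24000)      # one second of real audio
y_hat = torch.randn(1, 1, 24000)  # one second of generated audio
y_d_rs, y_d_gs, fmap_rs, fmap_gs = cqtd(y, y_hat)
assert len(y_d_rs) == 3  # the defaults configure three CQT discriminators
```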