lt-tensor 0.0.1a34__py3-none-any.whl → 0.0.1a36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +1 -1
- lt_tensor/losses.py +11 -7
- lt_tensor/lr_schedulers.py +147 -21
- lt_tensor/misc_utils.py +35 -42
- lt_tensor/model_zoo/activations/__init__.py +3 -0
- lt_tensor/model_zoo/activations/alias_free/__init__.py +3 -0
- lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/act.py +8 -6
- lt_tensor/model_zoo/activations/snake/__init__.py +41 -43
- lt_tensor/model_zoo/audio_models/__init__.py +2 -2
- lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +243 -0
- lt_tensor/model_zoo/audio_models/hifigan/__init__.py +22 -357
- lt_tensor/model_zoo/audio_models/istft/__init__.py +14 -349
- lt_tensor/model_zoo/audio_models/resblocks.py +248 -0
- lt_tensor/model_zoo/convs.py +21 -32
- lt_tensor/model_zoo/losses/CQT/__init__.py +0 -0
- lt_tensor/model_zoo/losses/CQT/transforms.py +336 -0
- lt_tensor/model_zoo/losses/CQT/utils.py +519 -0
- lt_tensor/model_zoo/losses/discriminators.py +375 -37
- lt_tensor/processors/audio.py +67 -57
- {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/METADATA +1 -1
- lt_tensor-0.0.1a36.dist-info/RECORD +43 -0
- lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +0 -1
- lt_tensor-0.0.1a34.dist-info/RECORD +0 -37
- /lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/filter.py +0 -0
- /lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/resample.py +0 -0
- {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/top_level.txt +0 -0
The hunks below are from lt_tensor/model_zoo/losses/discriminators.py (+375 -37). Note that the diff viewer truncates removed lines at the first intra-line change, so some `-` lines appear cut off.

```diff
@@ -7,8 +7,6 @@ from lt_tensor.model_base import Model
 from lt_tensor.model_zoo.convs import ConvNets
 from torch.nn import functional as F
 from torchaudio import transforms as T
-from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
-
 
 MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
     List[Tensor],
@@ -19,9 +17,11 @@ MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
 
 
 class MultiDiscriminatorWrapper(Model):
-    def __init__(
+    def __init__(
+        self, list_discriminator: Union[List["_MultiDiscriminatorT"], nn.ModuleList]
+    ):
         """Setup example:
-        model_d =
+        model_d = MultiDiscriminatorWrapper(
             [
                 MultiEnvelopeDiscriminator(),
                 MultiBandDiscriminator(),
@@ -31,7 +31,12 @@ class MultiDiscriminatorWrapper(Model):
         )
         """
         super().__init__()
-
+
+        self.disc: Sequence[_MultiDiscriminatorT] = (
+            nn.ModuleList(list_discriminator)
+            if isinstance(list_discriminator, (list, tuple, set))
+            else list_discriminator
+        )
         self.total = len(self.disc)
 
     def forward(
```
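The rewritten constructor above accepts either a plain Python container or a ready-made `nn.ModuleList`. A minimal usage sketch (the import path follows the file listing; class names are taken from the hunks in this diff):

```python
import torch.nn as nn
from lt_tensor.model_zoo.losses.discriminators import (
    MultiBandDiscriminator,
    MultiEnvelopeDiscriminator,
    MultiDiscriminatorWrapper,
)

discs = [MultiEnvelopeDiscriminator(), MultiBandDiscriminator()]

# A list/tuple/set is wrapped into an nn.ModuleList so the
# sub-discriminators are registered as submodules...
model_d = MultiDiscriminatorWrapper(discs)

# ...while an existing nn.ModuleList is stored as-is.
model_d = MultiDiscriminatorWrapper(nn.ModuleList(discs))
```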
```diff
@@ -96,7 +101,6 @@ class _MultiDiscriminatorT(ConvNets):
     def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
         pass
 
-    # for type hinting
     def __call__(self, *args, **kwds) -> MULTI_DISC_OUT_TYPE:
         return super().__call__(*args, **kwds)
 
@@ -159,7 +163,7 @@ class DiscriminatorP(ConvNets):
     def __init__(
         self,
         period: List[int],
-
+        discriminator_channel_multi: Number = 1,
         kernel_size: int = 5,
         stride: int = 3,
         use_spectral_norm: bool = False,
@@ -167,7 +171,7 @@ class DiscriminatorP(ConvNets):
         super().__init__()
         self.period = period
         norm_f = weight_norm if not use_spectral_norm else spectral_norm
-        dsc = lambda x: int(x *
+        dsc = lambda x: int(x * discriminator_channel_multi)
         self.convs = nn.ModuleList(
             [
                 norm_f(
@@ -242,19 +246,18 @@ class DiscriminatorP(ConvNets):
 class MultiPeriodDiscriminator(_MultiDiscriminatorT):
     def __init__(
         self,
-
+        discriminator_channel_multi: Number = 1,
         mpd_reshapes: list[int] = [2, 3, 5, 7, 11],
         use_spectral_norm: bool = False,
     ):
         super().__init__()
         self.mpd_reshapes = mpd_reshapes
-        print(f"mpd_reshapes: {self.mpd_reshapes}")
         self.discriminators = nn.ModuleList(
             [
                 DiscriminatorP(
                     rs,
                     use_spectral_norm=use_spectral_norm,
-
+                    discriminator_channel_multi=discriminator_channel_multi,
                 )
                 for rs in self.mpd_reshapes
             ]
```
```diff
@@ -276,6 +279,79 @@ class MultiPeriodDiscriminator(_MultiDiscriminatorT):
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
 
 
+class DiscriminatorS(ConvNets):
+    def __init__(
+        self,
+        use_spectral_norm=False,
+        discriminator_channel_multi: Number = 1,
+    ):
+        super().__init__()
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        dsc = lambda x: int(x * discriminator_channel_multi)
+        self.convs = nn.ModuleList(
+            [
+                norm_f(nn.Conv1d(1, dsc(128), 15, 1, padding=7)),
+                norm_f(nn.Conv1d(dsc(128), dsc(128), 41, 2, groups=4, padding=20)),
+                norm_f(nn.Conv1d(dsc(128), dsc(256), 41, 2, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(256), dsc(512), 41, 4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(512), dsc(1024), 41, 4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(1024), dsc(1024), 41, 1, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(1024), dsc(1024), 5, 1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv1d(dsc(1024), 1, 3, 1, padding=1))
+        self.activation = nn.LeakyReLU(0.1)
+
+    def forward(self, x):
+        fmap = []
+        for l in self.convs:
+            x = l(x)
+            x = self.activation(x)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        return x.flatten(1, -1), fmap
+
+
+class MultiScaleDiscriminator(ConvNets):
+    def __init__(
+        self,
+        discriminator_channel_multi: Number = 1,
+    ):
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorS(
+                    use_spectral_norm=True,
+                    discriminator_channel_multi=discriminator_channel_multi,
+                ),
+                DiscriminatorS(discriminator_channel_multi=discriminator_channel_multi),
+                DiscriminatorS(discriminator_channel_multi=discriminator_channel_multi),
+            ]
+        )
+        self.meanpools = nn.ModuleList(
+            [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
+        )
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            if i > 0:
+                y = self.meanpools[i - 1](y)
+                y_hat = self.meanpools[i - 1](y_hat)
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
 class EnvelopeExtractor(Model):
     """Extracts the amplitude envelope of the audio signal."""
 
```
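The newly added `DiscriminatorS`/`MultiScaleDiscriminator` pair mirrors the HiFi-GAN multi-scale layout: three identical sub-discriminators, with the second and third fed progressively average-pooled audio. A usage sketch, assuming mono input shaped `[batch, 1, samples]` as in the other discriminators here:

```python
import torch
from lt_tensor.model_zoo.losses.discriminators import MultiScaleDiscriminator

msd = MultiScaleDiscriminator(discriminator_channel_multi=1)
y = torch.randn(2, 1, 16384)      # real audio
y_hat = torch.randn(2, 1, 16384)  # generated audio

# Per-scale logits for real and generated audio, plus the intermediate
# feature maps typically used for feature-matching losses.
y_d_rs, y_d_gs, fmap_rs, fmap_gs = msd(y, y_hat)
```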
```diff
@@ -297,21 +373,35 @@ class EnvelopeExtractor(Model):
 
 
 class DiscriminatorEnvelope(ConvNets):
-    def __init__(
+    def __init__(
+        self,
+        use_spectral_norm=False,
+        discriminator_channel_multi: Number = 1,
+        kernel_size: int = 101,
+    ):
         super().__init__()
         norm_f = weight_norm if not use_spectral_norm else spectral_norm
-        self.extractor = EnvelopeExtractor(kernel_size=
+        self.extractor = EnvelopeExtractor(kernel_size=kernel_size)
+        dsc = lambda x: int(x * discriminator_channel_multi)
         self.convs = nn.ModuleList(
             [
-                norm_f(nn.Conv1d(1, 64, 15, stride=1, padding=7)),
-                norm_f(
-
-
-                norm_f(
-
+                norm_f(nn.Conv1d(1, dsc(64), 15, stride=1, padding=7)),
+                norm_f(
+                    nn.Conv1d(dsc(64), dsc(128), 41, stride=2, groups=4, padding=20)
+                ),
+                norm_f(
+                    nn.Conv1d(dsc(128), dsc(256), 41, stride=2, groups=16, padding=20)
+                ),
+                norm_f(
+                    nn.Conv1d(dsc(256), dsc(512), 41, stride=4, groups=16, padding=20)
+                ),
+                norm_f(
+                    nn.Conv1d(dsc(512), dsc(512), 41, stride=4, groups=16, padding=20)
+                ),
+                norm_f(nn.Conv1d(dsc(512), dsc(512), 5, stride=1, padding=2)),
             ]
         )
-        self.conv_post = norm_f(nn.Conv1d(512, 1, 3, stride=1, padding=1))
+        self.conv_post = norm_f(nn.Conv1d(dsc(512), 1, 3, stride=1, padding=1))
         self.activation = nn.LeakyReLU(0.1)
 
     def forward(self, x):
@@ -327,11 +417,17 @@ class DiscriminatorEnvelope(ConvNets):
 
 
 class MultiEnvelopeDiscriminator(_MultiDiscriminatorT):
-    def __init__(
+    def __init__(
+        self,
+        use_spectral_norm: bool = False,
+        discriminator_channel_multi: Number = 1,
+    ):
         super().__init__()
         self.discriminators = nn.ModuleList(
             [
-                DiscriminatorEnvelope(
+                DiscriminatorEnvelope(
+                    use_spectral_norm, discriminator_channel_multi
+                ),  # raw envelope
                 DiscriminatorEnvelope(use_spectral_norm),  # downsampled once
                 DiscriminatorEnvelope(use_spectral_norm),  # downsampled twice
             ]
@@ -431,7 +527,7 @@ class DiscriminatorB(ConvNets):
         for band, stack in zip(x_bands, self.band_convs):
             for i, layer in enumerate(stack):
                 band = layer(band)
-                band =
+                band = F.leaky_relu(band, 0.1)
                 if i > 0:
                     fmap.append(band)
             x.append(band)
@@ -452,11 +548,21 @@ class MultiBandDiscriminator(_MultiDiscriminatorT):
     def __init__(
         self,
         mbd_fft_sizes: list[int] = [2048, 1024, 512],
+        channels: int = 32,
+        hop_factor: float = 0.25,
+        bands: Tuple[Tuple[float, float], ...] = (
+            (0.0, 0.1),
+            (0.1, 0.25),
+            (0.25, 0.5),
+            (0.5, 0.75),
+            (0.75, 1.0),
+        ),
     ):
         super().__init__()
         self.fft_sizes = mbd_fft_sizes
+        kwargs_disc = dict(channels=channels, hop_factor=hop_factor, bands=bands)
         self.discriminators = nn.ModuleList(
-            [DiscriminatorB(window_length=w) for w in self.fft_sizes]
+            [DiscriminatorB(window_length=w, **kwargs_disc) for w in self.fft_sizes]
         )
 
     def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
```
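`MultiBandDiscriminator` now exposes the sub-band layout and forwards it to each `DiscriminatorB` through `kwargs_disc`. The `bands` entries are normalized `(low, high)` frequency fractions; assuming `DiscriminatorB` follows the usual DAC-style band-split convention (its body is not shown in this diff), each pair is scaled by the STFT bin count to pick a bin range:

```python
# Hypothetical illustration of the band -> bin mapping for window_length=2048.
n_bins = 2048 // 2 + 1  # 1025 one-sided STFT bins
bands = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0))
bin_ranges = [(int(lo * n_bins), int(hi * n_bins)) for lo, hi in bands]
print(bin_ranges)  # [(0, 102), (102, 256), (256, 512), (512, 768), (768, 1025)]
```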
```diff
@@ -483,7 +589,7 @@ class DiscriminatorR(ConvNets):
         self,
         resolution: List[int],
         use_spectral_norm: bool = False,
-
+        discriminator_channel_multi: Number = 1,
     ):
         super().__init__()
 
@@ -501,13 +607,13 @@ class DiscriminatorR(ConvNets):
             [
                 norm_f(
                     nn.Conv2d(
-                        1, int(32 *
+                        1, int(32 * discriminator_channel_multi), (3, 9), padding=(1, 4)
                     )
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 *
-                        int(32 *
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 9),
                         stride=(1, 2),
                         padding=(1, 4),
@@ -515,8 +621,8 @@
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 *
-                        int(32 *
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 9),
                         stride=(1, 2),
                         padding=(1, 4),
@@ -524,8 +630,8 @@
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 *
-                        int(32 *
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 9),
                         stride=(1, 2),
                         padding=(1, 4),
@@ -533,8 +639,8 @@
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 *
-                        int(32 *
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 3),
                         padding=(1, 1),
                     )
@@ -542,7 +648,7 @@
             ]
         )
         self.conv_post = norm_f(
-            nn.Conv2d(int(32 *
+            nn.Conv2d(int(32 * discriminator_channel_multi), 1, (3, 3), padding=(1, 1))
         )
 
     def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
@@ -586,7 +692,7 @@ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
     def __init__(
         self,
         use_spectral_norm: bool = False,
-
+        discriminator_channel_multi: Number = 1,
         resolutions: List[List[int]] = [
             [1024, 120, 600],
             [2048, 240, 1200],
@@ -601,7 +707,7 @@ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
         self.discriminators = nn.ModuleList(
             [
                 DiscriminatorR(
-                    resolution, use_spectral_norm,
+                    resolution, use_spectral_norm, discriminator_channel_multi
                 )
                 for resolution in self.resolutions
             ]
```
```diff
@@ -620,3 +726,235 @@ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
         y_d_gs.append(y_d_g)
         fmap_gs.append(fmap_g)
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorCQT(ConvNets):
+    """Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license."""
+
+    def __init__(
+        self,
+        hop_length: int,
+        n_octaves: int,
+        bins_per_octave: int,
+        sampling_rate: int,
+        cqtd_filters: int = 128,
+        cqtd_max_filters: int = 1024,
+        cqtd_filters_scale: int = 1,
+        cqtd_dilations: list = [1, 2, 4],
+        cqtd_in_channels: int = 1,
+        cqtd_out_channels: int = 1,
+        cqtd_normalize_volume: bool = False,
+    ):
+        super().__init__()
+        self.filters = cqtd_filters
+        self.max_filters = cqtd_max_filters
+        self.filters_scale = cqtd_filters_scale
+        self.kernel_size = (3, 9)
+        self.dilations = cqtd_dilations
+        self.stride = (1, 2)
+
+        self.fs = sampling_rate
+        self.in_channels = cqtd_in_channels
+        self.out_channels = cqtd_out_channels
+        self.hop_length = hop_length
+        self.n_octaves = n_octaves
+        self.bins_per_octave = bins_per_octave
+
+        # Lazy-load
+        from lt_tensor.model_zoo.losses.CQT.transforms import CQT2010v2
+
+        self.cqt_transform = CQT2010v2(
+            sr=self.fs * 2,
+            hop_length=self.hop_length,
+            n_bins=self.bins_per_octave * self.n_octaves,
+            bins_per_octave=self.bins_per_octave,
+            output_format="Complex",
+            pad_mode="constant",
+        )
+
+        self.conv_pres = nn.ModuleList()
+        for _ in range(self.n_octaves):
+            self.conv_pres.append(
+                nn.Conv2d(
+                    self.in_channels * 2,
+                    self.in_channels * 2,
+                    kernel_size=self.kernel_size,
+                    padding=self.get_2d_padding(self.kernel_size),
+                )
+            )
+
+        self.convs = nn.ModuleList()
+
+        self.convs.append(
+            nn.Conv2d(
+                self.in_channels * 2,
+                self.filters,
+                kernel_size=self.kernel_size,
+                padding=self.get_2d_padding(self.kernel_size),
+            )
+        )
+
+        in_chs = min(self.filters_scale * self.filters, self.max_filters)
+        for i, dilation in enumerate(self.dilations):
+            out_chs = min(
+                (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
+            )
+            self.convs.append(
+                weight_norm(
+                    nn.Conv2d(
+                        in_chs,
+                        out_chs,
+                        kernel_size=self.kernel_size,
+                        stride=self.stride,
+                        dilation=(dilation, 1),
+                        padding=self.get_2d_padding(self.kernel_size, (dilation, 1)),
+                    )
+                )
+            )
+            in_chs = out_chs
+        out_chs = min(
+            (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
+            self.max_filters,
+        )
+        self.convs.append(
+            weight_norm(
+                nn.Conv2d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=(self.kernel_size[0], self.kernel_size[0]),
+                    padding=self.get_2d_padding(
+                        (self.kernel_size[0], self.kernel_size[0])
+                    ),
+                )
+            )
+        )
+
+        self.conv_post = weight_norm(
+            nn.Conv2d(
+                out_chs,
+                self.out_channels,
+                kernel_size=(self.kernel_size[0], self.kernel_size[0]),
+                padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
+            )
+        )
+
+        self.activation = torch.nn.LeakyReLU(negative_slope=0.1)
+        self.resample = T.Resample(orig_freq=self.fs, new_freq=self.fs * 2)
+
+        self.cqtd_normalize_volume = cqtd_normalize_volume
+        if self.cqtd_normalize_volume:
+            print(
+                f"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
+            )
+
+    def get_2d_padding(
+        self,
+        kernel_size: Tuple[int, int],
+        dilation: Tuple[int, int] = (1, 1),
+    ):
+        return (
+            ((kernel_size[0] - 1) * dilation[0]) // 2,
+            ((kernel_size[1] - 1) * dilation[1]) // 2,
+        )
+
+    def forward(self, x: torch.tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        fmap = []
+
+        if self.cqtd_normalize_volume:
+            # Remove DC offset
+            x = x - x.mean(dim=-1, keepdims=True)
+            # Peak normalize the volume of input audio
+            x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+
+        x = self.resample(x)
+
+        z = self.cqt_transform(x)
+
+        z_amplitude = z[:, :, :, 0].unsqueeze(1)
+        z_phase = z[:, :, :, 1].unsqueeze(1)
+
+        z = torch.cat([z_amplitude, z_phase], dim=1)
+        z = torch.permute(z, (0, 1, 3, 2))  # [B, C, W, T] -> [B, C, T, W]
+
+        latent_z = []
+        for i in range(self.n_octaves):
+            latent_z.append(
+                self.conv_pres[i](
+                    z[
+                        :,
+                        :,
+                        :,
+                        i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
+                    ]
+                )
+            )
+        latent_z = torch.cat(latent_z, dim=-1)
+
+        for i, l in enumerate(self.convs):
+            latent_z = l(latent_z)
+
+            latent_z = self.activation(latent_z)
+            fmap.append(latent_z)
+
+        latent_z = self.conv_post(latent_z)
+
+        return latent_z, fmap
+
+
+class MultiScaleSubbandCQTDiscriminator(_MultiDiscriminatorT):
+    def __init__(
+        self,
+        sampling_rate: int,
+        cqtd_filters: int = 128,
+        cqtd_max_filters: int = 1024,
+        cqtd_filters_scale: Number = 1,
+        cqtd_dilations: list = [1, 2, 4],
+        cqtd_hop_lengths: list = [512, 256, 256],
+        cqtd_n_octaves: list = [9, 9, 9],
+        cqtd_bins_per_octaves: list = [24, 36, 48],
+        cqtd_in_channels: int = 1,
+        cqtd_out_channels: int = 1,
+        cqtd_normalize_volume: bool = False,
+    ):
+        super().__init__()
+
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorCQT(
+                    hop_length=cqtd_hop_lengths[i],
+                    n_octaves=cqtd_n_octaves[i],
+                    bins_per_octave=cqtd_bins_per_octaves[i],
+                    sampling_rate=sampling_rate,
+                    cqtd_filters=cqtd_filters,
+                    cqtd_max_filters=cqtd_max_filters,
+                    cqtd_filters_scale=cqtd_filters_scale,
+                    cqtd_dilations=cqtd_dilations,
+                    cqtd_in_channels=cqtd_in_channels,
+                    cqtd_out_channels=cqtd_out_channels,
+                    cqtd_normalize_volume=cqtd_normalize_volume,
+                )
+                for i in range(len(cqtd_hop_lengths))
+            ]
+        )
+
+    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+        List[torch.Tensor],
+        List[torch.Tensor],
+        List[List[torch.Tensor]],
+        List[List[torch.Tensor]],
+    ]:
+
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for disc in self.discriminators:
+            y_d_r, fmap_r = disc(y)
+            y_d_g, fmap_g = disc(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
```
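The new `MultiScaleSubbandCQTDiscriminator` builds one `DiscriminatorCQT` per entry of the hop-length/octave/bins-per-octave defaults, and each sub-discriminator internally resamples to twice `sampling_rate` before the CQT (see `T.Resample(orig_freq=self.fs, new_freq=self.fs * 2)` above). A minimal sketch, assuming mono audio shaped `[batch, 1, samples]` like the other discriminators in this file:

```python
import torch
from lt_tensor.model_zoo.losses.discriminators import (
    MultiScaleSubbandCQTDiscriminator,
)

cqtd = MultiScaleSubbandCQTDiscriminator(sampling_rate=24000)
y, y_hat = torch.randn(2, 1, 24000), torch.randn(2, 1, 24000)

# Same 4-tuple contract as the other multi-discriminators, so it can be
# dropped straight into MultiDiscriminatorWrapper.
y_d_rs, y_d_gs, fmap_rs, fmap_gs = cqtd(y, y_hat)
```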