lt-tensor 0.0.1a33__py3-none-any.whl → 0.0.1a35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +1 -1
- lt_tensor/losses.py +169 -47
- lt_tensor/lr_schedulers.py +147 -21
- lt_tensor/misc_utils.py +35 -42
- lt_tensor/model_zoo/activations/__init__.py +3 -0
- lt_tensor/model_zoo/activations/alias_free/__init__.py +3 -0
- lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/act.py +8 -6
- lt_tensor/model_zoo/activations/snake/__init__.py +41 -43
- lt_tensor/model_zoo/audio_models/__init__.py +2 -2
- lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +243 -0
- lt_tensor/model_zoo/audio_models/hifigan/__init__.py +16 -347
- lt_tensor/model_zoo/audio_models/istft/__init__.py +14 -349
- lt_tensor/model_zoo/audio_models/resblocks.py +248 -0
- lt_tensor/model_zoo/convs.py +21 -32
- lt_tensor/model_zoo/losses/discriminators.py +143 -230
- {lt_tensor-0.0.1a33.dist-info → lt_tensor-0.0.1a35.dist-info}/METADATA +1 -1
- lt_tensor-0.0.1a35.dist-info/RECORD +40 -0
- lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +0 -1
- lt_tensor-0.0.1a33.dist-info/RECORD +0 -37
- /lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/filter.py +0 -0
- /lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/resample.py +0 -0
- {lt_tensor-0.0.1a33.dist-info → lt_tensor-0.0.1a35.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a33.dist-info → lt_tensor-0.0.1a35.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a33.dist-info → lt_tensor-0.0.1a35.dist-info}/top_level.txt +0 -0
lt_tensor/model_zoo/convs.py
CHANGED
@@ -40,42 +40,22 @@ class ConvNets(Model):
 
     def remove_norms(self, name: str = "weight"):
         for module in self.modules():
-            if "Conv" in module.__class__.__name__:
-                remove_norm(module, name)
+            try:
+                if "Conv" in module.__class__.__name__:
+                    remove_norm(module, name)
+            except:
+                pass
 
     @staticmethod
-    def init_weights(
-        m: nn.Module,
-        norm: Optional[Literal["spectral", "weight"]] = None,
-        mean=0.0,
-        std=0.02,
-        name: str = "weight",
-        n_power_iterations: int = 1,
-        eps: float = 1e-9,
-        dim_sn: Optional[int] = None,
-        dim_wn: int = 0,
-    ):
+    def init_weights(m: nn.Module, mean=0.0, std=0.02):
         if "Conv" in m.__class__.__name__:
-            if norm is not None:
-                try:
-                    if norm == "spectral":
-                        m.apply(
-                            lambda m: spectral_norm(
-                                m,
-                                n_power_iterations=n_power_iterations,
-                                eps=eps,
-                                name=name,
-                                dim=dim_sn,
-                            )
-                        )
-                    else:
-                        m.apply(lambda m: weight_norm(m, name=name, dim=dim_wn))
-                except ValueError:
-                    pass
             m.weight.data.normal_(mean, std)
 
 
 class Conv1dEXT(ConvNets):
+
+    # TODO: Use this module to replace all that are using normalizations, mostly those in `audio_models`
+
     def __init__(
         self,
         in_channels: int,
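The reworked `remove_norms` above now skips modules that never had a norm attached instead of raising. A minimal standalone sketch of that pattern, using torch's public `remove_parametrizations` as a stand-in for lt_tensor's `remove_norm` helper (whose implementation is not shown in this diff):

```python
import torch.nn as nn
from torch.nn.utils.parametrize import remove_parametrizations

def remove_norms(model: nn.Module, name: str = "weight") -> None:
    # Try every conv-like submodule; a module without a parametrization
    # registered under `name` raises and is simply skipped, mirroring the
    # try/except-per-module loop introduced in this version.
    for module in model.modules():
        try:
            if "Conv" in module.__class__.__name__:
                remove_parametrizations(module, name)
        except Exception:
            pass
```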
@@ -90,7 +70,8 @@ class Conv1dEXT(ConvNets):
         device: Optional[Any] = None,
         dtype: Optional[Any] = None,
         apply_norm: Optional[Literal["weight", "spectral"]] = None,
-
+        activation_in: nn.Module = nn.Identity(),
+        activation_out: nn.Module = nn.Identity(),
         *args,
         **kwargs,
     ):
@@ -112,13 +93,21 @@ class Conv1dEXT(ConvNets):
         )
         if apply_norm is None:
             self.cnn = nn.Conv1d(**cnn_kwargs)
+            self.has_wn = False
         else:
+            self.has_wn = True
             if apply_norm == "spectral":
                 self.cnn = spectral_norm(nn.Conv1d(**cnn_kwargs))
             else:
                 self.cnn = weight_norm(nn.Conv1d(**cnn_kwargs))
-        self.
+        self.actv_in = activation_in
+        self.actv_out = activation_out
         self.cnn.apply(self.init_weights)
 
     def forward(self, input: Tensor):
-        return self.cnn(self.
+        return self.actv_out(self.cnn(self.actv_in(input)))
+
+    def remove_norms(self, name="weight"):
+        if self.has_wn:
+            remove_norm(self.cnn, name)
+            self.has_wn = False
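Taken together, the Conv1dEXT changes add pre/post activations and a norm-removal path. A usage sketch, assuming the constructor arguments not visible in these hunks mirror `nn.Conv1d` (only `in_channels`, `apply_norm`, `activation_in`/`activation_out`, `device`, and `dtype` appear above):

```python
import torch
import torch.nn as nn
from lt_tensor.model_zoo.convs import Conv1dEXT  # import path per this wheel's layout

conv = Conv1dEXT(
    in_channels=80,
    out_channels=256,                  # assumed nn.Conv1d-style kwarg
    kernel_size=7,                     # assumed nn.Conv1d-style kwarg
    padding=3,                         # assumed nn.Conv1d-style kwarg
    apply_norm="weight",               # wraps the conv in weight_norm, sets has_wn=True
    activation_in=nn.LeakyReLU(0.1),   # applied before the conv
    activation_out=nn.Tanh(),          # applied after the conv
)
y = conv(torch.randn(1, 80, 100))  # actv_out(cnn(actv_in(input))) per the new forward
conv.remove_norms()                # strips weight_norm once; a second call is a no-op
```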
lt_tensor/model_zoo/losses/discriminators.py
CHANGED
@@ -7,8 +7,6 @@ from lt_tensor.model_base import Model
 from lt_tensor.model_zoo.convs import ConvNets
 from torch.nn import functional as F
 from torchaudio import transforms as T
-from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
-
 
 MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
     List[Tensor],
@@ -19,9 +17,11 @@ MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
 
 
 class MultiDiscriminatorWrapper(Model):
-    def __init__(
+    def __init__(
+        self, list_discriminator: Union[List["_MultiDiscriminatorT"], nn.ModuleList]
+    ):
         """Setup example:
-        model_d =
+        model_d = MultiDiscriminatorWrapper(
             [
                 MultiEnvelopeDiscriminator(),
                 MultiBandDiscriminator(),
@@ -31,7 +31,12 @@ class MultiDiscriminatorWrapper(Model):
         )
         """
         super().__init__()
-
+
+        self.disc: Sequence[_MultiDiscriminatorT] = (
+            nn.ModuleList(list_discriminator)
+            if isinstance(list_discriminator, (list, tuple, set))
+            else list_discriminator
+        )
         self.total = len(self.disc)
 
     def forward(
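A construction sketch consistent with the new signature and the `isinstance` coercion above (import path assumed from this wheel's layout):

```python
import torch.nn as nn
from lt_tensor.model_zoo.losses.discriminators import (
    MultiDiscriminatorWrapper,
    MultiEnvelopeDiscriminator,
    MultiBandDiscriminator,
    MultiPeriodDiscriminator,
)

# A plain list/tuple/set is coerced to nn.ModuleList; an existing
# nn.ModuleList is stored as-is (the isinstance branch above).
model_d = MultiDiscriminatorWrapper(
    [
        MultiEnvelopeDiscriminator(),
        MultiBandDiscriminator(),
        MultiPeriodDiscriminator(),
    ]
)
assert isinstance(model_d.disc, nn.ModuleList)
assert model_d.total == 3
```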
@@ -82,23 +87,6 @@ class MultiDiscriminatorWrapper(Model):
         return disc_loss, disc_real_losses, disc_gen_losses
 
 
-def normalize_unit_norm(x: torch.Tensor, eps: float = 1e-5):
-    norm = torch.norm(x, dim=-1, keepdim=True)
-    return x / (norm + eps)
-
-
-def normalize_minmax(x: torch.Tensor, eps: float = 1e-5):
-    min_val = x.amin(dim=-1, keepdim=True)
-    max_val = x.amax(dim=-1, keepdim=True)
-    return (x - min_val) / (max_val - min_val + eps)
-
-
-def normalize_zscore(x: torch.Tensor, eps: float = 1e-5):
-    mean = x.mean(dim=-1, keepdim=True)
-    std = x.std(dim=-1, keepdim=True)
-    return (x - mean) / (std + eps)
-
-
 def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)
 
@@ -113,7 +101,6 @@ class _MultiDiscriminatorT(ConvNets):
     def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
         pass
 
-    # for type hinting
     def __call__(self, *args, **kwds) -> MULTI_DISC_OUT_TYPE:
         return super().__call__(*args, **kwds)
 
@@ -176,7 +163,7 @@ class DiscriminatorP(ConvNets):
     def __init__(
         self,
         period: List[int],
-
+        discriminator_channel_multi: Number = 1,
         kernel_size: int = 5,
         stride: int = 3,
         use_spectral_norm: bool = False,
@@ -184,7 +171,7 @@ class DiscriminatorP(ConvNets):
         super().__init__()
         self.period = period
         norm_f = weight_norm if not use_spectral_norm else spectral_norm
-        dsc = lambda x: int(x *
+        dsc = lambda x: int(x * discriminator_channel_multi)
         self.convs = nn.ModuleList(
             [
                 norm_f(
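`discriminator_channel_multi` scales every channel count through the `dsc` helper: `1` keeps the original widths, while fractional values shrink the discriminator uniformly. A worked example:

```python
discriminator_channel_multi = 0.5          # halve discriminator width
dsc = lambda x: int(x * discriminator_channel_multi)

assert dsc(32) == 16      # a layer declared with 32 base channels is built with 16
assert dsc(128) == 64
assert dsc(1024) == 512   # the widest layers shrink by the same factor
```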
@@ -259,19 +246,18 @@
 class MultiPeriodDiscriminator(_MultiDiscriminatorT):
     def __init__(
         self,
-
+        discriminator_channel_multi: Number = 1,
         mpd_reshapes: list[int] = [2, 3, 5, 7, 11],
         use_spectral_norm: bool = False,
     ):
         super().__init__()
         self.mpd_reshapes = mpd_reshapes
-        print(f"mpd_reshapes: {self.mpd_reshapes}")
         self.discriminators = nn.ModuleList(
             [
                 DiscriminatorP(
                     rs,
                     use_spectral_norm=use_spectral_norm,
-
+                    discriminator_channel_multi=discriminator_channel_multi,
                 )
                 for rs in self.mpd_reshapes
             ]
@@ -293,6 +279,79 @@ class MultiPeriodDiscriminator(_MultiDiscriminatorT):
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
 
 
+class DiscriminatorS(ConvNets):
+    def __init__(
+        self,
+        use_spectral_norm=False,
+        discriminator_channel_multi: Number = 1,
+    ):
+        super().__init__()
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        dsc = lambda x: int(x * discriminator_channel_multi)
+        self.convs = nn.ModuleList(
+            [
+                norm_f(nn.Conv1d(1, dsc(128), 15, 1, padding=7)),
+                norm_f(nn.Conv1d(dsc(128), dsc(128), 41, 2, groups=4, padding=20)),
+                norm_f(nn.Conv1d(dsc(128), dsc(256), 41, 2, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(256), dsc(512), 41, 4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(512), dsc(1024), 41, 4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(1024), dsc(1024), 41, 1, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(1024), dsc(1024), 5, 1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv1d(dsc(1024), 1, 3, 1, padding=1))
+        self.activation = nn.LeakyReLU(0.1)
+
+    def forward(self, x):
+        fmap = []
+        for l in self.convs:
+            x = l(x)
+            x = self.activation(x)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        return x.flatten(1, -1), fmap
+
+
+class MultiScaleDiscriminator(ConvNets):
+    def __init__(
+        self,
+        discriminator_channel_multi: Number = 1,
+    ):
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorS(
+                    use_spectral_norm=True,
+                    discriminator_channel_multi=discriminator_channel_multi,
+                ),
+                DiscriminatorS(discriminator_channel_multi=discriminator_channel_multi),
+                DiscriminatorS(discriminator_channel_multi=discriminator_channel_multi),
+            ]
+        )
+        self.meanpools = nn.ModuleList(
+            [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
+        )
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            if i > 0:
+                y = self.meanpools[i - 1](y)
+                y_hat = self.meanpools[i - 1](y_hat)
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
 class EnvelopeExtractor(Model):
     """Extracts the amplitude envelope of the audio signal."""
 
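A smoke-test sketch for the newly added MultiScaleDiscriminator (import path assumed; input is a batch of mono waveforms, shape `(batch, 1, samples)`):

```python
import torch
from lt_tensor.model_zoo.losses.discriminators import MultiScaleDiscriminator

msd = MultiScaleDiscriminator(discriminator_channel_multi=1)
real = torch.randn(2, 1, 8192)
fake = torch.randn(2, 1, 8192)

y_d_rs, y_d_gs, fmap_rs, fmap_gs = msd(real, fake)
# One logit tensor and one feature-map list per scale:
# raw audio, pooled once, pooled twice.
assert len(y_d_rs) == len(y_d_gs) == len(fmap_rs) == len(fmap_gs) == 3
```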
@@ -314,21 +373,35 @@ class EnvelopeExtractor(Model):
 
 
 class DiscriminatorEnvelope(ConvNets):
-    def __init__(
+    def __init__(
+        self,
+        use_spectral_norm=False,
+        discriminator_channel_multi: Number = 1,
+        kernel_size: int = 101,
+    ):
         super().__init__()
         norm_f = weight_norm if not use_spectral_norm else spectral_norm
-        self.extractor = EnvelopeExtractor(kernel_size=
+        self.extractor = EnvelopeExtractor(kernel_size=kernel_size)
+        dsc = lambda x: int(x * discriminator_channel_multi)
         self.convs = nn.ModuleList(
             [
-                norm_f(nn.Conv1d(1, 64, 15, stride=1, padding=7)),
-                norm_f(
-
-
-                norm_f(
-
+                norm_f(nn.Conv1d(1, dsc(64), 15, stride=1, padding=7)),
+                norm_f(
+                    nn.Conv1d(dsc(64), dsc(128), 41, stride=2, groups=4, padding=20)
+                ),
+                norm_f(
+                    nn.Conv1d(dsc(128), dsc(256), 41, stride=2, groups=16, padding=20)
+                ),
+                norm_f(
+                    nn.Conv1d(dsc(256), dsc(512), 41, stride=4, groups=16, padding=20)
+                ),
+                norm_f(
+                    nn.Conv1d(dsc(512), dsc(512), 41, stride=4, groups=16, padding=20)
+                ),
+                norm_f(nn.Conv1d(dsc(512), dsc(512), 5, stride=1, padding=2)),
             ]
         )
-        self.conv_post = norm_f(nn.Conv1d(512, 1, 3, stride=1, padding=1))
+        self.conv_post = norm_f(nn.Conv1d(dsc(512), 1, 3, stride=1, padding=1))
         self.activation = nn.LeakyReLU(0.1)
 
     def forward(self, x):
@@ -344,11 +417,17 @@ class DiscriminatorEnvelope(ConvNets):
 
 
 class MultiEnvelopeDiscriminator(_MultiDiscriminatorT):
-    def __init__(
+    def __init__(
+        self,
+        use_spectral_norm: bool = False,
+        discriminator_channel_multi: Number = 1,
+    ):
         super().__init__()
         self.discriminators = nn.ModuleList(
             [
-                DiscriminatorEnvelope(
+                DiscriminatorEnvelope(
+                    use_spectral_norm, discriminator_channel_multi
+                ),  # raw envelope
                 DiscriminatorEnvelope(use_spectral_norm),  # downsampled once
                 DiscriminatorEnvelope(use_spectral_norm),  # downsampled twice
             ]
@@ -448,7 +527,7 @@ class DiscriminatorB(ConvNets):
         for band, stack in zip(x_bands, self.band_convs):
             for i, layer in enumerate(stack):
                 band = layer(band)
-                band =
+                band = F.leaky_relu(band, 0.1)
                 if i > 0:
                     fmap.append(band)
                 x.append(band)
@@ -469,11 +548,21 @@ class MultiBandDiscriminator(_MultiDiscriminatorT):
     def __init__(
         self,
         mbd_fft_sizes: list[int] = [2048, 1024, 512],
+        channels: int = 32,
+        hop_factor: float = 0.25,
+        bands: Tuple[Tuple[float, float], ...] = (
+            (0.0, 0.1),
+            (0.1, 0.25),
+            (0.25, 0.5),
+            (0.5, 0.75),
+            (0.75, 1.0),
+        ),
     ):
         super().__init__()
         self.fft_sizes = mbd_fft_sizes
+        kwargs_disc = dict(channels=channels, hop_factor=hop_factor, bands=bands)
         self.discriminators = nn.ModuleList(
-            [DiscriminatorB(window_length=w) for w in self.fft_sizes]
+            [DiscriminatorB(window_length=w, **kwargs_disc) for w in self.fft_sizes]
         )
 
     def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
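The new `bands` default expresses sub-band edges as fractions of the spectrum. DiscriminatorB's internals are not shown in this diff, but under the usual DAC-style convention each fraction pair selects a slice of the `n_fft // 2 + 1` STFT bins, and the hop follows from `hop_factor`. A worked example under those assumptions:

```python
# Assumed band-to-bin mapping; DiscriminatorB's body is outside these hunks.
n_fft = 2048
n_bins = n_fft // 2 + 1          # 1025 STFT bins
bands = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]

band_slices = [(int(lo * n_bins), int(hi * n_bins)) for lo, hi in bands]
print(band_slices)  # [(0, 102), (102, 256), (256, 512), (512, 768), (768, 1025)]

hop_factor = 0.25
hop_length = int(n_fft * hop_factor)  # 512, assuming hop = window * hop_factor
```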
@@ -500,7 +589,7 @@ class DiscriminatorR(ConvNets):
         self,
         resolution: List[int],
         use_spectral_norm: bool = False,
-
+        discriminator_channel_multi: Number = 1,
     ):
         super().__init__()
 
@@ -518,13 +607,13 @@ class DiscriminatorR(ConvNets):
             [
                 norm_f(
                     nn.Conv2d(
-                        1, int(32 *
+                        1, int(32 * discriminator_channel_multi), (3, 9), padding=(1, 4)
                     )
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 *
-                        int(32 *
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 9),
                         stride=(1, 2),
                         padding=(1, 4),
@@ -532,8 +621,8 @@ class DiscriminatorR(ConvNets):
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 *
-                        int(32 *
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 9),
                         stride=(1, 2),
                         padding=(1, 4),
@@ -541,8 +630,8 @@ class DiscriminatorR(ConvNets):
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 *
-                        int(32 *
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 9),
                         stride=(1, 2),
                         padding=(1, 4),
@@ -550,8 +639,8 @@ class DiscriminatorR(ConvNets):
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 *
-                        int(32 *
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 3),
                         padding=(1, 1),
                     )
@@ -559,7 +648,7 @@ class DiscriminatorR(ConvNets):
             ]
         )
         self.conv_post = norm_f(
-            nn.Conv2d(int(32 *
+            nn.Conv2d(int(32 * discriminator_channel_multi), 1, (3, 3), padding=(1, 1))
         )
 
     def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
@@ -603,7 +692,7 @@ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
     def __init__(
         self,
         use_spectral_norm: bool = False,
-
+        discriminator_channel_multi: Number = 1,
         resolutions: List[List[int]] = [
             [1024, 120, 600],
             [2048, 240, 1200],
@@ -618,7 +707,7 @@ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
         self.discriminators = nn.ModuleList(
             [
                 DiscriminatorR(
-                    resolution, use_spectral_norm,
+                    resolution, use_spectral_norm, discriminator_channel_multi
                 )
                 for resolution in self.resolutions
             ]
@@ -637,179 +726,3 @@ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
             y_d_gs.append(y_d_g)
             fmap_gs.append(fmap_g)
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-class MultiMelScaleLoss(Model):
-    # TODO: Make the normalization an argument to be chosen by the dev
-    def __init__(
-        self,
-        sample_rate: int,
-        n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
-        window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
-        n_ffts: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
-        hops: List[int] = [8, 16, 32, 64, 128, 256, 512],
-        weight: float = 1.0,
-        lambda_mel: float = 1.0,
-        f_min: float = [0, 0, 0, 0, 0, 0, 0],
-        f_max: Optional[float] = [None, None, None, None, None, None, None],
-        loss_fn: Callable = nn.L1Loss(),
-        center: bool = True,
-        power: float = 1.0,
-        normalized: bool = False,
-        pad_mode: str = "reflect",
-        onesided: Optional[bool] = None,
-        std: int = 4,
-        mean: int = -4,
-        auto_interpolate: bool = True,
-        use_istft_norm: bool = True,
-        use_pitch_loss: bool = False,
-        use_rms_loss: bool = False,
-        lambda_pitch: float = 0.5,
-        lambda_rms: float = 0.5,
-    ):
-        super().__init__()
-        assert (
-            len(n_mels)
-            == len(window_lengths)
-            == len(n_ffts)
-            == len(hops)
-            == len(f_min)
-            == len(f_max)
-        )
-        self.loss_fn = loss_fn
-        self.lambda_mel = lambda_mel
-        self.weight = weight
-        self.use_istft_norm = use_istft_norm
-        self.auto_interpolate = auto_interpolate if not self.use_istft_norm else False
-        self.use_pitch_loss = use_pitch_loss
-        self.use_rms_loss = use_rms_loss
-        self.lambda_pitch = lambda_pitch
-        self.lambda_rms = lambda_rms
-
-        self._setup_mels(
-            sample_rate,
-            n_mels,
-            window_lengths,
-            n_ffts,
-            hops,
-            f_min,
-            f_max,
-            center,
-            power,
-            normalized,
-            pad_mode,
-            onesided,
-            std,
-            mean,
-        )
-
-    def _setup_mels(
-        self,
-        sample_rate: int,
-        n_mels: List[int],
-        window_lengths: List[int],
-        n_ffts: List[int],
-        hops: List[int],
-        f_min: List[float],
-        f_max: List[Optional[float]],
-        center: bool,
-        power: float,
-        normalized: bool,
-        pad_mode: str = "reflect",
-        onesided: Optional[bool] = None,
-        std: int = 4,
-        mean: int = -4,
-    ):
-        assert (
-            len(n_mels)
-            == len(window_lengths)
-            == len(n_ffts)
-            == len(hops)
-            == len(f_min)
-            == len(f_max)
-        )
-        _mel_kwargs = dict(
-            sample_rate=sample_rate,
-            center=center,
-            onesided=onesided,
-            normalized=normalized,
-            power=power,
-            pad_mode=pad_mode,
-            std=std,
-            mean=mean,
-        )
-        self.mel_spectrograms: List[AudioProcessor] = nn.ModuleList(
-            [
-                AudioProcessor(
-                    AudioProcessorConfig(
-                        **_mel_kwargs,
-                        n_mels=mel,
-                        n_fft=n_fft,
-                        win_length=win,
-                        hop_length=hop,
-                        f_min=fmin,
-                        f_max=fmax,
-                    )
-                )
-                for mel, win, n_fft, hop, fmin, fmax in zip(
-                    n_mels, window_lengths, n_ffts, hops, f_min, f_max
-                )
-            ]
-        )
-
-    def _process_tensor(
-        self,
-        input_wave: torch.Tensor,
-        target_wave: torch.Tensor,
-    ):
-        if input_wave.shape[-1] != target_wave.shape[-1]:
-            if input_wave.ndim < 3:
-                # To be compatible with interpolatin
-                if input_wave.ndim == 2:
-                    input_wave = input_wave.unsqueeze(1)
-                else:
-                    input_wave = input_wave.unsqueeze(0).unsqueeze(0)
-            input_wave = F.interpolate(input_wave, target_wave.shape[-1], mode="linear")
-        return input_wave
-
-    def forward(
-        self, input_wave: torch.Tensor, target_wave: torch.Tensor
-    ) -> torch.Tensor:
-        assert (
-            self.use_istft_norm
-            or self.auto_interpolate
-            or input_wave.shape[-1] == target_wave.shape[-1]
-        )
-        if self.auto_interpolate:
-            input_wave = self._process_tensor(input_wave, target_wave)
-
-        losses = 0.0
-        for M in self.mel_spectrograms:
-            # Apply normalization if requested
-            if self.use_istft_norm:
-                input_proc = M.istft_norm(input_wave, length=target_wave.shape[-1])
-                target_proc = M.istft_norm(target_wave, length=target_wave.shape[-1])
-            else:
-                input_proc, target_proc = input_wave, target_wave
-
-            x_mels = M(input_proc)
-            y_mels = M(target_proc)
-
-            loss = self.loss_fn(x_mels.squeeze(), y_mels.squeeze())
-            losses += loss * self.lambda_mel
-
-            # pitch/f0 loss
-            if self.use_pitch_loss:
-                x_pitch = normalize_unit_norm(M.compute_pitch(input_proc))
-                y_pitch = normalize_unit_norm(M.compute_pitch(target_proc))
-                f0_loss = self.loss_fn(x_pitch, y_pitch)
-                losses += f0_loss * self.lambda_pitch
-
-            # energy/rms loss
-            if self.use_rms_loss:
-                x_rms = normalize_unit_norm(M.compute_rms(input_proc, x_mels))
-                y_rms = normalize_unit_norm(M.compute_rms(target_proc, y_mels))
-                rms_loss = self.loss_fn(x_rms, y_rms)
-                losses += rms_loss * self.lambda_rms
-
-        return losses * self.weight
lt_tensor-0.0.1a35.dist-info/RECORD
ADDED
@@ -0,0 +1,40 @@
+lt_tensor/__init__.py,sha256=4NqhrI_O5q4YQMBpyoLtNUUbBnnbWkO92GE1hxHcrd8,441
+lt_tensor/config_templates.py,sha256=F9UvL8paAjkSvio890kp8WznpYeI50pYnm9iqQroBxk,2797
+lt_tensor/losses.py,sha256=Heco_WyoC1HkNkcJEircOAzS9umusATHiNAG-FKGyzc,8918
+lt_tensor/lr_schedulers.py,sha256=6_vcfaPHrozfH3wvmNEdKSFYl6iTIijYoHL8vuG-45U,7651
+lt_tensor/math_ops.py,sha256=ahX6Z1Mt3X-FhmwSZYZea5mB1B0S8GDuvKPfAm5e_FQ,2646
+lt_tensor/misc_utils.py,sha256=stL6q3M7S2N4FBICFYbgYpdPDrJRlwmr24-iCXMRifM,28933
+lt_tensor/model_base.py,sha256=5T4dbAh4MXbQmPRpihGtMYwTY8sJTQOhY6An3VboM58,18086
+lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
+lt_tensor/noise_tools.py,sha256=wFeAsHhLhSlEc5XU5LbFKaXoHeVxrWjiMeljjGdIKyM,11363
+lt_tensor/torch_commons.py,sha256=8l0bxmrAzwvyqjivCIVISXlbvKarlg4DdE0BOGSnMuQ,812
+lt_tensor/transform.py,sha256=dZm8T_ov0blHMQu6nGiehsdG1VSB7bZBUVmTkT-PBdc,13257
+lt_tensor/model_zoo/__init__.py,sha256=yPUVchgVhU2nAJ2ocA4HFfG7IMEiBu8qOi8I1KWTTkU,404
+lt_tensor/model_zoo/basic.py,sha256=pI8HyiHK-cmWcEEaVY_EduUJOjZW6HOtXvJd8Rbhq30,15452
+lt_tensor/model_zoo/convs.py,sha256=Tws0jrPfs9m7OLmJ30W0AfkAvZgppW7lNi4xt0e-qRU,3518
+lt_tensor/model_zoo/features.py,sha256=DO8dlE0kmPKTNC1Xkv9wKegOOYkQa_rkxM4hhcNwJWA,15655
+lt_tensor/model_zoo/fusion.py,sha256=usC1bcjQRNivDc8xzkIS5T1glm78OLcs2V_tPqfp-eI,5422
+lt_tensor/model_zoo/pos_encoder.py,sha256=3d1EYLinCU9UAy-WuEWeYMGhMqaGknCiQ5qEmhw_UYM,4487
+lt_tensor/model_zoo/residual.py,sha256=tMXgif9Ggep9bk75K93yueeU5vk5S25AGCRFwOQOyB8,6452
+lt_tensor/model_zoo/transformer.py,sha256=HUFoFFh7EQJErxdd9XIxhssdjvNVx2tNGDJOTUfwG2A,4301
+lt_tensor/model_zoo/activations/__init__.py,sha256=f_IsuC-SaFsX6w4OtBWa5bbS4TqR90X-cvLxGUgYfjk,67
+lt_tensor/model_zoo/activations/alias_free/__init__.py,sha256=dgLjatRm9nusoPVOl1pvCef5rZsaRfS3BJUs05SPYzw,64
+lt_tensor/model_zoo/activations/alias_free/act.py,sha256=1wxmab2kMD88L6wsQgf3t25dBwR7_he2eM1DlV0FQak,1424
+lt_tensor/model_zoo/activations/alias_free/filter.py,sha256=5TvXESv31toD5sePBe_OUJJfMXv6Ohwmx2YawjQL-pk,6004
+lt_tensor/model_zoo/activations/alias_free/resample.py,sha256=3iM4fNr9fLNXXMyXvzW-MwkSjOZOrMZLfS80UHs6zk0,3386
+lt_tensor/model_zoo/activations/snake/__init__.py,sha256=AtOAbJuMinxmKkppITGMzRbcbPQaALnl9mCtl1c3x0Q,4356
+lt_tensor/model_zoo/audio_models/__init__.py,sha256=WwiP9MekJreMOfKPWLl24VkRJIpLk6hhL8ch0aKgOss,103
+lt_tensor/model_zoo/audio_models/resblocks.py,sha256=u-foHxaFDUICjxSkpyHXljQYQG9zMxVYaOGqLR_nJ-k,7978
+lt_tensor/model_zoo/audio_models/bigvgan/__init__.py,sha256=Dpt_3JXUToldxQrZx4a1gfI-awsLIVipAXqWm4lzBzM,8495
+lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=PDuDYN1omD1RoAXcmxH3tEgfAuM3ZHAWzimD6ElMqEQ,9073
+lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=3HVfEreQ4NqYIC9AWEkmL4ePcIbR1kTyH0cBG8u_Jik,6387
+lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=blICjLX_z_IFmR3_TCz_dJiSayLYGza9eG6fd9aKyvE,7448
+lt_tensor/model_zoo/losses/__init__.py,sha256=B9RAUxBiOZwooztnij1oLeRwZ7_MjnN3mPoum7saD6s,59
+lt_tensor/model_zoo/losses/discriminators.py,sha256=HBO7jwCsUGsYfSz-JZPZccuYLnto6jfZs3Ve5j51JQE,24247
+lt_tensor/processors/__init__.py,sha256=Pvxhh0KR65zLCgUd53_k5Z0y5JWWcO0ZBXFK9rv0o5w,109
+lt_tensor/processors/audio.py,sha256=HNr1GS-6M2q0Rda4cErf5y2Jlc9f4jD58FvpX2ua9d4,18369
+lt_tensor-0.0.1a35.dist-info/licenses/LICENSE,sha256=TbiyJWLgNqqgqhfCnrGwFIxy7EqGNrIZZcKhHrefcuU,11354
+lt_tensor-0.0.1a35.dist-info/METADATA,sha256=0FrtLNnbU49bKOlyshasXPZOZ90Sok03XkXbtxP4VMI,1062
+lt_tensor-0.0.1a35.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a35.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a35.dist-info/RECORD,,
lt_tensor/model_zoo/activations/alias_free_torch/__init__.py
DELETED
@@ -1 +0,0 @@
-from . import *