lt-tensor 0.0.1a27__py3-none-any.whl → 0.0.1a28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/model_zoo/__init__.py +2 -0
- lt_tensor/model_zoo/losses/__init__.py +3 -0
- lt_tensor/model_zoo/losses/discriminators.py +610 -0
- lt_tensor/processors/audio.py +34 -40
- {lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a28.dist-info}/METADATA +1 -1
- {lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a28.dist-info}/RECORD +9 -7
- {lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a28.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a28.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a28.dist-info}/top_level.txt +0 -0
lt_tensor/model_zoo/losses/discriminators.py
ADDED
@@ -0,0 +1,610 @@
from lt_tensor.model_zoo.audio_models.hifigan import ConvNets
from lt_utils.common import *
from lt_tensor.torch_commons import *
from lt_tensor.model_base import Model
from lt_tensor.model_zoo.convs import ConvNets
from torch.nn import functional as F
from torchaudio import transforms as T

MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
    List[Tensor],
    List[Tensor],
    List[List[Tensor]],
    List[List[Tensor]],
]


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


class MultiDiscriminatorWrapper(ConvNets):
    """Base for all multi-steps type of discriminators"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.leaky_relu = nn.LeakyReLU(kwargs.get("negative_slope", 0.1))

    def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
        pass

    # for type hinting
    def __call__(self, *args, **kwds) -> MULTI_DISC_OUT_TYPE:
        return super().__call__(*args, **kwds)

    def gen_step(self, y: Tensor, y_hat: Tensor) -> tuple[Tensor, Tensor, List[float]]:
        """For generator loss step [feature loss, generator loss, list of generator losses (float)]"""
        _, y_hat_gen, feat_map_real, feat_map_gen = self.train_step(y, y_hat)
        loss_feat = self.feature_loss(feat_map_real, feat_map_gen)
        loss_generator, losses_gen_s = self.generator_loss(y_hat_gen)
        return loss_feat, loss_generator, losses_gen_s

    def disc_step(
        self, y: Tensor, y_hat: Tensor
    ) -> tuple[Tensor, tuple[List[float], List[float]]]:
        """For discriminator loss step [discriminator loss, (disc losses real, disc losses generated)]"""
        y_hat_real, y_hat_gen, _, _ = self.train_step(y, y_hat)

        loss_disc, losses_disc_real, losses_disc_generated = self.discriminator_loss(
            y_hat_real, y_hat_gen
        )
        return loss_disc, (losses_disc_real, losses_disc_generated)

    @staticmethod
    def discriminator_loss(
        disc_real_outputs, disc_generated_outputs
    ) -> Tuple[Tensor, List[float], List[float]]:
        loss = 0
        r_losses = []
        g_losses = []
        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
            r_loss = torch.mean((1 - dr) ** 2)
            g_loss = torch.mean(dg**2)
            loss += r_loss + g_loss
            r_losses.append(r_loss.item())
            g_losses.append(g_loss.item())

        return loss, r_losses, g_losses

    @staticmethod
    def feature_loss(fmap_r, fmap_g) -> Tensor:
        loss = 0
        for dr, dg in zip(fmap_r, fmap_g):
            for rl, gl in zip(dr, dg):
                loss += torch.mean(torch.abs(rl - gl))

        return loss * 2

    @staticmethod
    def generator_loss(disc_outputs) -> Tuple[Tensor, List[float]]:
        loss = 0
        gen_losses = []
        for dg in disc_outputs:
            l = torch.mean((1 - dg) ** 2)
            gen_losses.append(l.item())
            loss += l

        return loss, gen_losses


class DiscriminatorP(ConvNets):
    def __init__(
        self,
        period: List[int],
        discriminator_channel_mult: Number = 1,
        kernel_size: int = 5,
        stride: int = 3,
        use_spectral_norm: bool = False,
    ):
        super().__init__()
        self.period = period
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        dsc = lambda x: int(x * discriminator_channel_mult)
        self.convs = nn.ModuleList(
            [
                norm_f(
                    nn.Conv2d(
                        1,
                        dsc(32),
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    )
                ),
                norm_f(
                    nn.Conv2d(
                        dsc(32),
                        dsc(128),
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    )
                ),
                norm_f(
                    nn.Conv2d(
                        dsc(128),
                        dsc(512),
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    )
                ),
                norm_f(
                    nn.Conv2d(
                        dsc(512),
                        dsc(1024),
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    )
                ),
                norm_f(
                    nn.Conv2d(
                        dsc(1024),
                        dsc(1024),
                        (kernel_size, 1),
                        1,
                        padding=(2, 0),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(nn.Conv2d(dsc(1024), 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, 0.1)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return x.flatten(1, -1), fmap


class MultiPeriodDiscriminator(MultiDiscriminatorWrapper):
    def __init__(
        self,
        discriminator_channel_mult: Number = 1,
        mpd_reshapes: list[int] = [2, 3, 5, 7, 11],
        use_spectral_norm: bool = False,
    ):
        super().__init__()
        self.mpd_reshapes = mpd_reshapes
        print(f"mpd_reshapes: {self.mpd_reshapes}")
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorP(
                    rs,
                    use_spectral_norm=use_spectral_norm,
                    discriminator_channel_mult=discriminator_channel_mult,
                )
                for rs in self.mpd_reshapes
            ]
        )

    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> MULTI_DISC_OUT_TYPE:
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class EnvelopeExtractor(nn.Module):
    """Extracts the amplitude envelope of the audio signal."""

    def __init__(self, kernel_size=101):
        super().__init__()
        # Lowpass filter for smoothing envelope (moving average)
        self.kernel_size = kernel_size
        self.register_buffer("kernel", torch.ones(1, 1, kernel_size) / kernel_size)

    def forward(self, x):
        # x: (B, 1, T) -> abs(x)
        envelope = torch.abs(x)
        # Apply low-pass smoothing (via conv1d)
        envelope = F.pad(
            envelope, (self.kernel_size // 2, self.kernel_size // 2), mode="reflect"
        )
        envelope = F.conv1d(envelope, self.kernel)
        return envelope


class DiscriminatorEnvelope(ConvNets):
    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.extractor = EnvelopeExtractor(kernel_size=101)
        self.convs = nn.ModuleList(
            [
                norm_f(nn.Conv1d(1, 64, 15, stride=1, padding=7)),
                norm_f(nn.Conv1d(64, 128, 41, stride=2, groups=4, padding=20)),
                norm_f(nn.Conv1d(128, 256, 41, stride=2, groups=16, padding=20)),
                norm_f(nn.Conv1d(256, 512, 41, stride=4, groups=16, padding=20)),
                norm_f(nn.Conv1d(512, 512, 41, stride=4, groups=16, padding=20)),
                norm_f(nn.Conv1d(512, 512, 5, stride=1, padding=2)),
            ]
        )
        self.conv_post = norm_f(nn.Conv1d(512, 1, 3, stride=1, padding=1))
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        # Input: raw audio (B, 1, T)
        x = self.extractor(x)
        fmap = []
        for layer in self.convs:
            x = self.activation(layer(x))
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return x.flatten(1), fmap


class MultiEnvelopeDiscriminator(MultiDiscriminatorWrapper):
    def __init__(self, use_spectral_norm: bool = False):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorEnvelope(use_spectral_norm),  # raw envelope
                DiscriminatorEnvelope(use_spectral_norm),  # downsampled once
                DiscriminatorEnvelope(use_spectral_norm),  # downsampled twice
            ]
        )
        self.meanpools = nn.ModuleList(
            [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
        )

    def forward(self, y, y_hat):
        y_d_rs, y_d_gs = [], []
        fmap_rs, fmap_gs = [], []

        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorB(ConvNets):
    """
    Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
    and the modified code adapted from https://github.com/gemelo-ai/vocos.
    """

    def __init__(
        self,
        window_length: int,
        channels: int = 32,
        hop_factor: float = 0.25,
        bands: Tuple[Tuple[float, float], ...] = (
            (0.0, 0.1),
            (0.1, 0.25),
            (0.25, 0.5),
            (0.5, 0.75),
            (0.75, 1.0),
        ),
    ):
        super().__init__()
        self.window_length = window_length
        self.hop_factor = hop_factor
        self.spec_fn = T.Spectrogram(
            n_fft=window_length,
            hop_length=int(window_length * hop_factor),
            win_length=window_length,
            power=None,
        )
        n_fft = window_length // 2 + 1
        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
        self.bands = bands
        convs = lambda: nn.ModuleList(
            [
                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
                ),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
                ),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
                ),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
                ),
            ]
        )
        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])

        self.conv_post = weight_norm(
            nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
        )

    def spectrogram(self, x: Tensor) -> List[Tensor]:
        # Remove DC offset
        x = x - x.mean(dim=-1, keepdims=True)
        # Peak normalize the volume of input audio
        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
        x = self.spec_fn(x)
        x = torch.view_as_real(x)
        x = x.permute(0, 3, 2, 1)  # [B, F, T, C] -> [B, C, T, F]
        # Split into bands
        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
        return x_bands

    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
        x_bands = self.spectrogram(x.squeeze(1))
        fmap = []
        x = []

        for band, stack in zip(x_bands, self.band_convs):
            for i, layer in enumerate(stack):
                band = layer(band)
                band = torch.nn.functional.leaky_relu(band, 0.1)
                if i > 0:
                    fmap.append(band)
            x.append(band)

        x = torch.cat(x, dim=-1)
        x = self.conv_post(x)
        fmap.append(x)

        return x, fmap


class MultiBandDiscriminator(MultiDiscriminatorWrapper):
    """
    Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
    and the modified code adapted from https://github.com/gemelo-ai/vocos.
    """

    def __init__(
        self,
        mbd_fft_sizes: list[int] = [2048, 1024, 512],
    ):
        super().__init__()
        self.fft_sizes = mbd_fft_sizes
        self.discriminators = nn.ModuleList(
            [DiscriminatorB(window_length=w) for w in self.fft_sizes]
        )

    def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:

        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for d in self.discriminators:

            y_d_r, fmap_r = d(x=y)
            y_d_g, fmap_g = d(x=y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorR(ConvNets):
    def __init__(
        self,
        resolution: List[int],
        use_spectral_norm: bool = False,
        discriminator_channel_mult: int = 1,
    ):
        super().__init__()

        self.resolution = resolution
        assert (
            len(self.resolution) == 3
        ), f"MRD layer requires list with len=3, got {self.resolution}"
        self.lrelu_slope = 0.1

        self.register_buffer("window", torch.hann_window(self.resolution[-1]))

        norm_f = weight_norm if use_spectral_norm == False else spectral_norm

        self.convs = nn.ModuleList(
            [
                norm_f(
                    nn.Conv2d(
                        1, int(32 * discriminator_channel_mult), (3, 9), padding=(1, 4)
                    )
                ),
                norm_f(
                    nn.Conv2d(
                        int(32 * discriminator_channel_mult),
                        int(32 * discriminator_channel_mult),
                        (3, 9),
                        stride=(1, 2),
                        padding=(1, 4),
                    )
                ),
                norm_f(
                    nn.Conv2d(
                        int(32 * discriminator_channel_mult),
                        int(32 * discriminator_channel_mult),
                        (3, 9),
                        stride=(1, 2),
                        padding=(1, 4),
                    )
                ),
                norm_f(
                    nn.Conv2d(
                        int(32 * discriminator_channel_mult),
                        int(32 * discriminator_channel_mult),
                        (3, 9),
                        stride=(1, 2),
                        padding=(1, 4),
                    )
                ),
                norm_f(
                    nn.Conv2d(
                        int(32 * discriminator_channel_mult),
                        int(32 * discriminator_channel_mult),
                        (3, 3),
                        padding=(1, 1),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(
            nn.Conv2d(int(32 * discriminator_channel_mult), 1, (3, 3), padding=(1, 1))
        )

    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
        fmap = []
        x = self.spectrogram(x)
        x = x.unsqueeze(1)
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, self.lrelu_slope)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

    def spectrogram(self, x: Tensor) -> Tensor:
        n_fft, hop_length, win_length = self.resolution
        x = F.pad(
            x,
            (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
            mode="reflect",
        )
        x = x.squeeze(1)
        x = torch.stft(
            x,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            center=False,
            return_complex=True,
            window=self.window,
        )
        x = torch.view_as_real(x)  # [B, F, TT, 2]
        mag = torch.norm(x, p=2, dim=-1)  # [B, F, TT]

        return mag


class MultiResolutionDiscriminator(MultiDiscriminatorWrapper):
    def __init__(
        self,
        use_spectral_norm: bool = False,
        discriminator_channel_mult: int = 1,
        resolutions: List[List[int]] = [
            [1024, 120, 600],
            [2048, 240, 1200],
            [512, 50, 240],
        ],
    ):
        super().__init__()
        self.resolutions = resolutions
        assert (
            len(self.resolutions) == 3
        ), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}, type: {type(self.resolutions)}"
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorR(
                    resolution, use_spectral_norm, discriminator_channel_mult
                )
                for resolution in self.resolutions
            ]
        )

    def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for disc in self.discriminators:
            y_d_r, fmap_r = disc(x=y)
            y_d_g, fmap_g = disc(x=y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class MultiDiscriminatorStep(Model):
    def __init__(self, list_discriminator: List[MultiDiscriminatorWrapper]):
        super().__init__()
        self.disc: Sequence[MultiDiscriminatorWrapper] = nn.ModuleList(
            list_discriminator
        )
        self.total = len(self.disc)

    def forward(
        self,
        y: Tensor,
        y_hat: Tensor,
        step_type: Literal["discriminator", "generator"],
    ) -> Union[
        Tuple[Tensor, Tensor, List[float]], Tuple[Tensor, List[float], List[float]]
    ]:
        """
        It returns the content based on the choice of "step_type", being it a
        'discriminator' or 'generator'

        For generator it returns:
            Tuple[Tensor, Tensor, List[float]]
            "gen_loss, feat_loss, all_g_losses"

        For 'discriminator' it returns:
            Tuple[Tensor, List[float], List[float]]
            "disc_loss, disc_real_losses, disc_gen_losses"
        """
        if step_type == "generator":
            all_g_losses: List[float] = []
            feat_loss: Tensor = 0
            gen_loss: Tensor = 0
        else:
            disc_loss: Tensor = 0
            disc_real_losses: List[float] = []
            disc_gen_losses: List[float] = []

        for disc in self.disc:
            if step_type == "generator":
                # [feature loss, generator loss, list of generator losses (float)]
                f_loss, g_loss, g_losses = disc.gen_step(y, y_hat)
                gen_loss += g_loss
                feat_loss += f_loss
                all_g_losses.extend(g_losses)
            else:
                # [discriminator loss, (disc losses real, disc losses generated)]
                d_loss, (d_real_losses, d_gen_losses) = disc.disc_step(y, y_hat)
                disc_loss += d_loss
                disc_real_losses.extend(d_real_losses)
                disc_gen_losses.extend(d_gen_losses)

        if step_type == "generator":
            return gen_loss, feat_loss, all_g_losses
        return disc_loss, disc_real_losses, disc_gen_losses
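The new losses/discriminators.py module bundles the discriminator stacks (multi-period, multi-envelope, multi-band STFT, multi-resolution STFT) behind a common MultiDiscriminatorWrapper and adds MultiDiscriminatorStep to aggregate their losses. Below is a minimal usage sketch, not taken from the package: it assumes the star imports above resolve and that train_step on each wrapper is supplied by the inherited ConvNets/Model base classes, since it is not defined in this file. The random tensors merely stand in for a real batch and a generator output of shape (B, 1, T).

# Hedged sketch: wiring the new discriminator stack into a GAN vocoder
# training step. The class names below exist in the file above; batch
# shapes and the surrounding training loop are illustrative only.
import torch
from lt_tensor.model_zoo.losses.discriminators import (
    MultiDiscriminatorStep,
    MultiPeriodDiscriminator,
    MultiResolutionDiscriminator,
)

disc_stack = MultiDiscriminatorStep(
    [MultiPeriodDiscriminator(), MultiResolutionDiscriminator()]
)

real = torch.randn(4, 1, 8192)  # ground-truth waveforms, (B, 1, T)
fake = torch.randn(4, 1, 8192)  # stand-in for the generator output

# Discriminator update: summed loss plus per-head real/generated losses.
disc_loss, disc_real_losses, disc_gen_losses = disc_stack(
    real, fake.detach(), step_type="discriminator"
)

# Generator update: adversarial loss, feature-matching loss, per-head losses.
gen_loss, feat_loss, all_g_losses = disc_stack(real, fake, step_type="generator")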
lt_tensor/processors/audio.py
CHANGED
@@ -105,7 +105,6 @@ class AudioProcessor(Model):
             onesided=self.cfg.onesided,
             normalized=self.cfg.normalized,
         )
-        self.griffin_lm_iters = 32
         self.mel_rscale = torchaudio.transforms.InverseMelScale(
             n_stft=self.cfg.n_stft,
             n_mels=self.cfg.n_mels,
@@ -114,21 +113,19 @@ class AudioProcessor(Model):
             f_max=self.cfg.f_max,
             mel_scale=self.cfg.mel_scale,
         )
-
-            n_fft=self.cfg.n_fft,
-            win_length=self.cfg.win_length,
-            hop_length=self.cfg.hop_length,
-        )
+
         self.register_buffer(
             "window",
             (torch.hann_window(self.cfg.win_length) if window is None else window),
         )
 
     def _apply_device(self):
-        print(f"Audio Processor Device: {self.device.type}")
-        self.giffin_lim.to(device=self.device)
         self._mel_spec.to(device=self.device)
         self.mel_rscale.to(device=self.device)
+        try:
+            self.window.to(device=self.device)
+        except:
+            pass
 
     def from_numpy(
         self,
@@ -173,7 +170,9 @@ class AudioProcessor(Model):
         )
 
         if audio is None and mel is not None:
-            return self.from_numpy(
+            return self.from_numpy(
+                librosa.feature.rms(S=mel, **rms_kwargs)[0]
+            ).squeeze()
         default_dtype = audio.dtype
         default_device = audio.device
         if audio.ndim > 1:
@@ -192,8 +191,12 @@ class AudioProcessor(Model):
         audio = self.to_numpy_safe(audio)
         if B == 1:
             if mel is None:
-                return self.from_numpy(
-
+                return self.from_numpy(
+                    librosa.feature.rms(y=audio, **rms_kwargs)[0]
+                ).squeeze()
+            return self.from_numpy(
+                librosa.feature.rms(y=audio, S=mel, **rms_kwargs)[0]
+            ).squeeze()
         else:
             rms_ = []
             for i in range(B):
@@ -201,7 +204,7 @@ class AudioProcessor(Model):
                     0
                 ]
                 rms_.append(_r)
-            return self.from_numpy_batch(rms_, default_device, default_dtype)
+            return self.from_numpy_batch(rms_, default_device, default_dtype).squeeze()
 
     def compute_pitch(
         self,
@@ -273,7 +276,7 @@ class AudioProcessor(Model):
             win_length=win_length,
             freq_low=fmin,
             freq_high=fmax,
-        )
+        ).squeeze()
 
     def interpolate(
         self,
@@ -312,7 +315,7 @@ class AudioProcessor(Model):
             antialias=antialias,
         )
 
-    def
+    def istft(
         self,
         spec: Tensor,
         phase: Tensor,
@@ -320,6 +323,10 @@ class AudioProcessor(Model):
         hop_length: Optional[int] = None,
         win_length: Optional[int] = None,
         length: Optional[int] = None,
+        center: Optional[bool] = None,
+        normalized: Optional[bool] = None,
+        onesided: Optional[bool] = None,
+        return_complex: bool = False,
         *,
         _recall: bool = False,
     ):
@@ -331,25 +338,25 @@ class AudioProcessor(Model):
         try:
             return torch.istft(
                 spec * torch.exp(phase * 1j),
-                n_fft=n_fft
-                hop_length=hop_length
-                win_length=win_length
+                n_fft=default(n_fft, self.cfg.n_fft),
+                hop_length=default(hop_length, self.cfg.hop_length),
+                win_length=default(win_length, self.cfg.win_length),
                 window=window,
-                center=self.cfg.center,
-                normalized=self.cfg.normalized,
-                onesided=self.cfg.onesided,
+                center=default(center, self.cfg.center),
+                normalized=default(normalized, self.cfg.normalized),
+                onesided=default(onesided, self.cfg.onesided),
                 length=length,
-                return_complex=
+                return_complex=return_complex,
             )
         except RuntimeError as e:
             if not _recall and spec.device != self.window.device:
                 self.window = self.window.to(spec.device)
-                return self.
+                return self.istft(
                     spec, phase, n_fft, hop_length, win_length, length, _recall=True
                 )
             raise e
 
-    def
+    def istft_norm(
         self,
         wave: Tensor,
         length: Optional[int] = None,
@@ -389,7 +396,7 @@ class AudioProcessor(Model):
         except RuntimeError as e:
             if not _recall and wave.device != self.window.device:
                 self.window = self.window.to(wave.device)
-                return self.
+                return self.istft_norm(wave, length, _recall=True)
             raise e
 
     def compute_mel(
@@ -415,14 +422,6 @@ class AudioProcessor(Model):
                 return self.compute_mel(wave, raw_mel_only, eps, _recall=True)
             raise e
 
-    def inverse_mel_spectogram(self, mel: Tensor, n_iter: Optional[int] = None):
-        if isinstance(n_iter, int) and n_iter != self.griffin_lm_iters:
-            self.giffin_lim.n_iter = n_iter
-            self.griffin_lm_iters = n_iter
-        return self.giffin_lim.forward(
-            self.mel_rscale(mel),
-        )
-
     def load_audio(
         self,
         path: PathLike,
@@ -506,14 +505,9 @@ class AudioProcessor(Model):
             maximum,
         )
 
-    def stft_loss(
-
-        signal
-        ground: Tensor,
-    ):
-        with torch.no_grad():
-            ground = F.interpolate(ground, signal.shape[-1]).to(signal.device)
-        return F.l1_loss(signal, ground)
+    def stft_loss(self, signal: Tensor, ground: Tensor, magnitude: float = 1.0):
+        ground = F.interpolate(ground, signal.shape[-1]).to(signal.device)
+        return F.l1_loss(signal.squeeze(), ground.squeeze()) * magnitude
 
     def forward(
         self,
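The audio.py changes drop the Griffin-Lim members (griffin_lm_iters, giffin_lim, inverse_mel_spectogram), squeeze the RMS and pitch outputs, give istft per-call overrides for center, normalized, onesided, and return_complex (falling back to the config when left as None), and rework stft_loss to scale the L1 term by a magnitude factor without the previous no_grad wrapper. A hedged sketch of the updated calls follows; it is not part of the package, and ap stands for an already-constructed AudioProcessor, whose constructor is outside this diff.

# Hedged sketch of the reworked istft / stft_loss calls. Only keyword names
# visible in the diff are used; shapes and values are illustrative.
import torch


def example_calls(ap, spec: torch.Tensor, phase: torch.Tensor, length: int):
    # Per-call overrides: passing None keeps the previous behaviour of
    # reading the corresponding value from ap.cfg.
    wave = ap.istft(
        spec,
        phase,
        length=length,
        center=None,       # falls back to ap.cfg.center
        normalized=None,   # falls back to ap.cfg.normalized
        onesided=None,     # falls back to ap.cfg.onesided
        return_complex=False,
    )

    # stft_loss now keeps gradients and scales the L1 loss by `magnitude`;
    # (B, 1, T) inputs keep the F.interpolate call inside it well-defined.
    signal = torch.randn(2, 1, 8192, requires_grad=True)
    ground = torch.randn(2, 1, 4096)
    loss = ap.stft_loss(signal, ground, magnitude=0.5)
    return wave, loss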
{lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a28.dist-info}/RECORD
CHANGED
@@ -9,7 +9,7 @@ lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,
 lt_tensor/noise_tools.py,sha256=wFeAsHhLhSlEc5XU5LbFKaXoHeVxrWjiMeljjGdIKyM,11363
 lt_tensor/torch_commons.py,sha256=8l0bxmrAzwvyqjivCIVISXlbvKarlg4DdE0BOGSnMuQ,812
 lt_tensor/transform.py,sha256=dZm8T_ov0blHMQu6nGiehsdG1VSB7bZBUVmTkT-PBdc,13257
-lt_tensor/model_zoo/__init__.py,sha256=
+lt_tensor/model_zoo/__init__.py,sha256=yPUVchgVhU2nAJ2ocA4HFfG7IMEiBu8qOi8I1KWTTkU,404
 lt_tensor/model_zoo/basic.py,sha256=pI8HyiHK-cmWcEEaVY_EduUJOjZW6HOtXvJd8Rbhq30,15452
 lt_tensor/model_zoo/convs.py,sha256=YQRxek75Qpsha8nfc7wLhmJS9XxPeCa4WxuftLg6IcE,3927
 lt_tensor/model_zoo/features.py,sha256=DO8dlE0kmPKTNC1Xkv9wKegOOYkQa_rkxM4hhcNwJWA,15655
@@ -26,10 +26,12 @@ lt_tensor/model_zoo/audio_models/__init__.py,sha256=MoG9YjxLyvscq_6njK1ljGBletK9
 lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=PDuDYN1omD1RoAXcmxH3tEgfAuM3ZHAWzimD6ElMqEQ,9073
 lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=7GJqKLw7-juXpfp5IFzjASLut0uouDhjZ1CQknf3H68,16533
 lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=ltIuD9t1gmS3bTmCqZIwJHKrhC6DYya3OaXlskWX9kw,17606
+lt_tensor/model_zoo/losses/__init__.py,sha256=B9RAUxBiOZwooztnij1oLeRwZ7_MjnN3mPoum7saD6s,59
+lt_tensor/model_zoo/losses/discriminators.py,sha256=yYh7HzRTUtr0RVTG7cWpcYsJZsRCz6yzg6Loq8FtyOk,20405
 lt_tensor/processors/__init__.py,sha256=Pvxhh0KR65zLCgUd53_k5Z0y5JWWcO0ZBXFK9rv0o5w,109
-lt_tensor/processors/audio.py,sha256=
-lt_tensor-0.0.
-lt_tensor-0.0.
-lt_tensor-0.0.
-lt_tensor-0.0.
-lt_tensor-0.0.
+lt_tensor/processors/audio.py,sha256=rsnnNi8MtxPq9vAYoiRQ7lGjorfJIpRvrKEe3zA8YJk,16668
+lt_tensor-0.0.1a28.dist-info/licenses/LICENSE,sha256=tQHc38scHOba4kDBNG4U0U6PpObaloiZG-FvKSgv2b0,11336
+lt_tensor-0.0.1a28.dist-info/METADATA,sha256=2LLguzaCAM2bcAdy_D66j4PS9Oh5PU3ZnA9qy7xcx0w,1062
+lt_tensor-0.0.1a28.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a28.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a28.dist-info/RECORD,,
{lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a28.dist-info}/WHEEL
File without changes
{lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a28.dist-info}/licenses/LICENSE
File without changes
{lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a28.dist-info}/top_level.txt
File without changes