lt-tensor 0.0.1a27__py3-none-any.whl → 0.0.1a29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/model_base.py +18 -10
- lt_tensor/model_zoo/__init__.py +2 -0
- lt_tensor/model_zoo/losses/__init__.py +3 -0
- lt_tensor/model_zoo/losses/discriminators.py +610 -0
- lt_tensor/processors/audio.py +30 -44
- {lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a29.dist-info}/METADATA +1 -1
- {lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a29.dist-info}/RECORD +10 -8
- {lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a29.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a29.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a29.dist-info}/top_level.txt +0 -0
lt_tensor/model_base.py
CHANGED
@@ -137,16 +137,24 @@ class _Devices_Base(nn.Module):
         )
 
     def _apply_device(self):
-        """
-
-
-
-
-
-
-
-
-
+        """This may be seem as overkill, but its necessary"""
+        for modules in self.modules():
+            try:
+                modules.to(self.device)
+            except:
+                pass
+
+        for buffer in self.buffers():
+            try:
+                buffer.to(self.device)
+            except:
+                pass
+
+        for tensor in self.parameters():
+            try:
+                tensor.to(self.device)
+            except:
+                pass
 
     def _to_dvc(
         self, device_name: str, device_id: Optional[Union[int, torch.device]] = None
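A note for readers of this hunk: `nn.Module.to()` moves a module's parameters and buffers in place and recurses into submodules, whereas `Tensor.to()` is out-of-place and returns a new tensor. The following standalone sketch (the `Demo` class is hypothetical, not part of lt-tensor) illustrates why the module loop above is the one that actually changes state, while the buffer and parameter loops only take effect where their results are reassigned:

```python
import torch
from torch import nn


class Demo(nn.Module):  # hypothetical stand-in for an lt-tensor Model subclass
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)
        self.register_buffer("window", torch.hann_window(16))


demo = Demo()
demo.to("cpu")   # in place: parameters and buffers of all submodules move too

t = torch.zeros(3)
t.to("cpu")      # out of place: returns a new tensor, `t` itself is unchanged
t = t.to("cpu")  # reassignment is required for Tensor.to() to take effect
```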
lt_tensor/model_zoo/losses/discriminators.py
ADDED
@@ -0,0 +1,610 @@
+from lt_tensor.model_zoo.audio_models.hifigan import ConvNets
+from lt_utils.common import *
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model
+from lt_tensor.model_zoo.convs import ConvNets
+from torch.nn import functional as F
+from torchaudio import transforms as T
+
+MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
+    List[Tensor],
+    List[Tensor],
+    List[List[Tensor]],
+    List[List[Tensor]],
+]
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+class MultiDiscriminatorWrapper(ConvNets):
+    """Base for all multi-steps type of discriminators"""
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.leaky_relu = nn.LeakyReLU(kwargs.get("negative_slope", 0.1))
+
+    def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
+        pass
+
+    # for type hinting
+    def __call__(self, *args, **kwds) -> MULTI_DISC_OUT_TYPE:
+        return super().__call__(*args, **kwds)
+
+    def gen_step(self, y: Tensor, y_hat: Tensor) -> tuple[Tensor, Tensor, List[float]]:
+        """For generator loss step [feature loss, generator loss, list of generator losses (float)]"""
+        _, y_hat_gen, feat_map_real, feat_map_gen = self.train_step(y, y_hat)
+        loss_feat = self.feature_loss(feat_map_real, feat_map_gen)
+        loss_generator, losses_gen_s = self.generator_loss(y_hat_gen)
+        return loss_feat, loss_generator, losses_gen_s
+
+    def disc_step(
+        self, y: Tensor, y_hat: Tensor
+    ) -> tuple[Tensor, tuple[List[float], List[float]]]:
+        """For discriminator loss step [discriminator loss, (disc losses real, disc losses generated)]"""
+        y_hat_real, y_hat_gen, _, _ = self.train_step(y, y_hat)
+
+        loss_disc, losses_disc_real, losses_disc_generated = self.discriminator_loss(
+            y_hat_real, y_hat_gen
+        )
+        return loss_disc, (losses_disc_real, losses_disc_generated)
+
+    @staticmethod
+    def discriminator_loss(
+        disc_real_outputs, disc_generated_outputs
+    ) -> Tuple[Tensor, List[float], List[float]]:
+        loss = 0
+        r_losses = []
+        g_losses = []
+        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+            r_loss = torch.mean((1 - dr) ** 2)
+            g_loss = torch.mean(dg**2)
+            loss += r_loss + g_loss
+            r_losses.append(r_loss.item())
+            g_losses.append(g_loss.item())
+
+        return loss, r_losses, g_losses
+
+    @staticmethod
+    def feature_loss(fmap_r, fmap_g) -> Tensor:
+        loss = 0
+        for dr, dg in zip(fmap_r, fmap_g):
+            for rl, gl in zip(dr, dg):
+                loss += torch.mean(torch.abs(rl - gl))
+
+        return loss * 2
+
+    @staticmethod
+    def generator_loss(disc_outputs) -> Tuple[Tensor, List[float]]:
+        loss = 0
+        gen_losses = []
+        for dg in disc_outputs:
+            l = torch.mean((1 - dg) ** 2)
+            gen_losses.append(l.item())
+            loss += l
+
+        return loss, gen_losses
+
+
+class DiscriminatorP(ConvNets):
+    def __init__(
+        self,
+        period: List[int],
+        discriminator_channel_mult: Number = 1,
+        kernel_size: int = 5,
+        stride: int = 3,
+        use_spectral_norm: bool = False,
+    ):
+        super().__init__()
+        self.period = period
+        norm_f = weight_norm if not use_spectral_norm else spectral_norm
+        dsc = lambda x: int(x * discriminator_channel_mult)
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    nn.Conv2d(
+                        1,
+                        dsc(32),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(32),
+                        dsc(128),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(128),
+                        dsc(512),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(512),
+                        dsc(1024),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(1024),
+                        dsc(1024),
+                        (kernel_size, 1),
+                        1,
+                        padding=(2, 0),
+                    )
+                ),
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv2d(dsc(1024), 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, 0.1)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        return x.flatten(1, -1), fmap
+
+
+class MultiPeriodDiscriminator(MultiDiscriminatorWrapper):
+    def __init__(
+        self,
+        discriminator_channel_mult: Number = 1,
+        mpd_reshapes: list[int] = [2, 3, 5, 7, 11],
+        use_spectral_norm: bool = False,
+    ):
+        super().__init__()
+        self.mpd_reshapes = mpd_reshapes
+        print(f"mpd_reshapes: {self.mpd_reshapes}")
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorP(
+                    rs,
+                    use_spectral_norm=use_spectral_norm,
+                    discriminator_channel_mult=discriminator_channel_mult,
+                )
+                for rs in self.mpd_reshapes
+            ]
+        )
+
+    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> MULTI_DISC_OUT_TYPE:
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class EnvelopeExtractor(nn.Module):
+    """Extracts the amplitude envelope of the audio signal."""
+
+    def __init__(self, kernel_size=101):
+        super().__init__()
+        # Lowpass filter for smoothing envelope (moving average)
+        self.kernel_size = kernel_size
+        self.register_buffer("kernel", torch.ones(1, 1, kernel_size) / kernel_size)
+
+    def forward(self, x):
+        # x: (B, 1, T) -> abs(x)
+        envelope = torch.abs(x)
+        # Apply low-pass smoothing (via conv1d)
+        envelope = F.pad(
+            envelope, (self.kernel_size // 2, self.kernel_size // 2), mode="reflect"
+        )
+        envelope = F.conv1d(envelope, self.kernel)
+        return envelope
+
+
+class DiscriminatorEnvelope(ConvNets):
+    def __init__(self, use_spectral_norm=False):
+        super().__init__()
+        norm_f = weight_norm if not use_spectral_norm else spectral_norm
+        self.extractor = EnvelopeExtractor(kernel_size=101)
+        self.convs = nn.ModuleList(
+            [
+                norm_f(nn.Conv1d(1, 64, 15, stride=1, padding=7)),
+                norm_f(nn.Conv1d(64, 128, 41, stride=2, groups=4, padding=20)),
+                norm_f(nn.Conv1d(128, 256, 41, stride=2, groups=16, padding=20)),
+                norm_f(nn.Conv1d(256, 512, 41, stride=4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(512, 512, 41, stride=4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(512, 512, 5, stride=1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv1d(512, 1, 3, stride=1, padding=1))
+        self.activation = nn.LeakyReLU(0.1)
+
+    def forward(self, x):
+        # Input: raw audio (B, 1, T)
+        x = self.extractor(x)
+        fmap = []
+        for layer in self.convs:
+            x = self.activation(layer(x))
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        return x.flatten(1), fmap
+
+
+class MultiEnvelopeDiscriminator(MultiDiscriminatorWrapper):
+    def __init__(self, use_spectral_norm: bool = False):
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorEnvelope(use_spectral_norm),  # raw envelope
+                DiscriminatorEnvelope(use_spectral_norm),  # downsampled once
+                DiscriminatorEnvelope(use_spectral_norm),  # downsampled twice
+            ]
+        )
+        self.meanpools = nn.ModuleList(
+            [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
+        )
+
+    def forward(self, y, y_hat):
+        y_d_rs, y_d_gs = [], []
+        fmap_rs, fmap_gs = [], []
+
+        for i, d in enumerate(self.discriminators):
+            if i != 0:
+                y = self.meanpools[i - 1](y)
+                y_hat = self.meanpools[i - 1](y_hat)
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            y_d_gs.append(y_d_g)
+            fmap_rs.append(fmap_r)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorB(ConvNets):
+    """
+    Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
+    and the modified code adapted from https://github.com/gemelo-ai/vocos.
+    """
+
+    def __init__(
+        self,
+        window_length: int,
+        channels: int = 32,
+        hop_factor: float = 0.25,
+        bands: Tuple[Tuple[float, float], ...] = (
+            (0.0, 0.1),
+            (0.1, 0.25),
+            (0.25, 0.5),
+            (0.5, 0.75),
+            (0.75, 1.0),
+        ),
+    ):
+        super().__init__()
+        self.window_length = window_length
+        self.hop_factor = hop_factor
+        self.spec_fn = T.Spectrogram(
+            n_fft=window_length,
+            hop_length=int(window_length * hop_factor),
+            win_length=window_length,
+            power=None,
+        )
+        n_fft = window_length // 2 + 1
+        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+        self.bands = bands
+        convs = lambda: nn.ModuleList(
+            [
+                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
+                ),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
+                ),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
+                ),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
+                ),
+            ]
+        )
+        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+
+        self.conv_post = weight_norm(
+            nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
+        )
+
+    def spectrogram(self, x: Tensor) -> List[Tensor]:
+        # Remove DC offset
+        x = x - x.mean(dim=-1, keepdims=True)
+        # Peak normalize the volume of input audio
+        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+        x = self.spec_fn(x)
+        x = torch.view_as_real(x)
+        x = x.permute(0, 3, 2, 1)  # [B, F, T, C] -> [B, C, T, F]
+        # Split into bands
+        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+        return x_bands
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+        x_bands = self.spectrogram(x.squeeze(1))
+        fmap = []
+        x = []
+
+        for band, stack in zip(x_bands, self.band_convs):
+            for i, layer in enumerate(stack):
+                band = layer(band)
+                band = torch.nn.functional.leaky_relu(band, 0.1)
+                if i > 0:
+                    fmap.append(band)
+            x.append(band)
+
+        x = torch.cat(x, dim=-1)
+        x = self.conv_post(x)
+        fmap.append(x)
+
+        return x, fmap
+
+
+class MultiBandDiscriminator(MultiDiscriminatorWrapper):
+    """
+    Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
+    and the modified code adapted from https://github.com/gemelo-ai/vocos.
+    """
+
+    def __init__(
+        self,
+        mbd_fft_sizes: list[int] = [2048, 1024, 512],
+    ):
+        super().__init__()
+        self.fft_sizes = mbd_fft_sizes
+        self.discriminators = nn.ModuleList(
+            [DiscriminatorB(window_length=w) for w in self.fft_sizes]
+        )
+
+    def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
+
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for d in self.discriminators:
+
+            y_d_r, fmap_r = d(x=y)
+            y_d_g, fmap_g = d(x=y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorR(ConvNets):
+    def __init__(
+        self,
+        resolution: List[int],
+        use_spectral_norm: bool = False,
+        discriminator_channel_mult: int = 1,
+    ):
+        super().__init__()
+
+        self.resolution = resolution
+        assert (
+            len(self.resolution) == 3
+        ), f"MRD layer requires list with len=3, got {self.resolution}"
+        self.lrelu_slope = 0.1
+
+        self.register_buffer("window", torch.hann_window(self.resolution[-1]))
+
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    nn.Conv2d(
+                        1, int(32 * discriminator_channel_mult), (3, 9), padding=(1, 4)
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 3),
+                        padding=(1, 1),
+                    )
+                ),
+            ]
+        )
+        self.conv_post = norm_f(
+            nn.Conv2d(int(32 * discriminator_channel_mult), 1, (3, 3), padding=(1, 1))
+        )
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+        fmap = []
+        x = self.spectrogram(x)
+        x = x.unsqueeze(1)
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, self.lrelu_slope)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+    def spectrogram(self, x: Tensor) -> Tensor:
+        n_fft, hop_length, win_length = self.resolution
+        x = F.pad(
+            x,
+            (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
+            mode="reflect",
+        )
+        x = x.squeeze(1)
+        x = torch.stft(
+            x,
+            n_fft=n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            center=False,
+            return_complex=True,
+            window=self.window,
+        )
+        x = torch.view_as_real(x)  # [B, F, TT, 2]
+        mag = torch.norm(x, p=2, dim=-1)  # [B, F, TT]
+
+        return mag
+
+
+class MultiResolutionDiscriminator(MultiDiscriminatorWrapper):
+    def __init__(
+        self,
+        use_spectral_norm: bool = False,
+        discriminator_channel_mult: int = 1,
+        resolutions: List[List[int]] = [
+            [1024, 120, 600],
+            [2048, 240, 1200],
+            [512, 50, 240],
+        ],
+    ):
+        super().__init__()
+        self.resolutions = resolutions
+        assert (
+            len(self.resolutions) == 3
+        ), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}, type: {type(self.resolutions)}"
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorR(
+                    resolution, use_spectral_norm, discriminator_channel_mult
+                )
+                for resolution in self.resolutions
+            ]
+        )
+
+    def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for disc in self.discriminators:
+            y_d_r, fmap_r = disc(x=y)
+            y_d_g, fmap_g = disc(x=y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class MultiDiscriminatorStep(Model):
+    def __init__(self, list_discriminator: List[MultiDiscriminatorWrapper]):
+        super().__init__()
+        self.disc: Sequence[MultiDiscriminatorWrapper] = nn.ModuleList(
+            list_discriminator
+        )
+        self.total = len(self.disc)
+
+    def forward(
+        self,
+        y: Tensor,
+        y_hat: Tensor,
+        step_type: Literal["discriminator", "generator"],
+    ) -> Union[
+        Tuple[Tensor, Tensor, List[float]], Tuple[Tensor, List[float], List[float]]
+    ]:
+        """
+        It returns the content based on the choice of "step_type", being it a
+        'discriminator' or 'generator'
+
+        For generator it returns:
+            Tuple[Tensor, Tensor, List[float]]
+            "gen_loss, feat_loss, all_g_losses"
+
+        For 'discriminator' it returns:
+            Tuple[Tensor, List[float], List[float]]
+            "disc_loss, disc_real_losses, disc_gen_losses"
+        """
+        if step_type == "generator":
+            all_g_losses: List[float] = []
+            feat_loss: Tensor = 0
+            gen_loss: Tensor = 0
+        else:
+            disc_loss: Tensor = 0
+            disc_real_losses: List[float] = []
+            disc_gen_losses: List[float] = []
+
+        for disc in self.disc:
+            if step_type == "generator":
+                # feature loss, generator loss, list of generator losses (float)]
+                f_loss, g_loss, g_losses = disc.gen_step(y, y_hat)
+                gen_loss += g_loss
+                feat_loss += f_loss
+                all_g_losses.extend(g_losses)
+            else:
+                # [discriminator loss, (disc losses real, disc losses generated)]
+                d_loss, (d_real_losses, d_gen_losses) = disc.disc_step(y, y_hat)
+                disc_loss += d_loss
+                disc_real_losses.extend(d_real_losses)
+                disc_gen_losses.extend(d_gen_losses)
+
+        if step_type == "generator":
+            return gen_loss, feat_loss, all_g_losses
+        return disc_loss, disc_real_losses, disc_gen_losses
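To make the intended call pattern of the new module concrete, here is an untested training-step sketch. It assumes the import path matches the file listing above and that `train_step` on the wrapper, which is provided by the lt-tensor base classes and not shown in this diff, dispatches real and generated audio through `forward`; the random tensors stand in for a real generator and dataset:

```python
import torch
from lt_tensor.model_zoo.losses.discriminators import (
    MultiDiscriminatorStep,
    MultiPeriodDiscriminator,
    MultiResolutionDiscriminator,
)

# Combine several multi-discriminators behind one step interface.
disc_step = MultiDiscriminatorStep(
    [MultiPeriodDiscriminator(), MultiResolutionDiscriminator()]
)

y = torch.randn(2, 1, 8192)                           # real audio (B, 1, T)
y_hat = torch.randn(2, 1, 8192, requires_grad=True)   # stand-in generator output

# Discriminator update: summed loss plus per-discriminator float logs.
d_loss, d_real_losses, d_gen_losses = disc_step(y, y_hat.detach(), "discriminator")
d_loss.backward()

# Generator update: adversarial loss plus feature-matching loss.
g_loss, feat_loss, all_g_losses = disc_step(y, y_hat, "generator")
(g_loss + feat_loss).backward()
```

The per-element losses follow the least-squares GAN formulation familiar from HiFi-GAN: on the discriminator step real outputs are pulled toward 1 and generated outputs toward 0, while on the generator step the generated outputs are pulled toward 1, with the feature-matching term comparing intermediate activations.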
lt_tensor/processors/audio.py
CHANGED
@@ -105,7 +105,6 @@ class AudioProcessor(Model):
             onesided=self.cfg.onesided,
             normalized=self.cfg.normalized,
         )
-        self.griffin_lm_iters = 32
         self.mel_rscale = torchaudio.transforms.InverseMelScale(
             n_stft=self.cfg.n_stft,
             n_mels=self.cfg.n_mels,
@@ -114,22 +113,12 @@ class AudioProcessor(Model):
             f_max=self.cfg.f_max,
             mel_scale=self.cfg.mel_scale,
         )
-        self.giffin_lim = torchaudio.transforms.GriffinLim(
-            n_fft=self.cfg.n_fft,
-            win_length=self.cfg.win_length,
-            hop_length=self.cfg.hop_length,
-        )
+
         self.register_buffer(
            "window",
            (torch.hann_window(self.cfg.win_length) if window is None else window),
        )
 
-    def _apply_device(self):
-        print(f"Audio Processor Device: {self.device.type}")
-        self.giffin_lim.to(device=self.device)
-        self._mel_spec.to(device=self.device)
-        self.mel_rscale.to(device=self.device)
-
     def from_numpy(
         self,
         array: np.ndarray,
@@ -173,7 +162,9 @@ class AudioProcessor(Model):
         )
 
         if audio is None and mel is not None:
-            return self.from_numpy(
+            return self.from_numpy(
+                librosa.feature.rms(S=mel, **rms_kwargs)[0]
+            ).squeeze()
         default_dtype = audio.dtype
         default_device = audio.device
         if audio.ndim > 1:
@@ -192,8 +183,12 @@ class AudioProcessor(Model):
         audio = self.to_numpy_safe(audio)
         if B == 1:
             if mel is None:
-                return self.from_numpy(
-
+                return self.from_numpy(
+                    librosa.feature.rms(y=audio, **rms_kwargs)[0]
+                ).squeeze()
+            return self.from_numpy(
+                librosa.feature.rms(y=audio, S=mel, **rms_kwargs)[0]
+            ).squeeze()
         else:
             rms_ = []
             for i in range(B):
@@ -201,7 +196,7 @@ class AudioProcessor(Model):
                     0
                 ]
                 rms_.append(_r)
-            return self.from_numpy_batch(rms_, default_device, default_dtype)
+            return self.from_numpy_batch(rms_, default_device, default_dtype).squeeze()
 
     def compute_pitch(
         self,
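For context on the `.squeeze()` additions: `librosa.feature.rms` returns an array of shape `(1, n_frames)` for mono input, so indexing `[0]` strips the leading channel axis and the new `.squeeze()` trims any remaining singleton dimensions after conversion back to a tensor. A quick standalone check with librosa defaults (the values below are illustrative, not from the package):

```python
import numpy as np
import librosa

y = np.random.randn(22050).astype(np.float32)  # one second of noise at 22.05 kHz
rms = librosa.feature.rms(y=y)                 # framed RMS energy
print(rms.shape, rms[0].shape)                 # (1, 44) (44,) with the default hop of 512
```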
@@ -273,7 +268,7 @@ class AudioProcessor(Model):
             win_length=win_length,
             freq_low=fmin,
             freq_high=fmax,
-        )
+        ).squeeze()
 
     def interpolate(
         self,
@@ -312,7 +307,7 @@ class AudioProcessor(Model):
             antialias=antialias,
         )
 
-    def
+    def istft(
         self,
         spec: Tensor,
         phase: Tensor,
@@ -320,6 +315,10 @@ class AudioProcessor(Model):
         hop_length: Optional[int] = None,
         win_length: Optional[int] = None,
         length: Optional[int] = None,
+        center: Optional[bool] = None,
+        normalized: Optional[bool] = None,
+        onesided: Optional[bool] = None,
+        return_complex: bool = False,
         *,
         _recall: bool = False,
     ):
@@ -331,25 +330,25 @@ class AudioProcessor(Model):
         try:
             return torch.istft(
                 spec * torch.exp(phase * 1j),
-                n_fft=n_fft
-                hop_length=hop_length
-                win_length=win_length
+                n_fft=default(n_fft, self.cfg.n_fft),
+                hop_length=default(hop_length, self.cfg.hop_length),
+                win_length=default(win_length, self.cfg.win_length),
                 window=window,
-                center=self.cfg.center,
-                normalized=self.cfg.normalized,
-                onesided=self.cfg.onesided,
+                center=default(center, self.cfg.center),
+                normalized=default(normalized, self.cfg.normalized),
+                onesided=default(onesided, self.cfg.onesided),
                 length=length,
-                return_complex=
+                return_complex=return_complex,
             )
         except RuntimeError as e:
             if not _recall and spec.device != self.window.device:
                 self.window = self.window.to(spec.device)
-                return self.
+                return self.istft(
                     spec, phase, n_fft, hop_length, win_length, length, _recall=True
                 )
             raise e
 
-    def
+    def istft_norm(
         self,
         wave: Tensor,
         length: Optional[int] = None,
@@ -389,7 +388,7 @@ class AudioProcessor(Model):
         except RuntimeError as e:
             if not _recall and wave.device != self.window.device:
                 self.window = self.window.to(wave.device)
-                return self.
+                return self.istft_norm(wave, length, _recall=True)
             raise e
 
     def compute_mel(
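The rewritten `istft` routes every optional argument through a `default` helper before handing it to `torch.istft`. That helper appears to come from `lt_utils` (star-imported elsewhere in the package) and is presumably equivalent to the usual null-coalescing idiom, sketched here as an assumption rather than the library's actual code:

```python
def default(value, fallback):
    # Hypothetical reconstruction of the lt_utils helper: keep the caller's
    # value when one was passed, otherwise fall back to the configured one.
    return value if value is not None else fallback
```

Under that reading, a call like `ap.istft(spec, phase, center=False)` (with `ap` a hypothetical `AudioProcessor` instance) overrides only `center`, while `n_fft`, `hop_length`, `win_length`, `normalized`, and `onesided` keep their values from `self.cfg`.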
@@ -415,14 +414,6 @@ class AudioProcessor(Model):
             return self.compute_mel(wave, raw_mel_only, eps, _recall=True)
         raise e
 
-    def inverse_mel_spectogram(self, mel: Tensor, n_iter: Optional[int] = None):
-        if isinstance(n_iter, int) and n_iter != self.griffin_lm_iters:
-            self.giffin_lim.n_iter = n_iter
-            self.griffin_lm_iters = n_iter
-        return self.giffin_lim.forward(
-            self.mel_rscale(mel),
-        )
-
     def load_audio(
         self,
         path: PathLike,
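With `inverse_mel_spectogram` and the `giffin_lim` transform removed, callers that relied on mel inversion need to rebuild it themselves. A standalone torchaudio sketch of the removed behavior (the parameters here are illustrative, not the package's config values):

```python
import torch
import torchaudio

n_fft, n_mels, sample_rate = 1024, 80, 22050

# Linear-frequency reconstruction followed by Griffin-Lim phase recovery,
# mirroring the deleted giffin_lim / mel_rscale pipeline.
mel_rscale = torchaudio.transforms.InverseMelScale(
    n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate
)
griffin_lim = torchaudio.transforms.GriffinLim(n_fft=n_fft, n_iter=32)

mel = torch.rand(1, n_mels, 100)      # a (B, n_mels, T) mel spectrogram
wave = griffin_lim(mel_rscale(mel))   # linear spectrogram -> waveform estimate
```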
@@ -506,14 +497,9 @@ class AudioProcessor(Model):
             maximum,
         )
 
-    def stft_loss(
-
-        signal
-        ground: Tensor,
-    ):
-        with torch.no_grad():
-            ground = F.interpolate(ground, signal.shape[-1]).to(signal.device)
-            return F.l1_loss(signal, ground)
+    def stft_loss(self, signal: Tensor, ground: Tensor, magnitude: float = 1.0):
+        ground = F.interpolate(ground, signal.shape[-1]).to(signal.device)
+        return F.l1_loss(signal.squeeze(), ground.squeeze()) * magnitude
 
     def forward(
         self,
{lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a29.dist-info}/RECORD
CHANGED
@@ -4,12 +4,12 @@ lt_tensor/losses.py,sha256=zvkCOnE5XpF3v6ymivRIdqPTsMM5zc94ZMom7YDi3zM,4946
 lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
 lt_tensor/math_ops.py,sha256=TkD4WQG42KsQ9Fg7FXOjf8f-ixtW0apf2XjaooecVx4,2257
 lt_tensor/misc_utils.py,sha256=N2r3UmxC4RM2BZBQhpjDZ_BKLrzsyIlKzopTzJbnjFU,28962
-lt_tensor/model_base.py,sha256=
+lt_tensor/model_base.py,sha256=8qcFXe0_y8f1_tAwt18gwQjyyapbnVEKcjCMrKnQatw,17614
 lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
 lt_tensor/noise_tools.py,sha256=wFeAsHhLhSlEc5XU5LbFKaXoHeVxrWjiMeljjGdIKyM,11363
 lt_tensor/torch_commons.py,sha256=8l0bxmrAzwvyqjivCIVISXlbvKarlg4DdE0BOGSnMuQ,812
 lt_tensor/transform.py,sha256=dZm8T_ov0blHMQu6nGiehsdG1VSB7bZBUVmTkT-PBdc,13257
-lt_tensor/model_zoo/__init__.py,sha256=
+lt_tensor/model_zoo/__init__.py,sha256=yPUVchgVhU2nAJ2ocA4HFfG7IMEiBu8qOi8I1KWTTkU,404
 lt_tensor/model_zoo/basic.py,sha256=pI8HyiHK-cmWcEEaVY_EduUJOjZW6HOtXvJd8Rbhq30,15452
 lt_tensor/model_zoo/convs.py,sha256=YQRxek75Qpsha8nfc7wLhmJS9XxPeCa4WxuftLg6IcE,3927
 lt_tensor/model_zoo/features.py,sha256=DO8dlE0kmPKTNC1Xkv9wKegOOYkQa_rkxM4hhcNwJWA,15655
@@ -26,10 +26,12 @@ lt_tensor/model_zoo/audio_models/__init__.py,sha256=MoG9YjxLyvscq_6njK1ljGBletK9
 lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=PDuDYN1omD1RoAXcmxH3tEgfAuM3ZHAWzimD6ElMqEQ,9073
 lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=7GJqKLw7-juXpfp5IFzjASLut0uouDhjZ1CQknf3H68,16533
 lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=ltIuD9t1gmS3bTmCqZIwJHKrhC6DYya3OaXlskWX9kw,17606
+lt_tensor/model_zoo/losses/__init__.py,sha256=B9RAUxBiOZwooztnij1oLeRwZ7_MjnN3mPoum7saD6s,59
+lt_tensor/model_zoo/losses/discriminators.py,sha256=yYh7HzRTUtr0RVTG7cWpcYsJZsRCz6yzg6Loq8FtyOk,20405
 lt_tensor/processors/__init__.py,sha256=Pvxhh0KR65zLCgUd53_k5Z0y5JWWcO0ZBXFK9rv0o5w,109
-lt_tensor/processors/audio.py,sha256=
-lt_tensor-0.0.
-lt_tensor-0.0.
-lt_tensor-0.0.
-lt_tensor-0.0.
-lt_tensor-0.0.
+lt_tensor/processors/audio.py,sha256=1JuxxexfUsXkLjVjWUk-oTRU-QNnCCwvKX3eP0m7LGE,16452
+lt_tensor-0.0.1a29.dist-info/licenses/LICENSE,sha256=tQHc38scHOba4kDBNG4U0U6PpObaloiZG-FvKSgv2b0,11336
+lt_tensor-0.0.1a29.dist-info/METADATA,sha256=F03dNMnEydcKjjZF3IntNaIj34FwLdoy-L0pBB_yz0E,1062
+lt_tensor-0.0.1a29.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a29.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a29.dist-info/RECORD,,
{lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a29.dist-info}/WHEEL
File without changes

{lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a29.dist-info}/licenses/LICENSE
File without changes

{lt_tensor-0.0.1a27.dist-info → lt_tensor-0.0.1a29.dist-info}/top_level.txt
File without changes