lt-tensor 0.0.1a11__py3-none-any.whl → 0.0.1a12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +2 -0
- lt_tensor/config_templates.py +97 -0
- lt_tensor/datasets/audio.py +21 -7
- lt_tensor/losses.py +1 -1
- lt_tensor/math_ops.py +1 -1
- lt_tensor/misc_utils.py +71 -2
- lt_tensor/model_base.py +157 -203
- lt_tensor/model_zoo/__init__.py +2 -2
- lt_tensor/model_zoo/bsc.py +6 -6
- lt_tensor/model_zoo/disc.py +1 -1
- lt_tensor/model_zoo/fsn.py +2 -2
- lt_tensor/model_zoo/gns.py +4 -4
- lt_tensor/model_zoo/istft/__init__.py +5 -0
- lt_tensor/model_zoo/istft/generator.py +150 -0
- lt_tensor/model_zoo/istft/trainer.py +450 -0
- lt_tensor/model_zoo/istft.py +508 -25
- lt_tensor/model_zoo/pos.py +2 -2
- lt_tensor/model_zoo/rsd.py +16 -146
- lt_tensor/model_zoo/tfrms.py +4 -4
- lt_tensor/noise_tools.py +2 -2
- lt_tensor/processors/audio.py +87 -16
- lt_tensor/transform.py +30 -37
- {lt_tensor-0.0.1a11.dist-info → lt_tensor-0.0.1a12.dist-info}/METADATA +3 -2
- lt_tensor-0.0.1a12.dist-info/RECORD +32 -0
- lt_tensor-0.0.1a11.dist-info/RECORD +0 -28
- {lt_tensor-0.0.1a11.dist-info → lt_tensor-0.0.1a12.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a11.dist-info → lt_tensor-0.0.1a12.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a11.dist-info → lt_tensor-0.0.1a12.dist-info}/top_level.txt +0 -0
lt_tensor/model_zoo/istft.py
CHANGED
@@ -1,14 +1,495 @@
-
-
-
-
-
+__all__ = ["WaveSettings", "WaveDecoder", "iSTFTGenerator"]
+import gc
+import math
+import itertools
+from lt_utils.common import *
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model
+from lt_tensor.misc_utils import log_tensor
+from lt_tensor.model_zoo.rsd import ResBlock1D, ConvNets, get_weight_norm
+from lt_utils.misc_utils import log_traceback
+from lt_tensor.processors import AudioProcessor
+from lt_utils.type_utils import is_dir, is_pathlike
+from lt_tensor.misc_utils import set_seed, clear_cache
+from lt_tensor.model_zoo.disc import MultiPeriodDiscriminator, MultiScaleDiscriminator
 import torch.nn.functional as F
+from lt_tensor.config_templates import updateDict, ModelConfig
+
+
+def feature_loss(real_feats, fake_feats):
+    loss = 0.0
+    for r, f in zip(real_feats, fake_feats):
+        for ri, fi in zip(r, f):
+            loss += F.l1_loss(ri, fi)
+    return loss
+
+
+def generator_adv_loss(fake_preds):
+    loss = 0.0
+    for f in fake_preds:
+        loss += torch.mean((f - 1.0) ** 2)
+    return loss
+
+
+def discriminator_loss(real_preds, fake_preds):
+    loss = 0.0
+    for r, f in zip(real_preds, fake_preds):
+        loss += torch.mean((r - 1.0) ** 2) + torch.mean(f**2)
+    return loss
+
+
+class WaveSettings:
+    def __init__(
+        self,
+        n_mels: int = 80,
+        upsample_rates: List[Union[int, List[int]]] = [8, 8],
+        upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16],
+        upsample_initial_channel: int = 512,
+        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+        resblock_dilation_sizes: List[Union[int, List[int]]] = [
+            [1, 3, 5],
+            [1, 3, 5],
+            [1, 3, 5],
+        ],
+        n_fft: int = 16,
+        activation: nn.Module = nn.LeakyReLU(0.1),
+        msd_layers: int = 3,
+        mpd_periods: List[int] = [2, 3, 5, 7, 11],
+        seed: Optional[int] = None,
+        lr: float = 1e-5,
+        adamw_betas: List[float] = [0.75, 0.98],
+        scheduler_template: Callable[
+            [optim.Optimizer], optim.lr_scheduler.LRScheduler
+        ] = lambda optimizer: optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.998),
+    ):
+        self.in_channels = n_mels
+        self.upsample_rates = upsample_rates
+        self.upsample_kernel_sizes = upsample_kernel_sizes
+        self.upsample_initial_channel = upsample_initial_channel
+        self.resblock_kernel_sizes = resblock_kernel_sizes
+        self.resblock_dilation_sizes = resblock_dilation_sizes
+        self.n_fft = n_fft
+        self.activation = activation
+        self.mpd_periods = mpd_periods
+        self.msd_layers = msd_layers
+        self.seed = seed
+        self.lr = lr
+        self.adamw_betas = adamw_betas
+        self.scheduler_template = scheduler_template
+
+    def to_dict(self):
+        return {k: y for k, y in self.__dict__.items()}
 
+    def set_value(self, var_name: str, value: str) -> None:
+        updateDict(self, {var_name: value})
 
-
-
-
+    def get_value(self, var_name: str) -> Any:
+        return self.__dict__.get(var_name)
+
+
+class WaveDecoder(Model):
+    def __init__(
+        self,
+        audio_processor: AudioProcessor,
+        settings: Optional[WaveSettings] = None,
+        generator: Optional[Union[Model, "iSTFTGenerator"]] = None,  # not initialized!
+    ):
+        super().__init__()
+        if settings is None:
+            self.settings = WaveSettings()
+        elif isinstance(settings, dict):
+            self.settings = WaveSettings(**settings)
+        elif isinstance(settings, WaveSettings):
+            self.settings = settings
+        else:
+            raise ValueError(
+                "Cannot initialize the WaveDecoder with the given settings. "
+                "Use either a dictionary or the WaveSettings class to set up the settings. "
+                "Alternatively, leave it as None to use the default values."
+            )
+        if self.settings.seed is not None:
+            set_seed(self.settings.seed)
+        if generator is None:
+            generator = iSTFTGenerator
+        self.generator: iSTFTGenerator = generator(
+            in_channels=self.settings.in_channels,
+            upsample_rates=self.settings.upsample_rates,
+            upsample_kernel_sizes=self.settings.upsample_kernel_sizes,
+            upsample_initial_channel=self.settings.upsample_initial_channel,
+            resblock_kernel_sizes=self.settings.resblock_kernel_sizes,
+            resblock_dilation_sizes=self.settings.resblock_dilation_sizes,
+            n_fft=self.settings.n_fft,
+            activation=self.settings.activation,
+        )
+        self.generator.eval()
+        self.g_optim = None
+        self.d_optim = None
+        self.gan_training = False
+        self.audio_processor = audio_processor
+        self.register_buffer("msd", None, persistent=False)
+        self.register_buffer("mpd", None, persistent=False)
+
+    def setup_training_mode(self, load_weights_from: Optional[PathLike] = None):
+        """The location must be a directory path, not a file!"""
+        self.finish_training_setup()
+        if self.msd is None:
+            self.msd = MultiScaleDiscriminator(self.settings.msd_layers)
+        if self.mpd is None:
+            self.mpd = MultiPeriodDiscriminator(self.settings.mpd_periods)
+        if load_weights_from is not None:
+            if is_dir(path=load_weights_from, validate=False):
+                try:
+                    self.msd.load_weights(Path(load_weights_from, "msd.pt"))
+                except Exception as e:
+                    log_traceback(e, "MSD Loading")
+                try:
+                    self.mpd.load_weights(Path(load_weights_from, "mpd.pt"))
+                except Exception as e:
+                    log_traceback(e, "MPD Loading")
+
+        self.update_schedulers_and_optimizer()
+        self.msd.to(device=self.device)
+        self.mpd.to(device=self.device)
+
+        self.gan_training = True
+        return True
+
+    def update_schedulers_and_optimizer(self):
+        self.g_optim = optim.AdamW(
+            self.generator.parameters(),
+            lr=self.settings.lr,
+            betas=self.settings.adamw_betas,
+        )
+        self.g_scheduler = self.settings.scheduler_template(self.g_optim)
+        if any([self.mpd is None, self.msd is None]):
+            return
+        self.d_optim = optim.AdamW(
+            itertools.chain(self.mpd.parameters(), self.msd.parameters()),
+            lr=self.settings.lr,
+            betas=self.settings.adamw_betas,
+        )
+        self.d_scheduler = self.settings.scheduler_template(self.d_optim)
+
+    def set_lr(self, new_lr: float = 1e-4):
+        if self.g_optim is not None:
+            for groups in self.g_optim.param_groups:
+                groups["lr"] = new_lr
+
+        if self.d_optim is not None:
+            for groups in self.d_optim.param_groups:
+                groups["lr"] = new_lr
+        return self.get_lr()
+
+    def get_lr(self) -> Tuple[float, float]:
+        g = float("nan")
+        d = float("nan")
+        if self.g_optim is not None:
+            g = self.g_optim.param_groups[0]["lr"]
+        if self.d_optim is not None:
+            d = self.d_optim.param_groups[0]["lr"]
+        return g, d
+
+    def save_weights(self, path, replace=True):
+        is_pathlike(path, check_if_empty=True, validate=True)
+        if str(path).endswith(".pt"):
+            path = Path(path).parent
+        else:
+            path = Path(path)
+        self.generator.save_weights(Path(path, "generator.pt"), replace)
+        if self.msd is not None:
+            self.msd.save_weights(Path(path, "msd.pt"), replace)
+        if self.mpd is not None:
+            self.mpd.save_weights(Path(path, "mpd.pt"), replace)
+
+    def load_weights(
+        self,
+        path,
+        raise_if_not_exists=False,
+        strict=True,
+        assign=False,
+        weights_only=False,
+        mmap=None,
+        **torch_loader_kwargs
+    ):
+        is_pathlike(path, check_if_empty=True, validate=True)
+        if str(path).endswith(".pt"):
+            path = Path(path)
+        else:
+            path = Path(path, "generator.pt")
+
+        self.generator.load_weights(
+            path,
+            raise_if_not_exists,
+            strict,
+            assign,
+            weights_only,
+            mmap,
+            **torch_loader_kwargs,
+        )
+
+    def finish_training_setup(self):
+        gc.collect()
+        self.mpd = None
+        clear_cache()
+        gc.collect()
+        self.msd = None
+        clear_cache()
+        self.gan_training = False
+
+    def forward(self, mel_spec: Tensor) -> Tuple[Tensor, Tensor]:
+        """Returns the generated spec and phase."""
+        return self.generator.forward(mel_spec)
+
+    def inference(
+        self,
+        mel_spec: Tensor,
+        return_dict: bool = False,
+    ) -> Union[Dict[str, Tensor], Tensor]:
+        spec, phase = super().inference(mel_spec)
+        wave = self.audio_processor.inverse_transform(
+            spec,
+            phase,
+            self.settings.n_fft,
+            hop_length=4,
+            win_length=self.settings.n_fft,
+        )
+        if not return_dict:
+            return wave[:, : wave.shape[-1] - 256]
+        return {
+            "wave": wave[:, : wave.shape[-1] - 256],
+            "spec": spec,
+            "phase": phase,
+        }
+
+    def set_device(self, device: str):
+        self.to(device=device)
+        self.generator.to(device=device)
+        self.audio_processor.to(device=device)
+        self.msd.to(device=device)
+        self.mpd.to(device=device)
+
+    def train_step(
+        self,
+        mels: Tensor,
+        real_audio: Tensor,
+        stft_scale: float = 1.0,
+        mel_scale: float = 1.0,
+        adv_scale: float = 1.0,
+        fm_scale: float = 1.0,
+        fm_add: float = 0.0,
+        is_discriminator_frozen: bool = False,
+        is_generator_frozen: bool = False,
+    ):
+        if not self.gan_training:
+            self.setup_training_mode()
+        spec, phase = super().train_step(mels)
+        real_audio = real_audio.squeeze(1)
+        fake_audio = self.audio_processor.inverse_transform(
+            spec,
+            phase,
+            self.settings.n_fft,
+            hop_length=4,
+            win_length=self.settings.n_fft,
+            # length=real_audio.shape[-1]
+        )[:, : real_audio.shape[-1]]
+        # smallest = min(real_audio.shape[-1], fake_audio.shape[-1])
+        # real_audio = real_audio[:, :, :smallest].squeeze(1)
+        # fake_audio = fake_audio[:, :smallest]
+
+        disc_kwargs = dict(
+            real_audio=real_audio,
+            fake_audio=fake_audio.detach(),
+            am_i_frozen=is_discriminator_frozen,
+        )
+        if is_discriminator_frozen:
+            with torch.no_grad():
+                disc_out = self._discriminator_step(**disc_kwargs)
+        else:
+            disc_out = self._discriminator_step(**disc_kwargs)
+
+        generator_kwargs = dict(
+            mels=mels,
+            real_audio=real_audio,
+            fake_audio=fake_audio,
+            **disc_out,
+            stft_scale=stft_scale,
+            mel_scale=mel_scale,
+            adv_scale=adv_scale,
+            fm_add=fm_add,
+            fm_scale=fm_scale,
+            am_i_frozen=is_generator_frozen,
+        )
+
+        if is_generator_frozen:
+            with torch.no_grad():
+                return self._generator_step(**generator_kwargs)
+        return self._generator_step(**generator_kwargs)
+
+    def _discriminator_step(
+        self,
+        real_audio: Tensor,
+        fake_audio: Tensor,
+        am_i_frozen: bool = False,
+    ):
+        # ========== Discriminator Forward Pass ==========
+
+        # MPD
+        real_mpd_preds, _ = self.mpd(real_audio)
+        fake_mpd_preds, _ = self.mpd(fake_audio)
+        # MSD
+        real_msd_preds, _ = self.msd(real_audio)
+        fake_msd_preds, _ = self.msd(fake_audio)
+
+        loss_d_mpd = discriminator_loss(real_mpd_preds, fake_mpd_preds)
+        loss_d_msd = discriminator_loss(real_msd_preds, fake_msd_preds)
+        loss_d = loss_d_mpd + loss_d_msd
+
+        if not am_i_frozen:
+            self.d_optim.zero_grad()
+            loss_d.backward()
+            self.d_optim.step()
+
+        return {
+            "loss_d": loss_d.item(),
+        }
+
+    def _generator_step(
+        self,
+        mels: Tensor,
+        real_audio: Tensor,
+        fake_audio: Tensor,
+        loss_d: float,
+        stft_scale: float = 1.0,
+        mel_scale: float = 1.0,
+        adv_scale: float = 1.0,
+        fm_scale: float = 1.0,
+        fm_add: float = 0.0,
+        am_i_frozen: bool = False,
+    ):
+        # ========== Generator Loss ==========
+        real_mpd_feats = self.mpd(real_audio)[1]
+        real_msd_feats = self.msd(real_audio)[1]
+
+        fake_mpd_preds, fake_mpd_feats = self.mpd(fake_audio)
+        fake_msd_preds, fake_msd_feats = self.msd(fake_audio)
+
+        loss_adv_mpd = generator_adv_loss(fake_mpd_preds)
+        loss_adv_msd = generator_adv_loss(fake_msd_preds)
+        loss_fm_mpd = feature_loss(real_mpd_feats, fake_mpd_feats)
+        loss_fm_msd = feature_loss(real_msd_feats, fake_msd_feats)
+
+        loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
+        loss_mel = (
+            F.l1_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
+        )
+        loss_fm = ((loss_fm_mpd + loss_fm_msd) * fm_scale) + fm_add
+
+        loss_adv = (loss_adv_mpd + loss_adv_msd) * adv_scale
+
+        loss_g = loss_adv + loss_fm + loss_stft + loss_mel
+        if not am_i_frozen:
+            self.g_optim.zero_grad()
+            loss_g.backward()
+            self.g_optim.step()
+        return {
+            "loss_g": loss_g.item(),
+            "loss_d": loss_d,
+            "loss_adv": loss_adv.item(),
+            "loss_fm": loss_fm.item(),
+            "loss_stft": loss_stft.item(),
+            "loss_mel": loss_mel.item(),
+            "lr_g": self.g_optim.param_groups[0]["lr"],
+            "lr_d": self.d_optim.param_groups[0]["lr"],
+        }
+
+    def step_scheduler(
+        self, is_disc_frozen: bool = False, is_generator_frozen: bool = False
+    ):
+        if self.d_scheduler is not None and not is_disc_frozen:
+            self.d_scheduler.step()
+        if self.g_scheduler is not None and not is_generator_frozen:
+            self.g_scheduler.step()
+
+    def reset_schedulers(self, lr: Optional[float] = None):
+        """
+        In case you have adopted another strategy, this function makes it
+        possible to restart the schedulers and set the lr to another value.
+        """
+        if lr is not None:
+            self.set_lr(lr)
+        if self.d_optim is not None:
+            self.d_scheduler = None
+            self.d_scheduler = self.settings.scheduler_template(self.d_optim)
+        if self.g_optim is not None:
+            self.g_scheduler = None
+            self.g_scheduler = self.settings.scheduler_template(self.g_optim)
+
+
+class ResBlocks(ConvNets):
+    def __init__(
+        self,
+        channels: int,
+        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+        resblock_dilation_sizes: List[Union[int, List[int]]] = [
+            [1, 3, 5],
+            [1, 3, 5],
+            [1, 3, 5],
+        ],
+        activation: nn.Module = nn.LeakyReLU(0.1),
+    ):
+        super().__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.rb = nn.ModuleList()
+        self.activation = activation
+
+        for k, j in zip(resblock_kernel_sizes, resblock_dilation_sizes):
+            self.rb.append(ResBlock1D(channels, k, j, activation))
+
+        self.rb.apply(self.init_weights)
+
+    def forward(self, x: torch.Tensor):
+        xs = None
+        for i, block in enumerate(self.rb):
+            if i == 0:
+                xs = block(x)
+            else:
+                xs += block(x)
+        x = xs / self.num_kernels
+        return self.activation(x)
+
+
+class PhaseRefineNet(ConvNets):
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int = 1,
+        dilation: int = 1,
+        padding: int = 0,
+        activation: nn.Module = nn.LeakyReLU(0.1),
+        norm_type: Optional[Literal["weight", "spectral"]] = None,
+    ):
+        super().__init__()
+        weight_norm_fn = get_weight_norm(norm_type=norm_type)
+        self.net = nn.Sequential(
+            activation,
+            weight_norm_fn(
+                nn.Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    padding=padding,
+                    dilation=max(dilation, 1),
+                )
+            ),
+        )
+
+        self.net.apply(self.init_weights)
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class iSTFTGenerator(ConvNets):
 
     def __init__(
         self,
@@ -24,10 +505,12 @@ class Generator(Model):
         ],
         n_fft: int = 16,
         activation: nn.Module = nn.LeakyReLU(0.1),
+        hop_length: int = 256,
     ):
         super().__init__()
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_rates)
+        self.hop_length = hop_length
         self.conv_pre = weight_norm(
             nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
         )
@@ -47,7 +530,20 @@ class Generator(Model):
         self.post_n_fft = n_fft // 2 + 1
         self.conv_post = weight_norm(nn.Conv1d(ch, n_fft + 2, 7, 1, padding=3))
         self.conv_post.apply(self.init_weights)
-        self.reflection_pad =
+        self.reflection_pad = nn.ReflectionPad1d((1, 0))
+
+        self.phase_pass = nn.Sequential(
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+        )
+        self.spec_pass = nn.Sequential(
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+        )
 
     def _make_blocks(
         self,
@@ -70,7 +566,7 @@ class Generator(Model):
                         u,
                         padding=(k - u) // 2,
                     )
-                ),
+                ).apply(self.init_weights),
             ),
             residual=ResBlocks(
                 channels,
@@ -89,20 +585,7 @@ class Generator(Model):
 
         x = self.reflection_pad(x)
         x = self.conv_post(x)
-        spec = torch.exp(x[:, : self.post_n_fft, :])
-        phase = torch.sin(x[:, self.post_n_fft :, :])
+        spec = torch.exp(self.spec_pass(x[:, : self.post_n_fft, :]))
+        phase = torch.sin(self.phase_pass(x[:, self.post_n_fft :, :]))
 
         return spec, phase
-
-    def remove_weight_norm(self):
-        for module in self.modules():
-            try:
-                remove_weight_norm(module)
-            except ValueError:
-                pass  # Not normed, skip
-
-    @staticmethod
-    def init_weights(m, mean=0.0, std=0.01):
-        classname = m.__class__.__name__
-        if "Conv" in classname:
-            m.weight.data.normal_(mean, std)