lt-tensor 0.0.1a11__py3-none-any.whl → 0.0.1a12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,450 @@
1
+ __all__ = ["AudioSettings", "AudioDecoder"]
2
+ import gc
3
+ import math
4
+ import itertools
5
+ from lt_utils.common import *
6
+ import torch.nn.functional as F
7
+ from lt_tensor.torch_commons import *
8
+ from lt_tensor.model_base import Model
9
+ from lt_tensor.misc_utils import log_tensor
10
+ from lt_utils.misc_utils import log_traceback
11
+ from lt_tensor.processors import AudioProcessor
12
+ from lt_tensor.misc_utils import set_seed, clear_cache
13
+ from lt_utils.type_utils import is_dir, is_pathlike, is_file
14
+ from lt_tensor.config_templates import updateDict, ModelConfig
15
+ from lt_tensor.model_zoo.istft.generator import iSTFTGenerator
16
+ from lt_tensor.model_zoo.rsd import ResBlock1D, ConvNets, get_weight_norm
17
+ from lt_tensor.model_zoo.disc import MultiPeriodDiscriminator, MultiScaleDiscriminator
18
+
19
+
20
def feature_loss(real_feats, fake_feats):
    """Sum of L1 distances between paired discriminator feature maps.

    Both arguments are sequences of per-discriminator feature lists
    (one list per discriminator head). Pairing follows ``zip`` semantics,
    so any trailing unmatched entries are ignored.
    """
    total = 0.0
    paired = (
        (feat_r, feat_f)
        for group_r, group_f in zip(real_feats, fake_feats)
        for feat_r, feat_f in zip(group_r, group_f)
    )
    for feat_r, feat_f in paired:
        total = total + F.l1_loss(feat_r, feat_f)
    return total
26
+
27
+
28
def generator_adv_loss(fake_preds):
    """LSGAN generator adversarial loss.

    For each discriminator head's prediction on generated audio, computes
    ``mean((pred - 1)^2)`` (i.e. pushes fakes toward the "real" label 1.0)
    and returns the sum over all heads.
    """
    total = 0.0
    for pred in fake_preds:
        total = total + torch.mean(torch.square(pred - 1.0))
    return total
33
+
34
+
35
def discriminator_loss(real_preds, fake_preds):
    """LSGAN discriminator loss.

    Per head: ``mean((D(real) - 1)^2) + mean(D(fake)^2)`` — real predictions
    are pushed toward 1.0 and fake predictions toward 0.0. Returns the sum
    over all heads (``zip`` semantics on the two prediction lists).
    """
    total = 0.0
    for pred_real, pred_fake in zip(real_preds, fake_preds):
        real_term = torch.mean(torch.square(pred_real - 1.0))
        fake_term = torch.mean(torch.square(pred_fake))
        total = total + real_term + fake_term
    return total
40
+
41
+
42
class AudioSettings(ModelConfig):
    """Configuration container for :class:`AudioDecoder` and its generator.

    Fix over the previous revision: the list-typed parameters used mutable
    default arguments, so every instance constructed with defaults shared
    the *same* list objects (mutating one config silently mutated all
    others). Defaults are now ``None`` sentinels expanded to fresh lists
    per instance; the effective default values are unchanged.

    NOTE(review): ``super().__init__()`` is not called here, matching the
    original code — confirm against ``ModelConfig`` whether that is intended.
    """

    def __init__(
        self,
        n_mels: int = 80,
        upsample_rates: Optional[List[Union[int, List[int]]]] = None,
        upsample_kernel_sizes: Optional[List[Union[int, List[int]]]] = None,
        upsample_initial_channel: int = 512,
        resblock_kernel_sizes: Optional[List[Union[int, List[int]]]] = None,
        resblock_dilation_sizes: Optional[List[Union[int, List[int]]]] = None,
        n_fft: int = 16,
        activation: nn.Module = nn.LeakyReLU(0.1),
        msd_layers: int = 3,
        mpd_periods: Optional[List[int]] = None,
        seed: Optional[int] = None,
        lr: float = 1e-5,
        adamw_betas: Optional[List[float]] = None,
        scheduler_template: Callable[
            [optim.Optimizer], optim.lr_scheduler.LRScheduler
        ] = lambda optimizer: optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.998),
    ):
        # Mel channels feed the generator's input convolution.
        self.in_channels = n_mels
        self.upsample_rates = [8, 8] if upsample_rates is None else upsample_rates
        self.upsample_kernel_sizes = (
            [16, 16] if upsample_kernel_sizes is None else upsample_kernel_sizes
        )
        self.upsample_initial_channel = upsample_initial_channel
        self.resblock_kernel_sizes = (
            [3, 7, 11] if resblock_kernel_sizes is None else resblock_kernel_sizes
        )
        self.resblock_dilation_sizes = (
            [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
            if resblock_dilation_sizes is None
            else resblock_dilation_sizes
        )
        self.n_fft = n_fft
        # NOTE: the activation default is a single shared nn.LeakyReLU
        # instance; it holds no state, so sharing it is harmless.
        self.activation = activation
        self.mpd_periods = [2, 3, 5, 7, 11] if mpd_periods is None else mpd_periods
        self.msd_layers = msd_layers
        self.seed = seed
        self.lr = lr
        self.adamw_betas = [0.75, 0.98] if adamw_betas is None else adamw_betas
        # Factory that builds an LR scheduler from an optimizer; called once
        # per optimizer in AudioDecoder.update_schedulers_and_optimizer.
        self.scheduler_template = scheduler_template
80
+
81
+
82
class AudioDecoder(Model):
    """Mel-spectrogram -> waveform decoder.

    Wraps an ``iSTFTGenerator`` (which predicts STFT magnitude and phase)
    together with an ``AudioProcessor`` that performs the inverse STFT, and
    optionally a pair of GAN discriminators (MultiPeriod + MultiScale) for
    adversarial training via :meth:`train_step`.

    Fixes over the previous revision:
      * ``save_weights`` wrote the MSD to ``"msp.pt"`` while
        ``setup_training_mode`` loads ``"msd.pt"``, so saved MSD weights
        could never be reloaded — the filename is now ``"msd.pt"``.
      * ``set_device`` dereferenced ``self.msd`` / ``self.mpd`` even when
        they are ``None`` (outside GAN training), raising ``AttributeError``
        — they are now guarded.
    """

    def __init__(
        self,
        audio_processor: AudioProcessor,
        settings: Optional[AudioSettings] = None,
        generator: Optional[Union[Model, "iSTFTGenerator"]] = None,  # a class, not an instance!
    ):
        super().__init__()
        # Accept None (defaults), a plain dict, or a ready AudioSettings.
        if settings is None:
            self.settings = AudioSettings()
        elif isinstance(settings, dict):
            self.settings = AudioSettings(**settings)
        elif isinstance(settings, AudioSettings):
            self.settings = settings
        else:
            raise ValueError(
                "Cannot initialize the waveDecoder with the given settings. "
                "Use either a dictionary, or the class WaveSettings to setup the settings. "
                "Alternatively, leave it None to use the default values."
            )
        if self.settings.seed is not None:
            set_seed(self.settings.seed)
        if generator is None:
            generator = iSTFTGenerator
        self.generator: iSTFTGenerator = generator(
            in_channels=self.settings.in_channels,
            upsample_rates=self.settings.upsample_rates,
            upsample_kernel_sizes=self.settings.upsample_kernel_sizes,
            upsample_initial_channel=self.settings.upsample_initial_channel,
            resblock_kernel_sizes=self.settings.resblock_kernel_sizes,
            resblock_dilation_sizes=self.settings.resblock_dilation_sizes,
            n_fft=self.settings.n_fft,
            activation=self.settings.activation,
        )
        # Inference mode by default; train_step() switches GAN training on.
        self.generator.eval()
        self.g_optim = None
        self.d_optim = None
        self.gan_training = False
        self.audio_processor = audio_processor
        # None placeholders so self.msd / self.mpd always exist as
        # attributes; persistent=False keeps them out of the state dict.
        # Assigning real Modules to these names later re-registers them as
        # submodules — presumably relied upon here; confirm against the
        # Model base class if this is changed.
        self.register_buffer("msd", None, persistent=False)
        self.register_buffer("mpd", None, persistent=False)

    def setup_training_mode(self, load_weights_from: Optional[PathLike] = None):
        """Create (and optionally load) the discriminators and optimizers.

        ``load_weights_from`` must be a directory, not a file; weight
        loading failures are logged and ignored (best-effort).
        """
        # Discard any previous training state before rebuilding.
        self.finish_training_setup()
        if self.msd is None:
            self.msd = MultiScaleDiscriminator(self.settings.msd_layers)
        if self.mpd is None:
            self.mpd = MultiPeriodDiscriminator(self.settings.mpd_periods)
        if load_weights_from is not None:
            if is_dir(path=load_weights_from, validate=False):
                try:
                    self.msd.load_weights(Path(load_weights_from, "msd.pt"))
                except Exception as e:
                    log_traceback(e, "MSD Loading")
                try:
                    self.mpd.load_weights(Path(load_weights_from, "mpd.pt"))
                except Exception as e:
                    log_traceback(e, "MPD Loading")

        self.update_schedulers_and_optimizer()
        self.msd.to(device=self.device)
        self.mpd.to(device=self.device)

        self.gan_training = True
        return True

    def update_schedulers_and_optimizer(self):
        """(Re)build the generator/discriminator optimizers and schedulers."""
        self.g_optim = optim.AdamW(
            self.generator.parameters(),
            lr=self.settings.lr,
            betas=self.settings.adamw_betas,
        )
        self.g_scheduler = self.settings.scheduler_template(self.g_optim)
        # Without both discriminators there is nothing to optimize on the
        # discriminator side; leave d_optim/d_scheduler untouched.
        if any([self.mpd is None, self.msd is None]):
            return
        self.d_optim = optim.AdamW(
            itertools.chain(self.mpd.parameters(), self.msd.parameters()),
            lr=self.settings.lr,
            betas=self.settings.adamw_betas,
        )
        self.d_scheduler = self.settings.scheduler_template(self.d_optim)

    def set_lr(self, new_lr: float = 1e-4):
        """Set the learning rate on both optimizers; returns (g_lr, d_lr)."""
        if self.g_optim is not None:
            for groups in self.g_optim.param_groups:
                groups["lr"] = new_lr

        if self.d_optim is not None:
            for groups in self.d_optim.param_groups:
                groups["lr"] = new_lr
        return self.get_lr()

    def get_lr(self) -> Tuple[float, float]:
        """Return (generator_lr, discriminator_lr); NaN if not initialized."""
        g = float("nan")
        d = float("nan")
        if self.g_optim is not None:
            g = self.g_optim.param_groups[0]["lr"]
        if self.d_optim is not None:
            d = self.d_optim.param_groups[0]["lr"]
        return g, d

    def save_weights(self, path, replace=True):
        """Save generator (and discriminators, if present) under `path`.

        `path` is treated as a directory; a trailing ``*.pt`` filename is
        stripped to its parent directory.
        """
        is_pathlike(path, check_if_empty=True, validate=True)
        if str(path).endswith(".pt"):
            path = Path(path).parent
        else:
            path = Path(path)
        self.generator.save_weights(Path(path, "generator.pt"), replace)
        if self.msd is not None:
            # Fixed: was "msp.pt", which setup_training_mode never loads.
            self.msd.save_weights(Path(path, "msd.pt"), replace)
        if self.mpd is not None:
            self.mpd.save_weights(Path(path, "mpd.pt"), replace)

    def load_weights(
        self,
        path,
        raise_if_not_exists=False,
        strict=True,
        assign=False,
        weights_only=False,
        mmap=None,
        **torch_loader_kwargs
    ):
        """Load generator weights from `path` (a ``.pt`` file or a directory
        containing ``generator.pt``). Extra kwargs go to the torch loader."""
        is_pathlike(path, check_if_empty=True, validate=True)
        if str(path).endswith(".pt"):
            path = Path(path)
        else:
            path = Path(path, "generator.pt")

        self.generator.load_weights(
            path,
            raise_if_not_exists,
            strict,
            assign,
            weights_only,
            mmap,
            **torch_loader_kwargs,
        )

    def finish_training_setup(self):
        """Drop discriminators and free their memory; leaves inference-only state."""
        gc.collect()
        self.mpd = None
        clear_cache()
        gc.collect()
        self.msd = None
        clear_cache()
        self.gan_training = False

    def forward(self, mel_spec: Tensor) -> Tuple[Tensor, Tensor]:
        """Returns the generated spec and phase"""
        return self.generator.forward(mel_spec)

    def inference(
        self,
        mel_spec: Tensor,
        return_dict: bool = False,
    ) -> Union[Dict[str, Tensor], Tensor]:
        """Generate a waveform from a mel spectrogram.

        Returns the waveform tensor, or (when ``return_dict``) a dict with
        ``wave``, ``spec`` and ``phase``. The last 256 samples are trimmed —
        presumably an iSTFT padding artifact; confirm against the processor.
        """
        spec, phase = super().inference(mel_spec)
        wave = self.audio_processor.inverse_transform(
            spec,
            phase,
            self.settings.n_fft,
            hop_length=4,
            win_length=self.settings.n_fft,
        )
        if not return_dict:
            return wave[:, : wave.shape[-1] - 256]
        return {
            "wave": wave[:, : wave.shape[-1] - 256],
            "spec": spec,
            "phase": phase,
        }

    def set_device(self, device: str):
        """Move the decoder, generator, processor and (if present) the
        discriminators to `device`."""
        self.to(device=device)
        self.generator.to(device=device)
        self.audio_processor.to(device=device)
        # Discriminators exist only during GAN training; guard against None.
        if self.msd is not None:
            self.msd.to(device=device)
        if self.mpd is not None:
            self.mpd.to(device=device)

    def train_step(
        self,
        mels: Tensor,
        real_audio: Tensor,
        stft_scale: float = 1.0,
        mel_scale: float = 1.0,
        adv_scale: float = 1.0,
        fm_scale: float = 1.0,
        fm_add: float = 0.0,
        is_discriminator_frozen: bool = False,
        is_generator_frozen: bool = False,
    ):
        """Run one GAN training step (discriminator then generator).

        Lazily sets up GAN training on first call. The ``*_scale`` / ``fm_add``
        arguments weight the individual generator loss terms. Frozen sides run
        under ``torch.no_grad()`` and skip their optimizer step.
        """
        if not self.gan_training:
            self.setup_training_mode()
        spec, phase = super().train_step(mels)
        real_audio = real_audio.squeeze(1)
        # Reconstruct audio, then crop to the reference length so the
        # discriminators see equally-sized real/fake pairs.
        fake_audio = self.audio_processor.inverse_transform(
            spec,
            phase,
            self.settings.n_fft,
            hop_length=4,
            win_length=self.settings.n_fft,
        )[:, : real_audio.shape[-1]]

        disc_kwargs = dict(
            real_audio=real_audio,
            fake_audio=fake_audio.detach(),  # do not backprop D loss into G
            am_i_frozen=is_discriminator_frozen,
        )
        if is_discriminator_frozen:
            with torch.no_grad():
                disc_out = self._discriminator_step(**disc_kwargs)
        else:
            disc_out = self._discriminator_step(**disc_kwargs)

        generator_kwargs = dict(
            mels=mels,
            real_audio=real_audio,
            fake_audio=fake_audio,
            **disc_out,
            stft_scale=stft_scale,
            mel_scale=mel_scale,
            adv_scale=adv_scale,
            fm_add=fm_add,
            fm_scale=fm_scale,
            am_i_frozen=is_generator_frozen,
        )

        if is_generator_frozen:
            with torch.no_grad():
                return self._generator_step(**generator_kwargs)
        return self._generator_step(**generator_kwargs)

    def _discriminator_step(
        self,
        real_audio: Tensor,
        fake_audio: Tensor,
        am_i_frozen: bool = False,
    ):
        """LSGAN discriminator update over both MPD and MSD.

        Returns ``{"loss_d": float}`` which train_step forwards into the
        generator step for logging.
        """
        # ========== Discriminator Forward Pass ==========

        # MPD
        real_mpd_preds, _ = self.mpd(real_audio)
        fake_mpd_preds, _ = self.mpd(fake_audio)
        # MSD
        real_msd_preds, _ = self.msd(real_audio)
        fake_msd_preds, _ = self.msd(fake_audio)

        loss_d_mpd = discriminator_loss(real_mpd_preds, fake_mpd_preds)
        loss_d_msd = discriminator_loss(real_msd_preds, fake_msd_preds)
        loss_d = loss_d_mpd + loss_d_msd

        if not am_i_frozen:
            self.d_optim.zero_grad()
            loss_d.backward()
            self.d_optim.step()

        return {
            "loss_d": loss_d.item(),
        }

    def _generator_step(
        self,
        mels: Tensor,
        real_audio: Tensor,
        fake_audio: Tensor,
        loss_d: float,
        stft_scale: float = 1.0,
        mel_scale: float = 1.0,
        adv_scale: float = 1.0,
        fm_scale: float = 1.0,
        fm_add: float = 0.0,
        am_i_frozen: bool = False,
    ):
        """Generator update: adversarial + feature-matching + STFT + mel losses.

        ``loss_d`` is only passed through for the returned log dict.
        """
        # ========== Generator Loss ==========
        # Real features carry no generator gradient but are recomputed here
        # because the discriminators were just updated.
        real_mpd_feats = self.mpd(real_audio)[1]
        real_msd_feats = self.msd(real_audio)[1]

        fake_mpd_preds, fake_mpd_feats = self.mpd(fake_audio)
        fake_msd_preds, fake_msd_feats = self.msd(fake_audio)

        loss_adv_mpd = generator_adv_loss(fake_mpd_preds)
        loss_adv_msd = generator_adv_loss(fake_msd_preds)
        loss_fm_mpd = feature_loss(real_mpd_feats, fake_mpd_feats)
        loss_fm_msd = feature_loss(real_msd_feats, fake_msd_feats)

        loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
        loss_mel = (
            F.l1_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
        )
        loss_fm = ((loss_fm_mpd + loss_fm_msd) * fm_scale) + fm_add

        loss_adv = (loss_adv_mpd + loss_adv_msd) * adv_scale

        loss_g = loss_adv + loss_fm + loss_stft + loss_mel
        if not am_i_frozen:
            self.g_optim.zero_grad()
            loss_g.backward()
            self.g_optim.step()
        return {
            "loss_g": loss_g.item(),
            "loss_d": loss_d,
            "loss_adv": loss_adv.item(),
            "loss_fm": loss_fm.item(),
            "loss_stft": loss_stft.item(),
            "loss_mel": loss_mel.item(),
            "lr_g": self.g_optim.param_groups[0]["lr"],
            "lr_d": self.d_optim.param_groups[0]["lr"],
        }

    def step_scheduler(
        self, is_disc_frozen: bool = False, is_generator_frozen: bool = False
    ):
        """Advance the LR schedulers for the unfrozen sides.

        NOTE(review): g_scheduler/d_scheduler only exist after
        update_schedulers_and_optimizer has run — confirm callers invoke
        this only during GAN training.
        """
        if self.d_scheduler is not None and not is_disc_frozen:
            self.d_scheduler.step()
        if self.g_scheduler is not None and not is_generator_frozen:
            self.g_scheduler.step()

    def reset_schedulers(self, lr: Optional[float] = None):
        """
        In case you have adopted another strategy, with this function,
        it is possible restart the scheduler and set the lr to another value.
        """
        if lr is not None:
            self.set_lr(lr)
        if self.d_optim is not None:
            self.d_scheduler = None
            self.d_scheduler = self.settings.scheduler_template(self.d_optim)
        if self.g_optim is not None:
            self.g_scheduler = None
            self.g_scheduler = self.settings.scheduler_template(self.g_optim)
418
+
419
+
420
class ResBlocks(ConvNets):
    """Bank of parallel 1D residual blocks whose outputs are averaged.

    Each (kernel_size, dilations) pair from the two config lists builds one
    ``ResBlock1D``; ``forward`` feeds the same input to every block, averages
    the results, and applies the shared activation.
    """

    def __init__(
        self,
        channels: int,
        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
        resblock_dilation_sizes: List[Union[int, List[int]]] = [
            [1, 3, 5],
            [1, 3, 5],
            [1, 3, 5],
        ],
        activation: nn.Module = nn.LeakyReLU(0.1),
    ):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.activation = activation
        self.rb = nn.ModuleList(
            ResBlock1D(channels, kernel, dilations, activation)
            for kernel, dilations in zip(
                resblock_kernel_sizes, resblock_dilation_sizes
            )
        )
        self.rb.apply(self.init_weights)

    def forward(self, x: torch.Tensor):
        # Seed the accumulator with the first block, then add the rest;
        # every block sees the *original* input x.
        blocks = iter(self.rb)
        acc = next(blocks)(x)
        for block in blocks:
            acc = acc + block(x)
        return self.activation(acc / self.num_kernels)