lt-tensor 0.0.1a11__py3-none-any.whl → 0.0.1a12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,14 +1,495 @@
1
- from ..torch_commons import *
2
- from ..model_base import Model
3
- from .rsd import ResBlocks
4
- from ..misc_utils import log_tensor
5
-
1
+ __all__ = ["WaveSettings", "WaveDecoder", "iSTFTGenerator"]
2
+ import gc
3
+ import math
4
+ import itertools
5
+ from lt_utils.common import *
6
+ from lt_tensor.torch_commons import *
7
+ from lt_tensor.model_base import Model
8
+ from lt_tensor.misc_utils import log_tensor
9
+ from lt_tensor.model_zoo.rsd import ResBlock1D, ConvNets, get_weight_norm
10
+ from lt_utils.misc_utils import log_traceback
11
+ from lt_tensor.processors import AudioProcessor
12
+ from lt_utils.type_utils import is_dir, is_pathlike
13
+ from lt_tensor.misc_utils import set_seed, clear_cache
14
+ from lt_tensor.model_zoo.disc import MultiPeriodDiscriminator, MultiScaleDiscriminator
6
15
  import torch.nn.functional as F
16
+ from lt_tensor.config_templates import updateDict, ModelConfig
17
+
18
+
19
def feature_loss(real_feats, fake_feats):
    """Accumulate the L1 distance between matching real/fake feature maps.

    Both arguments are walked as nested iterables: the outer level is zipped
    pairwise, and every inner pair of tensors contributes one ``F.l1_loss``
    term. When nothing is paired the plain float ``0.0`` is returned.
    """
    total = 0.0
    for real_group, fake_group in zip(real_feats, fake_feats):
        for real_map, fake_map in zip(real_group, fake_group):
            total = total + F.l1_loss(real_map, fake_map)
    return total
25
+
26
+
27
def generator_adv_loss(fake_preds):
    """Sum, over all predictions, the mean squared distance to ``1.0``.

    Each element of ``fake_preds`` contributes ``mean((pred - 1) ** 2)``.
    An empty iterable yields the plain float ``0.0``.
    """
    return sum((torch.mean((pred - 1.0) ** 2) for pred in fake_preds), 0.0)
32
+
33
+
34
def discriminator_loss(real_preds, fake_preds):
    """Sum the least-squares discriminator terms over paired predictions.

    For every ``(real, fake)`` pair the contribution is
    ``mean((real - 1) ** 2) + mean(fake ** 2)``. Pairs are formed with
    ``zip``, so surplus entries of the longer input are ignored; with no
    pairs the plain float ``0.0`` is returned.
    """
    per_pair = (
        torch.mean((real - 1.0) ** 2) + torch.mean(fake**2)
        for real, fake in zip(real_preds, fake_preds)
    )
    return sum(per_pair, 0.0)
39
+
40
+
41
class WaveSettings:
    """Plain attribute container for the hyper-parameters read by ``WaveDecoder``.

    Every constructor argument is stored on the instance; ``n_mels`` is stored
    under the attribute name ``in_channels``. List arguments are copied on
    the way in so that the list objects used as signature defaults are never
    shared between instances.
    """

    def __init__(
        self,
        n_mels: int = 80,
        upsample_rates: List[Union[int, List[int]]] = [8, 8],
        upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16],
        upsample_initial_channel: int = 512,
        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
        resblock_dilation_sizes: List[Union[int, List[int]]] = [
            [1, 3, 5],
            [1, 3, 5],
            [1, 3, 5],
        ],
        n_fft: int = 16,
        activation: nn.Module = nn.LeakyReLU(0.1),
        msd_layers: int = 3,
        mpd_periods: List[int] = [2, 3, 5, 7, 11],
        seed: Optional[int] = None,
        lr: float = 1e-5,
        adamw_betas: List[float] = [0.75, 0.98],
        scheduler_template: Callable[
            [optim.Optimizer], optim.lr_scheduler.LRScheduler
        ] = lambda optimizer: optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.998),
    ):
        self.in_channels = n_mels
        # The list defaults above are created once, at definition time.
        # Storing them directly would let a mutation on one instance leak
        # into every later instance, so each list is copied first.
        self.upsample_rates = self._copy_nested(upsample_rates)
        self.upsample_kernel_sizes = self._copy_nested(upsample_kernel_sizes)
        self.upsample_initial_channel = upsample_initial_channel
        self.resblock_kernel_sizes = self._copy_nested(resblock_kernel_sizes)
        self.resblock_dilation_sizes = self._copy_nested(resblock_dilation_sizes)
        self.n_fft = n_fft
        self.activation = activation
        self.mpd_periods = self._copy_nested(mpd_periods)
        self.msd_layers = msd_layers
        self.seed = seed
        self.lr = lr
        self.adamw_betas = self._copy_nested(adamw_betas)
        self.scheduler_template = scheduler_template

    @staticmethod
    def _copy_nested(values):
        """Copy a list and any lists directly inside it; pass other types through."""
        if not isinstance(values, list):
            return values
        return [list(item) if isinstance(item, list) else item for item in values]

    def to_dict(self):
        """Return a shallow copy of the instance attributes as a dict."""
        return dict(self.__dict__)

    def set_value(self, var_name: str, value: Any) -> None:
        """Set ``var_name`` to ``value`` by delegating to ``updateDict``."""
        # NOTE(review): the exact semantics of updateDict are defined elsewhere
        # in the project; this call is kept exactly as the original had it.
        updateDict(self, {var_name: value})

    def get_value(self, var_name: str) -> Any:
        """Return the attribute named ``var_name``, or ``None`` if it is not set."""
        return self.__dict__.get(var_name)
88
+
89
+
90
class WaveDecoder(Model):
    """Couples a mel-to-(spec, phase) generator with its GAN training state.

    The instance owns the generator, an ``AudioProcessor`` used to invert the
    predicted spectrogram/phase into a waveform, and, only while GAN training
    is set up, a multi-scale and a multi-period discriminator with their
    optimizers and schedulers.
    """

    def __init__(
        self,
        audio_processor: AudioProcessor,
        settings: Optional[WaveSettings] = None,
        generator: Optional[
            Union[Model, "iSTFTGenerator"]
        ] = None,  # not initialized: a class/factory that is called below
    ):
        """Build the generator from ``settings`` and leave training state empty.

        Args:
            audio_processor: Object used for ``inverse_transform``,
                ``stft_loss`` and ``compute_mel``.
            settings: ``None`` (defaults), a dict of ``WaveSettings`` keyword
                arguments, or a ``WaveSettings`` instance.
            generator: Callable returning the generator module; defaults to
                ``iSTFTGenerator``.

        Raises:
            ValueError: If ``settings`` is none of the accepted types.
        """
        super().__init__()
        if settings is None:
            self.settings = WaveSettings()
        elif isinstance(settings, dict):
            self.settings = WaveSettings(**settings)
        elif isinstance(settings, WaveSettings):
            self.settings = settings
        else:
            raise ValueError(
                "Cannot initialize the waveDecoder with the given settings. "
                "Use either a dictionary, or the class WaveSettings to setup the settings. "
                "Alternatively, leave it None to use the default values."
            )
        if self.settings.seed is not None:
            set_seed(self.settings.seed)
        if generator is None:
            generator = iSTFTGenerator
        self.generator: iSTFTGenerator = generator(
            in_channels=self.settings.in_channels,
            upsample_rates=self.settings.upsample_rates,
            upsample_kernel_sizes=self.settings.upsample_kernel_sizes,
            upsample_initial_channel=self.settings.upsample_initial_channel,
            resblock_kernel_sizes=self.settings.resblock_kernel_sizes,
            resblock_dilation_sizes=self.settings.resblock_dilation_sizes,
            n_fft=self.settings.n_fft,
            activation=self.settings.activation,
        )
        self.generator.eval()
        self.g_optim = None
        self.d_optim = None
        # Defined up front so step_scheduler() can be called before any
        # training setup without raising AttributeError.
        self.g_scheduler = None
        self.d_scheduler = None
        self.gan_training = False
        self.audio_processor = audio_processor
        self.register_buffer("msd", None, persistent=False)
        self.register_buffer("mpd", None, persistent=False)

    def setup_training_mode(self, load_weights_from: Optional[PathLike] = None):
        """Create the discriminators, optimizers and schedulers for GAN training.

        Args:
            load_weights_from: Optional directory (not a file) that may hold
                ``msd.pt`` and ``mpd.pt``. A legacy ``msp.pt`` is used for
                the MSD when ``msd.pt`` is absent, because earlier versions
                saved it under that name.

        Returns:
            ``True`` once setup has completed.
        """
        self.finish_training_setup()
        if self.msd is None:
            self.msd = MultiScaleDiscriminator(self.settings.msd_layers)
        if self.mpd is None:
            self.mpd = MultiPeriodDiscriminator(self.settings.mpd_periods)
        if load_weights_from is not None and is_dir(
            path=load_weights_from, validate=False
        ):
            msd_file = Path(load_weights_from, "msd.pt")
            legacy_msd_file = Path(load_weights_from, "msp.pt")
            if not msd_file.exists() and legacy_msd_file.exists():
                msd_file = legacy_msd_file
            # Loading is best-effort: a failure is logged and training
            # continues with freshly initialised discriminator weights.
            try:
                self.msd.load_weights(msd_file)
            except Exception as e:
                log_traceback(e, "MSD Loading")
            try:
                self.mpd.load_weights(Path(load_weights_from, "mpd.pt"))
            except Exception as e:
                log_traceback(e, "MPD Loading")

        # Move the discriminators before the optimizers are built over
        # their parameters.
        self.msd.to(device=self.device)
        self.mpd.to(device=self.device)
        self.update_schedulers_and_optimizer()

        self.gan_training = True
        return True

    def update_schedulers_and_optimizer(self):
        """(Re)create the AdamW optimizers and their schedulers.

        The generator pair is always rebuilt. The discriminator pair is only
        rebuilt when both discriminators exist.
        """
        self.g_optim = optim.AdamW(
            self.generator.parameters(),
            lr=self.settings.lr,
            betas=self.settings.adamw_betas,
        )
        self.g_scheduler = self.settings.scheduler_template(self.g_optim)
        if self.mpd is None or self.msd is None:
            return
        self.d_optim = optim.AdamW(
            itertools.chain(self.mpd.parameters(), self.msd.parameters()),
            lr=self.settings.lr,
            betas=self.settings.adamw_betas,
        )
        self.d_scheduler = self.settings.scheduler_template(self.d_optim)

    def set_lr(self, new_lr: float = 1e-4):
        """Write ``new_lr`` into every param group of the existing optimizers.

        Returns:
            The ``(generator_lr, discriminator_lr)`` tuple from ``get_lr``.
        """
        for optimizer in (self.g_optim, self.d_optim):
            if optimizer is None:
                continue
            for group in optimizer.param_groups:
                group["lr"] = new_lr
        return self.get_lr()

    def get_lr(self) -> Tuple[float, float]:
        """Return the first param-group lr of each optimizer, ``nan`` if unset."""
        g = float("nan")
        d = float("nan")
        if self.g_optim is not None:
            g = self.g_optim.param_groups[0]["lr"]
        if self.d_optim is not None:
            d = self.d_optim.param_groups[0]["lr"]
        return g, d

    def save_weights(self, path, replace=True):
        """Save the generator, and any discriminators, into a directory.

        A ``path`` ending in ``.pt`` is read as a file inside the target
        directory, so its parent is used. Files written are ``generator.pt``,
        ``msd.pt`` and ``mpd.pt``.
        """
        # NOTE(review): is_pathlike presumably validates/raises on a bad
        # path; its contract is defined outside this file.
        is_pathlike(path, check_if_empty=True, validate=True)
        if str(path).endswith(".pt"):
            path = Path(path).parent
        else:
            path = Path(path)
        self.generator.save_weights(Path(path, "generator.pt"), replace)
        if self.msd is not None:
            # Previously written as "msp.pt", which setup_training_mode
            # never looked for; the name now matches the loader.
            self.msd.save_weights(Path(path, "msd.pt"), replace)
        if self.mpd is not None:
            self.mpd.save_weights(Path(path, "mpd.pt"), replace)

    def load_weights(
        self,
        path,
        raise_if_not_exists=False,
        strict=True,
        assign=False,
        weights_only=False,
        mmap=None,
        **torch_loader_kwargs
    ):
        """Load generator weights only.

        ``path`` may be a ``.pt`` file, or a directory in which
        ``generator.pt`` is looked up. The other arguments are forwarded
        unchanged to ``self.generator.load_weights``.
        """
        is_pathlike(path, check_if_empty=True, validate=True)
        if str(path).endswith(".pt"):
            path = Path(path)
        else:
            path = Path(path, "generator.pt")

        self.generator.load_weights(
            path,
            raise_if_not_exists,
            strict,
            assign,
            weights_only,
            mmap,
            **torch_loader_kwargs,
        )

    def finish_training_setup(self):
        """Drop both discriminators, clear caches and leave GAN-training mode."""
        gc.collect()
        self.mpd = None
        clear_cache()
        gc.collect()
        self.msd = None
        clear_cache()
        self.gan_training = False

    def forward(self, mel_spec: Tensor) -> Tuple[Tensor, Tensor]:
        """Returns the generated spec and phase"""
        return self.generator.forward(mel_spec)

    def inference(
        self,
        mel_spec: Tensor,
        return_dict: bool = False,
    ) -> Union[Dict[str, Tensor], Tensor]:
        """Generate a waveform from ``mel_spec``.

        Returns:
            The trimmed waveform, or, when ``return_dict`` is true, a dict
            with the keys ``"wave"``, ``"spec"`` and ``"phase"``.
        """
        spec, phase = super().inference(mel_spec)
        wave = self.audio_processor.inverse_transform(
            spec,
            phase,
            self.settings.n_fft,
            hop_length=4,
            win_length=self.settings.n_fft,
        )
        # NOTE(review): hop_length=4 and the 256-sample trim are hard-coded;
        # presumably tied to the default n_fft / mel hop. Confirm before
        # using other settings.
        trimmed = wave[:, : wave.shape[-1] - 256]
        if not return_dict:
            return trimmed
        return {
            "wave": trimmed,
            "spec": spec,
            "phase": phase,
        }

    def set_device(self, device: str):
        """Move this module and its sub-components to ``device``.

        The discriminators only exist while training is set up, so they are
        moved only when present.
        """
        self.to(device=device)
        self.generator.to(device=device)
        self.audio_processor.to(device=device)
        if self.msd is not None:
            self.msd.to(device=device)
        if self.mpd is not None:
            self.mpd.to(device=device)

    def train_step(
        self,
        mels: Tensor,
        real_audio: Tensor,
        stft_scale: float = 1.0,
        mel_scale: float = 1.0,
        adv_scale: float = 1.0,
        fm_scale: float = 1.0,
        fm_add: float = 0.0,
        is_discriminator_frozen: bool = False,
        is_generator_frozen: bool = False,
    ):
        """Run one discriminator step followed by one generator step.

        Training mode is set up first if it is not already active. The
        discriminator sees a detached copy of the generated audio. A frozen
        side is evaluated under ``torch.no_grad()`` and its optimizer is not
        stepped.

        Returns:
            The metrics dict produced by ``_generator_step``.
        """
        if not self.gan_training:
            self.setup_training_mode()
        spec, phase = super().train_step(mels)
        real_audio = real_audio.squeeze(1)
        # The generated audio is cut to the length of the reference audio.
        fake_audio = self.audio_processor.inverse_transform(
            spec,
            phase,
            self.settings.n_fft,
            hop_length=4,
            win_length=self.settings.n_fft,
        )[:, : real_audio.shape[-1]]

        disc_kwargs = dict(
            real_audio=real_audio,
            fake_audio=fake_audio.detach(),
            am_i_frozen=is_discriminator_frozen,
        )
        if is_discriminator_frozen:
            with torch.no_grad():
                disc_out = self._discriminator_step(**disc_kwargs)
        else:
            disc_out = self._discriminator_step(**disc_kwargs)

        generator_kwargs = dict(
            mels=mels,
            real_audio=real_audio,
            fake_audio=fake_audio,
            **disc_out,
            stft_scale=stft_scale,
            mel_scale=mel_scale,
            adv_scale=adv_scale,
            fm_add=fm_add,
            fm_scale=fm_scale,
            am_i_frozen=is_generator_frozen,
        )

        if is_generator_frozen:
            with torch.no_grad():
                return self._generator_step(**generator_kwargs)
        return self._generator_step(**generator_kwargs)

    def _discriminator_step(
        self,
        real_audio: Tensor,
        fake_audio: Tensor,
        am_i_frozen: bool = False,
    ):
        """Compute the combined MPD + MSD loss and, unless frozen, step ``d_optim``.

        Returns:
            ``{"loss_d": float}``.
        """
        # ========== Discriminator Forward Pass ==========

        # MPD
        real_mpd_preds, _ = self.mpd(real_audio)
        fake_mpd_preds, _ = self.mpd(fake_audio)
        # MSD
        real_msd_preds, _ = self.msd(real_audio)
        fake_msd_preds, _ = self.msd(fake_audio)

        loss_d_mpd = discriminator_loss(real_mpd_preds, fake_mpd_preds)
        loss_d_msd = discriminator_loss(real_msd_preds, fake_msd_preds)
        loss_d = loss_d_mpd + loss_d_msd

        if not am_i_frozen:
            self.d_optim.zero_grad()
            loss_d.backward()
            self.d_optim.step()

        return {
            "loss_d": loss_d.item(),
        }

    def _generator_step(
        self,
        mels: Tensor,
        real_audio: Tensor,
        fake_audio: Tensor,
        loss_d: float,
        stft_scale: float = 1.0,
        mel_scale: float = 1.0,
        adv_scale: float = 1.0,
        fm_scale: float = 1.0,
        fm_add: float = 0.0,
        am_i_frozen: bool = False,
    ):
        """Compute the generator loss and, unless frozen, step ``g_optim``.

        The total is the sum of the scaled adversarial, feature-matching
        (plus ``fm_add``), STFT and mel L1 terms.

        Returns:
            A dict of float metrics: ``loss_g``, ``loss_d`` (passed through),
            ``loss_adv``, ``loss_fm``, ``loss_stft``, ``loss_mel``, ``lr_g``
            and ``lr_d``.
        """
        # ========== Generator Loss ==========
        real_mpd_feats = self.mpd(real_audio)[1]
        real_msd_feats = self.msd(real_audio)[1]

        fake_mpd_preds, fake_mpd_feats = self.mpd(fake_audio)
        fake_msd_preds, fake_msd_feats = self.msd(fake_audio)

        loss_adv_mpd = generator_adv_loss(fake_mpd_preds)
        loss_adv_msd = generator_adv_loss(fake_msd_preds)
        loss_fm_mpd = feature_loss(real_mpd_feats, fake_mpd_feats)
        loss_fm_msd = feature_loss(real_msd_feats, fake_msd_feats)

        loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
        loss_mel = (
            F.l1_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
        )
        loss_fm = ((loss_fm_mpd + loss_fm_msd) * fm_scale) + fm_add

        loss_adv = (loss_adv_mpd + loss_adv_msd) * adv_scale

        loss_g = loss_adv + loss_fm + loss_stft + loss_mel
        if not am_i_frozen:
            self.g_optim.zero_grad()
            loss_g.backward()
            self.g_optim.step()
        return {
            "loss_g": loss_g.item(),
            "loss_d": loss_d,
            "loss_adv": loss_adv.item(),
            "loss_fm": loss_fm.item(),
            "loss_stft": loss_stft.item(),
            "loss_mel": loss_mel.item(),
            "lr_g": self.g_optim.param_groups[0]["lr"],
            "lr_d": self.d_optim.param_groups[0]["lr"],
        }

    def step_scheduler(
        self, is_disc_frozen: bool = False, is_generator_frozen: bool = False
    ):
        """Advance each existing scheduler whose side is not frozen."""
        if self.d_scheduler is not None and not is_disc_frozen:
            self.d_scheduler.step()
        if self.g_scheduler is not None and not is_generator_frozen:
            self.g_scheduler.step()

    def reset_schedulers(self, lr: Optional[float] = None):
        """
        In case you have adopted another strategy, with this function,
        it is possible to restart the schedulers and set the lr to another value.
        """
        if lr is not None:
            self.set_lr(lr)
        if self.d_optim is not None:
            self.d_scheduler = None
            self.d_scheduler = self.settings.scheduler_template(self.d_optim)
        if self.g_optim is not None:
            self.g_scheduler = None
            self.g_scheduler = self.settings.scheduler_template(self.g_optim)
426
+
427
+
428
class ResBlocks(ConvNets):
    """Bank of ``ResBlock1D`` modules run on the same input and averaged."""

    def __init__(
        self,
        channels: int,
        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
        resblock_dilation_sizes: List[Union[int, List[int]]] = [
            [1, 3, 5],
            [1, 3, 5],
            [1, 3, 5],
        ],
        activation: nn.Module = nn.LeakyReLU(0.1),
    ):
        """Create one ``ResBlock1D`` per (kernel size, dilation list) pair."""
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.rb = nn.ModuleList(
            ResBlock1D(channels, kernel, dilations, activation)
            for kernel, dilations in zip(
                resblock_kernel_sizes, resblock_dilation_sizes
            )
        )
        self.activation = activation

        self.rb.apply(self.init_weights)

    def forward(self, x: torch.Tensor):
        """Sum every block's output for ``x``, divide by ``num_kernels``, activate."""
        accumulated = None
        for block in self.rb:
            out = block(x)
            if accumulated is None:
                accumulated = out
            else:
                accumulated += out
        averaged = accumulated / self.num_kernels
        return self.activation(averaged)
459
+
460
+
461
class PhaseRefineNet(ConvNets):
    """Activation followed by a single channel-preserving ``Conv1d``."""

    def __init__(
        self,
        channels: int,
        kernel_size: int = 1,
        dilation: int = 1,
        padding: int = 0,
        activation: nn.Module = nn.LeakyReLU(0.1),
        norm_type: Optional[Literal["weight", "spectral"]] = None,
    ):
        """Build the ``activation -> normalised conv`` stack and initialise it.

        The dilation passed to the convolution is clamped to at least 1.
        The convolution is wrapped by whatever ``get_weight_norm`` returns
        for ``norm_type``.
        """
        super().__init__()
        weight_norm_fn = get_weight_norm(norm_type=norm_type)
        conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            padding=padding,
            dilation=max(dilation, 1),
        )
        self.net = nn.Sequential(activation, weight_norm_fn(conv))

        self.net.apply(self.init_weights)

    def forward(self, x):
        """Apply the activation and convolution stack to ``x``."""
        return self.net(x)
490
+
491
+
492
+ class iSTFTGenerator(ConvNets):
12
493
 
13
494
  def __init__(
14
495
  self,
@@ -24,10 +505,12 @@ class Generator(Model):
24
505
  ],
25
506
  n_fft: int = 16,
26
507
  activation: nn.Module = nn.LeakyReLU(0.1),
508
+ hop_length: int = 256,
27
509
  ):
28
510
  super().__init__()
29
511
  self.num_kernels = len(resblock_kernel_sizes)
30
512
  self.num_upsamples = len(upsample_rates)
513
+ self.hop_length = hop_length
31
514
  self.conv_pre = weight_norm(
32
515
  nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
33
516
  )
@@ -47,7 +530,20 @@ class Generator(Model):
47
530
  self.post_n_fft = n_fft // 2 + 1
48
531
  self.conv_post = weight_norm(nn.Conv1d(ch, n_fft + 2, 7, 1, padding=3))
49
532
  self.conv_post.apply(self.init_weights)
50
- self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
533
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
534
+
535
+ self.phase_pass = nn.Sequential(
536
+ nn.LeakyReLU(0.2),
537
+ nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
538
+ nn.LeakyReLU(0.2),
539
+ nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
540
+ )
541
+ self.spec_pass = nn.Sequential(
542
+ nn.LeakyReLU(0.2),
543
+ nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
544
+ nn.LeakyReLU(0.2),
545
+ nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
546
+ )
51
547
 
52
548
  def _make_blocks(
53
549
  self,
@@ -70,7 +566,7 @@ class Generator(Model):
70
566
  u,
71
567
  padding=(k - u) // 2,
72
568
  )
73
- ),
569
+ ).apply(self.init_weights),
74
570
  ),
75
571
  residual=ResBlocks(
76
572
  channels,
@@ -89,20 +585,7 @@ class Generator(Model):
89
585
 
90
586
  x = self.reflection_pad(x)
91
587
  x = self.conv_post(x)
92
- spec = torch.exp(x[:, : self.post_n_fft, :])
93
- phase = torch.sin(x[:, self.post_n_fft :, :])
588
+ spec = torch.exp(self.spec_pass(x[:, : self.post_n_fft, :]))
589
+ phase = torch.sin(self.phase_pass(x[:, self.post_n_fft :, :]))
94
590
 
95
591
  return spec, phase
96
-
97
- def remove_weight_norm(self):
98
- for module in self.modules():
99
- try:
100
- remove_weight_norm(module)
101
- except ValueError:
102
- pass # Not normed, skip
103
-
104
- @staticmethod
105
- def init_weights(m, mean=0.0, std=0.01):
106
- classname = m.__class__.__name__
107
- if "Conv" in classname:
108
- m.weight.data.normal_(mean, std)
@@ -5,8 +5,8 @@ __all__ = [
5
5
  ]
6
6
 
7
7
  import math
8
- from ..torch_commons import *
9
- from ..model_base import Model
8
+ from lt_tensor.torch_commons import *
9
+ from lt_tensor.model_base import Model
10
10
 
11
11
 
12
12
  class RotaryEmbedding(nn.Module):