lt-tensor 0.0.1a15__py3-none-any.whl → 0.0.1a16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,627 +0,0 @@
1
- __all__ = ["AudioSettings", "AudioDecoderTrainer", "AudioGeneratorOnlyTrainer"]
2
- import gc
3
- import itertools
4
- from lt_utils.common import *
5
- import torch.nn.functional as F
6
- from lt_tensor.torch_commons import *
7
- from lt_tensor.model_base import Model
8
- from lt_utils.misc_utils import log_traceback
9
- from lt_tensor.processors import AudioProcessor
10
- from lt_tensor.misc_utils import set_seed, clear_cache
11
- from lt_utils.type_utils import is_dir, is_pathlike
12
- from lt_tensor.config_templates import ModelConfig
13
- from lt_tensor.model_zoo.istft.generator import iSTFTGenerator
14
- from lt_tensor.model_zoo.discriminator import (
15
- MultiPeriodDiscriminator,
16
- MultiScaleDiscriminator,
17
- )
18
- from lt_tensor.model_zoo.residual import ResBlock1D2, ResBlock1D
19
-
20
-
21
- def feature_loss(fmap_r, fmap_g):
22
- loss = 0
23
- for dr, dg in zip(fmap_r, fmap_g):
24
- for rl, gl in zip(dr, dg):
25
- loss += torch.mean(torch.abs(rl - gl))
26
- return loss * 2
27
-
28
-
29
- def generator_adv_loss(disc_outputs):
30
- loss = 0
31
- for dg in disc_outputs:
32
- adv_l = torch.mean((1 - dg) ** 2)
33
- loss += adv_l
34
- return loss
35
-
36
-
37
- def discriminator_loss(disc_real_outputs, disc_generated_outputs):
38
- loss = 0
39
-
40
- for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
41
- r_loss = torch.mean((1 - dr) ** 2)
42
- g_loss = torch.mean(dg**2)
43
- loss += r_loss + g_loss
44
- return loss
45
-
46
-
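For reference, a minimal sketch (not part of the package) of how these LSGAN-style helpers behave on toy inputs; the list-of-tensors layout mirrors what the MPD/MSD discriminators return:

    import torch

    # per-sub-discriminator score tensors; LSGAN targets are 1 for real, 0 for fake
    real_scores = [torch.full((4,), 0.9), torch.full((4,), 1.1)]
    fake_scores = [torch.full((4,), 0.2), torch.full((4,), -0.1)]

    d_loss = discriminator_loss(real_scores, fake_scores)  # mean (1 - real)^2 + mean fake^2
    g_loss = generator_adv_loss(fake_scores)                # mean (1 - fake)^2

    # feature_loss takes lists of per-layer feature maps and returns 2 * mean L1 distance
    fm_loss = feature_loss([[torch.zeros(2, 8)]], [[torch.ones(2, 8)]])  # -> 2.0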
47
- class AudioSettings(ModelConfig):
48
- def __init__(
49
- self,
50
- n_mels: int = 80,
51
- upsample_rates: List[Union[int, List[int]]] = [8, 8],
52
- upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16],
53
- upsample_initial_channel: int = 512,
54
- resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
55
- resblock_dilation_sizes: List[Union[int, List[int]]] = [
56
- [1, 3, 5],
57
- [1, 3, 5],
58
- [1, 3, 5],
59
- ],
60
- n_fft: int = 16,
61
- activation: nn.Module = nn.LeakyReLU(0.1),
62
- msd_layers: int = 3,
63
- mpd_periods: List[int] = [2, 3, 5, 7, 11],
64
- seed: Optional[int] = None,
65
- lr: float = 1e-5,
66
- adamw_betas: List[float] = [0.75, 0.98],
67
- scheduler_template: Callable[
68
- [optim.Optimizer], optim.lr_scheduler.LRScheduler
69
- ] = lambda optimizer: optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.998),
70
- residual_cls: Union[ResBlock1D, ResBlock1D2] = ResBlock1D,
71
- ):
72
- self.in_channels = n_mels
73
- self.upsample_rates = upsample_rates
74
- self.upsample_kernel_sizes = upsample_kernel_sizes
75
- self.upsample_initial_channel = upsample_initial_channel
76
- self.resblock_kernel_sizes = resblock_kernel_sizes
77
- self.resblock_dilation_sizes = resblock_dilation_sizes
78
- self.n_fft = n_fft
79
- self.activation = activation
80
- self.mpd_periods = mpd_periods
81
- self.msd_layers = msd_layers
82
- self.seed = seed
83
- self.lr = lr
84
- self.adamw_betas = adamw_betas
85
- self.scheduler_template = scheduler_template
86
- self.residual_cls = residual_cls
87
-
88
-
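A hedged sketch of overriding a few AudioSettings defaults; the keyword names follow the constructor above, while the concrete values are only illustrative:

    from torch import optim

    settings = AudioSettings(
        n_mels=80,
        lr=2e-4,                      # illustrative; the default is 1e-5
        adamw_betas=[0.8, 0.99],
        # any callable mapping an optimizer to an LRScheduler is accepted
        scheduler_template=lambda opt: optim.lr_scheduler.ExponentialLR(opt, gamma=0.999),
    )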
89
- class AudioDecoderTrainer(Model):
90
- def __init__(
91
- self,
92
- audio_processor: AudioProcessor,
93
- settings: Optional[AudioSettings] = None,
94
- generator: Optional[Union[Model, "iSTFTGenerator"]] = None, # non initialized!
95
- ):
96
- super().__init__()
97
- if settings is None:
98
- self.settings = AudioSettings()
99
- elif isinstance(settings, dict):
100
- self.settings = AudioSettings(**settings)
101
- elif isinstance(settings, AudioSettings):
102
- self.settings = settings
103
- else:
104
- raise ValueError(
105
- "Cannot initialize the waveDecoder with the given settings. "
106
- "Use either a dictionary, or the class WaveSettings to setup the settings. "
107
- "Alternatively, leave it None to use the default values."
108
- )
109
- if self.settings.seed is not None:
110
- set_seed(self.settings.seed)
111
- if generator is None:
112
- generator = iSTFTGenerator
113
- self.generator: iSTFTGenerator = generator(
114
- in_channels=self.settings.in_channels,
115
- upsample_rates=self.settings.upsample_rates,
116
- upsample_kernel_sizes=self.settings.upsample_kernel_sizes,
117
- upsample_initial_channel=self.settings.upsample_initial_channel,
118
- resblock_kernel_sizes=self.settings.resblock_kernel_sizes,
119
- resblock_dilation_sizes=self.settings.resblock_dilation_sizes,
120
- n_fft=self.settings.n_fft,
121
- activation=self.settings.activation,
122
- )
123
- self.generator.eval()
124
- self.g_optim = None
125
- self.d_optim = None
126
- self.gan_training = False
127
- self.audio_processor = audio_processor
128
- self.register_buffer("msd", None, persistent=False)
129
- self.register_buffer("mpd", None, persistent=False)
130
-
131
- def setup_training_mode(self, load_weights_from: Optional[PathLike] = None):
132
- """The location must be path not a file!"""
133
- self.finish_training_setup()
134
- if self.msd is None:
135
- self.msd = MultiScaleDiscriminator(self.settings.msd_layers)
136
- if self.mpd is None:
137
- self.mpd = MultiPeriodDiscriminator(self.settings.mpd_periods)
138
- if load_weights_from is not None:
139
- if is_dir(path=load_weights_from, validate=False):
140
- try:
141
- self.msd.load_weights(Path(load_weights_from, "msd.pt"))
142
- except Exception as e:
143
- log_traceback(e, "MSD Loading")
144
- try:
145
- self.mpd.load_weights(Path(load_weights_from, "mpd.pt"))
146
- except Exception as e:
147
- log_traceback(e, "MPD Loading")
148
-
149
- self.update_schedulers_and_optimizer()
150
- self.msd.to(device=self.device)
151
- self.mpd.to(device=self.device)
152
-
153
- self.gan_training = True
154
- return True
155
-
156
- def update_schedulers_and_optimizer(self):
157
- gc.collect()
158
- self.g_optim = None
159
- self.g_scheduler = None
160
- gc.collect()
161
- self.g_optim = optim.AdamW(
162
- self.generator.parameters(),
163
- lr=self.settings.lr,
164
- betas=self.settings.adamw_betas,
165
- )
166
- gc.collect()
167
- self.g_scheduler = self.settings.scheduler_template(self.g_optim)
168
- if any([self.mpd is None, self.msd is None]):
169
- return
170
- gc.collect()
171
- self.d_optim = optim.AdamW(
172
- itertools.chain(self.mpd.parameters(), self.msd.parameters()),
173
- lr=self.settings.lr,
174
- betas=self.settings.adamw_betas,
175
- )
176
- self.d_scheduler = self.settings.scheduler_template(self.d_optim)
177
-
178
- def set_lr(self, new_lr: float = 1e-4):
179
- if self.g_optim is not None:
180
- for groups in self.g_optim.param_groups:
181
- groups["lr"] = new_lr
182
-
183
- if self.d_optim is not None:
184
- for groups in self.d_optim.param_groups:
185
- groups["lr"] = new_lr
186
- return self.get_lr()
187
-
188
- def get_lr(self) -> Tuple[float, float]:
189
- g = float("nan")
190
- d = float("nan")
191
- if self.g_optim is not None:
192
- g = self.g_optim.param_groups[0]["lr"]
193
- if self.d_optim is not None:
194
- d = self.d_optim.param_groups[0]["lr"]
195
- return g, d
196
-
197
- def save_weights(self, path, replace=True):
198
- is_pathlike(path, check_if_empty=True, validate=True)
199
- if str(path).endswith(".pt"):
200
- path = Path(path).parent
201
- else:
202
- path = Path(path)
203
- self.generator.save_weights(Path(path, "generator.pt"), replace)
204
- if self.msd is not None:
205
- self.msd.save_weights(Path(path, "msd.pt"), replace)
206
- if self.mpd is not None:
207
- self.mpd.save_weights(Path(path, "mpd.pt"), replace)
208
-
209
- def load_weights(
210
- self,
211
- path,
212
- raise_if_not_exists=False,
213
- strict=True,
214
- assign=False,
215
- weights_only=False,
216
- mmap=None,
217
- **torch_loader_kwargs
218
- ):
219
- is_pathlike(path, check_if_empty=True, validate=True)
220
- if str(path).endswith(".pt"):
221
- path = Path(path)
222
- else:
223
- path = Path(path, "generator.pt")
224
-
225
- self.generator.load_weights(
226
- path,
227
- raise_if_not_exists,
228
- strict,
229
- assign,
230
- weights_only,
231
- mmap,
232
- **torch_loader_kwargs,
233
- )
234
-
235
- def finish_training_setup(self):
236
- gc.collect()
237
- self.mpd = None
238
- clear_cache()
239
- gc.collect()
240
- self.msd = None
241
- clear_cache()
242
- self.gan_training = False
243
-
244
- def forward(self, mel_spec: Tensor) -> Tuple[Tensor, Tensor]:
245
- """Returns the generated spec and phase"""
246
- return self.generator.forward(mel_spec)
247
-
248
- def inference(
249
- self,
250
- mel_spec: Tensor,
251
- return_dict: bool = False,
252
- ) -> Union[Dict[str, Tensor], Tensor]:
253
- spec, phase = super().inference(mel_spec)
254
- wave = self.audio_processor.inverse_transform(
255
- spec,
256
- phase,
257
- self.settings.n_fft,
258
- hop_length=4,
259
- win_length=self.settings.n_fft,
260
- )
261
- if not return_dict:
262
- return wave
263
- return {
264
- "wave": wave,
265
- "spec": spec,
266
- "phase": phase,
267
- }
268
-
269
- def set_device(self, device: str):
270
- self.to(device=device)
271
- self.generator.to(device=device)
272
- self.audio_processor.to(device=device)
273
- if self.msd is not None:
274
-     self.msd.to(device=device)
- if self.mpd is not None:
-     self.mpd.to(device=device)
275
-
276
- def train_step(
277
- self,
278
- mels: Tensor,
279
- real_audio: Tensor,
280
- stft_scale: float = 1.0,
281
- mel_scale: float = 1.0,
282
- adv_scale: float = 1.0,
283
- fm_scale: float = 1.0,
284
- fm_add: float = 0.0,
285
- is_discriminator_frozen: bool = False,
286
- is_generator_frozen: bool = False,
287
- ):
288
- if not self.gan_training:
289
- self.setup_training_mode()
290
- spec, phase = super().train_step(mels)
291
- real_audio = real_audio.squeeze(1)
292
- fake_audio = self.audio_processor.inverse_transform(
293
- spec,
294
- phase,
295
- self.settings.n_fft,
296
- hop_length=4,
297
- win_length=self.settings.n_fft,
298
- length=real_audio.shape[-1],
299
- )
300
-
301
- disc_kwargs = dict(
302
- real_audio=real_audio,
303
- fake_audio=fake_audio.detach(),
304
- am_i_frozen=is_discriminator_frozen,
305
- )
306
- if is_discriminator_frozen:
307
- with torch.no_grad():
308
- disc_out = self._discriminator_step(**disc_kwargs)
309
- else:
310
- disc_out = self._discriminator_step(**disc_kwargs)
311
-
312
- generator_kwargs = dict(
313
- mels=mels,
314
- real_audio=real_audio,
315
- fake_audio=fake_audio,
316
- **disc_out,
317
- stft_scale=stft_scale,
318
- mel_scale=mel_scale,
319
- adv_scale=adv_scale,
320
- fm_add=fm_add,
321
- fm_scale=fm_scale,
322
- am_i_frozen=is_generator_frozen,
323
- )
324
-
325
- if is_generator_frozen:
326
- with torch.no_grad():
327
- return self._generator_step(**generator_kwargs)
328
- return self._generator_step(**generator_kwargs)
329
-
330
- def _discriminator_step(
331
- self,
332
- real_audio: Tensor,
333
- fake_audio: Tensor,
334
- am_i_frozen: bool = False,
335
- ):
336
- # ========== Discriminator Forward Pass ==========
337
- if not am_i_frozen:
338
- self.d_optim.zero_grad()
339
- # MPD
340
- real_mpd_preds, _ = self.mpd(real_audio)
341
- fake_mpd_preds, _ = self.mpd(fake_audio)
342
- # MSD
343
- real_msd_preds, _ = self.msd(real_audio)
344
- fake_msd_preds, _ = self.msd(fake_audio)
345
-
346
- loss_d_mpd = discriminator_loss(real_mpd_preds, fake_mpd_preds)
347
- loss_d_msd = discriminator_loss(real_msd_preds, fake_msd_preds)
348
- loss_d = loss_d_mpd + loss_d_msd
349
-
350
- if not am_i_frozen:
351
- loss_d.backward()
352
- self.d_optim.step()
353
-
354
- return {
355
- "loss_d": loss_d.item(),
356
- }
357
-
358
- def _generator_step(
359
- self,
360
- mels: Tensor,
361
- real_audio: Tensor,
362
- fake_audio: Tensor,
363
- loss_d: float,
364
- stft_scale: float = 1.0,
365
- mel_scale: float = 1.0,
366
- adv_scale: float = 1.0,
367
- fm_scale: float = 1.0,
368
- fm_add: float = 0.0,
369
- am_i_frozen: bool = False,
370
- ):
371
- # ========== Generator Loss ==========
372
- if not am_i_frozen:
373
- self.g_optim.zero_grad()
374
- real_mpd_feats = self.mpd(real_audio)[1]
375
- real_msd_feats = self.msd(real_audio)[1]
376
-
377
- fake_mpd_preds, fake_mpd_feats = self.mpd(fake_audio)
378
- fake_msd_preds, fake_msd_feats = self.msd(fake_audio)
379
-
380
- loss_adv_mpd = generator_adv_loss(fake_mpd_preds)
381
- loss_adv_msd = generator_adv_loss(fake_msd_preds)
382
- loss_fm_mpd = feature_loss(real_mpd_feats, fake_mpd_feats)
383
- loss_fm_msd = feature_loss(real_msd_feats, fake_msd_feats)
384
-
385
- # loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
386
- loss_mel = (
387
- F.huber_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
388
- )
389
- loss_fm = ((loss_fm_mpd + loss_fm_msd) * fm_scale) + fm_add
390
-
391
- loss_adv = (loss_adv_mpd + loss_adv_msd) * adv_scale
392
-
393
- loss_g = loss_adv + loss_fm + loss_mel # + loss_stft
394
- if not am_i_frozen:
395
- loss_g.backward()
396
- self.g_optim.step()
397
-
398
- lr_g, lr_d = self.get_lr()
399
- return {
400
- "loss_g": loss_g.item(),
401
- "loss_d": loss_d,
402
- "loss_adv": loss_adv.item(),
403
- "loss_fm": loss_fm.item(),
404
- "loss_stft": 1.0, # loss_stft.item(),
405
- "loss_mel": loss_mel.item(),
406
- "lr_g": lr_g,
407
- "lr_d": lr_d,
408
- }
409
-
410
- def step_scheduler(
411
- self, is_disc_frozen: bool = False, is_generator_frozen: bool = False
412
- ):
413
- if self.d_scheduler is not None and not is_disc_frozen:
414
- self.d_scheduler.step()
415
- if self.g_scheduler is not None and not is_generator_frozen:
416
- self.g_scheduler.step()
417
-
418
- def reset_schedulers(self, lr: Optional[float] = None):
419
- """
420
- In case you have adopted another strategy, this function makes it
421
- possible to restart the scheduler and set the lr to another value.
422
- """
423
- if lr is not None:
424
- self.set_lr(lr)
425
- if self.d_optim is not None:
426
- self.d_scheduler = None
427
- self.d_scheduler = self.settings.scheduler_template(self.d_optim)
428
- if self.g_optim is not None:
429
- self.g_scheduler = None
430
- self.g_scheduler = self.settings.scheduler_template(self.g_optim)
431
-
432
-
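A rough, hedged sketch of a training loop over AudioDecoderTrainer; the AudioProcessor construction and the `dataloader` yielding (mel, waveform) batches are assumptions for illustration, not part of this module:

    trainer = AudioDecoderTrainer(audio_processor=AudioProcessor())  # assumes a default-constructible processor
    trainer.setup_training_mode()          # builds the MPD/MSD discriminators, optimizers and schedulers
    trainer.set_device("cuda")

    for epoch in range(10):
        for mels, audio in dataloader:     # hypothetical (mel, waveform) batches
            logs = trainer.train_step(mels.to(trainer.device), audio.to(trainer.device))
        trainer.step_scheduler()

    trainer.save_weights("checkpoints/vocoder")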
433
- class AudioGeneratorOnlyTrainer(Model):
434
- def __init__(
435
- self,
436
- audio_processor: AudioProcessor,
437
- settings: Optional[AudioSettings] = None,
438
- generator: Optional[Union[Model, "iSTFTGenerator"]] = None, # non initialized!
439
- ):
440
- super().__init__()
441
- if settings is None:
442
- self.settings = AudioSettings()
443
- elif isinstance(settings, dict):
444
- self.settings = AudioSettings(**settings)
445
- elif isinstance(settings, AudioSettings):
446
- self.settings = settings
447
- else:
448
- raise ValueError(
449
- "Cannot initialize the waveDecoder with the given settings. "
450
- "Use either a dictionary, or the class WaveSettings to setup the settings. "
451
- "Alternatively, leave it None to use the default values."
452
- )
453
- if self.settings.seed is not None:
454
- set_seed(self.settings.seed)
455
- if generator is None:
456
- generator = iSTFTGenerator
457
- self.generator: iSTFTGenerator = generator(
458
- in_channels=self.settings.in_channels,
459
- upsample_rates=self.settings.upsample_rates,
460
- upsample_kernel_sizes=self.settings.upsample_kernel_sizes,
461
- upsample_initial_channel=self.settings.upsample_initial_channel,
462
- resblock_kernel_sizes=self.settings.resblock_kernel_sizes,
463
- resblock_dilation_sizes=self.settings.resblock_dilation_sizes,
464
- n_fft=self.settings.n_fft,
465
- activation=self.settings.activation,
466
- )
467
- self.generator.eval()
468
- self.gen_training = False
469
- self.audio_processor = audio_processor
470
-
471
- def setup_training_mode(self, *args, **kwargs):
472
- self.finish_training_setup()
473
- self.update_schedulers_and_optimizer()
474
- self.gen_training = True
475
- return True
476
-
477
- def update_schedulers_and_optimizer(self):
478
- self.g_optim = optim.AdamW(
479
- self.generator.parameters(),
480
- lr=self.settings.lr,
481
- betas=self.settings.adamw_betas,
482
- )
483
- self.g_scheduler = self.settings.scheduler_template(self.g_optim)
484
-
485
- def set_lr(self, new_lr: float = 1e-4):
486
- if self.g_optim is not None:
487
- for groups in self.g_optim.param_groups:
488
- groups["lr"] = new_lr
489
- return self.get_lr()
490
-
491
- def get_lr(self) -> float:
492
- if self.g_optim is not None:
493
- return self.g_optim.param_groups[0]["lr"]
494
- return float("nan")
495
-
496
- def save_weights(self, path, replace=True):
497
- is_pathlike(path, check_if_empty=True, validate=True)
498
- if str(path).endswith(".pt"):
499
- path = Path(path).parent
500
- else:
501
- path = Path(path)
502
- self.generator.save_weights(Path(path, "generator.pt"), replace)
503
-
504
- def load_weights(
505
- self,
506
- path,
507
- raise_if_not_exists=False,
508
- strict=True,
509
- assign=False,
510
- weights_only=False,
511
- mmap=None,
512
- **torch_loader_kwargs
513
- ):
514
- is_pathlike(path, check_if_empty=True, validate=True)
515
- if str(path).endswith(".pt"):
516
- path = Path(path)
517
- else:
518
- path = Path(path, "generator.pt")
519
-
520
- self.generator.load_weights(
521
- path,
522
- raise_if_not_exists,
523
- strict,
524
- assign,
525
- weights_only,
526
- mmap,
527
- **torch_loader_kwargs,
528
- )
529
-
530
- def finish_training_setup(self):
531
- gc.collect()
532
- clear_cache()
533
- self.eval()
534
- self.gen_training = False
535
-
536
- def forward(self, mel_spec: Tensor) -> Tuple[Tensor, Tensor]:
537
- """Returns the generated spec and phase"""
538
- return self.generator.forward(mel_spec)
539
-
540
- def inference(
541
- self,
542
- mel_spec: Tensor,
543
- return_dict: bool = False,
544
- ) -> Union[Dict[str, Tensor], Tensor]:
545
- spec, phase = super().inference(mel_spec)
546
- wave = self.audio_processor.inverse_transform(
547
- spec,
548
- phase,
549
- self.settings.n_fft,
550
- hop_length=4,
551
- win_length=self.settings.n_fft,
552
- )
553
- if not return_dict:
554
- return wave[:, : wave.shape[-1] - 256]
555
- return {
556
- "wave": wave[:, : wave.shape[-1] - 256],
557
- "spec": spec,
558
- "phase": phase,
559
- }
560
-
561
- def set_device(self, device: str):
562
- self.to(device=device)
563
- self.generator.to(device=device)
564
- self.audio_processor.to(device=device)
567
-
568
- def train_step(
569
- self,
570
- mels: Tensor,
571
- real_audio: Tensor,
572
- stft_scale: float = 1.0,
573
- mel_scale: float = 1.0,
574
- ext_loss: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
575
- ):
576
- if not self.gen_training:
577
- self.setup_training_mode()
578
-
579
- self.g_optim.zero_grad()
580
- spec, phase = self.generator.train_step(mels)
581
-
582
- real_audio = real_audio.squeeze(1)
584
- fake_audio = self.audio_processor.inverse_transform(
585
- spec,
586
- phase,
587
- self.settings.n_fft,
588
- hop_length=4,
589
- win_length=self.settings.n_fft,
590
- )[:, : real_audio.shape[-1]]
591
- loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
592
- loss_mel = (
593
- F.huber_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
594
- )
596
- loss_g = loss_stft + loss_mel
597
- loss_ext = 0
598
-
599
- if ext_loss is not None:
600
- l_ext = ext_loss(fake_audio, real_audio)
601
- loss_g = loss_g + l_ext
602
- loss_ext = l_ext.item()
603
-
604
- loss_g.backward()
- self.g_optim.step()
605
- return {
606
- "loss": loss_g.item(),
607
- "loss_stft": loss_stft.item(),
608
- "loss_mel": loss_mel.item(),
609
- "loss_ext": loss_ext,
610
- "lr": self.get_lr(),
611
- }
612
-
613
- def step_scheduler(self):
614
-
615
- if self.g_scheduler is not None:
616
- self.g_scheduler.step()
617
-
618
- def reset_schedulers(self, lr: Optional[float] = None):
619
- """
620
- In case you have adopted another strategy, this function makes it
621
- possible to restart the scheduler and set the lr to another value.
622
- """
623
- if lr is not None:
624
- self.set_lr(lr)
625
- if self.g_optim is not None:
626
- self.g_scheduler = None
627
- self.g_scheduler = self.settings.scheduler_template(self.g_optim)
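And a hedged inference sketch for the generator-only trainer; the mel shape and checkpoint directory are assumptions based on the code above:

    import torch

    gen = AudioGeneratorOnlyTrainer(audio_processor=AudioProcessor())  # assumed default-constructible
    gen.load_weights("checkpoints/vocoder")    # resolves to <dir>/generator.pt
    gen.eval()

    mel = torch.randn(1, 80, 200)              # (batch, n_mels, frames); shape assumed
    with torch.no_grad():
        out = gen.inference(mel, return_dict=True)
    wave, spec, phase = out["wave"], out["spec"], out["phase"]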