lt-tensor 0.0.1a33__py3-none-any.whl → 0.0.1a35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/model_zoo/convs.py

```diff
@@ -40,42 +40,22 @@ class ConvNets(Model):
 
     def remove_norms(self, name: str = "weight"):
         for module in self.modules():
-            if "Conv" in module.__class__.__name__:
-                remove_norm(module, name)
+            try:
+                if "Conv" in module.__class__.__name__:
+                    remove_norm(module, name)
+            except:
+                pass
 
     @staticmethod
-    def init_weights(
-        m: nn.Module,
-        norm: Optional[Literal["spectral", "weight"]] = None,
-        mean=0.0,
-        std=0.02,
-        name: str = "weight",
-        n_power_iterations: int = 1,
-        eps: float = 1e-9,
-        dim_sn: Optional[int] = None,
-        dim_wn: int = 0,
-    ):
+    def init_weights(m: nn.Module, mean=0.0, std=0.02):
         if "Conv" in m.__class__.__name__:
-            if norm is not None:
-                try:
-                    if norm == "spectral":
-                        m.apply(
-                            lambda m: spectral_norm(
-                                m,
-                                n_power_iterations=n_power_iterations,
-                                eps=eps,
-                                name=name,
-                                dim=dim_sn,
-                            )
-                        )
-                    else:
-                        m.apply(lambda m: weight_norm(m, name=name, dim=dim_wn))
-                except ValueError:
-                    pass
             m.weight.data.normal_(mean, std)
 
 
 class Conv1dEXT(ConvNets):
+
+    # TODO: Use this module to replace all that are using normalizations, mostly those in `audio_models`
+
     def __init__(
         self,
         in_channels: int,
@@ -90,7 +70,8 @@ class Conv1dEXT(ConvNets):
         device: Optional[Any] = None,
         dtype: Optional[Any] = None,
         apply_norm: Optional[Literal["weight", "spectral"]] = None,
-        activation: nn.Module = nn.Identity(),
+        activation_in: nn.Module = nn.Identity(),
+        activation_out: nn.Module = nn.Identity(),
         *args,
         **kwargs,
     ):
@@ -112,13 +93,21 @@ class Conv1dEXT(ConvNets):
         )
         if apply_norm is None:
             self.cnn = nn.Conv1d(**cnn_kwargs)
+            self.has_wn = False
         else:
+            self.has_wn = True
             if apply_norm == "spectral":
                 self.cnn = spectral_norm(nn.Conv1d(**cnn_kwargs))
             else:
                 self.cnn = weight_norm(nn.Conv1d(**cnn_kwargs))
-        self.activation = activation
+        self.actv_in = activation_in
+        self.actv_out = activation_out
         self.cnn.apply(self.init_weights)
 
     def forward(self, input: Tensor):
-        return self.cnn(self.activation(input))
+        return self.actv_out(self.cnn(self.actv_in(input)))
+
+    def remove_norms(self, name="weight"):
+        if self.has_wn:
+            remove_norm(self.cnn, name)
+            self.has_wn = False
```
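Taken together, these hunks split `Conv1dEXT`'s single `activation` hook into `activation_in` (applied before the conv) and `activation_out` (applied after it), and add a per-module `remove_norms`. A usage sketch; anything not visible in the hunks above, such as the remaining constructor arguments mirroring `nn.Conv1d`, is assumed:

```python
import torch
import torch.nn as nn
from lt_tensor.model_zoo.convs import Conv1dEXT

conv = Conv1dEXT(
    in_channels=80,
    out_channels=256,
    kernel_size=3,                    # assumed to pass through to nn.Conv1d
    apply_norm="weight",              # wraps the conv in weight_norm; sets has_wn
    activation_in=nn.LeakyReLU(0.1),  # applied before the convolution
    activation_out=nn.Tanh(),         # new in a35: applied after the convolution
)

x = torch.randn(1, 80, 100)
y = conv(x)          # equivalent to actv_out(cnn(actv_in(x)))
conv.remove_norms()  # strips the weight_norm reparametrization, e.g. before export
```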
lt_tensor/model_zoo/losses/discriminators.py

```diff
@@ -7,8 +7,6 @@ from lt_tensor.model_base import Model
 from lt_tensor.model_zoo.convs import ConvNets
 from torch.nn import functional as F
 from torchaudio import transforms as T
-from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
-
 
 MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
     List[Tensor],
@@ -19,9 +17,11 @@ MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
 
 
 class MultiDiscriminatorWrapper(Model):
-    def __init__(self, list_discriminator: List["_MultiDiscriminatorT"]):
+    def __init__(
+        self, list_discriminator: Union[List["_MultiDiscriminatorT"], nn.ModuleList]
+    ):
         """Setup example:
-        model_d = MultiDiscriminatorStep(
+        model_d = MultiDiscriminatorWrapper(
             [
                 MultiEnvelopeDiscriminator(),
                 MultiBandDiscriminator(),
@@ -31,7 +31,12 @@ class MultiDiscriminatorWrapper(Model):
         )
         """
         super().__init__()
-        self.disc: Sequence[_MultiDiscriminatorT] = nn.ModuleList(list_discriminator)
+
+        self.disc: Sequence[_MultiDiscriminatorT] = (
+            nn.ModuleList(list_discriminator)
+            if isinstance(list_discriminator, (list, tuple, set))
+            else list_discriminator
+        )
         self.total = len(self.disc)
 
     def forward(
```
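With the widened signature, the wrapper now passes an existing `nn.ModuleList` through untouched instead of re-wrapping it; plain lists, tuples, and sets still get wrapped. A short sketch (the discriminator classes are the ones defined in this module):

```python
import torch.nn as nn

# a35 accepts both forms:
discs = nn.ModuleList([MultiEnvelopeDiscriminator(), MultiBandDiscriminator()])
wrapper = MultiDiscriminatorWrapper(discs)  # used as-is

wrapper_alt = MultiDiscriminatorWrapper(    # list is wrapped in an nn.ModuleList
    [MultiEnvelopeDiscriminator(), MultiBandDiscriminator()]
)
```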
```diff
@@ -82,23 +87,6 @@ class MultiDiscriminatorWrapper(Model):
         return disc_loss, disc_real_losses, disc_gen_losses
 
 
-def normalize_unit_norm(x: torch.Tensor, eps: float = 1e-5):
-    norm = torch.norm(x, dim=-1, keepdim=True)
-    return x / (norm + eps)
-
-
-def normalize_minmax(x: torch.Tensor, eps: float = 1e-5):
-    min_val = x.amin(dim=-1, keepdim=True)
-    max_val = x.amax(dim=-1, keepdim=True)
-    return (x - min_val) / (max_val - min_val + eps)
-
-
-def normalize_zscore(x: torch.Tensor, eps: float = 1e-5):
-    mean = x.mean(dim=-1, keepdim=True)
-    std = x.std(dim=-1, keepdim=True)
-    return (x - mean) / (std + eps)
-
-
 def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)
 
```
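These three helpers existed mainly for `MultiMelScaleLoss`, which is removed at the end of this file (see the final hunk below); imports of them from this module will now raise `ImportError`. For downstream code that still needs them, a verbatim copy of the removed a33 definitions:

```python
import torch


def normalize_unit_norm(x: torch.Tensor, eps: float = 1e-5):
    # scale each last-dim vector to (approximately) unit L2 norm
    norm = torch.norm(x, dim=-1, keepdim=True)
    return x / (norm + eps)


def normalize_minmax(x: torch.Tensor, eps: float = 1e-5):
    # rescale each last-dim vector into [0, 1]
    min_val = x.amin(dim=-1, keepdim=True)
    max_val = x.amax(dim=-1, keepdim=True)
    return (x - min_val) / (max_val - min_val + eps)


def normalize_zscore(x: torch.Tensor, eps: float = 1e-5):
    # standardize each last-dim vector to zero mean, unit variance
    mean = x.mean(dim=-1, keepdim=True)
    std = x.std(dim=-1, keepdim=True)
    return (x - mean) / (std + eps)
```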
```diff
@@ -113,7 +101,6 @@ class _MultiDiscriminatorT(ConvNets):
     def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
         pass
 
-    # for type hinting
     def __call__(self, *args, **kwds) -> MULTI_DISC_OUT_TYPE:
         return super().__call__(*args, **kwds)
 
@@ -176,7 +163,7 @@ class DiscriminatorP(ConvNets):
     def __init__(
         self,
         period: List[int],
-        discriminator_channel_mult: Number = 1,
+        discriminator_channel_multi: Number = 1,
         kernel_size: int = 5,
         stride: int = 3,
         use_spectral_norm: bool = False,
@@ -184,7 +171,7 @@ class DiscriminatorP(ConvNets):
         super().__init__()
         self.period = period
         norm_f = weight_norm if not use_spectral_norm else spectral_norm
-        dsc = lambda x: int(x * discriminator_channel_mult)
+        dsc = lambda x: int(x * discriminator_channel_multi)
         self.convs = nn.ModuleList(
             [
                 norm_f(
@@ -259,19 +246,18 @@ class DiscriminatorP(ConvNets):
 class MultiPeriodDiscriminator(_MultiDiscriminatorT):
     def __init__(
         self,
-        discriminator_channel_mult: Number = 1,
+        discriminator_channel_multi: Number = 1,
         mpd_reshapes: list[int] = [2, 3, 5, 7, 11],
         use_spectral_norm: bool = False,
     ):
         super().__init__()
         self.mpd_reshapes = mpd_reshapes
-        print(f"mpd_reshapes: {self.mpd_reshapes}")
         self.discriminators = nn.ModuleList(
             [
                 DiscriminatorP(
                     rs,
                     use_spectral_norm=use_spectral_norm,
-                    discriminator_channel_mult=discriminator_channel_mult,
+                    discriminator_channel_multi=discriminator_channel_multi,
                 )
                 for rs in self.mpd_reshapes
             ]
```
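The `discriminator_channel_mult` → `discriminator_channel_multi` rename repeats across `DiscriminatorP`, `MultiPeriodDiscriminator`, `DiscriminatorR`, and `MultiResolutionDiscriminator` (later hunks), so callers passing the old keyword will hit a `TypeError` on upgrade. The migration is mechanical:

```python
# a33:
# mpd = MultiPeriodDiscriminator(discriminator_channel_mult=0.5)

# a35 — identical semantics, renamed keyword:
mpd = MultiPeriodDiscriminator(discriminator_channel_multi=0.5)
```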
```diff
@@ -293,6 +279,79 @@ class MultiPeriodDiscriminator(_MultiDiscriminatorT):
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
 
 
+class DiscriminatorS(ConvNets):
+    def __init__(
+        self,
+        use_spectral_norm=False,
+        discriminator_channel_multi: Number = 1,
+    ):
+        super().__init__()
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        dsc = lambda x: int(x * discriminator_channel_multi)
+        self.convs = nn.ModuleList(
+            [
+                norm_f(nn.Conv1d(1, dsc(128), 15, 1, padding=7)),
+                norm_f(nn.Conv1d(dsc(128), dsc(128), 41, 2, groups=4, padding=20)),
+                norm_f(nn.Conv1d(dsc(128), dsc(256), 41, 2, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(256), dsc(512), 41, 4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(512), dsc(1024), 41, 4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(1024), dsc(1024), 41, 1, groups=16, padding=20)),
+                norm_f(nn.Conv1d(dsc(1024), dsc(1024), 5, 1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv1d(dsc(1024), 1, 3, 1, padding=1))
+        self.activation = nn.LeakyReLU(0.1)
+
+    def forward(self, x):
+        fmap = []
+        for l in self.convs:
+            x = l(x)
+            x = self.activation(x)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        return x.flatten(1, -1), fmap
+
+
+class MultiScaleDiscriminator(ConvNets):
+    def __init__(
+        self,
+        discriminator_channel_multi: Number = 1,
+    ):
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorS(
+                    use_spectral_norm=True,
+                    discriminator_channel_multi=discriminator_channel_multi,
+                ),
+                DiscriminatorS(discriminator_channel_multi=discriminator_channel_multi),
+                DiscriminatorS(discriminator_channel_multi=discriminator_channel_multi),
+            ]
+        )
+        self.meanpools = nn.ModuleList(
+            [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
+        )
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            if i > 0:
+                y = self.meanpools[i - 1](y)
+                y_hat = self.meanpools[i - 1](y_hat)
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
 class EnvelopeExtractor(Model):
     """Extracts the amplitude envelope of the audio signal."""
 
```
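The new `DiscriminatorS`/`MultiScaleDiscriminator` pair follows the HiFi-GAN multi-scale layout: the first sub-discriminator is spectral-normalized, the other two weight-normalized, with inputs average-pooled once and twice. Note that `MultiScaleDiscriminator` subclasses `ConvNets` rather than `_MultiDiscriminatorT`. A smoke-test sketch, assuming mono waveforms shaped `(batch, 1, samples)`:

```python
import torch

msd = MultiScaleDiscriminator(discriminator_channel_multi=1)
y = torch.randn(2, 1, 8192)      # real batch (assumed shape)
y_hat = torch.randn(2, 1, 8192)  # generated batch
y_d_rs, y_d_gs, fmap_rs, fmap_gs = msd(y, y_hat)
assert len(y_d_rs) == 3          # one logit tensor per scale
```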
```diff
@@ -314,21 +373,35 @@ class EnvelopeExtractor(Model):
 
 
 class DiscriminatorEnvelope(ConvNets):
-    def __init__(self, use_spectral_norm=False):
+    def __init__(
+        self,
+        use_spectral_norm=False,
+        discriminator_channel_multi: Number = 1,
+        kernel_size: int = 101,
+    ):
         super().__init__()
         norm_f = weight_norm if not use_spectral_norm else spectral_norm
-        self.extractor = EnvelopeExtractor(kernel_size=101)
+        self.extractor = EnvelopeExtractor(kernel_size=kernel_size)
+        dsc = lambda x: int(x * discriminator_channel_multi)
         self.convs = nn.ModuleList(
             [
-                norm_f(nn.Conv1d(1, 64, 15, stride=1, padding=7)),
-                norm_f(nn.Conv1d(64, 128, 41, stride=2, groups=4, padding=20)),
-                norm_f(nn.Conv1d(128, 256, 41, stride=2, groups=16, padding=20)),
-                norm_f(nn.Conv1d(256, 512, 41, stride=4, groups=16, padding=20)),
-                norm_f(nn.Conv1d(512, 512, 41, stride=4, groups=16, padding=20)),
-                norm_f(nn.Conv1d(512, 512, 5, stride=1, padding=2)),
+                norm_f(nn.Conv1d(1, dsc(64), 15, stride=1, padding=7)),
+                norm_f(
+                    nn.Conv1d(dsc(64), dsc(128), 41, stride=2, groups=4, padding=20)
+                ),
+                norm_f(
+                    nn.Conv1d(dsc(128), dsc(256), 41, stride=2, groups=16, padding=20)
+                ),
+                norm_f(
+                    nn.Conv1d(dsc(256), dsc(512), 41, stride=4, groups=16, padding=20)
+                ),
+                norm_f(
+                    nn.Conv1d(dsc(512), dsc(512), 41, stride=4, groups=16, padding=20)
+                ),
+                norm_f(nn.Conv1d(dsc(512), dsc(512), 5, stride=1, padding=2)),
             ]
         )
-        self.conv_post = norm_f(nn.Conv1d(512, 1, 3, stride=1, padding=1))
+        self.conv_post = norm_f(nn.Conv1d(dsc(512), 1, 3, stride=1, padding=1))
         self.activation = nn.LeakyReLU(0.1)
 
     def forward(self, x):
@@ -344,11 +417,17 @@ class DiscriminatorEnvelope(ConvNets):
 
 
 class MultiEnvelopeDiscriminator(_MultiDiscriminatorT):
-    def __init__(self, use_spectral_norm: bool = False):
+    def __init__(
+        self,
+        use_spectral_norm: bool = False,
+        discriminator_channel_multi: Number = 1,
+    ):
         super().__init__()
         self.discriminators = nn.ModuleList(
             [
-                DiscriminatorEnvelope(use_spectral_norm),  # raw envelope
+                DiscriminatorEnvelope(
+                    use_spectral_norm, discriminator_channel_multi
+                ),  # raw envelope
                 DiscriminatorEnvelope(use_spectral_norm),  # downsampled once
                 DiscriminatorEnvelope(use_spectral_norm),  # downsampled twice
             ]
```
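Note that only the first (raw-envelope) branch receives `discriminator_channel_multi`; the two downsampled branches keep their default widths, which may or may not be intentional:

```python
med = MultiEnvelopeDiscriminator(
    use_spectral_norm=False,
    discriminator_channel_multi=0.5,  # in a35 this scales only the first branch
)
```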
```diff
@@ -448,7 +527,7 @@ class DiscriminatorB(ConvNets):
         for band, stack in zip(x_bands, self.band_convs):
             for i, layer in enumerate(stack):
                 band = layer(band)
-                band = torch.nn.functional.leaky_relu(band, 0.1)
+                band = F.leaky_relu(band, 0.1)
                 if i > 0:
                     fmap.append(band)
             x.append(band)
@@ -469,11 +548,21 @@ class MultiBandDiscriminator(_MultiDiscriminatorT):
     def __init__(
         self,
         mbd_fft_sizes: list[int] = [2048, 1024, 512],
+        channels: int = 32,
+        hop_factor: float = 0.25,
+        bands: Tuple[Tuple[float, float], ...] = (
+            (0.0, 0.1),
+            (0.1, 0.25),
+            (0.25, 0.5),
+            (0.5, 0.75),
+            (0.75, 1.0),
+        ),
     ):
         super().__init__()
         self.fft_sizes = mbd_fft_sizes
+        kwargs_disc = dict(channels=channels, hop_factor=hop_factor, bands=bands)
         self.discriminators = nn.ModuleList(
-            [DiscriminatorB(window_length=w) for w in self.fft_sizes]
+            [DiscriminatorB(window_length=w, **kwargs_disc) for w in self.fft_sizes]
         )
 
     def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
```
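`MultiBandDiscriminator` now exposes `DiscriminatorB`'s band configuration instead of relying on its defaults. A sketch with a coarser split; judging by the defaults above, band edges are assumed to be normalized fractions of the frequency axis:

```python
mbd = MultiBandDiscriminator(
    mbd_fft_sizes=[2048, 1024, 512],
    channels=64,      # wider conv stacks than the default 32
    hop_factor=0.25,
    bands=((0.0, 0.25), (0.25, 0.5), (0.5, 1.0)),  # assumed-valid 3-band split
)
```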
```diff
@@ -500,7 +589,7 @@ class DiscriminatorR(ConvNets):
         self,
         resolution: List[int],
         use_spectral_norm: bool = False,
-        discriminator_channel_mult: int = 1,
+        discriminator_channel_multi: Number = 1,
     ):
         super().__init__()
 
@@ -518,13 +607,13 @@ class DiscriminatorR(ConvNets):
             [
                 norm_f(
                     nn.Conv2d(
-                        1, int(32 * discriminator_channel_mult), (3, 9), padding=(1, 4)
+                        1, int(32 * discriminator_channel_multi), (3, 9), padding=(1, 4)
                     )
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 * discriminator_channel_mult),
-                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 9),
                         stride=(1, 2),
                         padding=(1, 4),
@@ -532,8 +621,8 @@ class DiscriminatorR(ConvNets):
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 * discriminator_channel_mult),
-                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 9),
                         stride=(1, 2),
                         padding=(1, 4),
@@ -541,8 +630,8 @@ class DiscriminatorR(ConvNets):
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 * discriminator_channel_mult),
-                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 9),
                         stride=(1, 2),
                         padding=(1, 4),
@@ -550,8 +639,8 @@ class DiscriminatorR(ConvNets):
                 ),
                 norm_f(
                     nn.Conv2d(
-                        int(32 * discriminator_channel_mult),
-                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_multi),
+                        int(32 * discriminator_channel_multi),
                         (3, 3),
                         padding=(1, 1),
                     )
@@ -559,7 +648,7 @@ class DiscriminatorR(ConvNets):
             ]
         )
         self.conv_post = norm_f(
-            nn.Conv2d(int(32 * discriminator_channel_mult), 1, (3, 3), padding=(1, 1))
+            nn.Conv2d(int(32 * discriminator_channel_multi), 1, (3, 3), padding=(1, 1))
         )
 
     def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
@@ -603,7 +692,7 @@ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
     def __init__(
         self,
         use_spectral_norm: bool = False,
-        discriminator_channel_mult: int = 1,
+        discriminator_channel_multi: Number = 1,
         resolutions: List[List[int]] = [
             [1024, 120, 600],
             [2048, 240, 1200],
@@ -618,7 +707,7 @@ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
         self.discriminators = nn.ModuleList(
             [
                 DiscriminatorR(
-                    resolution, use_spectral_norm, discriminator_channel_mult
+                    resolution, use_spectral_norm, discriminator_channel_multi
                 )
                 for resolution in self.resolutions
             ]
```
```diff
@@ -637,179 +726,3 @@ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
             y_d_gs.append(y_d_g)
             fmap_gs.append(fmap_g)
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-class MultiMelScaleLoss(Model):
-    # TODO: Make the normalization an argument to be chosen by the dev
-    def __init__(
-        self,
-        sample_rate: int,
-        n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
-        window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
-        n_ffts: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
-        hops: List[int] = [8, 16, 32, 64, 128, 256, 512],
-        weight: float = 1.0,
-        lambda_mel: float = 1.0,
-        f_min: float = [0, 0, 0, 0, 0, 0, 0],
-        f_max: Optional[float] = [None, None, None, None, None, None, None],
-        loss_fn: Callable = nn.L1Loss(),
-        center: bool = True,
-        power: float = 1.0,
-        normalized: bool = False,
-        pad_mode: str = "reflect",
-        onesided: Optional[bool] = None,
-        std: int = 4,
-        mean: int = -4,
-        auto_interpolate: bool = True,
-        use_istft_norm: bool = True,
-        use_pitch_loss: bool = False,
-        use_rms_loss: bool = False,
-        lambda_pitch: float = 0.5,
-        lambda_rms: float = 0.5,
-    ):
-        super().__init__()
-        assert (
-            len(n_mels)
-            == len(window_lengths)
-            == len(n_ffts)
-            == len(hops)
-            == len(f_min)
-            == len(f_max)
-        )
-        self.loss_fn = loss_fn
-        self.lambda_mel = lambda_mel
-        self.weight = weight
-        self.use_istft_norm = use_istft_norm
-        self.auto_interpolate = auto_interpolate if not self.use_istft_norm else False
-        self.use_pitch_loss = use_pitch_loss
-        self.use_rms_loss = use_rms_loss
-        self.lambda_pitch = lambda_pitch
-        self.lambda_rms = lambda_rms
-
-        self._setup_mels(
-            sample_rate,
-            n_mels,
-            window_lengths,
-            n_ffts,
-            hops,
-            f_min,
-            f_max,
-            center,
-            power,
-            normalized,
-            pad_mode,
-            onesided,
-            std,
-            mean,
-        )
-
-    def _setup_mels(
-        self,
-        sample_rate: int,
-        n_mels: List[int],
-        window_lengths: List[int],
-        n_ffts: List[int],
-        hops: List[int],
-        f_min: List[float],
-        f_max: List[Optional[float]],
-        center: bool,
-        power: float,
-        normalized: bool,
-        pad_mode: str = "reflect",
-        onesided: Optional[bool] = None,
-        std: int = 4,
-        mean: int = -4,
-    ):
-        assert (
-            len(n_mels)
-            == len(window_lengths)
-            == len(n_ffts)
-            == len(hops)
-            == len(f_min)
-            == len(f_max)
-        )
-        _mel_kwargs = dict(
-            sample_rate=sample_rate,
-            center=center,
-            onesided=onesided,
-            normalized=normalized,
-            power=power,
-            pad_mode=pad_mode,
-            std=std,
-            mean=mean,
-        )
-        self.mel_spectrograms: List[AudioProcessor] = nn.ModuleList(
-            [
-                AudioProcessor(
-                    AudioProcessorConfig(
-                        **_mel_kwargs,
-                        n_mels=mel,
-                        n_fft=n_fft,
-                        win_length=win,
-                        hop_length=hop,
-                        f_min=fmin,
-                        f_max=fmax,
-                    )
-                )
-                for mel, win, n_fft, hop, fmin, fmax in zip(
-                    n_mels, window_lengths, n_ffts, hops, f_min, f_max
-                )
-            ]
-        )
-
-    def _process_tensor(
-        self,
-        input_wave: torch.Tensor,
-        target_wave: torch.Tensor,
-    ):
-        if input_wave.shape[-1] != target_wave.shape[-1]:
-            if input_wave.ndim < 3:
-                # To be compatible with interpolatin
-                if input_wave.ndim == 2:
-                    input_wave = input_wave.unsqueeze(1)
-                else:
-                    input_wave = input_wave.unsqueeze(0).unsqueeze(0)
-            input_wave = F.interpolate(input_wave, target_wave.shape[-1], mode="linear")
-        return input_wave
-
-    def forward(
-        self, input_wave: torch.Tensor, target_wave: torch.Tensor
-    ) -> torch.Tensor:
-        assert (
-            self.use_istft_norm
-            or self.auto_interpolate
-            or input_wave.shape[-1] == target_wave.shape[-1]
-        )
-        if self.auto_interpolate:
-            input_wave = self._process_tensor(input_wave, target_wave)
-
-        losses = 0.0
-        for M in self.mel_spectrograms:
-            # Apply normalization if requested
-            if self.use_istft_norm:
-                input_proc = M.istft_norm(input_wave, length=target_wave.shape[-1])
-                target_proc = M.istft_norm(target_wave, length=target_wave.shape[-1])
-            else:
-                input_proc, target_proc = input_wave, target_wave
-
-            x_mels = M(input_proc)
-            y_mels = M(target_proc)
-
-            loss = self.loss_fn(x_mels.squeeze(), y_mels.squeeze())
-            losses += loss * self.lambda_mel
-
-            # pitch/f0 loss
-            if self.use_pitch_loss:
-                x_pitch = normalize_unit_norm(M.compute_pitch(input_proc))
-                y_pitch = normalize_unit_norm(M.compute_pitch(target_proc))
-                f0_loss = self.loss_fn(x_pitch, y_pitch)
-                losses += f0_loss * self.lambda_pitch
-
-            # energy/rms loss
-            if self.use_rms_loss:
-                x_rms = normalize_unit_norm(M.compute_rms(input_proc, x_mels))
-                y_rms = normalize_unit_norm(M.compute_rms(target_proc, y_mels))
-                rms_loss = self.loss_fn(x_rms, y_rms)
-                losses += rms_loss * self.lambda_rms
-
-        return losses * self.weight
```
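`MultiMelScaleLoss` leaves a35 without an in-package replacement (the `AudioProcessor` import it depended on was dropped at the top of this file). Projects still training against it can pin the previous release with `pip install lt-tensor==0.0.1a33`, or vendor the class above together with the removed `normalize_*` helpers.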
lt_tensor-0.0.1a33.dist-info/METADATA → lt_tensor-0.0.1a35.dist-info/METADATA

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a33
+Version: 0.0.1a35
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
```
lt_tensor-0.0.1a35.dist-info/RECORD

```diff
@@ -0,0 +1,40 @@
+lt_tensor/__init__.py,sha256=4NqhrI_O5q4YQMBpyoLtNUUbBnnbWkO92GE1hxHcrd8,441
+lt_tensor/config_templates.py,sha256=F9UvL8paAjkSvio890kp8WznpYeI50pYnm9iqQroBxk,2797
+lt_tensor/losses.py,sha256=Heco_WyoC1HkNkcJEircOAzS9umusATHiNAG-FKGyzc,8918
+lt_tensor/lr_schedulers.py,sha256=6_vcfaPHrozfH3wvmNEdKSFYl6iTIijYoHL8vuG-45U,7651
+lt_tensor/math_ops.py,sha256=ahX6Z1Mt3X-FhmwSZYZea5mB1B0S8GDuvKPfAm5e_FQ,2646
+lt_tensor/misc_utils.py,sha256=stL6q3M7S2N4FBICFYbgYpdPDrJRlwmr24-iCXMRifM,28933
+lt_tensor/model_base.py,sha256=5T4dbAh4MXbQmPRpihGtMYwTY8sJTQOhY6An3VboM58,18086
+lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
+lt_tensor/noise_tools.py,sha256=wFeAsHhLhSlEc5XU5LbFKaXoHeVxrWjiMeljjGdIKyM,11363
+lt_tensor/torch_commons.py,sha256=8l0bxmrAzwvyqjivCIVISXlbvKarlg4DdE0BOGSnMuQ,812
+lt_tensor/transform.py,sha256=dZm8T_ov0blHMQu6nGiehsdG1VSB7bZBUVmTkT-PBdc,13257
+lt_tensor/model_zoo/__init__.py,sha256=yPUVchgVhU2nAJ2ocA4HFfG7IMEiBu8qOi8I1KWTTkU,404
+lt_tensor/model_zoo/basic.py,sha256=pI8HyiHK-cmWcEEaVY_EduUJOjZW6HOtXvJd8Rbhq30,15452
+lt_tensor/model_zoo/convs.py,sha256=Tws0jrPfs9m7OLmJ30W0AfkAvZgppW7lNi4xt0e-qRU,3518
+lt_tensor/model_zoo/features.py,sha256=DO8dlE0kmPKTNC1Xkv9wKegOOYkQa_rkxM4hhcNwJWA,15655
+lt_tensor/model_zoo/fusion.py,sha256=usC1bcjQRNivDc8xzkIS5T1glm78OLcs2V_tPqfp-eI,5422
+lt_tensor/model_zoo/pos_encoder.py,sha256=3d1EYLinCU9UAy-WuEWeYMGhMqaGknCiQ5qEmhw_UYM,4487
+lt_tensor/model_zoo/residual.py,sha256=tMXgif9Ggep9bk75K93yueeU5vk5S25AGCRFwOQOyB8,6452
+lt_tensor/model_zoo/transformer.py,sha256=HUFoFFh7EQJErxdd9XIxhssdjvNVx2tNGDJOTUfwG2A,4301
+lt_tensor/model_zoo/activations/__init__.py,sha256=f_IsuC-SaFsX6w4OtBWa5bbS4TqR90X-cvLxGUgYfjk,67
+lt_tensor/model_zoo/activations/alias_free/__init__.py,sha256=dgLjatRm9nusoPVOl1pvCef5rZsaRfS3BJUs05SPYzw,64
+lt_tensor/model_zoo/activations/alias_free/act.py,sha256=1wxmab2kMD88L6wsQgf3t25dBwR7_he2eM1DlV0FQak,1424
+lt_tensor/model_zoo/activations/alias_free/filter.py,sha256=5TvXESv31toD5sePBe_OUJJfMXv6Ohwmx2YawjQL-pk,6004
+lt_tensor/model_zoo/activations/alias_free/resample.py,sha256=3iM4fNr9fLNXXMyXvzW-MwkSjOZOrMZLfS80UHs6zk0,3386
+lt_tensor/model_zoo/activations/snake/__init__.py,sha256=AtOAbJuMinxmKkppITGMzRbcbPQaALnl9mCtl1c3x0Q,4356
+lt_tensor/model_zoo/audio_models/__init__.py,sha256=WwiP9MekJreMOfKPWLl24VkRJIpLk6hhL8ch0aKgOss,103
+lt_tensor/model_zoo/audio_models/resblocks.py,sha256=u-foHxaFDUICjxSkpyHXljQYQG9zMxVYaOGqLR_nJ-k,7978
+lt_tensor/model_zoo/audio_models/bigvgan/__init__.py,sha256=Dpt_3JXUToldxQrZx4a1gfI-awsLIVipAXqWm4lzBzM,8495
+lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=PDuDYN1omD1RoAXcmxH3tEgfAuM3ZHAWzimD6ElMqEQ,9073
+lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=3HVfEreQ4NqYIC9AWEkmL4ePcIbR1kTyH0cBG8u_Jik,6387
+lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=blICjLX_z_IFmR3_TCz_dJiSayLYGza9eG6fd9aKyvE,7448
+lt_tensor/model_zoo/losses/__init__.py,sha256=B9RAUxBiOZwooztnij1oLeRwZ7_MjnN3mPoum7saD6s,59
+lt_tensor/model_zoo/losses/discriminators.py,sha256=HBO7jwCsUGsYfSz-JZPZccuYLnto6jfZs3Ve5j51JQE,24247
+lt_tensor/processors/__init__.py,sha256=Pvxhh0KR65zLCgUd53_k5Z0y5JWWcO0ZBXFK9rv0o5w,109
+lt_tensor/processors/audio.py,sha256=HNr1GS-6M2q0Rda4cErf5y2Jlc9f4jD58FvpX2ua9d4,18369
+lt_tensor-0.0.1a35.dist-info/licenses/LICENSE,sha256=TbiyJWLgNqqgqhfCnrGwFIxy7EqGNrIZZcKhHrefcuU,11354
+lt_tensor-0.0.1a35.dist-info/METADATA,sha256=0FrtLNnbU49bKOlyshasXPZOZ90Sok03XkXbtxP4VMI,1062
+lt_tensor-0.0.1a35.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a35.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a35.dist-info/RECORD,,
```
```diff
@@ -1 +0,0 @@
-from . import *
```