lt-tensor 0.0.1a33__py3-none-any.whl → 0.0.1a35__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,19 +1,18 @@
 __all__ = ["HifiganGenerator", "HifiganConfig"]
+
+
 from lt_utils.common import *
 from lt_tensor.torch_commons import *
 from lt_tensor.model_zoo.convs import ConvNets
-from torch.nn import functional as F
-from lt_utils.file_ops import load_json, is_file, is_dir, is_path_valid
-from lt_tensor.misc_utils import get_config, get_weights
+from lt_tensor.config_templates import ModelConfig
+from lt_utils.file_ops import is_file
+from lt_tensor.model_zoo.audio_models.resblocks import ResBlock1, ResBlock2
 
 
 def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)
 
 
-from lt_tensor.config_templates import ModelConfig
-
-
 class HifiganConfig(ModelConfig):
     # Training params
     in_channels: int = 80
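
The unchanged `get_padding` helper in this hunk computes the padding that keeps a stride-1 dilated `Conv1d` length-preserving. A quick worked check of the formula (plain Python, nothing assumed beyond the function shown above):

# (kernel_size * dilation - dilation) / 2, truncated to int, is the
# "same" padding for stride-1 convolutions.
assert get_padding(3) == 1              # (3*1 - 1) / 2 = 1
assert get_padding(3, dilation=5) == 5  # (3*5 - 5) / 2 = 5
assert get_padding(7) == 3              # (7*1 - 1) / 2 = 3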
@@ -28,6 +27,7 @@ class HifiganConfig(ModelConfig):
     ]
 
     activation: nn.Module = nn.LeakyReLU(0.1)
+    resblock_activation: nn.Module = nn.LeakyReLU(0.1)
    resblock: int = 0
 
     def __init__(
@@ -43,6 +43,7 @@ class HifiganConfig(ModelConfig):
             [1, 3, 5],
         ],
         activation: nn.Module = nn.LeakyReLU(0.1),
+        resblock_activation: nn.Module = nn.LeakyReLU(0.1),
         resblock: Union[int, str] = "1",
         *args,
         **kwargs,
@@ -55,6 +56,7 @@ class HifiganConfig(ModelConfig):
             "resblock_kernel_sizes": resblock_kernel_sizes,
             "resblock_dilation_sizes": resblock_dilation_sizes,
             "activation": activation,
+            "resblock_activation": resblock_activation,
             "resblock": resblock,
         }
         super().__init__(**settings)
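
The three hunks above thread a new `resblock_activation` setting through `HifiganConfig`, separate from the generator-level `activation`. A minimal usage sketch (the import path is assumed from this diff's context and may differ; both arguments default to `nn.LeakyReLU(0.1)`):

from torch import nn
from lt_tensor.model_zoo.audio_models.hifigan import HifiganConfig  # assumed path

cfg = HifiganConfig(
    activation=nn.LeakyReLU(0.1),           # the generator's own activation
    resblock_activation=nn.LeakyReLU(0.2),  # hypothetical alternative slope for the resblocks
)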
@@ -64,128 +66,6 @@ class HifiganConfig(ModelConfig):
         self.resblock = 0 if self.resblock == "1" else 1
 
 
-class ResBlock1(ConvNets):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super().__init__()
-
-        self.convs1 = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[2],
-                        padding=get_padding(kernel_size, dilation[2]),
-                    )
-                ),
-            ]
-        )
-        self.convs1.apply(self.init_weights)
-
-        self.convs2 = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-            ]
-        )
-        self.convs2.apply(self.init_weights)
-        self.activation = nn.LeakyReLU(0.1)
-
-    def forward(self, x):
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = c1(self.activation(x))
-            xt = c2(self.activation(xt))
-            x = xt + x
-        return x
-
-
-class ResBlock2(ConvNets):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
-        super().__init__()
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-            ]
-        )
-        self.convs.apply(self.init_weights)
-        self.activation = nn.LeakyReLU(0.1)
-
-    def forward(self, x):
-        for c in self.convs:
-            xt = c(self.activation(x))
-            x = xt + x
-        return x
-
-
 class HifiganGenerator(ConvNets):
     def __init__(self, cfg: HifiganConfig = HifiganConfig()):
         super().__init__()
@@ -219,7 +99,7 @@ class HifiganGenerator(ConvNets):
             for j, (k, d) in enumerate(
                 zip(cfg.resblock_kernel_sizes, cfg.resblock_dilation_sizes)
             ):
-                self.resblocks.append(resblock(ch, k, d))
+                self.resblocks.append(resblock(ch, k, d, cfg.resblock_activation))
 
         self.conv_post = weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
         self.ups.apply(self.init_weights)
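
The generator now passes `cfg.resblock_activation` as a fourth positional argument, so the shared `ResBlock1`/`ResBlock2` imported from `lt_tensor.model_zoo.audio_models.resblocks` must accept an activation module. A sketch of `ResBlock2` under that assumption, pieced together from this call site and the class removed above (the parameter name `activation` is a guess, not the package's actual definition):

from torch import nn
from torch.nn.utils import weight_norm
from lt_tensor.model_zoo.convs import ConvNets


class ResBlock2(ConvNets):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3), activation=nn.LeakyReLU(0.1)):
        super().__init__()
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=(kernel_size * d - d) // 2,  # same formula as get_padding
                    )
                )
                for d in dilation
            ]
        )
        self.convs.apply(self.init_weights)
        # Previously hard-coded to nn.LeakyReLU(0.1); now injected via the config.
        self.activation = activation

    def forward(self, x):
        for c in self.convs:
            x = c(self.activation(x)) + x  # pre-activation conv with residual add
        return x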
@@ -237,9 +117,7 @@ class HifiganGenerator(ConvNets):
                 xs += self.resblocks[i * self.num_kernels + j](x)
             x = xs / self.num_kernels
         x = self.conv_post(self.activation(x))
-        x = torch.tanh(x)
-
-        return x
+        return x.tanh()
 
     def load_weights(
         self,
@@ -252,7 +130,7 @@ class HifiganGenerator(ConvNets):
         **pickle_load_args,
     ):
         try:
-            incompatible_keys = super().load_weights(
+            return super().load_weights(
                 path,
                 raise_if_not_exists,
                 strict,
@@ -261,18 +139,6 @@ class HifiganGenerator(ConvNets):
                 mmap,
                 **pickle_load_args,
             )
-            if incompatible_keys:
-                self.remove_norms()
-                incompatible_keys = super().load_weights(
-                    path,
-                    raise_if_not_exists,
-                    strict,
-                    assign,
-                    weights_only,
-                    mmap,
-                    **pickle_load_args,
-                )
-            return incompatible_keys
         except RuntimeError:
             self.remove_norms()
             return super().load_weights(
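
With the `incompatible_keys` retry removed, `load_weights` now returns the first successful load directly; weight norm is stripped only in the `except RuntimeError` fallback shown in this hunk. Typical use is unchanged (hypothetical checkpoint path):

model = HifiganGenerator(HifiganConfig())
model.load_weights("generator.pt")  # falls back to remove_norms() + reload on RuntimeError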
@@ -291,6 +157,7 @@ class HifiganGenerator(ConvNets):
         model_file: PathLike,
         model_config: Union[HifiganConfig, Dict[str, Any]],
         *,
+        remove_norms: bool = False,
         strict: bool = False,
         map_location: str = "cpu",
         weights_only: bool = False,
@@ -308,11 +175,11 @@ class HifiganGenerator(ConvNets):
             h = HifiganConfig(**model_config)
 
         model = cls(h)
+        if remove_norms:
+            model.remove_norms()
         try:
-            incompatible_keys = model.load_state_dict(model_state_dict, strict=strict)
-            if incompatible_keys:
-                model.remove_norms()
-                model.load_state_dict(model_state_dict, strict=strict)
+            model.load_state_dict(model_state_dict, strict=strict)
+            return model
         except RuntimeError:
             print(
                 f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
@@ -320,201 +187,3 @@ class HifiganGenerator(ConvNets):
         model.remove_norms()
         model.load_state_dict(model_state_dict, strict=strict)
         return model
-
-
-class DiscriminatorP(ConvNets):
-    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
-        super(DiscriminatorP, self).__init__()
-        self.period = period
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(
-                    nn.Conv2d(
-                        1,
-                        32,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(5, 1), 0),
-                    )
-                ),
-                norm_f(
-                    nn.Conv2d(
-                        32,
-                        128,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(5, 1), 0),
-                    )
-                ),
-                norm_f(
-                    nn.Conv2d(
-                        128,
-                        512,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(5, 1), 0),
-                    )
-                ),
-                norm_f(
-                    nn.Conv2d(
-                        512,
-                        1024,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(5, 1), 0),
-                    )
-                ),
-                norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
-            ]
-        )
-        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
-        self.activation = nn.LeakyReLU(0.1)
-
-    def forward(self, x):
-        fmap = []
-
-        # 1d to 2d
-        b, c, t = x.shape
-        if t % self.period != 0:  # pad first
-            n_pad = self.period - (t % self.period)
-            x = F.pad(x, (0, n_pad), "reflect")
-            t = t + n_pad
-        x = x.view(b, c, t // self.period, self.period)
-
-        for l in self.convs:
-            x = l(x)
-            x = self.activation(x)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class MultiPeriodDiscriminator(ConvNets):
-    def __init__(self):
-        super(MultiPeriodDiscriminator, self).__init__()
-        self.discriminators = nn.ModuleList(
-            [
-                DiscriminatorP(2),
-                DiscriminatorP(3),
-                DiscriminatorP(5),
-                DiscriminatorP(7),
-                DiscriminatorP(11),
-            ]
-        )
-
-    def forward(self, y, y_hat):
-        y_d_rs = []
-        y_d_gs = []
-        fmap_rs = []
-        fmap_gs = []
-        for i, d in enumerate(self.discriminators):
-            y_d_r, fmap_r = d(y)
-            y_d_g, fmap_g = d(y_hat)
-            y_d_rs.append(y_d_r)
-            fmap_rs.append(fmap_r)
-            y_d_gs.append(y_d_g)
-            fmap_gs.append(fmap_g)
-
-        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-class DiscriminatorS(ConvNets):
-    def __init__(self, use_spectral_norm=False):
-        super(DiscriminatorS, self).__init__()
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
-                norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
-                norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
-                norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
-                norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
-                norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
-                norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
-            ]
-        )
-        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
-        self.activation = nn.LeakyReLU(0.1)
-
-    def forward(self, x):
-        fmap = []
-        for l in self.convs:
-            x = l(x)
-            x = self.activation(x)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class MultiScaleDiscriminator(ConvNets):
-    def __init__(self):
-        super(MultiScaleDiscriminator, self).__init__()
-        self.discriminators = nn.ModuleList(
-            [
-                DiscriminatorS(use_spectral_norm=True),
-                DiscriminatorS(),
-                DiscriminatorS(),
-            ]
-        )
-        self.meanpools = nn.ModuleList(
-            [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
-        )
-
-    def forward(self, y, y_hat):
-        y_d_rs = []
-        y_d_gs = []
-        fmap_rs = []
-        fmap_gs = []
-        for i, d in enumerate(self.discriminators):
-            if i != 0:
-                y = self.meanpools[i - 1](y)
-                y_hat = self.meanpools[i - 1](y_hat)
-            y_d_r, fmap_r = d(y)
-            y_d_g, fmap_g = d(y_hat)
-            y_d_rs.append(y_d_r)
-            fmap_rs.append(fmap_r)
-            y_d_gs.append(y_d_g)
-            fmap_gs.append(fmap_g)
-
-        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-def feature_loss(fmap_r, fmap_g):
-    loss = 0
-    for dr, dg in zip(fmap_r, fmap_g):
-        for rl, gl in zip(dr, dg):
-            loss += torch.mean(torch.abs(rl - gl))
-
-    return loss * 2
-
-
-def discriminator_loss(disc_real_outputs, disc_generated_outputs):
-    loss = 0
-    r_losses = []
-    g_losses = []
-    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-        r_loss = torch.mean((1 - dr) ** 2)
-        g_loss = torch.mean(dg**2)
-        loss += r_loss + g_loss
-        r_losses.append(r_loss.item())
-        g_losses.append(g_loss.item())
-
-    return loss, r_losses, g_losses
-
-
-def generator_loss(disc_outputs):
-    loss = 0
-    gen_losses = []
-    for dg in disc_outputs:
-        l = torch.mean((1 - dg) ** 2)
-        gen_losses.append(l)
-        loss += l
-
-    return loss, gen_losses
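
Taken together with the removed retry logic, the new `remove_norms` keyword lets a caller strip weight norm before the state dict is loaded, rather than relying on a failed load to trigger it. An end-to-end sketch (checkpoint path and import path are hypothetical; the forward pass follows the generator code in this diff):

import torch
from lt_tensor.model_zoo.audio_models.hifigan import HifiganGenerator, HifiganConfig  # assumed path

model = HifiganGenerator.from_pretrained(
    "generator.pt",     # hypothetical checkpoint saved without weight norm
    HifiganConfig(),
    remove_norms=True,  # added between 0.0.1a33 and 0.0.1a35: strip norms up front
)
model.eval()

with torch.inference_mode():
    mel = torch.randn(1, 80, 128)  # (batch, in_channels=80, mel frames)
    audio = model(mel)             # waveform in [-1, 1] via the final tanh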