lt-tensor 0.0.1a34__py3-none-any.whl → 0.0.1a36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. lt_tensor/__init__.py +1 -1
  2. lt_tensor/losses.py +11 -7
  3. lt_tensor/lr_schedulers.py +147 -21
  4. lt_tensor/misc_utils.py +35 -42
  5. lt_tensor/model_zoo/activations/__init__.py +3 -0
  6. lt_tensor/model_zoo/activations/alias_free/__init__.py +3 -0
  7. lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/act.py +8 -6
  8. lt_tensor/model_zoo/activations/snake/__init__.py +41 -43
  9. lt_tensor/model_zoo/audio_models/__init__.py +2 -2
  10. lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +243 -0
  11. lt_tensor/model_zoo/audio_models/hifigan/__init__.py +22 -357
  12. lt_tensor/model_zoo/audio_models/istft/__init__.py +14 -349
  13. lt_tensor/model_zoo/audio_models/resblocks.py +248 -0
  14. lt_tensor/model_zoo/convs.py +21 -32
  15. lt_tensor/model_zoo/losses/CQT/__init__.py +0 -0
  16. lt_tensor/model_zoo/losses/CQT/transforms.py +336 -0
  17. lt_tensor/model_zoo/losses/CQT/utils.py +519 -0
  18. lt_tensor/model_zoo/losses/discriminators.py +375 -37
  19. lt_tensor/processors/audio.py +67 -57
  20. {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/METADATA +1 -1
  21. lt_tensor-0.0.1a36.dist-info/RECORD +43 -0
  22. lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +0 -1
  23. lt_tensor-0.0.1a34.dist-info/RECORD +0 -37
  24. /lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/filter.py +0 -0
  25. /lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/resample.py +0 -0
  26. {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/WHEEL +0 -0
  27. {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/licenses/LICENSE +0 -0
  28. {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/top_level.txt +0 -0
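Note: items 7, 24, and 25 above are renames of the anti-aliasing activation helpers from lt_tensor/model_zoo/activations/alias_free_torch/ to lt_tensor/model_zoo/activations/alias_free/, so code importing the old package path needs updating when upgrading. A minimal sketch of the adjustment (that Activation1d was importable from the old path under the same name is an assumption; the new path is used by the BigVGAN code below as alias_free.Activation1d):

# 0.0.1a34 (old path, removed)
# from lt_tensor.model_zoo.activations.alias_free_torch import Activation1d
# 0.0.1a36 (new path)
from lt_tensor.model_zoo.activations.alias_free import Activation1d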
lt_tensor/model_zoo/audio_models/bigvgan/__init__.py (new file)
@@ -0,0 +1,243 @@
+ from lt_utils.common import *
+ from lt_tensor.torch_commons import *
+ from lt_tensor.model_zoo.convs import ConvNets
+ from lt_tensor.config_templates import ModelConfig
+ from lt_tensor.model_zoo.activations import snake, alias_free
+ from lt_tensor.model_zoo.audio_models.resblocks import AMPBlock1, AMPBlock2, get_snake
+ from lt_utils.file_ops import load_json, is_file, is_dir, is_path_valid
+
+
+ class BigVGANConfig(ModelConfig):
+     # Training params
+     in_channels: int = 80
+     upsample_rates: List[Union[int, List[int]]] = [4, 4, 2, 2, 2, 2]
+     upsample_kernel_sizes: List[Union[int, List[int]]] = [8, 8, 4, 4, 4, 4]
+     upsample_initial_channel: int = 1536
+     resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11]
+     resblock_dilation_sizes: List[Union[int, List[int]]] = [
+         [1, 3, 5],
+         [1, 3, 5],
+         [1, 3, 5],
+     ]
+
+     activation: Literal["snake", "snakebeta"] = "snakebeta"
+     resblock_activation: Literal["snake", "snakebeta"] = "snakebeta"
+     resblock: int = 0
+     use_bias_at_final: bool = True
+     use_tanh_at_final: bool = True
+     snake_logscale: bool = True
+
+     def __init__(
+         self,
+         in_channels: int = 80,
+         upsample_rates: List[Union[int, List[int]]] = [4, 4, 2, 2, 2, 2],
+         upsample_kernel_sizes: List[Union[int, List[int]]] = [8, 8, 4, 4, 4, 4],
+         upsample_initial_channel: int = 1536,
+         resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+         resblock_dilation_sizes: List[Union[int, List[int]]] = [
+             [1, 3, 5],
+             [1, 3, 5],
+             [1, 3, 5],
+         ],
+         activation: Literal["snake", "snakebeta"] = "snakebeta",
+         resblock_activation: Literal["snake", "snakebeta"] = "snakebeta",
+         resblock: Union[int, str] = "1",
+         use_bias_at_final: bool = False,
+         use_tanh_at_final: bool = False,
+         *args,
+         **kwargs,
+     ):
+         settings = {
+             "in_channels": in_channels,
+             "upsample_rates": upsample_rates,
+             "upsample_kernel_sizes": upsample_kernel_sizes,
+             "upsample_initial_channel": upsample_initial_channel,
+             "resblock_kernel_sizes": resblock_kernel_sizes,
+             "resblock_dilation_sizes": resblock_dilation_sizes,
+             "activation": activation,
+             "resblock_activation": resblock_activation,
+             "resblock": resblock,
+             "use_bias_at_final": use_bias_at_final,
+             "use_tanh_at_final": use_tanh_at_final,
+         }
+         super().__init__(**settings)
+
+     def post_process(self):
+         if isinstance(self.resblock, str):
+             self.resblock = 0 if self.resblock == "1" else 1
+
+
+ class BigVGAN(ConvNets):
+     """Modified from 'https://github.com/NVIDIA/BigVGAN/blob/main/bigvgan.py' under mit license.
+
+     BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
+     New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
+
+     Args:
+         cfg (BigVGANConfig): Hyperparameters.
+
+     """
+
+     def __init__(self, cfg: BigVGANConfig):
+         super().__init__()
+         self.cfg = cfg
+         actv = get_snake(self.cfg.activation)
+
+         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
+
+         self.num_kernels = len(cfg.resblock_kernel_sizes)
+         self.num_upsamples = len(cfg.upsample_rates)
+
+         # Pre-conv
+         self.conv_pre = weight_norm(
+             nn.Conv1d(cfg.in_channels, cfg.upsample_initial_channel, 7, 1, padding=3)
+         )
+
+         # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
+         resblock_class = AMPBlock1 if cfg.resblock == 0 else AMPBlock2
+
+         # Transposed conv-based upsamplers. does not apply anti-aliasing
+         self.ups = nn.ModuleList()
+         for i, (u, k) in enumerate(zip(cfg.upsample_rates, cfg.upsample_kernel_sizes)):
+             self.ups.append(
+                 nn.ModuleList(
+                     [
+                         weight_norm(
+                             nn.ConvTranspose1d(
+                                 cfg.upsample_initial_channel // (2**i),
+                                 cfg.upsample_initial_channel // (2 ** (i + 1)),
+                                 k,
+                                 u,
+                                 padding=(k - u) // 2,
+                             )
+                         )
+                     ]
+                 )
+             )
+
+         # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
+         self.resblocks = nn.ModuleList()
+         for i in range(len(self.ups)):
+             ch = cfg.upsample_initial_channel // (2 ** (i + 1))
+             for k, d in zip(cfg.resblock_kernel_sizes, cfg.resblock_dilation_sizes):
+                 self.resblocks.append(
+                     resblock_class(
+                         ch,
+                         k,
+                         d,
+                         snake_logscale=cfg.snake_logscale,
+                         activation=cfg.resblock_activation,
+                     )
+                 )
+
+         # Post-conv
+         activation_post = actv(ch, alpha_logscale=cfg.snake_logscale)
+
+         self.activation_post = alias_free.Activation1d(activation=activation_post)
+
+         # Whether to use bias for the final conv_post. Default to True for backward compatibility
+         self.conv_post = weight_norm(
+             nn.Conv1d(ch, 1, 7, 1, padding=3, bias=self.cfg.use_bias_at_final)
+         )
+
+         # Weight initialization
+         for i in range(len(self.ups)):
+             self.ups[i].apply(self.init_weights)
+         self.conv_post.apply(self.init_weights)
+
+         # Final tanh activation. Defaults to True for backward compatibility
+         self.use_tanh_at_final = cfg.use_tanh_at_final
+
+     def forward(self, x):
+         # Pre-conv
+         x = self.conv_pre(x)
+
+         for i in range(self.num_upsamples):
+             # Upsampling
+             for i_up in range(len(self.ups[i])):
+                 x = self.ups[i][i_up](x)
+             # AMP blocks
+             xs = None
+             for j in range(self.num_kernels):
+                 if xs is None:
+                     xs = self.resblocks[i * self.num_kernels + j](x)
+                 else:
+                     xs += self.resblocks[i * self.num_kernels + j](x)
+             x = xs / self.num_kernels
+
+         # Post-conv
+         x = self.activation_post(x)
+         x: Tensor = self.conv_post(x)
+         # Final tanh activation
+         if self.use_tanh_at_final:
+             return x.tanh()
+         return x.clamp(min=-1.0, max=1.0)
+
+     def load_weights(
+         self,
+         path,
+         strict=False,
+         assign=False,
+         weights_only=False,
+         mmap=None,
+         raise_if_not_exists=False,
+         **pickle_load_args,
+     ):
+         try:
+             return super().load_weights(
+                 path,
+                 raise_if_not_exists,
+                 strict,
+                 assign,
+                 weights_only,
+                 mmap,
+                 **pickle_load_args,
+             )
+         except RuntimeError:
+             self.remove_norms()
+             return super().load_weights(
+                 path,
+                 raise_if_not_exists,
+                 strict,
+                 assign,
+                 weights_only,
+                 mmap,
+                 **pickle_load_args,
+             )
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_file: PathLike,
+         model_config: Union[BigVGANConfig, Dict[str, Any]],
+         *,
+         remove_norms: bool = False,
+         strict: bool = False,
+         map_location: str = "cpu",
+         weights_only: bool = False,
+         **kwargs,
+     ):
+
+         is_file(model_file, validate=True)
+         model_state_dict = torch.load(
+             model_file, weights_only=weights_only, map_location=map_location
+         )
+
+         if isinstance(model_config, BigVGANConfig):
+             h = model_config
+         else:
+             h = BigVGANConfig(**model_config)
+
+         model = cls(h)
+         if remove_norms:
+             model.remove_norms()
+         try:
+             model.load_state_dict(model_state_dict, strict=strict)
+             return model
+         except RuntimeError:
+             print(
+                 f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
+             )
+             model.remove_norms()
+             model.load_state_dict(model_state_dict, strict=strict)
+             return model
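The whole file above is new in 0.0.1a36. A minimal usage sketch of the added generator, using only names visible in this diff; the (batch, in_channels, frames) mel layout and the output length of frames x prod(upsample_rates) are assumptions about shapes, not documented behaviour:

import torch
from lt_tensor.model_zoo.audio_models.bigvgan import BigVGAN, BigVGANConfig

cfg = BigVGANConfig(in_channels=80)   # defaults: upsample_rates [4, 4, 2, 2, 2, 2] -> 256x upsampling
model = BigVGAN(cfg).eval()

mel = torch.randn(1, 80, 128)         # assumed layout: (batch, mel bins, frames)
with torch.no_grad():
    audio = model(mel)                # roughly (1, 1, 128 * 256); limited to [-1, 1] via tanh or clamp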
lt_tensor/model_zoo/audio_models/hifigan/__init__.py
@@ -1,48 +1,45 @@
  __all__ = ["HifiganGenerator", "HifiganConfig"]
+
+
  from lt_utils.common import *
  from lt_tensor.torch_commons import *
  from lt_tensor.model_zoo.convs import ConvNets
- from torch.nn import functional as F
- from lt_utils.file_ops import load_json, is_file, is_dir, is_path_valid
- from lt_tensor.misc_utils import get_config, get_weights
+ from lt_tensor.config_templates import ModelConfig
+ from lt_utils.file_ops import is_file
+ from lt_tensor.model_zoo.audio_models.resblocks import ResBlock1, ResBlock2


  def get_padding(kernel_size, dilation=1):
      return int((kernel_size * dilation - dilation) / 2)


- from lt_tensor.config_templates import ModelConfig
-
-
  class HifiganConfig(ModelConfig):
      # Training params
      in_channels: int = 80
-     upsample_rates: List[Union[int, List[int]]] = [8, 8]
-     upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16]
+     upsample_rates: List[Union[int, List[int]]] = [8,8,2,2]
+     upsample_kernel_sizes: List[Union[int, List[int]]] = [16,16,4,4]
      upsample_initial_channel: int = 512
      resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11]
-     resblock_dilation_sizes: List[Union[int, List[int]]] = [
-         [1, 3, 5],
-         [1, 3, 5],
-         [1, 3, 5],
-     ]
+     resblock_dilation_sizes: List[Union[int, List[int]]] = [[1,3,5], [1,3,5], [1,3,5]]

      activation: nn.Module = nn.LeakyReLU(0.1)
+     resblock_activation: nn.Module = nn.LeakyReLU(0.1)
      resblock: int = 0

      def __init__(
          self,
          in_channels: int = 80,
-         upsample_rates: List[Union[int, List[int]]] = [8, 8],
-         upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16],
+         upsample_rates: List[Union[int, List[int]]] = [8,8,2,2],
+         upsample_kernel_sizes: List[Union[int, List[int]]] = [16,16,4,4],
          upsample_initial_channel: int = 512,
-         resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+         resblock_kernel_sizes: List[Union[int, List[int]]] = [3,7,11],
          resblock_dilation_sizes: List[Union[int, List[int]]] = [
              [1, 3, 5],
              [1, 3, 5],
              [1, 3, 5],
          ],
          activation: nn.Module = nn.LeakyReLU(0.1),
+         resblock_activation: nn.Module = nn.LeakyReLU(0.1),
          resblock: Union[int, str] = "1",
          *args,
          **kwargs,
@@ -55,6 +52,7 @@ class HifiganConfig(ModelConfig):
              "resblock_kernel_sizes": resblock_kernel_sizes,
              "resblock_dilation_sizes": resblock_dilation_sizes,
              "activation": activation,
+             "resblock_activation": resblock_activation,
              "resblock": resblock,
          }
          super().__init__(**settings)
@@ -64,128 +62,6 @@ class HifiganConfig(ModelConfig):
              self.resblock = 0 if self.resblock == "1" else 1


- class ResBlock1(ConvNets):
-     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-         super().__init__()
-
-         self.convs1 = nn.ModuleList(
-             [
-                 weight_norm(
-                     nn.Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=dilation[0],
-                         padding=get_padding(kernel_size, dilation[0]),
-                     )
-                 ),
-                 weight_norm(
-                     nn.Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=dilation[1],
-                         padding=get_padding(kernel_size, dilation[1]),
-                     )
-                 ),
-                 weight_norm(
-                     nn.Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=dilation[2],
-                         padding=get_padding(kernel_size, dilation[2]),
-                     )
-                 ),
-             ]
-         )
-         self.convs1.apply(self.init_weights)
-
-         self.convs2 = nn.ModuleList(
-             [
-                 weight_norm(
-                     nn.Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=1,
-                         padding=get_padding(kernel_size, 1),
-                     )
-                 ),
-                 weight_norm(
-                     nn.Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=1,
-                         padding=get_padding(kernel_size, 1),
-                     )
-                 ),
-                 weight_norm(
-                     nn.Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=1,
-                         padding=get_padding(kernel_size, 1),
-                     )
-                 ),
-             ]
-         )
-         self.convs2.apply(self.init_weights)
-         self.activation = nn.LeakyReLU(0.1)
-
-     def forward(self, x):
-         for c1, c2 in zip(self.convs1, self.convs2):
-             xt = c1(self.activation(x))
-             xt = c2(self.activation(xt))
-             x = xt + x
-         return x
-
-
- class ResBlock2(ConvNets):
-     def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
-         super().__init__()
-         self.convs = nn.ModuleList(
-             [
-                 weight_norm(
-                     nn.Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=dilation[0],
-                         padding=get_padding(kernel_size, dilation[0]),
-                     )
-                 ),
-                 weight_norm(
-                     nn.Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=dilation[1],
-                         padding=get_padding(kernel_size, dilation[1]),
-                     )
-                 ),
-             ]
-         )
-         self.convs.apply(self.init_weights)
-         self.activation = nn.LeakyReLU(0.1)
-
-     def forward(self, x):
-         for c in self.convs:
-             xt = c(self.activation(x))
-             x = xt + x
-         return x
-
-
  class HifiganGenerator(ConvNets):
      def __init__(self, cfg: HifiganConfig = HifiganConfig()):
          super().__init__()
@@ -219,7 +95,7 @@ class HifiganGenerator(ConvNets):
              for j, (k, d) in enumerate(
                  zip(cfg.resblock_kernel_sizes, cfg.resblock_dilation_sizes)
              ):
-                 self.resblocks.append(resblock(ch, k, d))
+                 self.resblocks.append(resblock(ch, k, d, cfg.resblock_activation))

          self.conv_post = weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
          self.ups.apply(self.init_weights)
@@ -237,9 +113,7 @@ class HifiganGenerator(ConvNets):
                      xs += self.resblocks[i * self.num_kernels + j](x)
              x = xs / self.num_kernels
          x = self.conv_post(self.activation(x))
-         x = torch.tanh(x)
-
-         return x
+         return x.tanh()

      def load_weights(
          self,
@@ -252,7 +126,7 @@ class HifiganGenerator(ConvNets):
          **pickle_load_args,
      ):
          try:
-             incompatible_keys = super().load_weights(
+             return super().load_weights(
                  path,
                  raise_if_not_exists,
                  strict,
@@ -261,18 +135,6 @@ class HifiganGenerator(ConvNets):
                  mmap,
                  **pickle_load_args,
              )
-             if incompatible_keys:
-                 self.remove_norms()
-                 incompatible_keys = super().load_weights(
-                     path,
-                     raise_if_not_exists,
-                     strict,
-                     assign,
-                     weights_only,
-                     mmap,
-                     **pickle_load_args,
-                 )
-             return incompatible_keys
          except RuntimeError:
              self.remove_norms()
              return super().load_weights(
@@ -291,6 +153,7 @@ class HifiganGenerator(ConvNets):
          model_file: PathLike,
          model_config: Union[HifiganConfig, Dict[str, Any]],
          *,
+         remove_norms: bool = False,
          strict: bool = False,
          map_location: str = "cpu",
          weights_only: bool = False,
@@ -308,11 +171,11 @@ class HifiganGenerator(ConvNets):
              h = HifiganConfig(**model_config)

          model = cls(h)
+         if remove_norms:
+             model.remove_norms()
          try:
-             incompatible_keys = model.load_state_dict(model_state_dict, strict=strict)
-             if incompatible_keys:
-                 model.remove_norms()
-                 model.load_state_dict(model_state_dict, strict=strict)
+             model.load_state_dict(model_state_dict, strict=strict)
+             return model
          except RuntimeError:
              print(
                  f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
@@ -320,201 +183,3 @@ class HifiganGenerator(ConvNets):
              model.remove_norms()
              model.load_state_dict(model_state_dict, strict=strict)
              return model
-
-
- class DiscriminatorP(ConvNets):
-     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
-         super(DiscriminatorP, self).__init__()
-         self.period = period
-         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-         self.convs = nn.ModuleList(
-             [
-                 norm_f(
-                     nn.Conv2d(
-                         1,
-                         32,
-                         (kernel_size, 1),
-                         (stride, 1),
-                         padding=(get_padding(5, 1), 0),
-                     )
-                 ),
-                 norm_f(
-                     nn.Conv2d(
-                         32,
-                         128,
-                         (kernel_size, 1),
-                         (stride, 1),
-                         padding=(get_padding(5, 1), 0),
-                     )
-                 ),
-                 norm_f(
-                     nn.Conv2d(
-                         128,
-                         512,
-                         (kernel_size, 1),
-                         (stride, 1),
-                         padding=(get_padding(5, 1), 0),
-                     )
-                 ),
-                 norm_f(
-                     nn.Conv2d(
-                         512,
-                         1024,
-                         (kernel_size, 1),
-                         (stride, 1),
-                         padding=(get_padding(5, 1), 0),
-                     )
-                 ),
-                 norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
-             ]
-         )
-         self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
-         self.activation = nn.LeakyReLU(0.1)
-
-     def forward(self, x):
-         fmap = []
-
-         # 1d to 2d
-         b, c, t = x.shape
-         if t % self.period != 0:  # pad first
-             n_pad = self.period - (t % self.period)
-             x = F.pad(x, (0, n_pad), "reflect")
-             t = t + n_pad
-         x = x.view(b, c, t // self.period, self.period)
-
-         for l in self.convs:
-             x = l(x)
-             x = self.activation(x)
-             fmap.append(x)
-         x = self.conv_post(x)
-         fmap.append(x)
-         x = torch.flatten(x, 1, -1)
-
-         return x, fmap
-
-
- class MultiPeriodDiscriminator(ConvNets):
-     def __init__(self):
-         super(MultiPeriodDiscriminator, self).__init__()
-         self.discriminators = nn.ModuleList(
-             [
-                 DiscriminatorP(2),
-                 DiscriminatorP(3),
-                 DiscriminatorP(5),
-                 DiscriminatorP(7),
-                 DiscriminatorP(11),
-             ]
-         )
-
-     def forward(self, y, y_hat):
-         y_d_rs = []
-         y_d_gs = []
-         fmap_rs = []
-         fmap_gs = []
-         for i, d in enumerate(self.discriminators):
-             y_d_r, fmap_r = d(y)
-             y_d_g, fmap_g = d(y_hat)
-             y_d_rs.append(y_d_r)
-             fmap_rs.append(fmap_r)
-             y_d_gs.append(y_d_g)
-             fmap_gs.append(fmap_g)
-
-         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
- class DiscriminatorS(ConvNets):
-     def __init__(self, use_spectral_norm=False):
-         super(DiscriminatorS, self).__init__()
-         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-         self.convs = nn.ModuleList(
-             [
-                 norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
-                 norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
-                 norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
-                 norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
-                 norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
-                 norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
-                 norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
-             ]
-         )
-         self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
-         self.activation = nn.LeakyReLU(0.1)
-
-     def forward(self, x):
-         fmap = []
-         for l in self.convs:
-             x = l(x)
-             x = self.activation(x)
-             fmap.append(x)
-         x = self.conv_post(x)
-         fmap.append(x)
-         x = torch.flatten(x, 1, -1)
-
-         return x, fmap
-
-
- class MultiScaleDiscriminator(ConvNets):
-     def __init__(self):
-         super(MultiScaleDiscriminator, self).__init__()
-         self.discriminators = nn.ModuleList(
-             [
-                 DiscriminatorS(use_spectral_norm=True),
-                 DiscriminatorS(),
-                 DiscriminatorS(),
-             ]
-         )
-         self.meanpools = nn.ModuleList(
-             [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
-         )
-
-     def forward(self, y, y_hat):
-         y_d_rs = []
-         y_d_gs = []
-         fmap_rs = []
-         fmap_gs = []
-         for i, d in enumerate(self.discriminators):
-             if i != 0:
-                 y = self.meanpools[i - 1](y)
-                 y_hat = self.meanpools[i - 1](y_hat)
-             y_d_r, fmap_r = d(y)
-             y_d_g, fmap_g = d(y_hat)
-             y_d_rs.append(y_d_r)
-             fmap_rs.append(fmap_r)
-             y_d_gs.append(y_d_g)
-             fmap_gs.append(fmap_g)
-
-         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
- def feature_loss(fmap_r, fmap_g):
-     loss = 0
-     for dr, dg in zip(fmap_r, fmap_g):
-         for rl, gl in zip(dr, dg):
-             loss += torch.mean(torch.abs(rl - gl))
-
-     return loss * 2
-
-
- def discriminator_loss(disc_real_outputs, disc_generated_outputs):
-     loss = 0
-     r_losses = []
-     g_losses = []
-     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-         r_loss = torch.mean((1 - dr) ** 2)
-         g_loss = torch.mean(dg**2)
-         loss += r_loss + g_loss
-         r_losses.append(r_loss.item())
-         g_losses.append(g_loss.item())
-
-     return loss, r_losses, g_losses
-
-
- def generator_loss(disc_outputs):
-     loss = 0
-     gen_losses = []
-     for dg in disc_outputs:
-         l = torch.mean((1 - dg) ** 2)
-         gen_losses.append(l)
-         loss += l
-
-     return loss, gen_losses
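The discriminators and loss helpers removed above are superseded by the expanded lt_tensor/model_zoo/losses/discriminators.py (file 18, +375 -37), and both generators now share the same from_pretrained flow with the new remove_norms keyword and an early return on a successful load. A hedged sketch of loading a HifiganGenerator checkpoint using only the signature visible in this diff (the checkpoint filename is hypothetical, and stripping weight norm is shown purely as an inference-time choice):

from lt_tensor.model_zoo.audio_models.hifigan import HifiganConfig, HifiganGenerator

cfg = HifiganConfig()  # resblock_activation (new in 0.0.1a36) defaults to nn.LeakyReLU(0.1)
model = HifiganGenerator.from_pretrained(
    "generator.pt",      # hypothetical checkpoint path
    cfg,
    remove_norms=True,   # new keyword: strip weight norm before loading the state dict
    map_location="cpu",
)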