lt-tensor 0.0.1a26__py3-none-any.whl → 0.0.1a28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/model_base.py CHANGED
@@ -80,6 +80,62 @@ class _Devices_Base(nn.Module):
         assert isinstance(device, (str, torch.device))
         self._device = torch.device(device) if isinstance(device, str) else device
 
+    def freeze_all(self, exclude: list[str] = []):
+        for name, module in self.named_modules():
+            if name in exclude or not hasattr(module, "requires_grad"):
+                continue
+            try:
+                self.freeze_module(module)
+            except:
+                pass
+
+    def unfreeze_all(self, exclude: list[str] = []):
+        for name, module in self.named_modules():
+            if name in exclude or not hasattr(module, "requires_grad"):
+                continue
+            try:
+                self.unfreeze_module(module)
+            except:
+                pass
+
+    def freeze_module(
+        self, module_or_name: Union[str, nn.Module, nn.Parameter, "Model", Tensor]
+    ):
+        self._change_gradient_state(module_or_name, False)
+
+    def unfreeze_module(
+        self, module_or_name: Union[str, nn.Module, nn.Parameter, "Model", Tensor]
+    ):
+        self._change_gradient_state(module_or_name, True)
+
+    def _change_gradient_state(
+        self,
+        module_or_name: Union[str, nn.Module, nn.Parameter, "Model", Tensor],
+        new_state: bool,  # True = unfreeze
+    ):
+        assert isinstance(
+            module_or_name, (str, nn.Module, nn.Parameter, Model, Tensor)
+        ), f"Item '{module_or_name}' is not a valid module, parameter, tensor or a string."
+        if isinstance(module_or_name, (nn.Module, nn.Parameter, Model, Tensor)):
+            target = module_or_name
+        else:
+            target = getattr(self, module_or_name)
+
+        if isinstance(target, Tensor):
+            target.requires_grad = new_state
+        elif isinstance(target, nn.Parameter):
+            target.requires_grad = new_state
+        elif isinstance(target, Model):
+            target.freeze_all()
+        elif isinstance(target, nn.Module):
+            for param in target.parameters():
+                if hasattr(param, "requires_grad"):
+                    param.requires_grad = new_state
+        else:
+            raise ValueError(
+                f"Item '{module_or_name}' is not a valid module, parameter or tensor."
+            )
+
     def _apply_device(self):
         """Add here components that are needed to have device applied to them,
         that usually the '.to()' function fails to apply
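A minimal sketch of how the helpers added above are meant to be called; `model` stands for any existing `Model`/`_Devices_Base` subclass instance and its `encoder` submodule is a hypothetical name used only for illustration:

model.freeze_module("encoder")         # string names are resolved via getattr(self, "encoder")
model.freeze_module(model.encoder)     # modules, parameters and tensors can also be passed directly
model.unfreeze_module("encoder")
model.freeze_all(exclude=["encoder"])  # exclusion matches exact names from named_modules()
model.unfreeze_all()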
@@ -182,20 +238,12 @@ class Model(_Devices_Base, ABC):
     """
 
     _autocast: bool = False
-    _is_unfrozen: bool = False
-    # list with modules that can be frozen or unfrozen
-    registered_freezable_modules: List[str] = []
-    is_frozen: bool = False
-    _can_be_frozen: bool = (
-        False  # to control if the module can or cannot be freezed by other modules from 'Model' class
-    )
+
     # this is to be used on the case of they module requires low-rank adapters
     _low_rank_lambda: Optional[Callable[[], nn.Module]] = (
         None  # Example: lambda: nn.Linear(32, 32, True)
     )
     low_rank_adapter: Union[nn.Identity, nn.Module, nn.Sequential] = nn.Identity()
-    # never freeze:
-    _never_freeze_modules: List[str] = ["low_rank_adapter"]
 
     # dont save list:
     _dont_save_items: List[str] = []
@@ -208,75 +256,6 @@ class Model(_Devices_Base, ABC):
     def autocast(self, value: bool):
         self._autocast = value
 
-    def freeze_all(self, exclude: Optional[List[str]] = None, force: bool = False):
-        no_exclusions = not exclude
-        no_exclusions = not exclude
-        results = []
-        for name, module in self.named_modules():
-            if (
-                name in self._never_freeze_modules
-                or not force
-                and name not in self.registered_freezable_modules
-            ):
-                results.append(
-                    (
-                        name,
-                        "Unregistered module, to freeze/unfreeze it add its name into 'registered_freezable_modules'.",
-                    )
-                )
-                continue
-            if no_exclusions:
-                self.change_frozen_state(True, module)
-            elif not any(exclusion in name for exclusion in exclude):
-                results.append((name, self.change_frozen_state(True, module)))
-            else:
-                results.append((name, "excluded"))
-        return results
-
-    def unfreeze_all(self, exclude: Optional[list[str]] = None, force: bool = False):
-        """Unfreezes all model parameters except specified layers."""
-        no_exclusions = not exclude
-        results = []
-        for name, module in self.named_modules():
-            if (
-                name in self._never_freeze_modules
-                or not force
-                and name not in self.registered_freezable_modules
-            ):
-
-                results.append(
-                    (
-                        name,
-                        "Unregistered module, to freeze/unfreeze it add it into 'registered_freezable_modules'.",
-                    )
-                )
-                continue
-            if no_exclusions:
-                self.change_frozen_state(False, module)
-            elif not any(exclusion in name for exclusion in exclude):
-                results.append((name, self.change_frozen_state(False, module)))
-            else:
-                results.append((name, "excluded"))
-        return results
-
-    def change_frozen_state(self, freeze: bool, module: nn.Module):
-        assert isinstance(module, nn.Module)
-        if module.__class__.__name__ in self._never_freeze_modules:
-            return "Not Allowed"
-        try:
-            if isinstance(module, Model):
-                if module._can_be_frozen:
-                    if freeze:
-                        return module.freeze_all()
-                    return module.unfreeze_all()
-                else:
-                    return "Not Allowed"
-            else:
-                module.requires_grad_(not freeze)
-                return not freeze
-        except Exception as e:
-            return e
-
     def trainable_parameters(self, module_name: Optional[str] = None):
         """Gets the number of trainable parameters from either the entire model or from a specific module."""
         if module_name is not None:
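For callers of the removed API, a hedged before/after sketch (old names come from the removed block above, new names from the helpers added to _Devices_Base earlier in this diff; `model` and "encoder" are hypothetical):

# 0.0.1a26: freezing was gated by registration and routed through change_frozen_state
model.registered_freezable_modules.append("encoder")
model.freeze_all(force=True)

# 0.0.1a28: freezing is name/instance based; nothing is registered, and 'low_rank_adapter'
# is no longer skipped automatically, so exclude it explicitly if that behaviour is wanted
model.freeze_all(exclude=["low_rank_adapter"])
model.freeze_module("encoder")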
lt_tensor/model_zoo/__init__.py CHANGED
@@ -9,6 +9,7 @@ __all__ = [
     "audio_models",
     "hifigan",
     "istft",
+    "losses",
 ]
 from .audio_models import hifigan, istft
 from . import (
@@ -19,4 +20,5 @@ from . import (
     pos_encoder,
     residual,
     transformer,
+    losses,
 )
lt_tensor/model_zoo/losses/__init__.py ADDED
@@ -0,0 +1,3 @@
+from . import discriminators
+
+__all__ = ["discriminators"]
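A short sketch of reaching the new subpackage through the paths added above (the class names are taken from the discriminators module that follows):

from lt_tensor.model_zoo import losses
from lt_tensor.model_zoo.losses.discriminators import (
    MultiPeriodDiscriminator,
    MultiResolutionDiscriminator,
)

mpd = MultiPeriodDiscriminator()
mrd = MultiResolutionDiscriminator()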
lt_tensor/model_zoo/losses/discriminators.py ADDED
@@ -0,0 +1,610 @@
+from lt_tensor.model_zoo.audio_models.hifigan import ConvNets
+from lt_utils.common import *
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model
+from lt_tensor.model_zoo.convs import ConvNets
+from torch.nn import functional as F
+from torchaudio import transforms as T
+
+MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
+    List[Tensor],
+    List[Tensor],
+    List[List[Tensor]],
+    List[List[Tensor]],
+]
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+class MultiDiscriminatorWrapper(ConvNets):
+    """Base for all multi-steps type of discriminators"""
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.leaky_relu = nn.LeakyReLU(kwargs.get("negative_slope", 0.1))
+
+    def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
+        pass
+
+    # for type hinting
+    def __call__(self, *args, **kwds) -> MULTI_DISC_OUT_TYPE:
+        return super().__call__(*args, **kwds)
+
+    def gen_step(self, y: Tensor, y_hat: Tensor) -> tuple[Tensor, Tensor, List[float]]:
+        """For generator loss step [feature loss, generator loss, list of generator losses (float)]"""
+        _, y_hat_gen, feat_map_real, feat_map_gen = self.train_step(y, y_hat)
+        loss_feat = self.feature_loss(feat_map_real, feat_map_gen)
+        loss_generator, losses_gen_s = self.generator_loss(y_hat_gen)
+        return loss_feat, loss_generator, losses_gen_s
+
+    def disc_step(
+        self, y: Tensor, y_hat: Tensor
+    ) -> tuple[Tensor, tuple[List[float], List[float]]]:
+        """For discriminator loss step [discriminator loss, (disc losses real, disc losses generated)]"""
+        y_hat_real, y_hat_gen, _, _ = self.train_step(y, y_hat)
+
+        loss_disc, losses_disc_real, losses_disc_generated = self.discriminator_loss(
+            y_hat_real, y_hat_gen
+        )
+        return loss_disc, (losses_disc_real, losses_disc_generated)
+
+    @staticmethod
+    def discriminator_loss(
+        disc_real_outputs, disc_generated_outputs
+    ) -> Tuple[Tensor, List[float], List[float]]:
+        loss = 0
+        r_losses = []
+        g_losses = []
+        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+            r_loss = torch.mean((1 - dr) ** 2)
+            g_loss = torch.mean(dg**2)
+            loss += r_loss + g_loss
+            r_losses.append(r_loss.item())
+            g_losses.append(g_loss.item())
+
+        return loss, r_losses, g_losses
+
+    @staticmethod
+    def feature_loss(fmap_r, fmap_g) -> Tensor:
+        loss = 0
+        for dr, dg in zip(fmap_r, fmap_g):
+            for rl, gl in zip(dr, dg):
+                loss += torch.mean(torch.abs(rl - gl))
+
+        return loss * 2
+
+    @staticmethod
+    def generator_loss(disc_outputs) -> Tuple[Tensor, List[float]]:
+        loss = 0
+        gen_losses = []
+        for dg in disc_outputs:
+            l = torch.mean((1 - dg) ** 2)
+            gen_losses.append(l.item())
+            loss += l
+
+        return loss, gen_losses
+
+
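The three static losses above are the usual least-squares GAN terms (real scores pushed toward 1, fake toward 0) plus an L1 feature-matching term. A small sketch of calling them directly on placeholder score and feature-map lists (the import path comes from this diff; the shapes are arbitrary):

import torch
from lt_tensor.model_zoo.losses.discriminators import MultiDiscriminatorWrapper

scores_real = [torch.rand(2, 10) for _ in range(3)]      # one score tensor per sub-discriminator
scores_fake = [torch.rand(2, 10) for _ in range(3)]
fmaps_real = [[torch.rand(2, 8, 16)] for _ in range(3)]  # per-layer feature maps
fmaps_fake = [[torch.rand(2, 8, 16)] for _ in range(3)]

d_loss, d_real, d_fake = MultiDiscriminatorWrapper.discriminator_loss(scores_real, scores_fake)
g_loss, g_per_disc = MultiDiscriminatorWrapper.generator_loss(scores_fake)
fm_loss = MultiDiscriminatorWrapper.feature_loss(fmaps_real, fmaps_fake)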
+class DiscriminatorP(ConvNets):
+    def __init__(
+        self,
+        period: List[int],
+        discriminator_channel_mult: Number = 1,
+        kernel_size: int = 5,
+        stride: int = 3,
+        use_spectral_norm: bool = False,
+    ):
+        super().__init__()
+        self.period = period
+        norm_f = weight_norm if not use_spectral_norm else spectral_norm
+        dsc = lambda x: int(x * discriminator_channel_mult)
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    nn.Conv2d(
+                        1,
+                        dsc(32),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(32),
+                        dsc(128),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(128),
+                        dsc(512),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(512),
+                        dsc(1024),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        dsc(1024),
+                        dsc(1024),
+                        (kernel_size, 1),
+                        1,
+                        padding=(2, 0),
+                    )
+                ),
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv2d(dsc(1024), 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, 0.1)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        return x.flatten(1, -1), fmap
+
+
+class MultiPeriodDiscriminator(MultiDiscriminatorWrapper):
+    def __init__(
+        self,
+        discriminator_channel_mult: Number = 1,
+        mpd_reshapes: list[int] = [2, 3, 5, 7, 11],
+        use_spectral_norm: bool = False,
+    ):
+        super().__init__()
+        self.mpd_reshapes = mpd_reshapes
+        print(f"mpd_reshapes: {self.mpd_reshapes}")
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorP(
+                    rs,
+                    use_spectral_norm=use_spectral_norm,
+                    discriminator_channel_mult=discriminator_channel_mult,
+                )
+                for rs in self.mpd_reshapes
+            ]
+        )
+
+    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> MULTI_DISC_OUT_TYPE:
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
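A quick sketch of running the multi-period discriminator above on waveform batches shaped (batch, 1, samples); the shapes are illustrative only:

import torch
from lt_tensor.model_zoo.losses.discriminators import MultiPeriodDiscriminator

mpd = MultiPeriodDiscriminator()          # one DiscriminatorP per period in [2, 3, 5, 7, 11]
real = torch.randn(2, 1, 8192)
fake = torch.randn(2, 1, 8192)
score_r, score_g, fmap_r, fmap_g = mpd(real, fake)
print(len(score_r), score_r[0].shape)     # 5 flattened score maps, one per period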
+class EnvelopeExtractor(nn.Module):
+    """Extracts the amplitude envelope of the audio signal."""
+
+    def __init__(self, kernel_size=101):
+        super().__init__()
+        # Lowpass filter for smoothing envelope (moving average)
+        self.kernel_size = kernel_size
+        self.register_buffer("kernel", torch.ones(1, 1, kernel_size) / kernel_size)
+
+    def forward(self, x):
+        # x: (B, 1, T) -> abs(x)
+        envelope = torch.abs(x)
+        # Apply low-pass smoothing (via conv1d)
+        envelope = F.pad(
+            envelope, (self.kernel_size // 2, self.kernel_size // 2), mode="reflect"
+        )
+        envelope = F.conv1d(envelope, self.kernel)
+        return envelope
+
+
+class DiscriminatorEnvelope(ConvNets):
+    def __init__(self, use_spectral_norm=False):
+        super().__init__()
+        norm_f = weight_norm if not use_spectral_norm else spectral_norm
+        self.extractor = EnvelopeExtractor(kernel_size=101)
+        self.convs = nn.ModuleList(
+            [
+                norm_f(nn.Conv1d(1, 64, 15, stride=1, padding=7)),
+                norm_f(nn.Conv1d(64, 128, 41, stride=2, groups=4, padding=20)),
+                norm_f(nn.Conv1d(128, 256, 41, stride=2, groups=16, padding=20)),
+                norm_f(nn.Conv1d(256, 512, 41, stride=4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(512, 512, 41, stride=4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(512, 512, 5, stride=1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv1d(512, 1, 3, stride=1, padding=1))
+        self.activation = nn.LeakyReLU(0.1)
+
+    def forward(self, x):
+        # Input: raw audio (B, 1, T)
+        x = self.extractor(x)
+        fmap = []
+        for layer in self.convs:
+            x = self.activation(layer(x))
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        return x.flatten(1), fmap
+
+
+class MultiEnvelopeDiscriminator(MultiDiscriminatorWrapper):
+    def __init__(self, use_spectral_norm: bool = False):
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorEnvelope(use_spectral_norm),  # raw envelope
+                DiscriminatorEnvelope(use_spectral_norm),  # downsampled once
+                DiscriminatorEnvelope(use_spectral_norm),  # downsampled twice
+            ]
+        )
+        self.meanpools = nn.ModuleList(
+            [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
+        )
+
+    def forward(self, y, y_hat):
+        y_d_rs, y_d_gs = [], []
+        fmap_rs, fmap_gs = [], []
+
+        for i, d in enumerate(self.discriminators):
+            if i != 0:
+                y = self.meanpools[i - 1](y)
+                y_hat = self.meanpools[i - 1](y_hat)
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            y_d_gs.append(y_d_g)
+            fmap_rs.append(fmap_r)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
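The envelope path is just a rectification followed by a moving-average low-pass. A standalone sketch of that idea, independent of the classes above (function name and shapes are illustrative):

import torch
import torch.nn.functional as F

def amplitude_envelope(wave: torch.Tensor, kernel_size: int = 101) -> torch.Tensor:
    # Rectify, then smooth with a length-`kernel_size` moving average (same idea as EnvelopeExtractor).
    kernel = torch.ones(1, 1, kernel_size, device=wave.device) / kernel_size
    rectified = wave.abs()
    padded = F.pad(rectified, (kernel_size // 2, kernel_size // 2), mode="reflect")
    return F.conv1d(padded, kernel)

env = amplitude_envelope(torch.randn(2, 1, 8192))   # -> (2, 1, 8192)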
+class DiscriminatorB(ConvNets):
+    """
+    Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
+    and the modified code adapted from https://github.com/gemelo-ai/vocos.
+    """
+
+    def __init__(
+        self,
+        window_length: int,
+        channels: int = 32,
+        hop_factor: float = 0.25,
+        bands: Tuple[Tuple[float, float], ...] = (
+            (0.0, 0.1),
+            (0.1, 0.25),
+            (0.25, 0.5),
+            (0.5, 0.75),
+            (0.75, 1.0),
+        ),
+    ):
+        super().__init__()
+        self.window_length = window_length
+        self.hop_factor = hop_factor
+        self.spec_fn = T.Spectrogram(
+            n_fft=window_length,
+            hop_length=int(window_length * hop_factor),
+            win_length=window_length,
+            power=None,
+        )
+        n_fft = window_length // 2 + 1
+        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+        self.bands = bands
+        convs = lambda: nn.ModuleList(
+            [
+                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
+                ),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
+                ),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
+                ),
+                weight_norm(
+                    nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
+                ),
+            ]
+        )
+        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+
+        self.conv_post = weight_norm(
+            nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
+        )
+
+    def spectrogram(self, x: Tensor) -> List[Tensor]:
+        # Remove DC offset
+        x = x - x.mean(dim=-1, keepdims=True)
+        # Peak normalize the volume of input audio
+        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+        x = self.spec_fn(x)
+        x = torch.view_as_real(x)
+        x = x.permute(0, 3, 2, 1)  # [B, F, T, C] -> [B, C, T, F]
+        # Split into bands
+        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+        return x_bands
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+        x_bands = self.spectrogram(x.squeeze(1))
+        fmap = []
+        x = []
+
+        for band, stack in zip(x_bands, self.band_convs):
+            for i, layer in enumerate(stack):
+                band = layer(band)
+                band = torch.nn.functional.leaky_relu(band, 0.1)
+                if i > 0:
+                    fmap.append(band)
+            x.append(band)
+
+        x = torch.cat(x, dim=-1)
+        x = self.conv_post(x)
+        fmap.append(x)
+
+        return x, fmap
+
+
+class MultiBandDiscriminator(MultiDiscriminatorWrapper):
+    """
+    Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
+    and the modified code adapted from https://github.com/gemelo-ai/vocos.
+    """
+
+    def __init__(
+        self,
+        mbd_fft_sizes: list[int] = [2048, 1024, 512],
+    ):
+        super().__init__()
+        self.fft_sizes = mbd_fft_sizes
+        self.discriminators = nn.ModuleList(
+            [DiscriminatorB(window_length=w) for w in self.fft_sizes]
+        )
+
+    def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
+
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for d in self.discriminators:
+
+            y_d_r, fmap_r = d(x=y)
+            y_d_g, fmap_g = d(x=y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
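For the multi-band STFT path, the band fractions are converted into bin ranges of the one-sided spectrum. A quick worked example for window_length=1024; the values follow directly from the constructor above:

window_length = 1024
n_bins = window_length // 2 + 1            # 513 one-sided frequency bins
bands = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0))
bin_ranges = [(int(lo * n_bins), int(hi * n_bins)) for lo, hi in bands]
print(bin_ranges)                          # [(0, 51), (51, 128), (128, 256), (256, 384), (384, 513)]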
+class DiscriminatorR(ConvNets):
+    def __init__(
+        self,
+        resolution: List[int],
+        use_spectral_norm: bool = False,
+        discriminator_channel_mult: int = 1,
+    ):
+        super().__init__()
+
+        self.resolution = resolution
+        assert (
+            len(self.resolution) == 3
+        ), f"MRD layer requires list with len=3, got {self.resolution}"
+        self.lrelu_slope = 0.1
+
+        self.register_buffer("window", torch.hann_window(self.resolution[-1]))
+
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    nn.Conv2d(
+                        1, int(32 * discriminator_channel_mult), (3, 9), padding=(1, 4)
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * discriminator_channel_mult),
+                        int(32 * discriminator_channel_mult),
+                        (3, 3),
+                        padding=(1, 1),
+                    )
+                ),
+            ]
+        )
+        self.conv_post = norm_f(
+            nn.Conv2d(int(32 * discriminator_channel_mult), 1, (3, 3), padding=(1, 1))
+        )
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+        fmap = []
+        x = self.spectrogram(x)
+        x = x.unsqueeze(1)
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, self.lrelu_slope)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+    def spectrogram(self, x: Tensor) -> Tensor:
+        n_fft, hop_length, win_length = self.resolution
+        x = F.pad(
+            x,
+            (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
+            mode="reflect",
+        )
+        x = x.squeeze(1)
+        x = torch.stft(
+            x,
+            n_fft=n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            center=False,
+            return_complex=True,
+            window=self.window,
+        )
+        x = torch.view_as_real(x)  # [B, F, TT, 2]
+        mag = torch.norm(x, p=2, dim=-1)  # [B, F, TT]
+
+        return mag
+
+
+class MultiResolutionDiscriminator(MultiDiscriminatorWrapper):
+    def __init__(
+        self,
+        use_spectral_norm: bool = False,
+        discriminator_channel_mult: int = 1,
+        resolutions: List[List[int]] = [
+            [1024, 120, 600],
+            [2048, 240, 1200],
+            [512, 50, 240],
+        ],
+    ):
+        super().__init__()
+        self.resolutions = resolutions
+        assert (
+            len(self.resolutions) == 3
+        ), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}, type: {type(self.resolutions)}"
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorR(
+                    resolution, use_spectral_norm, discriminator_channel_mult
+                )
+                for resolution in self.resolutions
+            ]
+        )
+
+    def forward(self, y: Tensor, y_hat: Tensor) -> MULTI_DISC_OUT_TYPE:
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for disc in self.discriminators:
+            y_d_r, fmap_r = disc(x=y)
+            y_d_g, fmap_g = disc(x=y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
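Each entry of resolutions is unpacked as [n_fft, hop_length, win_length] inside DiscriminatorR.spectrogram. A hedged instantiation sketch using the defaults above (input lengths are illustrative):

import torch
from lt_tensor.model_zoo.losses.discriminators import MultiResolutionDiscriminator

mrd = MultiResolutionDiscriminator()      # resolutions [1024, 120, 600], [2048, 240, 1200], [512, 50, 240]
wave_real = torch.randn(2, 1, 16384)
wave_fake = torch.randn(2, 1, 16384)
score_r, score_g, fmap_r, fmap_g = mrd(wave_real, wave_fake)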
+class MultiDiscriminatorStep(Model):
+    def __init__(self, list_discriminator: List[MultiDiscriminatorWrapper]):
+        super().__init__()
+        self.disc: Sequence[MultiDiscriminatorWrapper] = nn.ModuleList(
+            list_discriminator
+        )
+        self.total = len(self.disc)
+
+    def forward(
+        self,
+        y: Tensor,
+        y_hat: Tensor,
+        step_type: Literal["discriminator", "generator"],
+    ) -> Union[
+        Tuple[Tensor, Tensor, List[float]], Tuple[Tensor, List[float], List[float]]
+    ]:
+        """
+        It returns the content based on the choice of "step_type", being it a
+        'discriminator' or 'generator'
+
+        For generator it returns:
+            Tuple[Tensor, Tensor, List[float]]
+            "gen_loss, feat_loss, all_g_losses"
+
+        For 'discriminator' it returns:
+            Tuple[Tensor, List[float], List[float]]
+            "disc_loss, disc_real_losses, disc_gen_losses"
+        """
+        if step_type == "generator":
+            all_g_losses: List[float] = []
+            feat_loss: Tensor = 0
+            gen_loss: Tensor = 0
+        else:
+            disc_loss: Tensor = 0
+            disc_real_losses: List[float] = []
+            disc_gen_losses: List[float] = []
+
+        for disc in self.disc:
+            if step_type == "generator":
+                # feature loss, generator loss, list of generator losses (float)]
+                f_loss, g_loss, g_losses = disc.gen_step(y, y_hat)
+                gen_loss += g_loss
+                feat_loss += f_loss
+                all_g_losses.extend(g_losses)
+            else:
+                # [discriminator loss, (disc losses real, disc losses generated)]
+                d_loss, (d_real_losses, d_gen_losses) = disc.disc_step(y, y_hat)
+                disc_loss += d_loss
+                disc_real_losses.extend(d_real_losses)
+                disc_gen_losses.extend(d_gen_losses)
+
+        if step_type == "generator":
+            return gen_loss, feat_loss, all_g_losses
+        return disc_loss, disc_real_losses, disc_gen_losses
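A hedged end-to-end sketch of how MultiDiscriminatorStep appears intended to sit inside a GAN training loop. The generator, the `real`/`fake` batches, the loss weighting, and the train_step hook that gen_step/disc_step rely on are assumptions not shown in this diff:

import torch
from lt_tensor.model_zoo.losses.discriminators import (
    MultiDiscriminatorStep,
    MultiPeriodDiscriminator,
    MultiResolutionDiscriminator,
)

disc_step = MultiDiscriminatorStep(
    [MultiPeriodDiscriminator(), MultiResolutionDiscriminator()]
)
opt_d = torch.optim.AdamW(disc_step.parameters(), lr=2e-4)

# inside the training loop, with `real` and `fake` as (B, 1, T) waveforms:
opt_d.zero_grad()
d_loss, d_real_losses, d_gen_losses = disc_step(real, fake.detach(), "discriminator")
d_loss.backward()
opt_d.step()

g_adv_loss, feat_loss, _ = disc_step(real, fake, "generator")
(g_adv_loss + 2.0 * feat_loss).backward()   # the 2.0 weighting is illustrative only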
lt_tensor/processors/audio.py CHANGED
@@ -105,7 +105,6 @@ class AudioProcessor(Model):
             onesided=self.cfg.onesided,
             normalized=self.cfg.normalized,
         )
-        self.griffin_lm_iters = 32
         self.mel_rscale = torchaudio.transforms.InverseMelScale(
             n_stft=self.cfg.n_stft,
             n_mels=self.cfg.n_mels,
@@ -114,21 +113,19 @@ class AudioProcessor(Model):
             f_max=self.cfg.f_max,
             mel_scale=self.cfg.mel_scale,
         )
-        self.giffin_lim = torchaudio.transforms.GriffinLim(
-            n_fft=self.cfg.n_fft,
-            win_length=self.cfg.win_length,
-            hop_length=self.cfg.hop_length,
-        )
+
         self.register_buffer(
             "window",
             (torch.hann_window(self.cfg.win_length) if window is None else window),
         )
 
     def _apply_device(self):
-        print(f"Audio Processor Device: {self.device.type}")
-        self.giffin_lim.to(device=self.device)
         self._mel_spec.to(device=self.device)
         self.mel_rscale.to(device=self.device)
+        try:
+            self.window.to(device=self.device)
+        except:
+            pass
 
     def from_numpy(
         self,
@@ -173,7 +170,9 @@ class AudioProcessor(Model):
         )
 
         if audio is None and mel is not None:
-            return self.from_numpy(librosa.feature.rms(S=mel, **rms_kwargs)[0])
+            return self.from_numpy(
+                librosa.feature.rms(S=mel, **rms_kwargs)[0]
+            ).squeeze()
         default_dtype = audio.dtype
         default_device = audio.device
         if audio.ndim > 1:
@@ -192,8 +191,12 @@ class AudioProcessor(Model):
        audio = self.to_numpy_safe(audio)
        if B == 1:
            if mel is None:
-                return self.from_numpy(librosa.feature.rms(y=audio, **rms_kwargs)[0])
-            return self.from_numpy(librosa.feature.rms(y=audio, S=mel, **rms_kwargs)[0])
+                return self.from_numpy(
+                    librosa.feature.rms(y=audio, **rms_kwargs)[0]
+                ).squeeze()
+            return self.from_numpy(
+                librosa.feature.rms(y=audio, S=mel, **rms_kwargs)[0]
+            ).squeeze()
        else:
            rms_ = []
            for i in range(B):
@@ -201,7 +204,7 @@ class AudioProcessor(Model):
                    0
                ]
                rms_.append(_r)
-            return self.from_numpy_batch(rms_, default_device, default_dtype)
+            return self.from_numpy_batch(rms_, default_device, default_dtype).squeeze()
 
     def compute_pitch(
         self,
@@ -250,7 +253,7 @@ class AudioProcessor(Model):
             for i in range(B):
                 f0_.append(librosa.yin(self.to_numpy_safe(audio[i, :]), **yn_kwargs))
             f0 = self.from_numpy_batch(f0_, default_device, default_dtype)
-        return f0
+        return f0.squeeze()
 
     def compute_pitch_torch(
         self,
@@ -273,7 +276,7 @@ class AudioProcessor(Model):
             win_length=win_length,
             freq_low=fmin,
             freq_high=fmax,
-        )
+        ).squeeze()
 
     def interpolate(
         self,
@@ -312,7 +315,7 @@ class AudioProcessor(Model):
             antialias=antialias,
         )
 
-    def inverse_transform(
+    def istft(
         self,
         spec: Tensor,
         phase: Tensor,
@@ -320,6 +323,10 @@ class AudioProcessor(Model):
         hop_length: Optional[int] = None,
         win_length: Optional[int] = None,
         length: Optional[int] = None,
+        center: Optional[bool] = None,
+        normalized: Optional[bool] = None,
+        onesided: Optional[bool] = None,
+        return_complex: bool = False,
         *,
         _recall: bool = False,
     ):
@@ -331,25 +338,25 @@ class AudioProcessor(Model):
         try:
             return torch.istft(
                 spec * torch.exp(phase * 1j),
-                n_fft=n_fft or self.cfg.n_fft,
-                hop_length=hop_length or self.cfg.hop_length,
-                win_length=win_length or self.cfg.win_length,
+                n_fft=default(n_fft, self.cfg.n_fft),
+                hop_length=default(hop_length, self.cfg.hop_length),
+                win_length=default(win_length, self.cfg.win_length),
                 window=window,
-                center=self.cfg.center,
-                normalized=self.cfg.normalized,
-                onesided=self.cfg.onesided,
+                center=default(center, self.cfg.center),
+                normalized=default(normalized, self.cfg.normalized),
+                onesided=default(onesided, self.cfg.onesided),
                 length=length,
-                return_complex=False,
+                return_complex=return_complex,
             )
         except RuntimeError as e:
             if not _recall and spec.device != self.window.device:
                 self.window = self.window.to(spec.device)
-                return self.inverse_transform(
+                return self.istft(
                     spec, phase, n_fft, hop_length, win_length, length, _recall=True
                 )
             raise e
 
-    def normalize_audio(
+    def istft_norm(
         self,
         wave: Tensor,
         length: Optional[int] = None,
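With the renames above, the inverse STFT and its normalization pass are exposed as istft and istft_norm. A hedged call sketch; the AudioProcessor construction and config are not shown in this diff, so treat `ap` as an already-built instance and `spec`/`phase` as outputs of a matching STFT front end:

wave = ap.istft(spec, phase, length=spec.shape[-1] * ap.cfg.hop_length)  # length value is illustrative
wave = ap.istft_norm(wave)

# previously (0.0.1a26) the same calls were spelled:
# wave = ap.inverse_transform(spec, phase, ...)
# wave = ap.normalize_audio(wave, ...)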
@@ -389,7 +396,7 @@ class AudioProcessor(Model):
         except RuntimeError as e:
             if not _recall and wave.device != self.window.device:
                 self.window = self.window.to(wave.device)
-                return self.normalize_audio(wave, length, _recall=True)
+                return self.istft_norm(wave, length, _recall=True)
             raise e
 
     def compute_mel(
@@ -407,11 +414,7 @@ class AudioProcessor(Model):
             mel_tensor = (
                 torch.log(eps + mel_tensor.unsqueeze(0)) - self.cfg.mean
             ) / self.cfg.std
-            if mel_tensor.ndim == 4:
-                return mel_tensor.squeeze()
-            elif mel_tensor.ndim == 2:
-                return mel_tensor.unsqueeze(0)
-            return mel_tensor
+            return mel_tensor.squeeze()
 
         except RuntimeError as e:
             if not _recall:
@@ -419,14 +422,6 @@ class AudioProcessor(Model):
                 return self.compute_mel(wave, raw_mel_only, eps, _recall=True)
             raise e
 
-    def inverse_mel_spectogram(self, mel: Tensor, n_iter: Optional[int] = None):
-        if isinstance(n_iter, int) and n_iter != self.griffin_lm_iters:
-            self.giffin_lim.n_iter = n_iter
-            self.griffin_lm_iters = n_iter
-        return self.giffin_lim.forward(
-            self.mel_rscale(mel),
-        )
-
     def load_audio(
         self,
         path: PathLike,
@@ -510,14 +505,9 @@ class AudioProcessor(Model):
             maximum,
         )
 
-    def stft_loss(
-        self,
-        signal: Tensor,
-        ground: Tensor,
-    ):
-        with torch.no_grad():
-            ground = F.interpolate(ground, signal.shape[-1]).to(signal.device)
-        return F.l1_loss(signal, ground)
+    def stft_loss(self, signal: Tensor, ground: Tensor, magnitude: float = 1.0):
+        ground = F.interpolate(ground, signal.shape[-1]).to(signal.device)
+        return F.l1_loss(signal.squeeze(), ground.squeeze()) * magnitude
 
     def forward(
         self,
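The reworked stft_loss above resamples the reference along the last axis and takes a scaled L1 distance. A standalone sketch of the same idea in plain torch (the helper name and shapes are illustrative, not part of the package):

import torch
import torch.nn.functional as F

def stft_l1(signal: torch.Tensor, ground: torch.Tensor, magnitude: float = 1.0) -> torch.Tensor:
    # Resample `ground` along the last axis to the length of `signal`, then take a scaled L1 distance.
    ground = F.interpolate(ground, signal.shape[-1]).to(signal.device)
    return F.l1_loss(signal.squeeze(), ground.squeeze()) * magnitude

loss = stft_l1(torch.randn(2, 1, 100), torch.randn(2, 1, 80), magnitude=45.0)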
lt_tensor-0.0.1a26.dist-info/METADATA → lt_tensor-0.0.1a28.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a26
+Version: 0.0.1a28
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
lt_tensor-0.0.1a26.dist-info/RECORD → lt_tensor-0.0.1a28.dist-info/RECORD
@@ -4,12 +4,12 @@ lt_tensor/losses.py,sha256=zvkCOnE5XpF3v6ymivRIdqPTsMM5zc94ZMom7YDi3zM,4946
 lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
 lt_tensor/math_ops.py,sha256=TkD4WQG42KsQ9Fg7FXOjf8f-ixtW0apf2XjaooecVx4,2257
 lt_tensor/misc_utils.py,sha256=N2r3UmxC4RM2BZBQhpjDZ_BKLrzsyIlKzopTzJbnjFU,28962
-lt_tensor/model_base.py,sha256=GvmQdt97ZSfOObBpBIq7UUTwpIE1g-aBm23za36YA0M,18431
+lt_tensor/model_base.py,sha256=DTg44N6eTXLmpIAj_ac29-M5dI_iY_sC0yA_K3E13GI,17446
 lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
 lt_tensor/noise_tools.py,sha256=wFeAsHhLhSlEc5XU5LbFKaXoHeVxrWjiMeljjGdIKyM,11363
 lt_tensor/torch_commons.py,sha256=8l0bxmrAzwvyqjivCIVISXlbvKarlg4DdE0BOGSnMuQ,812
 lt_tensor/transform.py,sha256=dZm8T_ov0blHMQu6nGiehsdG1VSB7bZBUVmTkT-PBdc,13257
-lt_tensor/model_zoo/__init__.py,sha256=ltVTvmOlbOCfDc5Trvg0-Ta_Ujgkw0UVF9V5rqHx-RI,378
+lt_tensor/model_zoo/__init__.py,sha256=yPUVchgVhU2nAJ2ocA4HFfG7IMEiBu8qOi8I1KWTTkU,404
 lt_tensor/model_zoo/basic.py,sha256=pI8HyiHK-cmWcEEaVY_EduUJOjZW6HOtXvJd8Rbhq30,15452
 lt_tensor/model_zoo/convs.py,sha256=YQRxek75Qpsha8nfc7wLhmJS9XxPeCa4WxuftLg6IcE,3927
 lt_tensor/model_zoo/features.py,sha256=DO8dlE0kmPKTNC1Xkv9wKegOOYkQa_rkxM4hhcNwJWA,15655
@@ -26,10 +26,12 @@ lt_tensor/model_zoo/audio_models/__init__.py,sha256=MoG9YjxLyvscq_6njK1ljGBletK9
 lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=PDuDYN1omD1RoAXcmxH3tEgfAuM3ZHAWzimD6ElMqEQ,9073
 lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=7GJqKLw7-juXpfp5IFzjASLut0uouDhjZ1CQknf3H68,16533
 lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=ltIuD9t1gmS3bTmCqZIwJHKrhC6DYya3OaXlskWX9kw,17606
+lt_tensor/model_zoo/losses/__init__.py,sha256=B9RAUxBiOZwooztnij1oLeRwZ7_MjnN3mPoum7saD6s,59
+lt_tensor/model_zoo/losses/discriminators.py,sha256=yYh7HzRTUtr0RVTG7cWpcYsJZsRCz6yzg6Loq8FtyOk,20405
 lt_tensor/processors/__init__.py,sha256=Pvxhh0KR65zLCgUd53_k5Z0y5JWWcO0ZBXFK9rv0o5w,109
-lt_tensor/processors/audio.py,sha256=WkumFNx8OXGQkTEU5Rkede9NLMrsGaTGY37Ti784Wv8,17028
-lt_tensor-0.0.1a26.dist-info/licenses/LICENSE,sha256=tQHc38scHOba4kDBNG4U0U6PpObaloiZG-FvKSgv2b0,11336
-lt_tensor-0.0.1a26.dist-info/METADATA,sha256=2STSK6jgD_qECwz9WygTXNDwfapEAR2mpHiS14bi9tQ,1062
-lt_tensor-0.0.1a26.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-lt_tensor-0.0.1a26.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
-lt_tensor-0.0.1a26.dist-info/RECORD,,
+lt_tensor/processors/audio.py,sha256=rsnnNi8MtxPq9vAYoiRQ7lGjorfJIpRvrKEe3zA8YJk,16668
+lt_tensor-0.0.1a28.dist-info/licenses/LICENSE,sha256=tQHc38scHOba4kDBNG4U0U6PpObaloiZG-FvKSgv2b0,11336
+lt_tensor-0.0.1a28.dist-info/METADATA,sha256=2LLguzaCAM2bcAdy_D66j4PS9Oh5PU3ZnA9qy7xcx0w,1062
+lt_tensor-0.0.1a28.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a28.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a28.dist-info/RECORD,,