lt-tensor 0.0.1a14__py3-none-any.whl → 0.0.1a16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,6 +20,7 @@ class WaveMelDataset(Dataset):
     randomize_ranges: bool = False
     alpha_wv: float = 1.0
     limit_files: Optional[int] = None
+    min_frame_length: Optional[int] = None
     max_frame_length: Optional[int] = None
 
     def __init__(
@@ -27,12 +28,13 @@ class WaveMelDataset(Dataset):
         audio_processor: AudioProcessor,
         dataset_path: PathLike,
         limit_files: Optional[int] = None,
+        min_frame_length: Optional[int] = None,
         max_frame_length: Optional[int] = None,
         randomize_ranges: Optional[bool] = None,
         pre_load: bool = False,
         normalize_waves: Optional[bool] = None,
         alpha_wv: Optional[float] = None,
-        n_noises: int = 0,  # TODO: Implement the random noises into the dataset
+        lib_norm: bool = True,
     ):
         super().__init__()
         assert max_frame_length is None or max_frame_length >= (
@@ -52,11 +54,15 @@ class WaveMelDataset(Dataset):
         self.randomize_ranges = randomize_ranges
 
         self.post_n_fft = (audio_processor.n_fft // 2) + 1
-
+        self.lib_norm = lib_norm
         if max_frame_length is not None:
             max_frame_length = max(self.post_n_fft + 1, max_frame_length)
             self.r_range = max(self.post_n_fft + 1, max_frame_length // 3)
             self.max_frame_length = max_frame_length
+            if min_frame_length is not None:
+                self.min_frame_length = max(
+                    self.post_n_fft + 1, min(min_frame_length, max_frame_length)
+                )
 
         self.files = self.ap.find_audios(dataset_path, maximum=None)
         if limit_files:
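Note: the clamp above keeps a user-supplied `min_frame_length` between the STFT floor and `max_frame_length`. A standalone reproduction of the same expression, with illustrative values that are not package defaults:

```python
# Illustrative values only; the package does not fix these defaults.
n_fft = 1024
post_n_fft = n_fft // 2 + 1  # 513, mirroring (audio_processor.n_fft // 2) + 1
max_frame_length = 2048
min_frame_length = 100  # requested floor, too small to hold one FFT frame
clamped = max(post_n_fft + 1, min(min_frame_length, max_frame_length))
print(clamped)  # 514: raised to post_n_fft + 1
```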
@@ -96,21 +102,26 @@ class WaveMelDataset(Dataset):
         }
 
     def load_data(self, file: PathLike):
-        initial_audio = self.ap.normalize_audio(
-            self.ap.load_audio(
-                file, normalize=self.normalize_waves, alpha=self.alpha_wv
-            )
+        initial_audio = self.ap.load_audio(
+            file, normalize=self.lib_norm, alpha=self.alpha_wv
         )
+        if self.normalize_waves:
+            initial_audio = self.ap.normalize_audio(initial_audio)
         if initial_audio.shape[-1] < self.post_n_fft:
             return None
 
+        if self.min_frame_length is not None:
+            if self.min_frame_length > initial_audio.shape[-1]:
+                return None
         if (
             not self.max_frame_length
             or initial_audio.shape[-1] <= self.max_frame_length
         ):
+
             audio_rms = self.ap.compute_rms(initial_audio)
             audio_pitch = self.ap.compute_pitch(initial_audio)
             audio_mel = self.ap.compute_mel(initial_audio, add_base=True)
+
             return [
                 self._add_dict(initial_audio, audio_mel, audio_pitch, audio_rms, file)
             ]
@@ -129,6 +140,12 @@ class WaveMelDataset(Dataset):
            if fragment.shape[-1] < self.post_n_fft:
                # Too small
                continue
+            if (
+                self.min_frame_length is not None
+                and self.min_frame_length > fragment.shape[-1]
+            ):
+                continue
+
            audio_rms = self.ap.compute_rms(fragment)
            audio_pitch = self.ap.compute_pitch(fragment)
            audio_mel = self.ap.compute_mel(fragment, add_base=True)
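Note: taken together, the dataset changes swap the removed `n_noises` argument for `lib_norm` (normalization inside the loader), turn `normalize_waves` into a separate post-load step, and skip clips shorter than `min_frame_length`. A hedged construction sketch; the `AudioProcessor` import path and constructor defaults are assumptions, as this diff does not show them:

```python
from lt_tensor.processors import AudioProcessor  # assumed import path

dataset = WaveMelDataset(
    audio_processor=AudioProcessor(),  # assumed default construction
    dataset_path="data/wavs",
    lib_norm=True,          # new: normalize inside ap.load_audio
    normalize_waves=True,   # now applied afterwards via ap.normalize_audio
    min_frame_length=1024,  # new: shorter clips are skipped
    max_frame_length=8192,
)
```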
lt_tensor/model_base.py CHANGED
@@ -179,6 +179,20 @@ class Model(_Devices_Base, ABC):
     """
 
     _is_unfrozen: bool = False
+    # list with modules that can be frozen or unfrozen
+    registered_freezable_modules: List[str] = []
+    is_frozen: bool = False
+    _is_gradient_freezable: bool = (
+        False  # to control if the module can or cannot be freezed by other modules from 'Model' class
+    )
+    # this is to be used on the case of they module requires low-rank adapters
+    _low_rank_lambda: Optional[Callable[[], nn.Module]] = (
+        None  # Example: lambda: nn.Linear(32, 32, True)
+    )
+    low_rank_adapter: Union[nn.Identity, nn.Module, nn.Sequential] = nn.Identity()
+
+    # dont save list:
+    _dont_save_items: List[str] = []
 
     def _apply_device_to(self):
         """Add here components that are needed to have device applied to them,
@@ -192,116 +206,67 @@ class Model(_Devices_Base, ABC):
         """
         pass
 
-    def freeze_weight(self, weight: Union[str, nn.Module], freeze: bool):
-        assert isinstance(weight, (str, nn.Module))
-        if isinstance(weight, str):
-            if hasattr(self, weight):
-                w = getattr(self, weight)
-                if isinstance(w, nn.Module):
-
-                    w.requires_grad_(not freeze)
-        else:
-            weight.requires_grad_(not freeze)
-
-    def _freeze_unfreeze(
-        self,
-        weight: Union[str, nn.Module],
-        task: Literal["freeze", "unfreeze"] = "freeze",
-        _skip_except: bool = False,
-    ):
-        try:
-            assert isinstance(weight, (str, nn.Module))
-            if isinstance(weight, str):
-                w_txt = f"Failed to {task} the module '{weight}'. Reason: is not a valid attribute of {self._get_name()}"
-                if hasattr(self, weight):
-                    w_txt = f"Failed to {task} the module '{weight}'. Reason: is not a Module type."
-                    w = getattr(self, weight)
-                    if isinstance(w, nn.Module):
-                        w_txt = f"Successfully {task} the module '{weight}'."
-                        w.requires_grad_(task == "unfreeze")
-
-            else:
-                w.requires_grad_(task == "unfreeze")
-                w_txt = f"Successfully '{task}' the module '{weight}'."
-            return w_txt
-        except Exception as e:
-            if not _skip_except:
-                raise e
-            return str(e)
-
-    def freeze_weight(
-        self,
-        weight: Union[str, nn.Module],
-        _skip_except: bool = False,
-    ):
-        return self._freeze_unfreeze(weight, "freeze", _skip_except)
-
-    def unfreeze_weight(
-        self,
-        weight: Union[str, nn.Module],
-        _skip_except: bool = False,
-    ):
-        return self._freeze_unfreeze(weight, "freeze", _skip_except)
-
     def freeze_all(self, exclude: Optional[List[str]] = None):
         no_exclusions = not exclude
-        frozen = []
-        not_frozen = []
-        for name, param in self.named_parameters():
+        no_exclusions = not exclude
+        results = []
+        for name, module in self.named_modules():
+            if name not in self.registered_freezable_modules:
+                results.append(
+                    (
+                        name,
+                        "Unregistered module, to freeze/unfreeze it add its name into 'registered_freezable_modules'.",
+                    )
+                )
+                continue
             if no_exclusions:
-                try:
-                    if param.requires_grad:
-                        param.requires_grad_(False)
-                        frozen.append(name)
-                    else:
-                        not_frozen.append((name, "was_frozen"))
-                except Exception as e:
-                    not_frozen.append((name, str(e)))
-            elif any(layer in name for layer in exclude):
-                try:
-                    if param.requires_grad:
-                        param.requires_grad_(False)
-                        frozen.append(name)
-                    else:
-                        not_frozen.append((name, "was_frozen"))
-                except Exception as e:
-                    not_frozen.append((name, str(e)))
+                self.change_frozen_state(True, module)
+            elif not any(exclusion in name for exclusion in exclude):
+                results.append((name, self.change_frozen_state(True, module)))
             else:
-                not_frozen.append((name, "excluded"))
-        return dict(frozen=frozen, not_frozen=not_frozen)
+                results.append((name, "excluded"))
+        return results
 
     def unfreeze_all(self, exclude: Optional[list[str]] = None):
         """Unfreezes all model parameters except specified layers."""
         no_exclusions = not exclude
-        unfrozen = []
-        not_unfrozen = []
-        for name, param in self.named_parameters():
+        results = []
+        for name, module in self.named_modules():
+            if name not in self.registered_freezable_modules:
+                results.append(
+                    (
+                        name,
+                        "Unregistered module, to freeze/unfreeze it add it into 'registered_freezable_modules'.",
+                    )
+                )
+                continue
             if no_exclusions:
-                try:
-                    if not param.requires_grad:
-                        param.requires_grad_(True)
-                        unfrozen.append(name)
-                    else:
-                        not_unfrozen.append((name, "was_unfrozen"))
-                except Exception as e:
-                    not_unfrozen.append((name, str(e)))
-            elif any(layer in name for layer in exclude):
-                try:
-                    if not param.requires_grad:
-                        param.requires_grad_(True)
-                        unfrozen.append(name)
-                    else:
-                        not_unfrozen.append((name, "was_unfrozen"))
-                except Exception as e:
-                    not_unfrozen.append((name, str(e)))
+                self.change_frozen_state(False, module)
+            elif not any(exclusion in name for exclusion in exclude):
+                results.append((name, self.change_frozen_state(False, module)))
             else:
-                not_unfrozen.append((name, "excluded"))
-        return dict(unfrozen=unfrozen, not_unfrozen=not_unfrozen)
+                results.append((name, "excluded"))
+        return results
+
+    def change_frozen_state(self, freeze: bool, module: nn.Module):
+        try:
+            if isinstance(module, Model):
+                if module._is_gradient_freezable:
+                    if freeze:
+                        return module.freeze_all()
+                    return module.unfreeze_all()
+                else:
+                    return "Not Allowed"
+            elif isinstance(module, nn.Module):
+                module.requires_grad_(not freeze)
+                return not freeze
+        except Exception as e:
+            return e
 
-    def count_trainable_parameters(self, module_name: Optional[str] = None):
+    def trainable_parameters(self, module_name: Optional[str] = None):
         """Gets the number of trainable parameters from either the entire model or from a specific module."""
         if module_name is not None:
-            assert hasattr(self, module_name), f"Module {module_name} does not exits"
+            assert hasattr(self, module_name), f"Module '{module_name}' not found."
             module = getattr(self, module_name)
             return sum(
                 [
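Note: freezing is now opt-in per module name: `named_modules()` is walked, and anything absent from `registered_freezable_modules` is only reported back, never touched. A minimal subclass sketch (layer sizes are illustrative, and it assumes `Model` imposes no further abstract methods):

```python
import torch.nn as nn

class TinyModel(Model):  # Model from lt_tensor.model_base
    registered_freezable_modules = ["encoder", "decoder"]
    _is_gradient_freezable = True

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(80, 256)  # illustrative sizes
        self.decoder = nn.Linear(256, 80)

    def forward(self, x):
        return self.decoder(self.encoder(x))

m = TinyModel()
report = m.freeze_all(exclude=["decoder"])
# "encoder" is frozen via change_frozen_state, "decoder" is reported as
# "excluded", and unregistered names (e.g. the root module "") are flagged.
```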
@@ -318,10 +283,10 @@ class Model(_Devices_Base, ABC):
             ]
         )
 
-    def count_non_trainable_parameters(self, module_name: Optional[str] = None):
+    def non_trainable_parameters(self, module_name: Optional[str] = None):
         """Gets the number of non-trainable parameters from either the entire model or from a specific module."""
         if module_name is not None:
-            assert hasattr(self, module_name), f"Module {module_name} does not exits"
+            assert hasattr(self, module_name), f"Module '{module_name}' not found."
             module = getattr(self, module_name)
             return sum(
                 [
@@ -338,10 +303,10 @@ class Model(_Devices_Base, ABC):
             ]
         )
 
-    def get_weights(self, module_name: Optional[str] = None) -> List[Tensor]:
+    def extract_weights(self, module_name: Optional[str] = None) -> List[Tensor]:
         """Returns the weights of the model entry model or from a specified module"""
         if module_name is not None:
-            assert hasattr(self, module_name), f"Module {module_name} does not exits"
+            assert hasattr(self, module_name), f"Module '{module_name}' not found."
             module = getattr(self, module_name)
             params = []
             if isinstance(module, nn.Module):
@@ -351,27 +316,29 @@ class Model(_Devices_Base, ABC):
             raise (f"{module_name} is has no weights")
         return [x.data.detach() for x in self.parameters()]
 
-    def print_trainable_parameters(
-        self, module_name: Optional[str] = None
-    ) -> List[Tensor]:
-        params = format(self.count_trainable_parameters(module_name), ",").replace(
+    def format_trainable_parameters(self, module_name: Optional[str] = None) -> str:
+        params = format(self.trainable_parameters(module_name), ",").replace(",", ".")
+        return params
+
+    def format_non_trainable_parameters(self, module_name: Optional[str] = None) -> str:
+        params = format(self.non_trainable_parameters(module_name), ",").replace(
             ",", "."
         )
-        if module_name:
-            print(f'Trainable Parameters from "{module_name}": {params}')
+        return params
+
+    def print_trainable_parameters(self, module_name: Optional[str] = None) -> str:
+        fmt = self.format_trainable_parameters(module_name)
+        if module_name is not None:
+            print(f"Trainable parameter(s) for module '{module_name}': {fmt}")
         else:
-            print(f"Trainable Parameters: {params}")
+            print(f"Trainable parameter(s): {fmt}")
 
-    def print_non_trainable_parameters(
-        self, module_name: Optional[str] = None
-    ) -> List[Tensor]:
-        params = format(self.count_non_trainable_parameters(module_name), ",").replace(
-            ",", "."
-        )
-        if module_name:
-            print(f'Non-Trainable Parameters from "{module_name}": {params}')
+    def print_non_trainable_parameters(self, module_name: Optional[str] = None) -> str:
+        fmt = self.format_non_trainable_parameters(module_name)
+        if module_name is not None:
+            print(f"Non-Trainable parameter(s) for module '{module_name}': {fmt}")
         else:
-            print(f"Non-Trainable Parameters: {params}")
+            print(f"Non-Trainable parameter(s): {fmt}")
 
     @classmethod
     def from_pretrained(
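Note: the counting, extraction, and printing helpers were renamed rather than removed; given some `model` instance, callers migrate like this (new names taken directly from the hunks above):

```python
n_train = model.trainable_parameters()      # was count_trainable_parameters()
n_fixed = model.non_trainable_parameters()  # was count_non_trainable_parameters()
weights = model.extract_weights("encoder")  # was get_weights(...)
model.print_trainable_parameters()          # now formats via format_trainable_parameters()
```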
@@ -390,6 +357,7 @@ class Model(_Devices_Base, ABC):
         self,
         path: Union[Path, str],
         replace: bool = False,
+        save_with_adapters: bool = False,
     ):
         path = Path(path)
         model_dir = path
@@ -405,17 +373,86 @@ class Model(_Devices_Base, ABC):
             if not "." in str(path):
                 model_dir = Path(path, f"model_{get_current_time()}.pt")
         path.parent.mkdir(exist_ok=True, parents=True)
-        torch.save(obj=self.state_dict(), f=str(model_dir))
+
+        state_dict = self.state_dict()
+        if not save_with_adapters or isinstance(self.low_rank_adapter, nn.Identity):
+            state_dict.pop("low_rank_adapter", None)
+        torch.save(obj=state_dict, f=str(model_dir))
+
+    def save_lora(
+        self,
+        path: Union[Path, str],
+        replace: bool = False,
+    ):
+        assert not isinstance(
+            self.low_rank_adapter, nn.Identity
+        ), "The adapter is empty!"
+        path = Path(path)
+        model_dir = path
+        if path.exists():
+            if path.is_dir():
+                model_dir = Path(path, f"adapter_{get_current_time()}.pt")
+            elif path.is_file():
+                if replace:
+                    path.unlink()
+                else:
+                    model_dir = Path(path.parent, f"adapter_{get_current_time()}.pt")
+        else:
+            if not "." in str(path):
+                model_dir = Path(path, f"adapter_{get_current_time()}.pt")
+
+        state_dict = self.low_rank_adapter.state_dict()
+        torch.save(obj=state_dict, f=str(model_dir))
+
+    def load_lora(
+        self,
+        path: Union[Path, str],
+        raise_if_not_exists: bool = False,
+        strict: bool = False,
+        assign: bool = False,
+        weights_only: bool = True,
+        mmap: Optional[bool] = None,
+        **pickle_load_args,
+    ):
+        assert (
+            self._low_rank_lambda is not None
+        ), "Lora not implemented! '_low_rank_lambda' must be setup to deploy a proper module"
+        path = Path(path)
+        if not path.exists():
+            assert not raise_if_not_exists, "Path does not exists!"
+            return None
+
+        if path.is_dir():
+            possible_files = list(Path(path).rglob("adapter_*.pt"))
+            assert (
+                possible_files or not raise_if_not_exists
+            ), "No model could be found in the given path!"
+            if not possible_files:
+                return None
+            path = sorted(possible_files)[-1]
+
+        state_dict = torch.load(
+            str(path), weights_only=weights_only, mmap=mmap, **pickle_load_args
+        )
+        self.low_rank_adapter = None
+        gc.collect()
+        self.low_rank_adapter = self._low_rank_lambda()
+        incompatible_keys = self.low_rank_adapter.load_state_dict(
+            state_dict,
+            strict=strict,
+            assign=assign,
+        )
+        return incompatible_keys
 
     def load_weights(
         self,
         path: Union[Path, str],
         raise_if_not_exists: bool = False,
-        strict: bool = True,
+        strict: bool = False,
         assign: bool = False,
-        weights_only: bool = False,
+        weights_only: bool = True,
         mmap: Optional[bool] = None,
-        **torch_loader_kwargs,
+        **pickle_load_args,
     ):
         path = Path(path)
         if not path.exists():
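Note: `save_lora`/`load_lora` round-trip only the adapter weights, so `_low_rank_lambda` must rebuild a module whose state dict matches what was saved. A sketch under two assumptions: the factory is assigned per instance (a class-level lambda would be bound as a method and receive `self`), and the target directory already exists, since `save_lora`, unlike `save_weights`, never calls `mkdir`:

```python
import torch.nn as nn

class AdapterModel(Model):
    def __init__(self):
        super().__init__()
        # Instance attribute, matching the factory example in the class comment.
        self._low_rank_lambda = lambda: nn.Linear(32, 32, True)
        self.low_rank_adapter = self._low_rank_lambda()

m = AdapterModel()
m.save_lora("ckpt")         # writes ckpt/adapter_<timestamp>.pt; "ckpt" must already exist
keys = m.load_lora("ckpt")  # rebuilds the adapter, loads the last adapter_*.pt in sorted order
m.save_weights("ckpt", save_with_adapters=True)  # keeps adapter entries in the state dict
```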
@@ -430,7 +467,7 @@ class Model(_Devices_Base, ABC):
                 return None
             path = sorted(possible_files)[-1]
         state_dict = torch.load(
-            str(path), weights_only=weights_only, mmap=mmap, **torch_loader_kwargs
+            str(path), weights_only=weights_only, mmap=mmap, **pickle_load_args
         )
         incompatible_keys = self.load_state_dict(
             state_dict,
@@ -439,6 +476,9 @@ class Model(_Devices_Base, ABC):
         )
         return incompatible_keys
 
+    def lora_step(self, *arg, **kwargs):
+        raise NotImplementedError("Not implemented for this model")
+
     @torch.no_grad()
     def inference(self, *args, **kwargs):
         if self.training:
@@ -524,7 +564,7 @@ class Model(_Devices_Base, ABC):
             bool: True when its frozen and false when its trainable.
         """
         if losses is not None:
-            self.add_loss(losses)
+            self.add_loss(losses, "train")
 
         if isinstance(trigger_loss, bool):
             if trigger_loss:
@@ -2,18 +2,20 @@ __all__ = [
     "basic",  # basic
     "residual",  # residual
     "transformer",  # transformer
-    "pos_encoder",
-    "fusion",
-    "features",
-    "discriminator",
+    "pos_encoder",
+    "fusion",
+    "features",
+    "discriminator",
+    "audio_models",
+    "hifigan",
     "istft",
 ]
+from .audio_models import hifigan, istft
 from . import (
     basic,
-    discriminator,
     features,
     fusion,
-    istft,
+    audio_models,
     pos_encoder,
     residual,
     transformer,
@@ -0,0 +1 @@
+from . import diffwave, istft, hifigan
@@ -0,0 +1,3 @@
+__all__ = ["DiffWave", "SpectrogramUpsampler", "DiffusionEmbedding"]
+
+from .model import DiffWave, SpectrogramUpsampler, DiffusionEmbedding
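Note: for downstream users the reorganization mainly moves import locations: `discriminator` and `istft` leave the top-level import tuple, and the vocoders now live under the new `audio_models` subpackage, with `diffwave` alongside `hifigan` and `istft`. A hedged sketch of the resulting imports; the `lt_tensor.model_zoo` prefix is inferred from context, since this diff shows only relative paths:

```python
# Assumed absolute module paths; verify against the installed package.
from lt_tensor.model_zoo import audio_models
from lt_tensor.model_zoo.audio_models import hifigan, istft
from lt_tensor.model_zoo.audio_models.diffwave import DiffWave
```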