lt-tensor 0.0.1a26__py3-none-any.whl → 0.0.1a27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/model_base.py CHANGED
@@ -80,6 +80,62 @@ class _Devices_Base(nn.Module):
        assert isinstance(device, (str, torch.device))
        self._device = torch.device(device) if isinstance(device, str) else device

+    def freeze_all(self, exclude: list[str] = []):
+        for name, module in self.named_modules():
+            if name in exclude or not hasattr(module, "requires_grad"):
+                continue
+            try:
+                self.freeze_module(module)
+            except:
+                pass
+
+    def unfreeze_all(self, exclude: list[str] = []):
+        for name, module in self.named_modules():
+            if name in exclude or not hasattr(module, "requires_grad"):
+                continue
+            try:
+                self.unfreeze_module(module)
+            except:
+                pass
+
+    def freeze_module(
+        self, module_or_name: Union[str, nn.Module, nn.Parameter, "Model", Tensor]
+    ):
+        self._change_gradient_state(module_or_name, False)
+
+    def unfreeze_module(
+        self, module_or_name: Union[str, nn.Module, nn.Parameter, "Model", Tensor]
+    ):
+        self._change_gradient_state(module_or_name, True)
+
+    def _change_gradient_state(
+        self,
+        module_or_name: Union[str, nn.Module, nn.Parameter, "Model", Tensor],
+        new_state: bool,  # True = unfreeze
+    ):
+        assert isinstance(
+            module_or_name, (str, nn.Module, nn.Parameter, Model, Tensor)
+        ), f"Item '{module_or_name}' is not a valid module, parameter, tensor or a string."
+        if isinstance(module_or_name, (nn.Module, nn.Parameter, Model, Tensor)):
+            target = module_or_name
+        else:
+            target = getattr(self, module_or_name)
+
+        if isinstance(target, Tensor):
+            target.requires_grad = new_state
+        elif isinstance(target, nn.Parameter):
+            target.requires_grad = new_state
+        elif isinstance(target, Model):
+            target.freeze_all()
+        elif isinstance(target, nn.Module):
+            for param in target.parameters():
+                if hasattr(param, "requires_grad"):
+                    param.requires_grad = new_state
+        else:
+            raise ValueError(
+                f"Item '{module_or_name}' is not a valid module, parameter or tensor."
+            )
+
    def _apply_device(self):
        """Add here components that are needed to have device applied to them,
        that usually the '.to()' function fails to apply
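Note on the hunk above: the new freeze_module/unfreeze_module helpers reduce to toggling requires_grad on whatever they are handed (a tensor, a parameter, a module, or an attribute name). A minimal plain-PyTorch sketch of that behaviour, for illustration only (not lt-tensor's own code; the helper name set_requires_grad is hypothetical):

import torch
import torch.nn as nn

def set_requires_grad(target, new_state: bool):
    # Mirrors what _change_gradient_state does in the hunk above: tensors and
    # parameters get the flag set directly, modules have it applied to every
    # parameter they own.
    if isinstance(target, (nn.Parameter, torch.Tensor)):
        target.requires_grad = new_state
    elif isinstance(target, nn.Module):
        for param in target.parameters():
            param.requires_grad = new_state
    else:
        raise ValueError(f"{target!r} is not a module, parameter or tensor.")

layer = nn.Linear(4, 4)
set_requires_grad(layer, False)  # "freeze"
assert not any(p.requires_grad for p in layer.parameters())
set_requires_grad(layer, True)   # "unfreeze"

freeze_all/unfreeze_all then simply walk self.named_modules() and apply this to every module whose name is not in the exclude list.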
@@ -182,20 +238,12 @@ class Model(_Devices_Base, ABC):
    """

    _autocast: bool = False
-    _is_unfrozen: bool = False
-    # list with modules that can be frozen or unfrozen
-    registered_freezable_modules: List[str] = []
-    is_frozen: bool = False
-    _can_be_frozen: bool = (
-        False  # to control if the module can or cannot be freezed by other modules from 'Model' class
-    )
+
    # this is to be used on the case of they module requires low-rank adapters
    _low_rank_lambda: Optional[Callable[[], nn.Module]] = (
        None  # Example: lambda: nn.Linear(32, 32, True)
    )
    low_rank_adapter: Union[nn.Identity, nn.Module, nn.Sequential] = nn.Identity()
-    # never freeze:
-    _never_freeze_modules: List[str] = ["low_rank_adapter"]

    # dont save list:
    _dont_save_items: List[str] = []
@@ -208,75 +256,6 @@ class Model(_Devices_Base, ABC):
    def autocast(self, value: bool):
        self._autocast = value

-    def freeze_all(self, exclude: Optional[List[str]] = None, force: bool = False):
-        no_exclusions = not exclude
-        no_exclusions = not exclude
-        results = []
-        for name, module in self.named_modules():
-            if (
-                name in self._never_freeze_modules
-                or not force
-                and name not in self.registered_freezable_modules
-            ):
-                results.append(
-                    (
-                        name,
-                        "Unregistered module, to freeze/unfreeze it add its name into 'registered_freezable_modules'.",
-                    )
-                )
-                continue
-            if no_exclusions:
-                self.change_frozen_state(True, module)
-            elif not any(exclusion in name for exclusion in exclude):
-                results.append((name, self.change_frozen_state(True, module)))
-            else:
-                results.append((name, "excluded"))
-        return results
-
-    def unfreeze_all(self, exclude: Optional[list[str]] = None, force: bool = False):
-        """Unfreezes all model parameters except specified layers."""
-        no_exclusions = not exclude
-        results = []
-        for name, module in self.named_modules():
-            if (
-                name in self._never_freeze_modules
-                or not force
-                and name not in self.registered_freezable_modules
-            ):
-
-                results.append(
-                    (
-                        name,
-                        "Unregistered module, to freeze/unfreeze it add it into 'registered_freezable_modules'.",
-                    )
-                )
-                continue
-            if no_exclusions:
-                self.change_frozen_state(False, module)
-            elif not any(exclusion in name for exclusion in exclude):
-                results.append((name, self.change_frozen_state(False, module)))
-            else:
-                results.append((name, "excluded"))
-        return results
-
-    def change_frozen_state(self, freeze: bool, module: nn.Module):
-        assert isinstance(module, nn.Module)
-        if module.__class__.__name__ in self._never_freeze_modules:
-            return "Not Allowed"
-        try:
-            if isinstance(module, Model):
-                if module._can_be_frozen:
-                    if freeze:
-                        return module.freeze_all()
-                    return module.unfreeze_all()
-                else:
-                    return "Not Allowed"
-            else:
-                module.requires_grad_(not freeze)
-                return not freeze
-        except Exception as e:
-            return e
-
    def trainable_parameters(self, module_name: Optional[str] = None):
        """Gets the number of trainable parameters from either the entire model or from a specific module."""
        if module_name is not None:
lt_tensor/processors/audio.py CHANGED
@@ -250,7 +250,7 @@ class AudioProcessor(Model):
        for i in range(B):
            f0_.append(librosa.yin(self.to_numpy_safe(audio[i, :]), **yn_kwargs))
        f0 = self.from_numpy_batch(f0_, default_device, default_dtype)
-        return f0
+        return f0.squeeze()

    def compute_pitch_torch(
        self,
@@ -407,11 +407,7 @@ class AudioProcessor(Model):
            mel_tensor = (
                torch.log(eps + mel_tensor.unsqueeze(0)) - self.cfg.mean
            ) / self.cfg.std
-            if mel_tensor.ndim == 4:
-                return mel_tensor.squeeze()
-            elif mel_tensor.ndim == 2:
-                return mel_tensor.unsqueeze(0)
-            return mel_tensor
+            return mel_tensor.squeeze()

        except RuntimeError as e:
            if not _recall:
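Both AudioProcessor changes in this file swap shape-dependent returns for a plain .squeeze(), which removes every dimension of size 1 and leaves all other dimensions alone. A quick generic PyTorch illustration (not lt-tensor code):

import torch

mel = torch.zeros(1, 80, 120)                    # (batch=1, n_mels, frames)
print(mel.squeeze().shape)                       # torch.Size([80, 120]); the singleton batch dim is dropped
print(torch.zeros(4, 80, 120).squeeze().shape)   # torch.Size([4, 80, 120]); nothing to drop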
lt_tensor-0.0.1a26.dist-info/METADATA → lt_tensor-0.0.1a27.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a26
+Version: 0.0.1a27
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
lt_tensor-0.0.1a26.dist-info/RECORD → lt_tensor-0.0.1a27.dist-info/RECORD CHANGED
@@ -4,7 +4,7 @@ lt_tensor/losses.py,sha256=zvkCOnE5XpF3v6ymivRIdqPTsMM5zc94ZMom7YDi3zM,4946
 lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
 lt_tensor/math_ops.py,sha256=TkD4WQG42KsQ9Fg7FXOjf8f-ixtW0apf2XjaooecVx4,2257
 lt_tensor/misc_utils.py,sha256=N2r3UmxC4RM2BZBQhpjDZ_BKLrzsyIlKzopTzJbnjFU,28962
-lt_tensor/model_base.py,sha256=GvmQdt97ZSfOObBpBIq7UUTwpIE1g-aBm23za36YA0M,18431
+lt_tensor/model_base.py,sha256=DTg44N6eTXLmpIAj_ac29-M5dI_iY_sC0yA_K3E13GI,17446
 lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
 lt_tensor/noise_tools.py,sha256=wFeAsHhLhSlEc5XU5LbFKaXoHeVxrWjiMeljjGdIKyM,11363
 lt_tensor/torch_commons.py,sha256=8l0bxmrAzwvyqjivCIVISXlbvKarlg4DdE0BOGSnMuQ,812
@@ -27,9 +27,9 @@ lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=PDuDYN1omD1RoAXcmxH
 lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=7GJqKLw7-juXpfp5IFzjASLut0uouDhjZ1CQknf3H68,16533
 lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=ltIuD9t1gmS3bTmCqZIwJHKrhC6DYya3OaXlskWX9kw,17606
 lt_tensor/processors/__init__.py,sha256=Pvxhh0KR65zLCgUd53_k5Z0y5JWWcO0ZBXFK9rv0o5w,109
-lt_tensor/processors/audio.py,sha256=WkumFNx8OXGQkTEU5Rkede9NLMrsGaTGY37Ti784Wv8,17028
-lt_tensor-0.0.1a26.dist-info/licenses/LICENSE,sha256=tQHc38scHOba4kDBNG4U0U6PpObaloiZG-FvKSgv2b0,11336
-lt_tensor-0.0.1a26.dist-info/METADATA,sha256=2STSK6jgD_qECwz9WygTXNDwfapEAR2mpHiS14bi9tQ,1062
-lt_tensor-0.0.1a26.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-lt_tensor-0.0.1a26.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
-lt_tensor-0.0.1a26.dist-info/RECORD,,
+lt_tensor/processors/audio.py,sha256=mZY7LOeYACnX8PLz5AeFe0zqEebPoN-Q44Bi3yrlZMQ,16881
+lt_tensor-0.0.1a27.dist-info/licenses/LICENSE,sha256=tQHc38scHOba4kDBNG4U0U6PpObaloiZG-FvKSgv2b0,11336
+lt_tensor-0.0.1a27.dist-info/METADATA,sha256=NpXqioPXZMvXo-HzhXrS6O1qiftDnoc8ZzOfhfUMBaY,1062
+lt_tensor-0.0.1a27.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a27.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a27.dist-info/RECORD,,