lt-tensor 0.0.1a22__py3-none-any.whl → 0.0.1a27__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
lt_tensor/config_templates.py CHANGED
@@ -6,9 +6,7 @@ from lt_tensor.misc_utils import updateDict
 
 
 class ModelConfig(ABC, OrderedDict):
-    _default_settings: Dict[str, Any] = {}
     _forbidden_list: List[str] = [
-        "_default_settings",
         "_forbidden_list",
     ]
 
@@ -16,12 +14,15 @@ class ModelConfig(ABC, OrderedDict):
         self,
         **settings,
     ):
-        self._default_settings = settings
-        self.set_state_dict(self._default_settings)
+        self.set_state_dict(settings)
 
     def reset_settings(self):
         raise NotImplementedError("Not implemented")
-
+
+    def post_process(self):
+        """Implement the post process, to do a final check on the input data"""
+        pass
+
     def save_config(
         self,
         path: str,
@@ -48,6 +49,7 @@ class ModelConfig(ABC, OrderedDict):
         }
         updateDict(self, new_state)
         self.update(**new_state)
+        self.post_process()
 
     def state_dict(self):
         return {k: y for k, y in self.__dict__.items() if k not in self._forbidden_list}
@@ -89,3 +91,5 @@ class ModelConfig(ABC, OrderedDict):
         settings.pop("path_name", None)
 
         return ModelConfig(**settings)
+
+
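
The new post_process hook gives ModelConfig subclasses a place to validate or normalize settings after they are applied; set_state_dict now invokes it, and __init__ routes through set_state_dict. A minimal sketch of the intended use, where SampleRateConfig and its field are illustrative and not part of the package:

    from lt_tensor.config_templates import ModelConfig

    class SampleRateConfig(ModelConfig):
        # hypothetical subclass, for illustration only
        def __init__(self, sample_rate: int = 24000, **kwargs):
            super().__init__(sample_rate=sample_rate, **kwargs)

        def post_process(self):
            # called at the end of set_state_dict(), hence also during __init__
            assert self.sample_rate > 0, "sample_rate must be positive"

    cfg = SampleRateConfig(sample_rate=22050)  # post_process() runs here
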
lt_tensor/misc_utils.py CHANGED
@@ -111,10 +111,10 @@ def get_weights(directory: Union[str, PathLike]):
     directory = Path(directory)
     if is_file(directory):
         if directory.name.endswith((".pt", ".ckpt", ".pth")):
-            return directory
+            return [directory]
         directory = directory.parent
     res = sorted(find_files(directory, ["*.pt", "*.ckpt", "*.pth"]))
-    return res[-1] if res else None
+    return res
 
 
 def get_config(directory: Union[str, PathLike], default: Optional[Any] = None):
@@ -128,7 +128,19 @@ def get_config(directory: Union[str, PathLike], default: Optional[Any] = None):
                return load_json(directory, default)
            return load_yaml(directory, default)
        directory = directory.parent
-    res = sorted(find_files(directory, ["*.pt", "*.ckpt", "*.pth"]))
+    res = sorted(
+        find_files(
+            directory,
+            [
+                "config*.json",
+                "*config.json",
+                "config*.yml",
+                "*config.yml",
+                "*config.yaml",
+                "config*.yaml",
+            ],
+        )
+    )
     if res:
         res = res[-1]
         if Path(res).name.endswith(".json"):
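
Two behavior changes land here: get_weights now returns a list of matching checkpoint paths (previously the single newest path, or None), and get_config now globs for config-named JSON/YAML files instead of reusing the checkpoint patterns. A sketch of caller-side migration, with an illustrative directory:

    from lt_tensor.misc_utils import get_weights, get_config

    ckpts = get_weights("checkpoints/")       # 0.0.1a27: sorted list of *.pt/*.ckpt/*.pth
    latest = ckpts[-1] if ckpts else None     # recovers the old "newest or None" result
    cfg = get_config("checkpoints/", default={})  # matches config*.json, *config.yml, ...
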
lt_tensor/model_base.py CHANGED
@@ -80,6 +80,62 @@ class _Devices_Base(nn.Module):
         assert isinstance(device, (str, torch.device))
         self._device = torch.device(device) if isinstance(device, str) else device
 
+    def freeze_all(self, exclude: list[str] = []):
+        for name, module in self.named_modules():
+            if name in exclude or not hasattr(module, "requires_grad"):
+                continue
+            try:
+                self.freeze_module(module)
+            except Exception:
+                pass
+
+    def unfreeze_all(self, exclude: list[str] = []):
+        for name, module in self.named_modules():
+            if name in exclude or not hasattr(module, "requires_grad"):
+                continue
+            try:
+                self.unfreeze_module(module)
+            except Exception:
+                pass
+
+    def freeze_module(
+        self, module_or_name: Union[str, nn.Module, nn.Parameter, "Model", Tensor]
+    ):
+        self._change_gradient_state(module_or_name, False)
+
+    def unfreeze_module(
+        self, module_or_name: Union[str, nn.Module, nn.Parameter, "Model", Tensor]
+    ):
+        self._change_gradient_state(module_or_name, True)
+
+    def _change_gradient_state(
+        self,
+        module_or_name: Union[str, nn.Module, nn.Parameter, "Model", Tensor],
+        new_state: bool,  # True = unfreeze
+    ):
+        assert isinstance(
+            module_or_name, (str, nn.Module, nn.Parameter, Model, Tensor)
+        ), f"Item '{module_or_name}' is not a valid module, parameter, tensor or a string."
+        if isinstance(module_or_name, (nn.Module, nn.Parameter, Model, Tensor)):
+            target = module_or_name
+        else:
+            target = getattr(self, module_or_name)
+
+        if isinstance(target, Tensor):
+            target.requires_grad = new_state
+        elif isinstance(target, nn.Parameter):
+            target.requires_grad = new_state
+        elif isinstance(target, Model):
+            target.unfreeze_all() if new_state else target.freeze_all()  # honor new_state
+        elif isinstance(target, nn.Module):
+            for param in target.parameters():
+                if hasattr(param, "requires_grad"):
+                    param.requires_grad = new_state
+        else:
+            raise ValueError(
+                f"Item '{module_or_name}' is not a valid module, parameter or tensor."
+            )
+
     def _apply_device(self):
         """Add here components that are needed to have device applied to them,
         that usually the '.to()' function fails to apply
@@ -182,20 +238,12 @@ class Model(_Devices_Base, ABC):
     """
 
     _autocast: bool = False
-    _is_unfrozen: bool = False
-    # list with modules that can be frozen or unfrozen
-    registered_freezable_modules: List[str] = []
-    is_frozen: bool = False
-    _can_be_frozen: bool = (
-        False  # to control if the module can or cannot be freezed by other modules from 'Model' class
-    )
+
     # this is to be used in case the module requires low-rank adapters
    _low_rank_lambda: Optional[Callable[[], nn.Module]] = (
        None  # Example: lambda: nn.Linear(32, 32, True)
    )
    low_rank_adapter: Union[nn.Identity, nn.Module, nn.Sequential] = nn.Identity()
-    # never freeze:
-    _never_freeze_modules: List[str] = ["low_rank_adapter"]
 
    # dont save list:
    _dont_save_items: List[str] = []
@@ -208,75 +256,6 @@ class Model(_Devices_Base, ABC):
     def autocast(self, value: bool):
         self._autocast = value
 
-    def freeze_all(self, exclude: Optional[List[str]] = None, force: bool = False):
-        no_exclusions = not exclude
-        no_exclusions = not exclude
-        results = []
-        for name, module in self.named_modules():
-            if (
-                name in self._never_freeze_modules
-                or not force
-                and name not in self.registered_freezable_modules
-            ):
-                results.append(
-                    (
-                        name,
-                        "Unregistered module, to freeze/unfreeze it add its name into 'registered_freezable_modules'.",
-                    )
-                )
-                continue
-            if no_exclusions:
-                self.change_frozen_state(True, module)
-            elif not any(exclusion in name for exclusion in exclude):
-                results.append((name, self.change_frozen_state(True, module)))
-            else:
-                results.append((name, "excluded"))
-        return results
-
-    def unfreeze_all(self, exclude: Optional[list[str]] = None, force: bool = False):
-        """Unfreezes all model parameters except specified layers."""
-        no_exclusions = not exclude
-        results = []
-        for name, module in self.named_modules():
-            if (
-                name in self._never_freeze_modules
-                or not force
-                and name not in self.registered_freezable_modules
-            ):
-
-                results.append(
-                    (
-                        name,
-                        "Unregistered module, to freeze/unfreeze it add it into 'registered_freezable_modules'.",
-                    )
-                )
-                continue
-            if no_exclusions:
-                self.change_frozen_state(False, module)
-            elif not any(exclusion in name for exclusion in exclude):
-                results.append((name, self.change_frozen_state(False, module)))
-            else:
-                results.append((name, "excluded"))
-        return results
-
-    def change_frozen_state(self, freeze: bool, module: nn.Module):
-        assert isinstance(module, nn.Module)
-        if module.__class__.__name__ in self._never_freeze_modules:
-            return "Not Allowed"
-        try:
-            if isinstance(module, Model):
-                if module._can_be_frozen:
-                    if freeze:
-                        return module.freeze_all()
-                    return module.unfreeze_all()
-                else:
-                    return "Not Allowed"
-            else:
-                module.requires_grad_(not freeze)
-                return not freeze
-        except Exception as e:
-            return e
-
     def trainable_parameters(self, module_name: Optional[str] = None):
         """Gets the number of trainable parameters from either the entire model or from a specific module."""
         if module_name is not None:
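
The new freeze/unfreeze methods on _Devices_Base replace the registration-based API removed from Model above: no more registered_freezable_modules or change_frozen_state, just direct requires_grad toggling. A sketch of the new surface, assuming net is any model built on this base; the submodule names are illustrative:

    net.freeze_all(exclude=["decoder"])  # requires_grad=False everywhere but "decoder"
    net.freeze_module("encoder")         # accepts an attribute name...
    net.unfreeze_module(net.encoder)     # ...or the module/parameter/tensor itself
    net.unfreeze_all()                   # re-enable gradients for all modules
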
@@ -5,6 +5,7 @@ from lt_tensor.torch_commons import *
 from torch.nn import functional as F
 from lt_tensor.config_templates import ModelConfig
 from lt_tensor.torch_commons import *
+from lt_tensor.model_zoo.convs import ConvNets, Conv1dEXT
 from lt_tensor.model_base import Model
 from math import sqrt
 from lt_utils.common import *
@@ -18,6 +19,8 @@ class DiffWaveConfig(ModelConfig):
     residual_channels = 64
     dilation_cycle_length = 10
     unconditional = False
+    apply_norm: Optional[Literal["weight", "spectral"]] = None
+    apply_norm_resblock: Optional[Literal["weight", "spectral"]] = None
     noise_schedule: list[int] = np.linspace(1e-4, 0.05, 50).tolist()
     # settings for auto-fixes
     interpolate = False
@@ -44,6 +47,8 @@ class DiffWaveConfig(ModelConfig):
             "area",
             "nearest-exact",
         ] = "nearest",
+        apply_norm: Optional[Literal["weight", "spectral"]] = None,
+        apply_norm_resblock: Optional[Literal["weight", "spectral"]] = None,
     ):
         settings = {
             "n_mels": n_mels,
@@ -55,16 +60,12 @@ class DiffWaveConfig(ModelConfig):
             "noise_schedule": noise_schedule,
             "interpolate": interpolate_cond,
             "interpolation_mode": interpolation_mode,
+            "apply_norm": apply_norm,
+            "apply_norm_resblock": apply_norm_resblock,
         }
         super().__init__(**settings)
 
 
-def Conv1d(*args, **kwargs):
-    layer = nn.Conv1d(*args, **kwargs)
-    nn.init.kaiming_normal_(layer.weight)
-    return layer
-
-
 class DiffusionEmbedding(Model):
     def __init__(self, max_steps: int):
         super().__init__()
@@ -117,7 +118,14 @@ class SpectrogramUpsample(Model):
 
 
 class ResidualBlock(Model):
-    def __init__(self, n_mels, residual_channels, dilation, uncond=False):
+    def __init__(
+        self,
+        n_mels,
+        residual_channels,
+        dilation,
+        uncond=False,
+        apply_norm: Optional[Literal["weight", "spectral"]] = None,
+    ):
         """
         :param n_mels: inplanes of conv1x1 for spectrogram conditional
         :param residual_channels: audio conv
@@ -125,20 +133,28 @@ class ResidualBlock(Model):
         :param uncond: disable spectrogram conditional
         """
         super().__init__()
-        self.dilated_conv = Conv1d(
+        self.dilated_conv = Conv1dEXT(
             residual_channels,
             2 * residual_channels,
             3,
             padding=dilation,
             dilation=dilation,
+            apply_norm=apply_norm,
         )
         self.diffusion_projection = nn.Linear(512, residual_channels)
         if not uncond:  # conditional model
-            self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1)
+            self.conditioner_projection = Conv1dEXT(
+                n_mels,
+                2 * residual_channels,
+                1,
+                apply_norm=apply_norm,
+            )
         else:  # unconditional model
             self.conditioner_projection = None
 
-        self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)
+        self.output_projection = Conv1dEXT(
+            residual_channels, 2 * residual_channels, 1, apply_norm=apply_norm
+        )
 
     def forward(
         self,
@@ -172,7 +188,12 @@ class DiffWave(Model):
         self.n_hop = self.params.hop_samples
         self.interpolate = self.params.interpolate
         self.interpolate_mode = self.params.interpolation_mode
-        self.input_projection = Conv1d(1, params.residual_channels, 1)
+        self.input_projection = Conv1dEXT(
+            in_channels=1,
+            out_channels=params.residual_channels,
+            kernel_size=1,
+            apply_norm=self.params.apply_norm,
+        )
         self.diffusion_embedding = DiffusionEmbedding(len(params.noise_schedule))
         if self.params.unconditional:  # use unconditional model
             self.spectrogram_upsample = None
@@ -186,14 +207,20 @@ class DiffWave(Model):
                     params.residual_channels,
                     2 ** (i % params.dilation_cycle_length),
                     uncond=params.unconditional,
+                    apply_norm=self.params.apply_norm_resblock,
                 )
                 for i in range(params.residual_layers)
             ]
         )
-        self.skip_projection = Conv1d(
-            params.residual_channels, params.residual_channels, 1
+        self.skip_projection = Conv1dEXT(
+            in_channels=params.residual_channels,
+            out_channels=params.residual_channels,
+            kernel_size=1,
+            apply_norm=self.params.apply_norm,
+        )
+        self.output_projection = Conv1dEXT(
+            params.residual_channels, 1, 1, apply_norm=self.params.apply_norm
         )
-        self.output_projection = Conv1d(params.residual_channels, 1, 1)
         self.activation = nn.LeakyReLU(0.1)
         self.r_sqrt = sqrt(len(self.residual_layers))
         nn.init.zeros_(self.output_projection.weight)
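
Across these hunks the module-level Conv1d helper (a plain nn.Conv1d with Kaiming init) is replaced by Conv1dEXT from lt_tensor.model_zoo.convs, with the normalization choice threaded through DiffWaveConfig. A sketch of the new knobs; the values are illustrative and the remaining constructor arguments are left at their defaults:

    cfg = DiffWaveConfig(
        apply_norm="weight",       # norm for the input/skip/output projections
        apply_norm_resblock=None,  # leave the residual-block convs unnormalized
    )
    model = DiffWave(cfg)
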
@@ -1,10 +1,10 @@
 __all__ = ["HifiganGenerator", "HifiganConfig"]
 from lt_utils.common import *
 from lt_tensor.torch_commons import *
-from lt_tensor.model_zoo.residual import ConvNets
+from lt_tensor.model_zoo.convs import ConvNets
 from torch.nn import functional as F
 from lt_utils.file_ops import load_json, is_file, is_dir, is_path_valid
-from huggingface_hub import hf_hub_download
+from lt_tensor.misc_utils import get_config, get_weights
 
 
 def get_padding(kernel_size, dilation=1):
@@ -43,7 +43,7 @@ class HifiganConfig(ModelConfig):
             [1, 3, 5],
         ],
         activation: nn.Module = nn.LeakyReLU(0.1),
-        resblock: int = 0,
+        resblock: Union[int, str] = "1",
         *args,
         **kwargs,
     ):
@@ -59,6 +59,10 @@ class HifiganConfig(ModelConfig):
         }
         super().__init__(**settings)
 
+    def post_process(self):
+        if isinstance(self.resblock, str):
+            self.resblock = 0 if self.resblock == "1" else 1
+
 
 class ResBlock1(ConvNets):
     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
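
HifiganConfig.resblock now also accepts the string form used by upstream HiFi-GAN configs ("1"/"2"), and the new post_process hook (invoked via set_state_dict) folds it back to the integer index the generator expects. A sketch, assuming the remaining constructor defaults suffice:

    cfg = HifiganConfig(resblock="1")  # post_process(): "1" -> 0, any other string -> 1
    assert cfg.resblock == 0
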
@@ -240,15 +244,15 @@ class HifiganGenerator(ConvNets):
     def load_weights(
         self,
         path,
-        raise_if_not_exists=False,
         strict=False,
         assign=False,
-        weights_only=True,
+        weights_only=False,
         mmap=None,
+        raise_if_not_exists=False,
         **pickle_load_args,
     ):
         try:
-            return super().load_weights(
+            incompatible_keys = super().load_weights(
                 path,
                 raise_if_not_exists,
                 strict,
@@ -257,6 +261,18 @@ class HifiganGenerator(ConvNets):
                 mmap,
                 **pickle_load_args,
             )
+            if incompatible_keys:
+                self.remove_norms()
+                incompatible_keys = super().load_weights(
+                    path,
+                    raise_if_not_exists,
+                    strict,
+                    assign,
+                    weights_only,
+                    mmap,
+                    **pickle_load_args,
+                )
+            return incompatible_keys
         except RuntimeError:
             self.remove_norms()
             return super().load_weights(
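
load_weights now defaults to weights_only=False and, when a non-strict load reports incompatible keys, retries once after remove_norms(). A sketch with an illustrative checkpoint path:

    gen = HifiganGenerator(HifiganConfig())
    gen.load_weights("generator.pt")  # retries without weight norm on key mismatch
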
@@ -272,95 +288,37 @@ class HifiganGenerator(ConvNets):
     @classmethod
     def from_pretrained(
         cls,
-        model_id: str,
-        map_location: str = "cpu",
-        local_files_only: bool = False,
-        strict: bool = False,
+        model_file: PathLike,
+        model_config: Union[HifiganConfig, Dict[str, Any]],
         *,
-        subfolder: str | None = None,
-        repo_type: str | None = None,
-        revision: str | None = None,
-        cache_dir: str | Path | None = None,
-        force_download: bool = False,
-        proxies: Dict | None = None,
-        token: bool | str | None = None,
-        resume_download: bool | None = None,
-        local_dir_use_symlinks: bool | Literal["auto"] = "auto",
+        strict: bool = False,
+        map_location: str = "cpu",
+        weights_only: bool = False,
         **kwargs,
     ):
-        """Load Pytorch pretrained weights and return the loaded model."""
-        hub_kwargs = dict(
-            repo_id=model_id,
-            subfolder=subfolder,
-            repo_type=repo_type,
-            revision=revision,
-            cache_dir=cache_dir,
-            force_download=force_download,
-            proxies=proxies,
-            resume_download=resume_download,
-            token=token,
-            local_files_only=local_files_only,
-            local_dir_use_symlinks=local_dir_use_symlinks,
-        )
-
-        # Download and load hyperparameters (h) used by BigVGAN
-        _model_path = Path(model_id)
-        config_file = None
-        if is_path_valid(model_id):
-            if is_file(model_id):
-                _p_conf = _model_path.parent / "config.json"
-            else:
-                _p_conf = _model_path / "config.json"
-
-            if is_file(_p_conf):
-                print("Loading config.json from local directory")
-                config_file = Path(model_id, "config.json")
-            else:
-                if not local_files_only:
-                    print(f"Loading config from {model_id}")
-                    config_file = hf_hub_download(filename="config.json", **hub_kwargs)
-        else:
-            if not local_files_only:
-                print(f"Loading config from {model_id}")
-                config_file = hf_hub_download(filename="config.json", **hub_kwargs)
-
-        if config_file is not None:
-            model = cls(HifiganConfig(**load_json(config_file)))
-        else:
-            model = cls()
 
-        # Download and load pretrained generator weight
-        _retrieve_kwargs = dict(
-            **hub_kwargs,
-            filename="generator.pt",
+        is_file(model_file, validate=True)
+        model_state_dict = torch.load(
+            model_file, weights_only=weights_only, map_location=map_location
         )
-        path = Path(model_id)
-        if path.exists():
-            if path.is_dir():
-                path = path / "generator.pt"
-                if path.exists():
-                    print("Loading weights from local directory")
-                    model_file = str(path)
-                else:
-                    print(f"Loading weights from {model_id}")
-                    model_file = hf_hub_download(**_retrieve_kwargs)
-            else:
-                print("Loading weights from local directory")
-                model_file = str(path)
+
+        if isinstance(model_config, HifiganConfig):
+            h = model_config
         else:
-            print(f"Loading weights from {model_id}")
-            model_file = hf_hub_download(**_retrieve_kwargs)
-        checkpoint_dict = torch.load(model_file, map_location=map_location)
+            h = HifiganConfig(**model_config)
 
+        model = cls(h)
         try:
-            model.load_state_dict(checkpoint_dict, strict=strict)
+            incompatible_keys = model.load_state_dict(model_state_dict, strict=strict)
+            if incompatible_keys:
+                model.remove_norms()
+                model.load_state_dict(model_state_dict, strict=strict)
         except RuntimeError:
             print(
                 f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
             )
             model.remove_norms()
-            model.load_state_dict(checkpoint_dict, strict=strict)
-
+            model.load_state_dict(model_state_dict, strict=strict)
         return model
 
 
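
from_pretrained no longer talks to the Hugging Face Hub; the caller now supplies a local checkpoint plus a config. A sketch under the new signature, with an illustrative path and values:

    model = HifiganGenerator.from_pretrained(
        "weights/generator.pt",          # local file; hub download was removed
        model_config={"resblock": "1"},  # a dict or a HifiganConfig instance
        map_location="cpu",
    )
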