lt-tensor 0.0.1a22__tar.gz → 0.0.1a26__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/PKG-INFO +1 -1
  2. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/config_templates.py +9 -5
  3. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/misc_utils.py +15 -3
  4. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py +41 -14
  5. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/audio_models/hifigan/__init__.py +40 -82
  6. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/audio_models/istft/__init__.py +41 -83
  7. lt_tensor-0.0.1a26/lt_tensor/model_zoo/convs.py +124 -0
  8. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/residual.py +1 -136
  9. lt_tensor-0.0.1a26/lt_tensor/processors/__init__.py +3 -0
  10. lt_tensor-0.0.1a26/lt_tensor/processors/audio.py +527 -0
  11. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor.egg-info/PKG-INFO +1 -1
  12. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor.egg-info/SOURCES.txt +1 -2
  13. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/setup.py +1 -1
  14. lt_tensor-0.0.1a22/lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +0 -536
  15. lt_tensor-0.0.1a22/lt_tensor/model_zoo/audio_models/bigvgan/cuda/__init__.py +0 -160
  16. lt_tensor-0.0.1a22/lt_tensor/processors/__init__.py +0 -3
  17. lt_tensor-0.0.1a22/lt_tensor/processors/audio.py +0 -456
  18. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/LICENSE +0 -0
  19. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/README.md +0 -0
  20. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/__init__.py +0 -0
  21. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/losses.py +0 -0
  22. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/lr_schedulers.py +0 -0
  23. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/math_ops.py +0 -0
  24. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_base.py +0 -0
  25. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/__init__.py +0 -0
  26. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +0 -0
  27. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/activations/alias_free_torch/act.py +0 -0
  28. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/activations/alias_free_torch/filter.py +0 -0
  29. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/activations/alias_free_torch/resample.py +0 -0
  30. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/activations/snake/__init__.py +0 -0
  31. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/audio_models/__init__.py +0 -0
  32. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/basic.py +0 -0
  33. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/features.py +0 -0
  34. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/fusion.py +0 -0
  35. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/pos_encoder.py +0 -0
  36. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/transformer.py +0 -0
  37. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/monotonic_align.py +0 -0
  38. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/noise_tools.py +0 -0
  39. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/torch_commons.py +0 -0
  40. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/transform.py +0 -0
  41. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor.egg-info/dependency_links.txt +0 -0
  42. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor.egg-info/requires.txt +0 -0
  43. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor.egg-info/top_level.txt +0 -0
  44. {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lt-tensor
3
- Version: 0.0.1a22
3
+ Version: 0.0.1a26
4
4
  Summary: General utilities for PyTorch and others. Built for general use.
5
5
  Home-page: https://github.com/gr1336/lt-tensor/
6
6
  Author: gr1336
@@ -6,9 +6,7 @@ from lt_tensor.misc_utils import updateDict
6
6
 
7
7
 
8
8
  class ModelConfig(ABC, OrderedDict):
9
- _default_settings: Dict[str, Any] = {}
10
9
  _forbidden_list: List[str] = [
11
- "_default_settings",
12
10
  "_forbidden_list",
13
11
  ]
14
12
 
@@ -16,12 +14,15 @@ class ModelConfig(ABC, OrderedDict):
16
14
  self,
17
15
  **settings,
18
16
  ):
19
- self._default_settings = settings
20
- self.set_state_dict(self._default_settings)
17
+ self.set_state_dict(settings)
21
18
 
22
19
  def reset_settings(self):
23
20
  raise NotImplementedError("Not implemented")
24
-
21
+
22
+ def post_process(self):
23
+ """Implement the post process, to do a final check to the input data"""
24
+ pass
25
+
25
26
  def save_config(
26
27
  self,
27
28
  path: str,
@@ -48,6 +49,7 @@ class ModelConfig(ABC, OrderedDict):
48
49
  }
49
50
  updateDict(self, new_state)
50
51
  self.update(**new_state)
52
+ self.post_process()
51
53
 
52
54
  def state_dict(self):
53
55
  return {k: y for k, y in self.__dict__.items() if k not in self._forbidden_list}
@@ -89,3 +91,5 @@ class ModelConfig(ABC, OrderedDict):
89
91
  settings.pop("path_name", None)
90
92
 
91
93
  return ModelConfig(**settings)
94
+
95
+
@@ -111,10 +111,10 @@ def get_weights(directory: Union[str, PathLike]):
111
111
  directory = Path(directory)
112
112
  if is_file(directory):
113
113
  if directory.name.endswith((".pt", ".ckpt", ".pth")):
114
- return directory
114
+ return [directory]
115
115
  directory = directory.parent
116
116
  res = sorted(find_files(directory, ["*.pt", "*.ckpt", "*.pth"]))
117
- return res[-1] if res else None
117
+ return res
118
118
 
119
119
 
120
120
  def get_config(directory: Union[str, PathLike], default: Optional[Any] = None):
@@ -128,7 +128,19 @@ def get_config(directory: Union[str, PathLike], default: Optional[Any] = None):
128
128
  return load_json(directory, default)
129
129
  return load_yaml(directory, default)
130
130
  directory = directory.parent
131
- res = sorted(find_files(directory, ["*.pt", "*.ckpt", "*.pth"]))
131
+ res = sorted(
132
+ find_files(
133
+ directory,
134
+ [
135
+ "config*.json",
136
+ "*config.json",
137
+ "config*.yml",
138
+ "*config.yml",
139
+ "*config.yaml",
140
+ "config*.yaml",
141
+ ],
142
+ )
143
+ )
132
144
  if res:
133
145
  res = res[-1]
134
146
  if Path(res).name.endswith(".json"):
@@ -5,6 +5,7 @@ from lt_tensor.torch_commons import *
5
5
  from torch.nn import functional as F
6
6
  from lt_tensor.config_templates import ModelConfig
7
7
  from lt_tensor.torch_commons import *
8
+ from lt_tensor.model_zoo.convs import ConvNets, Conv1dEXT
8
9
  from lt_tensor.model_base import Model
9
10
  from math import sqrt
10
11
  from lt_utils.common import *
@@ -18,6 +19,8 @@ class DiffWaveConfig(ModelConfig):
18
19
  residual_channels = 64
19
20
  dilation_cycle_length = 10
20
21
  unconditional = False
22
+ apply_norm: Optional[Literal["weight", "spectral"]] = None
23
+ apply_norm_resblock: Optional[Literal["weight", "spectral"]] = None
21
24
  noise_schedule: list[int] = np.linspace(1e-4, 0.05, 50).tolist()
22
25
  # settings for auto-fixes
23
26
  interpolate = False
@@ -44,6 +47,8 @@ class DiffWaveConfig(ModelConfig):
44
47
  "area",
45
48
  "nearest-exact",
46
49
  ] = "nearest",
50
+ apply_norm: Optional[Literal["weight", "spectral"]] = None,
51
+ apply_norm_resblock: Optional[Literal["weight", "spectral"]] = None,
47
52
  ):
48
53
  settings = {
49
54
  "n_mels": n_mels,
@@ -55,16 +60,12 @@ class DiffWaveConfig(ModelConfig):
55
60
  "noise_schedule": noise_schedule,
56
61
  "interpolate": interpolate_cond,
57
62
  "interpolation_mode": interpolation_mode,
63
+ "apply_norm": apply_norm,
64
+ "apply_norm_resblock": apply_norm_resblock,
58
65
  }
59
66
  super().__init__(**settings)
60
67
 
61
68
 
62
- def Conv1d(*args, **kwargs):
63
- layer = nn.Conv1d(*args, **kwargs)
64
- nn.init.kaiming_normal_(layer.weight)
65
- return layer
66
-
67
-
68
69
  class DiffusionEmbedding(Model):
69
70
  def __init__(self, max_steps: int):
70
71
  super().__init__()
@@ -117,7 +118,14 @@ class SpectrogramUpsample(Model):
117
118
 
118
119
 
119
120
  class ResidualBlock(Model):
120
- def __init__(self, n_mels, residual_channels, dilation, uncond=False):
121
+ def __init__(
122
+ self,
123
+ n_mels,
124
+ residual_channels,
125
+ dilation,
126
+ uncond=False,
127
+ apply_norm: Optional[Literal["weight", "spectral"]] = None,
128
+ ):
121
129
  """
122
130
  :param n_mels: inplanes of conv1x1 for spectrogram conditional
123
131
  :param residual_channels: audio conv
@@ -125,20 +133,28 @@ class ResidualBlock(Model):
125
133
  :param uncond: disable spectrogram conditional
126
134
  """
127
135
  super().__init__()
128
- self.dilated_conv = Conv1d(
136
+ self.dilated_conv = Conv1dEXT(
129
137
  residual_channels,
130
138
  2 * residual_channels,
131
139
  3,
132
140
  padding=dilation,
133
141
  dilation=dilation,
142
+ apply_norm=apply_norm,
134
143
  )
135
144
  self.diffusion_projection = nn.Linear(512, residual_channels)
136
145
  if not uncond: # conditional model
137
- self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1)
146
+ self.conditioner_projection = Conv1dEXT(
147
+ n_mels,
148
+ 2 * residual_channels,
149
+ 1,
150
+ apply_norm=apply_norm,
151
+ )
138
152
  else: # unconditional model
139
153
  self.conditioner_projection = None
140
154
 
141
- self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)
155
+ self.output_projection = Conv1dEXT(
156
+ residual_channels, 2 * residual_channels, 1, apply_norm == apply_norm
157
+ )
142
158
 
143
159
  def forward(
144
160
  self,
@@ -172,7 +188,12 @@ class DiffWave(Model):
172
188
  self.n_hop = self.params.hop_samples
173
189
  self.interpolate = self.params.interpolate
174
190
  self.interpolate_mode = self.params.interpolation_mode
175
- self.input_projection = Conv1d(1, params.residual_channels, 1)
191
+ self.input_projection = Conv1dEXT(
192
+ in_channels=1,
193
+ out_channels=params.residual_channels,
194
+ kernel_size=1,
195
+ apply_norm=self.params.apply_norm,
196
+ )
176
197
  self.diffusion_embedding = DiffusionEmbedding(len(params.noise_schedule))
177
198
  if self.params.unconditional: # use unconditional model
178
199
  self.spectrogram_upsample = None
@@ -186,14 +207,20 @@ class DiffWave(Model):
186
207
  params.residual_channels,
187
208
  2 ** (i % params.dilation_cycle_length),
188
209
  uncond=params.unconditional,
210
+ apply_norm=self.params.apply_norm_resblock,
189
211
  )
190
212
  for i in range(params.residual_layers)
191
213
  ]
192
214
  )
193
- self.skip_projection = Conv1d(
194
- params.residual_channels, params.residual_channels, 1
215
+ self.skip_projection = Conv1dEXT(
216
+ in_channels=params.residual_channels,
217
+ out_channels=params.residual_channels,
218
+ kernel_size=1,
219
+ apply_norm=self.params.apply_norm,
220
+ )
221
+ self.output_projection = Conv1dEXT(
222
+ params.residual_channels, 1, 1, apply_norm=self.params.apply_norm
195
223
  )
196
- self.output_projection = Conv1d(params.residual_channels, 1, 1)
197
224
  self.activation = nn.LeakyReLU(0.1)
198
225
  self.r_sqrt = sqrt(len(self.residual_layers))
199
226
  nn.init.zeros_(self.output_projection.weight)
@@ -1,10 +1,10 @@
1
1
  __all__ = ["HifiganGenerator", "HifiganConfig"]
2
2
  from lt_utils.common import *
3
3
  from lt_tensor.torch_commons import *
4
- from lt_tensor.model_zoo.residual import ConvNets
4
+ from lt_tensor.model_zoo.convs import ConvNets
5
5
  from torch.nn import functional as F
6
6
  from lt_utils.file_ops import load_json, is_file, is_dir, is_path_valid
7
- from huggingface_hub import hf_hub_download
7
+ from lt_tensor.misc_utils import get_config, get_weights
8
8
 
9
9
 
10
10
  def get_padding(kernel_size, dilation=1):
@@ -43,7 +43,7 @@ class HifiganConfig(ModelConfig):
43
43
  [1, 3, 5],
44
44
  ],
45
45
  activation: nn.Module = nn.LeakyReLU(0.1),
46
- resblock: int = 0,
46
+ resblock: Union[int, str] = "1",
47
47
  *args,
48
48
  **kwargs,
49
49
  ):
@@ -59,6 +59,10 @@ class HifiganConfig(ModelConfig):
59
59
  }
60
60
  super().__init__(**settings)
61
61
 
62
+ def post_process(self):
63
+ if isinstance(self.resblock, str):
64
+ self.resblock = 0 if self.resblock == "1" else 1
65
+
62
66
 
63
67
  class ResBlock1(ConvNets):
64
68
  def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
@@ -240,15 +244,15 @@ class HifiganGenerator(ConvNets):
240
244
  def load_weights(
241
245
  self,
242
246
  path,
243
- raise_if_not_exists=False,
244
247
  strict=False,
245
248
  assign=False,
246
- weights_only=True,
249
+ weights_only=False,
247
250
  mmap=None,
251
+ raise_if_not_exists=False,
248
252
  **pickle_load_args,
249
253
  ):
250
254
  try:
251
- return super().load_weights(
255
+ incompatible_keys = super().load_weights(
252
256
  path,
253
257
  raise_if_not_exists,
254
258
  strict,
@@ -257,6 +261,18 @@ class HifiganGenerator(ConvNets):
257
261
  mmap,
258
262
  **pickle_load_args,
259
263
  )
264
+ if incompatible_keys:
265
+ self.remove_norms()
266
+ incompatible_keys = super().load_weights(
267
+ path,
268
+ raise_if_not_exists,
269
+ strict,
270
+ assign,
271
+ weights_only,
272
+ mmap,
273
+ **pickle_load_args,
274
+ )
275
+ return incompatible_keys
260
276
  except RuntimeError:
261
277
  self.remove_norms()
262
278
  return super().load_weights(
@@ -272,95 +288,37 @@ class HifiganGenerator(ConvNets):
272
288
  @classmethod
273
289
  def from_pretrained(
274
290
  cls,
275
- model_id: str,
276
- map_location: str = "cpu",
277
- local_files_only: bool = False,
278
- strict: bool = False,
291
+ model_file: PathLike,
292
+ model_config: Union[HifiganConfig, Dict[str, Any]],
279
293
  *,
280
- subfolder: str | None = None,
281
- repo_type: str | None = None,
282
- revision: str | None = None,
283
- cache_dir: str | Path | None = None,
284
- force_download: bool = False,
285
- proxies: Dict | None = None,
286
- token: bool | str | None = None,
287
- resume_download: bool | None = None,
288
- local_dir_use_symlinks: bool | Literal["auto"] = "auto",
294
+ strict: bool = False,
295
+ map_location: str = "cpu",
296
+ weights_only: bool = False,
289
297
  **kwargs,
290
298
  ):
291
- """Load Pytorch pretrained weights and return the loaded model."""
292
- hub_kwargs = dict(
293
- repo_id=model_id,
294
- subfolder=subfolder,
295
- repo_type=repo_type,
296
- revision=revision,
297
- cache_dir=cache_dir,
298
- force_download=force_download,
299
- proxies=proxies,
300
- resume_download=resume_download,
301
- token=token,
302
- local_files_only=local_files_only,
303
- local_dir_use_symlinks=local_dir_use_symlinks,
304
- )
305
-
306
- # Download and load hyperparameters (h) used by BigVGAN
307
- _model_path = Path(model_id)
308
- config_file = None
309
- if is_path_valid(model_id):
310
- if is_file(model_id):
311
- _p_conf = _model_path.parent / "config.json"
312
- else:
313
- _p_conf = _model_path / "config.json"
314
-
315
- if is_file(_p_conf):
316
- print("Loading config.json from local directory")
317
- config_file = Path(model_id, "config.json")
318
- else:
319
- if not local_files_only:
320
- print(f"Loading config from {model_id}")
321
- config_file = hf_hub_download(filename="config.json", **hub_kwargs)
322
- else:
323
- if not local_files_only:
324
- print(f"Loading config from {model_id}")
325
- config_file = hf_hub_download(filename="config.json", **hub_kwargs)
326
-
327
- if config_file is not None:
328
- model = cls(HifiganConfig(**load_json(config_file)))
329
- else:
330
- model = cls()
331
299
 
332
- # Download and load pretrained generator weight
333
- _retrieve_kwargs = dict(
334
- **hub_kwargs,
335
- filename="generator.pt",
300
+ is_file(model_file, validate=True)
301
+ model_state_dict = torch.load(
302
+ model_file, weights_only=weights_only, map_location=map_location
336
303
  )
337
- path = Path(model_id)
338
- if path.exists():
339
- if path.is_dir():
340
- path = path / "generator.pt"
341
- if path.exists():
342
- print("Loading weights from local directory")
343
- model_file = str(path)
344
- else:
345
- print(f"Loading weights from {model_id}")
346
- model_file = hf_hub_download(**_retrieve_kwargs)
347
- else:
348
- print("Loading weights from local directory")
349
- model_file = str(path)
304
+
305
+ if isinstance(model_config, HifiganConfig):
306
+ h = model_config
350
307
  else:
351
- print(f"Loading weights from {model_id}")
352
- model_file = hf_hub_download(**_retrieve_kwargs)
353
- checkpoint_dict = torch.load(model_file, map_location=map_location)
308
+ h = HifiganConfig(**model_config)
354
309
 
310
+ model = cls(h)
355
311
  try:
356
- model.load_state_dict(checkpoint_dict, strict=strict)
312
+ incompatible_keys = model.load_state_dict(model_state_dict, strict=strict)
313
+ if incompatible_keys:
314
+ model.remove_norms()
315
+ model.load_state_dict(model_state_dict, strict=strict)
357
316
  except RuntimeError:
358
317
  print(
359
318
  f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
360
319
  )
361
320
  model.remove_norms()
362
- model.load_state_dict(checkpoint_dict, strict=strict)
363
-
321
+ model.load_state_dict(model_state_dict, strict=strict)
364
322
  return model
365
323
 
366
324
 
@@ -1,11 +1,11 @@
1
1
  __all__ = ["iSTFTNetGenerator", "iSTFTNetConfig"]
2
2
  from lt_utils.common import *
3
3
  from lt_tensor.torch_commons import *
4
- from lt_tensor.model_zoo.residual import ConvNets
4
+ from lt_tensor.model_zoo.convs import ConvNets
5
5
  from torch.nn import functional as F
6
6
  from lt_tensor.config_templates import ModelConfig
7
+ from lt_tensor.misc_utils import get_config, get_weights
7
8
  from lt_utils.file_ops import load_json, is_file, is_dir, is_path_valid
8
- from huggingface_hub import hf_hub_download
9
9
 
10
10
 
11
11
  class iSTFTNetConfig(ModelConfig):
@@ -39,7 +39,7 @@ class iSTFTNetConfig(ModelConfig):
39
39
  [1, 3, 5],
40
40
  ],
41
41
  activation: nn.Module = nn.LeakyReLU(0.1),
42
- resblock: int = 0,
42
+ resblock: Union[int, str] = "1",
43
43
  gen_istft_n_fft: int = 16,
44
44
  sampling_rate: Number = 24000,
45
45
  *args,
@@ -59,7 +59,11 @@ class iSTFTNetConfig(ModelConfig):
59
59
  }
60
60
  super().__init__(**settings)
61
61
 
62
-
62
+ def post_process(self):
63
+ if isinstance(self.resblock, str):
64
+ self.resblock = 0 if self.resblock == "1" else 1
65
+
66
+
63
67
  def get_padding(ks, d):
64
68
  return int((ks * d - d) / 2)
65
69
 
@@ -271,15 +275,15 @@ class iSTFTNetGenerator(ConvNets):
271
275
  def load_weights(
272
276
  self,
273
277
  path,
274
- raise_if_not_exists=False,
275
278
  strict=False,
276
279
  assign=False,
277
- weights_only=True,
280
+ weights_only=False,
278
281
  mmap=None,
282
+ raise_if_not_exists=False,
279
283
  **pickle_load_args,
280
284
  ):
281
285
  try:
282
- return super().load_weights(
286
+ incompatible_keys = super().load_weights(
283
287
  path,
284
288
  raise_if_not_exists,
285
289
  strict,
@@ -288,6 +292,18 @@ class iSTFTNetGenerator(ConvNets):
288
292
  mmap,
289
293
  **pickle_load_args,
290
294
  )
295
+ if incompatible_keys:
296
+ self.remove_norms()
297
+ incompatible_keys = super().load_weights(
298
+ path,
299
+ raise_if_not_exists,
300
+ strict,
301
+ assign,
302
+ weights_only,
303
+ mmap,
304
+ **pickle_load_args,
305
+ )
306
+ return incompatible_keys
291
307
  except RuntimeError:
292
308
  self.remove_norms()
293
309
  return super().load_weights(
@@ -303,95 +319,37 @@ class iSTFTNetGenerator(ConvNets):
303
319
  @classmethod
304
320
  def from_pretrained(
305
321
  cls,
306
- model_id: str,
307
- map_location: str = "cpu",
308
- local_files_only: bool = False,
309
- strict: bool = False,
322
+ model_file: PathLike,
323
+ model_config: Union[iSTFTNetConfig, Dict[str, Any]],
310
324
  *,
311
- subfolder: str | None = None,
312
- repo_type: str | None = None,
313
- revision: str | None = None,
314
- cache_dir: str | Path | None = None,
315
- force_download: bool = False,
316
- proxies: Dict | None = None,
317
- token: bool | str | None = None,
318
- resume_download: bool | None = None,
319
- local_dir_use_symlinks: bool | Literal["auto"] = "auto",
325
+ strict: bool = False,
326
+ map_location: str = "cpu",
327
+ weights_only: bool = False,
320
328
  **kwargs,
321
329
  ):
322
- """Load Pytorch pretrained weights and return the loaded model."""
323
- hub_kwargs = dict(
324
- repo_id=model_id,
325
- subfolder=subfolder,
326
- repo_type=repo_type,
327
- revision=revision,
328
- cache_dir=cache_dir,
329
- force_download=force_download,
330
- proxies=proxies,
331
- resume_download=resume_download,
332
- token=token,
333
- local_files_only=local_files_only,
334
- local_dir_use_symlinks=local_dir_use_symlinks,
335
- )
336
330
 
337
- # Download and load hyperparameters (h) used by BigVGAN
338
- _model_path = Path(model_id)
339
- config_file = None
340
- if is_path_valid(model_id):
341
- if is_file(model_id):
342
- _p_conf = _model_path.parent / "config.json"
343
- else:
344
- _p_conf = _model_path / "config.json"
345
-
346
- if is_file(_p_conf):
347
- print("Loading config.json from local directory")
348
- config_file = Path(model_id, "config.json")
349
- else:
350
- if not local_files_only:
351
- print(f"Loading config from {model_id}")
352
- config_file = hf_hub_download(filename="config.json", **hub_kwargs)
353
- else:
354
- if not local_files_only:
355
- print(f"Loading config from {model_id}")
356
- config_file = hf_hub_download(filename="config.json", **hub_kwargs)
357
-
358
- if config_file is not None:
359
- model = cls(iSTFTNetConfig(**load_json(config_file)))
360
- else:
361
- model = cls()
362
-
363
- # Download and load pretrained generator weight
364
- _retrieve_kwargs = dict(
365
- **hub_kwargs,
366
- filename="generator.pt",
331
+ is_file(model_file, validate=True)
332
+ model_state_dict = torch.load(
333
+ model_file, weights_only=weights_only, map_location=map_location
367
334
  )
368
- path = Path(model_id)
369
- if path.exists():
370
- if path.is_dir():
371
- path = path / "generator.pt"
372
- if path.exists():
373
- print("Loading weights from local directory")
374
- model_file = str(path)
375
- else:
376
- print(f"Loading weights from {model_id}")
377
- model_file = hf_hub_download(**_retrieve_kwargs)
378
- else:
379
- print("Loading weights from local directory")
380
- model_file = str(path)
335
+
336
+ if isinstance(model_config, iSTFTNetConfig):
337
+ h = model_config
381
338
  else:
382
- print(f"Loading weights from {model_id}")
383
- model_file = hf_hub_download(**_retrieve_kwargs)
384
- checkpoint_dict = torch.load(model_file, map_location=map_location)
339
+ h = iSTFTNetConfig(**model_config)
385
340
 
341
+ model = cls(h)
386
342
  try:
387
- model.load_state_dict(checkpoint_dict, strict=strict)
343
+ incompatible_keys = model.load_state_dict(model_state_dict, strict=strict)
344
+ if incompatible_keys:
345
+ model.remove_norms()
346
+ model.load_state_dict(model_state_dict, strict=strict)
388
347
  except RuntimeError:
389
348
  print(
390
349
  f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
391
350
  )
392
351
  model.remove_norms()
393
- model.load_state_dict(checkpoint_dict, strict=strict)
394
-
352
+ model.load_state_dict(model_state_dict, strict=strict)
395
353
  return model
396
354
 
397
355