lt-tensor 0.0.1a22__py3-none-any.whl → 0.0.1a27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
```diff
@@ -1,11 +1,11 @@
 __all__ = ["iSTFTNetGenerator", "iSTFTNetConfig"]
 from lt_utils.common import *
 from lt_tensor.torch_commons import *
-from lt_tensor.model_zoo.residual import ConvNets
+from lt_tensor.model_zoo.convs import ConvNets
 from torch.nn import functional as F
 from lt_tensor.config_templates import ModelConfig
+from lt_tensor.misc_utils import get_config, get_weights
 from lt_utils.file_ops import load_json, is_file, is_dir, is_path_valid
-from huggingface_hub import hf_hub_download
 
 
 class iSTFTNetConfig(ModelConfig):
```
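`ConvNets` now comes from the new `lt_tensor.model_zoo.convs` module rather than `model_zoo.residual`, and the direct `huggingface_hub` dependency is dropped in favor of the package's own `get_config`/`get_weights` helpers. Both import paths still resolve in this release, since `residual.py` re-imports `ConvNets` from `convs` (see the hunks further down):

```python
# Both of these work in 0.0.1a27; the second resolves through the
# re-import in residual.py shown later in this diff.
from lt_tensor.model_zoo.convs import ConvNets            # new canonical location
from lt_tensor.model_zoo.residual import ConvNets as _C   # still works via re-import
```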
```diff
@@ -39,7 +39,7 @@ class iSTFTNetConfig(ModelConfig):
             [1, 3, 5],
         ],
         activation: nn.Module = nn.LeakyReLU(0.1),
-        resblock: int = 0,
+        resblock: Union[int, str] = "1",
         gen_istft_n_fft: int = 16,
         sampling_rate: Number = 24000,
         *args,
@@ -59,7 +59,11 @@ class iSTFTNetConfig(ModelConfig):
         }
         super().__init__(**settings)
 
-
+    def post_process(self):
+        if isinstance(self.resblock, str):
+            self.resblock = 0 if self.resblock == "1" else 1
+
+
 def get_padding(ks, d):
     return int((ks * d - d) / 2)
 
```
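`resblock` now also accepts the HiFi-GAN-style string names, and the new `post_process` hook normalizes a string to the internal integer index (`"1"` maps to `0`, anything else to `1`), so both the old integer form and configs that store the resblock type as a string load cleanly. A standalone sketch of the mapping:

```python
# Standalone sketch of the normalization done in iSTFTNetConfig.post_process;
# `resblock` stands in for the config field.
def normalize_resblock(resblock):
    if isinstance(resblock, str):
        return 0 if resblock == "1" else 1
    return resblock

assert normalize_resblock("1") == 0  # HiFi-GAN "ResBlock1" -> index 0
assert normalize_resblock("2") == 1  # HiFi-GAN "ResBlock2" -> index 1
assert normalize_resblock(0) == 0    # integers pass through unchanged
```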
```diff
@@ -271,15 +275,15 @@ class iSTFTNetGenerator(ConvNets):
     def load_weights(
         self,
         path,
-        raise_if_not_exists=False,
         strict=False,
         assign=False,
-        weights_only=True,
+        weights_only=False,
         mmap=None,
+        raise_if_not_exists=False,
         **pickle_load_args,
     ):
         try:
-            return super().load_weights(
+            incompatible_keys = super().load_weights(
                 path,
                 raise_if_not_exists,
                 strict,
@@ -288,6 +292,18 @@ class iSTFTNetGenerator(ConvNets):
                 mmap,
                 **pickle_load_args,
             )
+            if incompatible_keys:
+                self.remove_norms()
+                incompatible_keys = super().load_weights(
+                    path,
+                    raise_if_not_exists,
+                    strict,
+                    assign,
+                    weights_only,
+                    mmap,
+                    **pickle_load_args,
+                )
+            return incompatible_keys
         except RuntimeError:
             self.remove_norms()
             return super().load_weights(
```
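`load_weights` changes in three ways: `weights_only` now defaults to `False` (checkpoints containing arbitrary pickled objects load again, at the cost of torch's safe-unpickling guarantee), `raise_if_not_exists` moves from the fourth parameter to the last (note the forwarded `super().load_weights(...)` call still passes it second, so the base `Model` signature presumably kept the old order), and a load that reports incompatible keys is now retried once after `remove_norms()` in addition to the existing `RuntimeError` fallback. Callers that passed arguments positionally should switch to keywords:

```python
# Hedged usage sketch: `model` stands for an iSTFTNetGenerator instance and
# "generator.pt" is a placeholder path. Keyword arguments sidestep the
# a22 -> a27 parameter reordering.
model.load_weights(
    "generator.pt",
    strict=False,
    weights_only=False,  # new default; pass True to keep safe unpickling
)
```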
```diff
@@ -303,95 +319,37 @@ class iSTFTNetGenerator(ConvNets):
     @classmethod
     def from_pretrained(
         cls,
-        model_id: str,
-        map_location: str = "cpu",
-        local_files_only: bool = False,
-        strict: bool = False,
+        model_file: PathLike,
+        model_config: Union[iSTFTNetConfig, Dict[str, Any]],
         *,
-        subfolder: str | None = None,
-        repo_type: str | None = None,
-        revision: str | None = None,
-        cache_dir: str | Path | None = None,
-        force_download: bool = False,
-        proxies: Dict | None = None,
-        token: bool | str | None = None,
-        resume_download: bool | None = None,
-        local_dir_use_symlinks: bool | Literal["auto"] = "auto",
+        strict: bool = False,
+        map_location: str = "cpu",
+        weights_only: bool = False,
         **kwargs,
     ):
-        """Load Pytorch pretrained weights and return the loaded model."""
-        hub_kwargs = dict(
-            repo_id=model_id,
-            subfolder=subfolder,
-            repo_type=repo_type,
-            revision=revision,
-            cache_dir=cache_dir,
-            force_download=force_download,
-            proxies=proxies,
-            resume_download=resume_download,
-            token=token,
-            local_files_only=local_files_only,
-            local_dir_use_symlinks=local_dir_use_symlinks,
-        )
 
-        # Download and load hyperparameters (h) used by BigVGAN
-        _model_path = Path(model_id)
-        config_file = None
-        if is_path_valid(model_id):
-            if is_file(model_id):
-                _p_conf = _model_path.parent / "config.json"
-            else:
-                _p_conf = _model_path / "config.json"
-
-            if is_file(_p_conf):
-                print("Loading config.json from local directory")
-                config_file = Path(model_id, "config.json")
-            else:
-                if not local_files_only:
-                    print(f"Loading config from {model_id}")
-                    config_file = hf_hub_download(filename="config.json", **hub_kwargs)
-        else:
-            if not local_files_only:
-                print(f"Loading config from {model_id}")
-                config_file = hf_hub_download(filename="config.json", **hub_kwargs)
-
-        if config_file is not None:
-            model = cls(iSTFTNetConfig(**load_json(config_file)))
-        else:
-            model = cls()
-
-        # Download and load pretrained generator weight
-        _retrieve_kwargs = dict(
-            **hub_kwargs,
-            filename="generator.pt",
+        is_file(model_file, validate=True)
+        model_state_dict = torch.load(
+            model_file, weights_only=weights_only, map_location=map_location
         )
-        path = Path(model_id)
-        if path.exists():
-            if path.is_dir():
-                path = path / "generator.pt"
-                if path.exists():
-                    print("Loading weights from local directory")
-                    model_file = str(path)
-                else:
-                    print(f"Loading weights from {model_id}")
-                    model_file = hf_hub_download(**_retrieve_kwargs)
-            else:
-                print("Loading weights from local directory")
-                model_file = str(path)
+
+        if isinstance(model_config, iSTFTNetConfig):
+            h = model_config
         else:
-            print(f"Loading weights from {model_id}")
-            model_file = hf_hub_download(**_retrieve_kwargs)
-        checkpoint_dict = torch.load(model_file, map_location=map_location)
+            h = iSTFTNetConfig(**model_config)
 
+        model = cls(h)
         try:
-            model.load_state_dict(checkpoint_dict, strict=strict)
+            incompatible_keys = model.load_state_dict(model_state_dict, strict=strict)
+            if incompatible_keys:
+                model.remove_norms()
+                model.load_state_dict(model_state_dict, strict=strict)
         except RuntimeError:
             print(
                 f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
             )
             model.remove_norms()
-            model.load_state_dict(checkpoint_dict, strict=strict)
-
+            model.load_state_dict(model_state_dict, strict=strict)
         return model
 
 
```
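`from_pretrained` is rewritten from a Hugging Face Hub downloader into a purely local loader: it takes a path to a weights file plus a config (either an `iSTFTNetConfig` or a plain dict), validates the file, `torch.load`s the state dict, and applies the same remove-norms retry logic as `load_weights`. Fetching `config.json` and `generator.pt` from a Hub repo is no longer handled here; the `get_config`/`get_weights` helpers imported at the top presumably take over remote retrieval, though no call site appears in this diff. A hedged usage sketch (assuming `iSTFTNetGenerator` is imported from its lt_tensor module, which this diff does not name, and using placeholder paths):

```python
import json

# Placeholder paths; locate or download these files yourself now that
# from_pretrained no longer talks to the Hugging Face Hub.
with open("config.json") as f:
    config = json.load(f)

model = iSTFTNetGenerator.from_pretrained(
    "generator.pt",     # local checkpoint file
    config,             # dict or iSTFTNetConfig instance
    strict=False,
    map_location="cpu",
)
```

The next hunk adds a brand-new module (judging by the updated import above, `lt_tensor/model_zoo/convs.py`) that becomes the home of `ConvNets` and introduces a `Conv1dEXT` convenience layer: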
```diff
@@ -0,0 +1,124 @@
+__all__ = ["ConvNets", "Conv1dEXT"]
+import math
+from lt_utils.common import *
+import torch.nn.functional as F
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model
+from lt_tensor.misc_utils import log_tensor
+from lt_tensor.model_zoo.fusion import AdaFusion1D, AdaIN1D
+
+
+def spectral_norm_select(module: nn.Module, enabled: bool):
+    if enabled:
+        return spectral_norm(module)
+    return module
+
+
+def get_weight_norm(norm_type: Optional[Literal["weight", "spectral"]] = None):
+    if not norm_type:
+        return lambda x: x
+    if norm_type == "weight":
+        return lambda x: weight_norm(x)
+    return lambda x: spectral_norm(x)
+
+
+def remove_norm(module, name: str = "weight"):
+    try:
+        try:
+            remove_parametrizations(module, name, leave_parametrized=False)
+        except:
+            # many times will fail with 'leave_parametrized'
+            remove_parametrizations(module, name)
+    except ValueError:
+        pass # not parametrized
+
+
+class ConvNets(Model):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def remove_norms(self, name: str = "weight"):
+        for module in self.modules():
+            if "Conv" in module.__class__.__name__:
+                remove_norm(module, name)
+
+    @staticmethod
+    def init_weights(
+        m: nn.Module,
+        norm: Optional[Literal["spectral", "weight"]] = None,
+        mean=0.0,
+        std=0.02,
+        name: str = "weight",
+        n_power_iterations: int = 1,
+        eps: float = 1e-9,
+        dim_sn: Optional[int] = None,
+        dim_wn: int = 0,
+    ):
+        if "Conv" in m.__class__.__name__:
+            if norm is not None:
+                try:
+                    if norm == "spectral":
+                        m.apply(
+                            lambda m: spectral_norm(
+                                m,
+                                n_power_iterations=n_power_iterations,
+                                eps=eps,
+                                name=name,
+                                dim=dim_sn,
+                            )
+                        )
+                    else:
+                        m.apply(lambda m: weight_norm(m, name=name, dim=dim_wn))
+                except ValueError:
+                    pass
+            m.weight.data.normal_(mean, std)
+
+
+class Conv1dEXT(ConvNets):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        kernel_size: Union[int, Tuple[int, ...]] = 1,
+        stride: Union[int, Tuple[int, ...]] = 1,
+        padding: Union[int, Tuple[int, ...]] = 0,
+        dilation: Union[int, Tuple[int, ...]] = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = "zeros",
+        device: Optional[Any] = None,
+        dtype: Optional[Any] = None,
+        apply_norm: Optional[Literal["weight", "spectral"]] = None,
+        activation: nn.Module = nn.Identity(),
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        if not out_channels:
+            out_channels = in_channels
+        cnn_kwargs = dict(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+            padding_mode=padding_mode,
+            device=device,
+            dtype=dtype,
+        )
+        if apply_norm is None:
+            self.cnn = nn.Conv1d(**cnn_kwargs)
+        else:
+            if apply_norm == "spectral":
+                self.cnn = spectral_norm(nn.Conv1d(**cnn_kwargs))
+            else:
+                self.cnn = weight_norm(nn.Conv1d(**cnn_kwargs))
+        self.activation = activation
+        self.cnn.apply(self.init_weights)
+
+    def forward(self, input: Tensor):
+        return self.cnn(self.activation(input))
```
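The new module collects the weight-norm helpers (`spectral_norm_select`, `get_weight_norm`, `remove_norm`), the `ConvNets` base class, and the new `Conv1dEXT` layer: an `nn.Conv1d` wrapper with optional weight or spectral normalization and a pre-activation design, meaning `forward` applies `self.activation` to the input *before* the convolution, not after. `out_channels` defaults to `in_channels` when omitted. A minimal usage sketch (the module path is taken from the import at the top of this diff):

```python
import torch
from lt_tensor.model_zoo.convs import Conv1dEXT

conv = Conv1dEXT(
    in_channels=80,
    kernel_size=3,
    padding=1,
    apply_norm="weight",                # "spectral", or None for a bare Conv1d
    activation=torch.nn.LeakyReLU(0.1),
)
x = torch.randn(1, 80, 100)             # (batch, channels, frames)
y = conv(x)                             # activation runs before the conv
print(y.shape)                          # torch.Size([1, 80, 100])
```

The next hunks strip the same machinery out of the residual module (apparently `lt_tensor/model_zoo/residual.py`, the old home of `ConvNets`):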
```diff
@@ -2,12 +2,10 @@ __all__ = [
     "spectral_norm_select",
     "get_weight_norm",
     "ResBlock1D",
-    "ResBlock2D",
     "ResBlock1DShuffled",
     "AdaResBlock1D",
     "ResBlocks1D",
     "ResBlock1D2",
-    "ShuffleBlock2D",
 ]
 import math
 from lt_utils.common import *
@@ -16,73 +14,7 @@ from lt_tensor.torch_commons import *
 from lt_tensor.model_base import Model
 from lt_tensor.misc_utils import log_tensor
 from lt_tensor.model_zoo.fusion import AdaFusion1D, AdaIN1D
-
-
-def spectral_norm_select(module: nn.Module, enabled: bool):
-    if enabled:
-        return spectral_norm(module)
-    return module
-
-
-def get_weight_norm(norm_type: Optional[Literal["weight", "spectral"]] = None):
-    if not norm_type:
-        return lambda x: x
-    if norm_type == "weight":
-        return lambda x: weight_norm(x)
-    return lambda x: spectral_norm(x)
-
-
-def remove_norm(module, name: str = "weight"):
-    try:
-        try:
-            remove_parametrizations(module, name, leave_parametrized=False)
-        except:
-            # many times will fail with 'leave_parametrized'
-            remove_parametrizations(module, name)
-    except ValueError:
-        pass # not parametrized
-
-
-class ConvNets(Model):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def remove_norms(self, name: str = "weight"):
-        for module in self.modules():
-            if "Conv" in module.__class__.__name__:
-                remove_norm(module, name)
-
-    @staticmethod
-    def init_weights(
-        m: nn.Module,
-        norm: Optional[Literal["spectral", "weight"]] = None,
-        mean=0.0,
-        std=0.02,
-        name: str = "weight",
-        n_power_iterations: int = 1,
-        eps: float = 1e-9,
-        dim_sn: Optional[int] = None,
-        dim_wn: int = 0,
-    ):
-        if "Conv" in m.__class__.__name__:
-            if norm is not None:
-                try:
-                    if norm == "spectral":
-                        m.apply(
-                            lambda m: spectral_norm(
-                                m,
-                                n_power_iterations=n_power_iterations,
-                                eps=eps,
-                                name=name,
-                                dim=dim_sn,
-                            )
-                        )
-                    else:
-                        m.apply(lambda m: weight_norm(m, name=name, dim=dim_wn))
-                except ValueError:
-                    pass
-            m.weight.data.normal_(mean, std)
+from lt_tensor.model_zoo.convs import ConvNets
 
 
 def get_padding(ks, d):
```
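With the helpers moved out, `residual.py` keeps only a re-import of `ConvNets`. One apparent oversight: `spectral_norm_select` and `get_weight_norm` are still listed in this module's `__all__` (first hunk above) but are no longer defined or imported here, so a wildcard import of the module should now fail:

```python
# Pitfall sketch for 0.0.1a27, based on the hunks above.
from lt_tensor.model_zoo.residual import ConvNets  # fine: re-imported from convs
from lt_tensor.model_zoo.residual import *         # AttributeError: module has no
                                                   # attribute 'spectral_norm_select'
```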
```diff
@@ -151,7 +83,6 @@ class ResBlock1DShuffled(ConvNets):
         self.last_index = len(self.conv_nets) - 1
 
     def _get_conv_layer(self, id, ch, k, stride, d, actv):
-        get_padding = lambda ks, d: int((ks * d - d) / 2)
         return nn.Sequential(
             actv, # 1
             weight_norm(
```
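A small cleanup: `_get_conv_layer` no longer rebuilds a local `get_padding` lambda on every call, presumably relying on the identical module-level `get_padding` helper instead.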
```diff
@@ -172,72 +103,6 @@ class ResBlock1DShuffled(ConvNets):
         return x
 
 
-class ResBlock2D(Model):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: Optional[int] = None,
-        hidden_dim: int = 32,
-        downscale: bool = False,
-        activation: nn.Module = nn.LeakyReLU(0.2),
-    ):
-        super().__init__()
-        stride = 2 if downscale else 1
-        if out_channels is None:
-            out_channels = in_channels
-
-        self.block = nn.Sequential(
-            nn.Conv2d(in_channels, hidden_dim, 3, stride, 1),
-            activation,
-            nn.Conv2d(hidden_dim, hidden_dim, 7, 1, 3),
-            activation,
-            nn.Conv2d(hidden_dim, out_channels, 3, 1, 1),
-        )
-
-        self.skip = nn.Identity()
-        if downscale or in_channels != out_channels:
-            self.skip = spectral_norm_select(
-                nn.Conv2d(in_channels, out_channels, 1, stride)
-            )
-        # on less to be handled every cycle
-        self.sqrt_2 = math.sqrt(2)
-
-    def forward(self, x: Tensor):
-        return x + ((self.block(x) + self.skip(x)) / self.sqrt_2)
-
-
-class ShuffleBlock2D(ConvNets):
-    def __init__(
-        self,
-        channels: int,
-        out_channels: Optional[int] = None,
-        hidden_dim: int = 32,
-        downscale: bool = False,
-        activation: nn.Module = nn.LeakyReLU(0.1),
-    ):
-        super().__init__()
-        if out_channels is None:
-            out_channels = channels
-        self.shuffle = nn.ChannelShuffle(groups=2)
-        self.ch_split = lambda tensor: torch.split(tensor, 1, dim=1)
-        self.activation = activation
-        self.resblock_2d = ResBlock2D(
-            channels, out_channels, hidden_dim, downscale, activation
-        )
-
-    def shuffle_channels(self, tensor: torch.Tensor):
-        with torch.no_grad():
-            x = F.channel_shuffle(tensor.transpose(1, -1), tensor.shape[1]).transpose(
-                -1, 1
-            )
-        return self.ch_split(x)
-
-    def forward(self, x: torch.Tensor):
-        ch1, ch2 = self.shuffle_channels(x)
-        ch2 = self.resblock_2d(ch2)
-        return torch.cat((ch1, ch2), dim=1)
-
-
 class AdaResBlock1D(ConvNets):
     def __init__(
         self,
```
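`ResBlock2D` and `ShuffleBlock2D` are removed outright, matching their deletion from `__all__` above. This is a breaking change for downstream code that imported either block; of this pair of releases, 0.0.1a22 is the last that ships them. The final hunk updates the processors package `__init__` to export the new config class: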
```diff
@@ -1,3 +1,3 @@
-from .audio import AudioProcessor
+from .audio import AudioProcessor, AudioProcessorConfig
 
-__all__ = ["AudioProcessor"]
+__all__ = ["AudioProcessor", "AudioProcessorConfig"]
```
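`AudioProcessor` settings now travel in a dedicated `AudioProcessorConfig`, exported alongside the processor itself. A sketch (the diff only shows the relative `.audio` import, so the absolute package path below is an assumption):

```python
# Assumed absolute path; adjust to wherever this __init__.py lives in lt_tensor.
from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
```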