lt-tensor 0.0.1a38__tar.gz → 0.0.1a40__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/PKG-INFO +1 -1
  2. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/__init__.py +1 -1
  3. lt_tensor-0.0.1a40/lt_tensor/model_zoo/audio_models/bemaganv2/__init__.py +205 -0
  4. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +14 -39
  5. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py +20 -19
  6. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/hifigan/__init__.py +24 -44
  7. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/istft/__init__.py +15 -39
  8. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/convs.py +35 -4
  9. lt_tensor-0.0.1a40/lt_tensor/model_zoo/losses/_envelope_disc/__init__.py +116 -0
  10. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/losses/discriminators.py +34 -64
  11. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/noise_tools.py +22 -13
  12. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/processors/audio.py +116 -62
  13. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor.egg-info/PKG-INFO +1 -1
  14. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor.egg-info/SOURCES.txt +2 -0
  15. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/setup.py +1 -1
  16. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/LICENSE +0 -0
  17. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/README.md +0 -0
  18. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/config_templates.py +0 -0
  19. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/losses.py +0 -0
  20. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/lr_schedulers.py +0 -0
  21. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/math_ops.py +0 -0
  22. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/misc_utils.py +0 -0
  23. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_base.py +0 -0
  24. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/__init__.py +0 -0
  25. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/activations/__init__.py +0 -0
  26. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/activations/alias_free/__init__.py +0 -0
  27. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/activations/alias_free/act.py +0 -0
  28. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/activations/alias_free/filter.py +0 -0
  29. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/activations/alias_free/resample.py +0 -0
  30. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/activations/snake/__init__.py +0 -0
  31. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/__init__.py +0 -0
  32. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/resblocks.py +0 -0
  33. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/basic.py +0 -0
  34. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/features.py +0 -0
  35. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/fusion.py +0 -0
  36. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/losses/CQT/__init__.py +0 -0
  37. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/losses/CQT/transforms.py +0 -0
  38. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/losses/CQT/utils.py +0 -0
  39. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/losses/__init__.py +0 -0
  40. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/pos_encoder.py +0 -0
  41. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/residual.py +0 -0
  42. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/transformer.py +0 -0
  43. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/monotonic_align.py +0 -0
  44. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/processors/__init__.py +0 -0
  45. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/torch_commons.py +0 -0
  46. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/transform.py +0 -0
  47. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor.egg-info/dependency_links.txt +0 -0
  48. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor.egg-info/requires.txt +0 -0
  49. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor.egg-info/top_level.txt +0 -0
  50. {lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/setup.cfg +0 -0
{lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lt-tensor
- Version: 0.0.1a38
+ Version: 0.0.1a40
  Summary: General utilities for PyTorch and others. Built for general use.
  Home-page: https://github.com/gr1336/lt-tensor/
  Author: gr1336
{lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/__init__.py
@@ -1,4 +1,4 @@
- __version__ = "0.0.1a38"
+ __version__ = "0.0.1a40"
  
  from . import (
      lr_schedulers,
lt_tensor-0.0.1a40/lt_tensor/model_zoo/audio_models/bemaganv2/__init__.py (new file)
@@ -0,0 +1,205 @@
+ from lt_utils.common import *
+ from lt_tensor.torch_commons import *
+ from lt_tensor.model_zoo.convs import ConvNets
+ from lt_tensor.config_templates import ModelConfig
+ from lt_utils.file_ops import is_file, load_json
+ from lt_tensor.model_zoo.audio_models.resblocks import ResBlock1, ResBlock2
+ from lt_tensor.model_zoo.activations import snake, alias_free
+ from lt_tensor.model_zoo.audio_models.resblocks import AMPBlock1, AMPBlock2, get_snake
+
+
+ class BemaGANv2Config(ModelConfig):
+     # Training params
+     in_channels: int = 80
+     upsample_rates: List[Union[int, List[int]]] = [8, 8, 2, 2]
+     upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16, 4, 4]
+     upsample_initial_channel: int = 1536
+     resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11]
+     resblock_dilation_sizes: List[Union[int, List[int]]] = [
+         [1, 3, 5],
+         [1, 3, 5],
+         [1, 3, 5],
+     ]
+
+     activation: Literal["snake", "snakebeta"] = "snakebeta"
+     resblock_activation: Literal["snake", "snakebeta"] = "snakebeta"
+     resblock: int = 0
+     use_bias_at_final: bool = True
+     use_tanh_at_final: bool = True
+     snake_logscale: bool = True
+
+     def __init__(
+         self,
+         in_channels: int = 80,
+         upsample_rates: List[Union[int, List[int]]] = [8, 8, 2, 2],
+         upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16, 4, 4],
+         upsample_initial_channel: int = 1536,
+         resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+         resblock_dilation_sizes: List[Union[int, List[int]]] = [
+             [1, 3, 5],
+             [1, 3, 5],
+             [1, 3, 5],
+         ],
+         activation: Literal["snake", "snakebeta"] = "snakebeta",
+         resblock_activation: Literal["snake", "snakebeta"] = "snakebeta",
+         resblock: Union[int, str] = "1",
+         use_bias_at_final: bool = False,
+         use_tanh_at_final: bool = False,
+         *args,
+         **kwargs,
+     ):
+         settings = {
+             "in_channels": in_channels,
+             "upsample_rates": upsample_rates,
+             "upsample_kernel_sizes": upsample_kernel_sizes,
+             "upsample_initial_channel": upsample_initial_channel,
+             "resblock_kernel_sizes": resblock_kernel_sizes,
+             "resblock_dilation_sizes": resblock_dilation_sizes,
+             "activation": activation,
+             "resblock_activation": resblock_activation,
+             "resblock": resblock,
+             "use_bias_at_final": use_bias_at_final,
+             "use_tanh_at_final": use_tanh_at_final,
+         }
+         super().__init__(**settings)
+
+     def post_process(self):
+         if isinstance(self.resblock, str):
+             self.resblock = 0 if self.resblock == "1" else 1
+
+
+ class BemaGANv2Generator(ConvNets):
+
+     def __init__(
+         self, cfg: Union[BemaGANv2Config, Dict[str, object]] = BemaGANv2Config()
+     ):
+         super().__init__()
+         cfg = cfg if isinstance(cfg, BemaGANv2Config) else BemaGANv2Config(**cfg)
+         self.cfg = cfg
+
+         actv = get_snake(self.cfg.activation)
+
+         self.num_kernels = len(cfg.resblock_kernel_sizes)
+         self.num_upsamples = len(cfg.upsample_rates)
+
+         self.conv_pre = weight_norm(
+             nn.Conv1d(cfg.num_mels, cfg.upsample_initial_channel, 7, 1, padding=3)
+         )
+
+         # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
+         resblock = AMPBlock1 if cfg.resblock == 0 else AMPBlock2
+
+         # transposed conv-based upsamplers. does not apply anti-aliasing
+         self.ups = nn.ModuleList()
+         for i, (u, k) in enumerate(zip(cfg.upsample_rates, cfg.upsample_kernel_sizes)):
+             self.ups.append(
+                 nn.ModuleList(
+                     [
+                         weight_norm(
+                             nn.ConvTranspose1d(
+                                 cfg.upsample_initial_channel // (2**i),
+                                 cfg.upsample_initial_channel // (2 ** (i + 1)),
+                                 k,
+                                 u,
+                                 padding=(k - u) // 2,
+                             )
+                         )
+                     ]
+                 )
+             )
+         # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
+         self.resblocks = nn.ModuleList()
+         for i in range(len(self.ups)):
+             ch = cfg.upsample_initial_channel // (2 ** (i + 1))
+             for k, d in zip(cfg.resblock_kernel_sizes, cfg.resblock_dilation_sizes):
+                 self.resblocks.append(
+                     resblock(
+                         ch,
+                         k,
+                         d,
+                         snake_logscale=cfg.snake_logscale,
+                         activation=cfg.resblock_activation,
+                     )
+                 )
+
+         self.activation_post = actv(ch, alpha_logscale=cfg.snake_logscale)
+         # post conv
+
+         self.conv_post = weight_norm(
+             nn.Conv1d(ch, 1, 7, 1, padding=3, bias=self.cfg.use_bias_at_final)
+         )
+         self._use_tanh = self.cfg.use_tanh_at_final
+
+         # weight initialization
+         for i in range(len(self.ups)):
+             self.ups[i].apply(self.init_weights)
+         self.conv_post.apply(self.init_weights)
+
+     def forward(self, x: Tensor):
+         # pre conv
+         x = self.conv_pre(x)
+
+         for i in range(self.num_upsamples):
+             # upsampling
+             for i_up in range(len(self.ups[i])):
+                 x = self.ups[i][i_up](x)
+             # AMP blocks
+             xs = None
+             for j in range(self.num_kernels):
+                 if xs is None:
+                     xs = self.resblocks[i * self.num_kernels + j](x)
+                 else:
+                     xs += self.resblocks[i * self.num_kernels + j](x)
+             x = xs / self.num_kernels
+
+         # post conv
+         x = self.activation_post(x)
+         x = self.conv_post(x)
+         if self._use_tanh:
+             return x.tanh()
+         return x
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_file: PathLike,
+         model_config: Union[
+             BemaGANv2Config, Dict[str, Any], PathLike
+         ] = BemaGANv2Config(),
+         *,
+         remove_norms: bool = False,
+         strict: bool = True,
+         map_location: str = "cpu",
+         weights_only: bool = False,
+         **kwargs,
+     ):
+
+         is_file(model_file, validate=True)
+         model_state_dict = torch.load(
+             model_file,
+             weights_only=weights_only,
+             map_location=map_location,
+         )
+
+         if isinstance(model_config, BemaGANv2Config):
+             h = model_config
+         elif isinstance(model_config, dict):
+             h = BemaGANv2Config(**model_config)
+         elif isinstance(model_config, (str, Path, bytes)):
+             h = BemaGANv2Config(
+                 **load_json(model_config, BemaGANv2Config().state_dict())
+             )
+
+         model = cls(h)
+         if remove_norms:
+             model.remove_norms()
+         try:
+             model.load_state_dict(model_state_dict, strict=strict)
+             return model
+         except RuntimeError:
+             print(
+                 f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
+             )
+             model.remove_norms()
+             model.load_state_dict(model_state_dict, strict=strict)
+             return model
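The new module mirrors the BigVGAN generator layout (AMP residual blocks, snake activations, transposed-conv upsampling). A minimal usage sketch follows; the checkpoint and config file names are placeholders, and note that conv_pre reads cfg.num_mels while BemaGANv2Config only declares in_channels, so which field supplies the mel-channel count depends on ModelConfig's defaults:

    import torch
    from lt_tensor.model_zoo.audio_models.bemaganv2 import BemaGANv2Generator

    # "generator.pt" / "config.json" are placeholder names, not files shipped here.
    model = BemaGANv2Generator.from_pretrained(
        "generator.pt",
        model_config="config.json",  # also accepts a BemaGANv2Config or a plain dict
        map_location="cpu",
    )
    model.eval()
    mel = torch.randn(1, 80, 32)  # (batch, mel_channels, frames)
    with torch.inference_mode():
        wav = model(mel)  # frames upsampled by prod(upsample_rates) = 8*8*2*2 = 256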
{lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/bigvgan/__init__.py
@@ -2,9 +2,9 @@ from lt_utils.common import *
  from lt_tensor.torch_commons import *
  from lt_tensor.model_zoo.convs import ConvNets
  from lt_tensor.config_templates import ModelConfig
- from lt_tensor.model_zoo.activations import snake, alias_free
+ from lt_tensor.model_zoo.activations import alias_free
  from lt_tensor.model_zoo.audio_models.resblocks import AMPBlock1, AMPBlock2, get_snake
- from lt_utils.file_ops import load_json, is_file, is_dir, is_path_valid
+ from lt_utils.file_ops import load_json, is_file
  
  
  class BigVGANConfig(ModelConfig):
@@ -78,8 +78,9 @@ class BigVGAN(ConvNets):
  
      """
  
-     def __init__(self, cfg: BigVGANConfig):
+     def __init__(self, cfg: Union[BigVGANConfig, Dict[str, object]] = BigVGANConfig()):
          super().__init__()
+         cfg = cfg if isinstance(cfg, BigVGANConfig) else BigVGANConfig(**cfg)
          self.cfg = cfg
          actv = get_snake(self.cfg.activation)
  
@@ -173,46 +174,16 @@ class BigVGAN(ConvNets):
              return x.tanh()
          return x.clamp(min=-1.0, max=1.0)
  
-     def load_weights(
-         self,
-         path,
-         strict=False,
-         assign=False,
-         weights_only=False,
-         mmap=None,
-         raise_if_not_exists=False,
-         **pickle_load_args,
-     ):
-         try:
-             return super().load_weights(
-                 path,
-                 raise_if_not_exists,
-                 strict,
-                 assign,
-                 weights_only,
-                 mmap,
-                 **pickle_load_args,
-             )
-         except RuntimeError:
-             self.remove_norms()
-             return super().load_weights(
-                 path,
-                 raise_if_not_exists,
-                 strict,
-                 assign,
-                 weights_only,
-                 mmap,
-                 **pickle_load_args,
-             )
-
      @classmethod
      def from_pretrained(
          cls,
          model_file: PathLike,
-         model_config: Union[BigVGANConfig, Dict[str, Any]],
+         model_config: Union[
+             BigVGANConfig, Dict[str, Any], Dict[str, Any], PathLike
+         ] = BigVGANConfig(),
          *,
          remove_norms: bool = False,
-         strict: bool = False,
+         strict: bool = True,
          map_location: str = "cpu",
          weights_only: bool = False,
          **kwargs,
@@ -220,13 +191,17 @@ class BigVGAN(ConvNets):
  
          is_file(model_file, validate=True)
          model_state_dict = torch.load(
-             model_file, weights_only=weights_only, map_location=map_location
+             model_file,
+             weights_only=weights_only,
+             map_location=map_location,
          )
  
          if isinstance(model_config, BigVGANConfig):
              h = model_config
-         else:
+         elif isinstance(model_config, dict):
              h = BigVGANConfig(**model_config)
+         elif isinstance(model_config, (str, Path, bytes)):
+             h = BigVGANConfig(**load_json(model_config, BigVGANConfig().state_dict()))
  
          model = cls(h)
          if remove_norms:
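The net effect of these hunks: model_config now accepts three forms instead of requiring a prebuilt config object, and the per-class load_weights override is gone (it moved into ConvNets; see the convs.py hunks below). A sketch of the three forms, with "g.pt" and "bigvgan_config.json" as placeholder paths:

    from lt_tensor.model_zoo.audio_models.bigvgan import BigVGAN, BigVGANConfig

    m1 = BigVGAN.from_pretrained("g.pt", model_config=BigVGANConfig())          # config object
    m2 = BigVGAN.from_pretrained("g.pt", model_config={"activation": "snake"})  # plain dict
    m3 = BigVGAN.from_pretrained("g.pt", model_config="bigvgan_config.json")    # JSON path,
    # merged over the defaults via load_json(..., BigVGANConfig().state_dict())

The same change is applied to HifiganGenerator and iSTFTNetGenerator further down.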
{lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py
@@ -177,43 +177,44 @@ class ResidualBlock(Model):
  
  
  class DiffWave(Model):
-     def __init__(self, params: DiffWaveConfig = DiffWaveConfig()):
+     def __init__(self, cfg: Union[DiffWaveConfig, dict[str, object]] = DiffWaveConfig()):
          super().__init__()
-         self.params = params
-         self.n_hop = self.params.hop_samples
+         cfg = cfg if isinstance(cfg, DiffWaveConfig) else DiffWaveConfig(**cfg)
+         self.cfg = cfg
+         self.n_hop = self.cfg.hop_samples
          self.input_projection = ConvEXT(
              in_channels=1,
-             out_channels=params.residual_channels,
+             out_channels=cfg.residual_channels,
              kernel_size=1,
-             apply_norm=self.params.apply_norm,
+             apply_norm=self.cfg.apply_norm,
              activation_out=nn.LeakyReLU(0.1),
          )
-         self.diffusion_embedding = DiffusionEmbedding(len(params.noise_schedule))
+         self.diffusion_embedding = DiffusionEmbedding(len(cfg.noise_schedule))
          self.spectrogram_upsampler = (
-             SpectrogramUpsampler() if not self.params.unconditional else None
+             SpectrogramUpsampler() if not self.cfg.unconditional else None
          )
  
          self.residual_layers = nn.ModuleList(
              [
                  ResidualBlock(
-                     params.n_mels,
-                     params.residual_channels,
-                     2 ** (i % params.dilation_cycle_length),
-                     uncond=params.unconditional,
-                     apply_norm=self.params.apply_norm_resblock,
+                     cfg.n_mels,
+                     cfg.residual_channels,
+                     2 ** (i % cfg.dilation_cycle_length),
+                     uncond=cfg.unconditional,
+                     apply_norm=self.cfg.apply_norm_resblock,
                  )
-                 for i in range(params.residual_layers)
+                 for i in range(cfg.residual_layers)
              ]
          )
          self.skip_projection = ConvEXT(
-             in_channels=params.residual_channels,
-             out_channels=params.residual_channels,
+             in_channels=cfg.residual_channels,
+             out_channels=cfg.residual_channels,
              kernel_size=1,
-             apply_norm=self.params.apply_norm,
+             apply_norm=self.cfg.apply_norm,
              activation_out=nn.LeakyReLU(0.1),
          )
          self.output_projection = ConvEXT(
-             params.residual_channels, 1, 1, apply_norm=self.params.apply_norm, init_weights=True,
+             cfg.residual_channels, 1, 1, apply_norm=self.cfg.apply_norm, init_weights=True,
          )
          self.activation = nn.LeakyReLU(0.1)
          self._res_d = sqrt(len(self.residual_layers))
@@ -224,7 +225,7 @@ class DiffWave(Model):
          diffusion_step: Tensor,
          spectrogram: Optional[Tensor] = None,
      ):
-         if not self.params.unconditional:
+         if not self.cfg.unconditional:
              assert spectrogram is not None
          if audio.ndim < 3:
              if audio.ndim == 2:
@@ -234,7 +235,7 @@ class DiffWave(Model):
  
          x = self.input_projection(audio)
          diffusion_step = self.diffusion_embedding(diffusion_step)
-         if not self.params.unconditional:  # use conditional model
+         if not self.cfg.unconditional:  # use conditional model
              spectrogram = self.spectrogram_upsampler(spectrogram)
  
          skip = torch.zeros_like(x, device=x.device)
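The DiffWave change is a rename (params → cfg) plus the same dict-accepting constructor as the vocoders above. A sketch (the residual_layers value is illustrative, not a recommendation):

    from lt_tensor.model_zoo.audio_models.diffwave import DiffWave, DiffWaveConfig

    model = DiffWave({"residual_layers": 30})  # dict is expanded into DiffWaveConfig(**cfg)
    model = DiffWave(DiffWaveConfig())         # equivalent, using the config object
    # forward(audio, diffusion_step, spectrogram) predicts the noise residual;
    # spectrogram may be omitted only when cfg.unconditional is True.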
{lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/hifigan/__init__.py
@@ -5,7 +5,7 @@ from lt_utils.common import *
  from lt_tensor.torch_commons import *
  from lt_tensor.model_zoo.convs import ConvNets
  from lt_tensor.config_templates import ModelConfig
- from lt_utils.file_ops import is_file
+ from lt_utils.file_ops import is_file, load_json
  from lt_tensor.model_zoo.audio_models.resblocks import ResBlock1, ResBlock2
  
  
@@ -16,11 +16,15 @@ def get_padding(kernel_size, dilation=1):
  class HifiganConfig(ModelConfig):
      # Training params
      in_channels: int = 80
-     upsample_rates: List[Union[int, List[int]]] = [8,8,2,2]
-     upsample_kernel_sizes: List[Union[int, List[int]]] = [16,16,4,4]
+     upsample_rates: List[Union[int, List[int]]] = [8, 8, 2, 2]
+     upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16, 4, 4]
      upsample_initial_channel: int = 512
      resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11]
-     resblock_dilation_sizes: List[Union[int, List[int]]] = [[1,3,5], [1,3,5], [1,3,5]]
+     resblock_dilation_sizes: List[Union[int, List[int]]] = [
+         [1, 3, 5],
+         [1, 3, 5],
+         [1, 3, 5],
+     ]
  
      activation: nn.Module = nn.LeakyReLU(0.1)
      resblock_activation: nn.Module = nn.LeakyReLU(0.1)
@@ -29,10 +33,10 @@ class HifiganConfig(ModelConfig):
      def __init__(
          self,
          in_channels: int = 80,
-         upsample_rates: List[Union[int, List[int]]] = [8,8,2,2],
-         upsample_kernel_sizes: List[Union[int, List[int]]] = [16,16,4,4],
+         upsample_rates: List[Union[int, List[int]]] = [8, 8, 2, 2],
+         upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16, 4, 4],
          upsample_initial_channel: int = 512,
-         resblock_kernel_sizes: List[Union[int, List[int]]] = [3,7,11],
+         resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
          resblock_dilation_sizes: List[Union[int, List[int]]] = [
              [1, 3, 5],
              [1, 3, 5],
@@ -63,9 +67,11 @@ class HifiganConfig(ModelConfig):
  
  
  class HifiganGenerator(ConvNets):
-     def __init__(self, cfg: HifiganConfig = HifiganConfig()):
+     def __init__(self, cfg: Union[HifiganConfig, Dict[str, object]] = HifiganConfig()):
          super().__init__()
+         cfg = cfg if isinstance(cfg, HifiganConfig) else HifiganConfig(**cfg)
          self.cfg = cfg
+
          self.num_kernels = len(cfg.resblock_kernel_sizes)
          self.num_upsamples = len(cfg.upsample_rates)
          self.conv_pre = weight_norm(
@@ -115,46 +121,16 @@ class HifiganGenerator(ConvNets):
          x = self.conv_post(self.activation(x))
          return x.tanh()
  
-     def load_weights(
-         self,
-         path,
-         strict=False,
-         assign=False,
-         weights_only=False,
-         mmap=None,
-         raise_if_not_exists=False,
-         **pickle_load_args,
-     ):
-         try:
-             return super().load_weights(
-                 path,
-                 raise_if_not_exists,
-                 strict,
-                 assign,
-                 weights_only,
-                 mmap,
-                 **pickle_load_args,
-             )
-         except RuntimeError:
-             self.remove_norms()
-             return super().load_weights(
-                 path,
-                 raise_if_not_exists,
-                 strict,
-                 assign,
-                 weights_only,
-                 mmap,
-                 **pickle_load_args,
-             )
-
      @classmethod
      def from_pretrained(
          cls,
          model_file: PathLike,
-         model_config: Union[HifiganConfig, Dict[str, Any]],
+         model_config: Union[
+             HifiganConfig, Dict[str, Any], Dict[str, Any], PathLike
+         ] = HifiganConfig(),
          *,
          remove_norms: bool = False,
-         strict: bool = False,
+         strict: bool = True,
          map_location: str = "cpu",
          weights_only: bool = False,
          **kwargs,
@@ -162,13 +138,17 @@ class HifiganGenerator(ConvNets):
  
          is_file(model_file, validate=True)
          model_state_dict = torch.load(
-             model_file, weights_only=weights_only, map_location=map_location
+             model_file,
+             weights_only=weights_only,
+             map_location=map_location,
          )
  
          if isinstance(model_config, HifiganConfig):
              h = model_config
-         else:
+         elif isinstance(model_config, dict):
              h = HifiganConfig(**model_config)
+         elif isinstance(model_config, (str, Path, bytes)):
+             h = HifiganConfig(**load_json(model_config, HifiganConfig().state_dict()))
  
          model = cls(h)
          if remove_norms:
{lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/istft/__init__.py
@@ -3,7 +3,7 @@ from lt_utils.common import *
  from lt_tensor.torch_commons import *
  from lt_tensor.model_zoo.convs import ConvNets
  from lt_tensor.config_templates import ModelConfig
- from lt_utils.file_ops import is_file
+ from lt_utils.file_ops import is_file, load_json
  from lt_tensor.model_zoo.audio_models.resblocks import ResBlock1, ResBlock2
  
  
@@ -67,8 +67,11 @@ class iSTFTNetConfig(ModelConfig):
  
  
  class iSTFTNetGenerator(ConvNets):
-     def __init__(self, cfg: iSTFTNetConfig = iSTFTNetConfig()):
+     def __init__(
+         self, cfg: Union[iSTFTNetConfig, Dict[str, object]] = iSTFTNetConfig()
+     ):
          super().__init__()
+         cfg = cfg if isinstance(cfg, iSTFTNetConfig) else iSTFTNetConfig(**cfg)
          self.cfg = cfg
          self.num_kernels = len(cfg.resblock_kernel_sizes)
          self.num_upsamples = len(cfg.upsample_rates)
@@ -146,46 +149,16 @@ class iSTFTNetGenerator(ConvNets):
  
          return spec, phase
  
-     def load_weights(
-         self,
-         path,
-         strict=False,
-         assign=False,
-         weights_only=False,
-         mmap=None,
-         raise_if_not_exists=False,
-         **pickle_load_args,
-     ):
-         try:
-             return super().load_weights(
-                 path,
-                 raise_if_not_exists,
-                 strict,
-                 assign,
-                 weights_only,
-                 mmap,
-                 **pickle_load_args,
-             )
-         except RuntimeError:
-             self.remove_norms()
-             return super().load_weights(
-                 path,
-                 raise_if_not_exists,
-                 strict,
-                 assign,
-                 weights_only,
-                 mmap,
-                 **pickle_load_args,
-             )
-
      @classmethod
      def from_pretrained(
          cls,
          model_file: PathLike,
-         model_config: Union[iSTFTNetConfig, Dict[str, Any]],
+         model_config: Union[
+             iSTFTNetConfig, Dict[str, Any], Dict[str, Any], PathLike
+         ] = iSTFTNetConfig(),
          *,
          remove_norms: bool = False,
-         strict: bool = False,
+         strict: bool = True,
          map_location: str = "cpu",
          weights_only: bool = False,
          **kwargs,
@@ -193,14 +166,17 @@ class iSTFTNetGenerator(ConvNets):
  
          is_file(model_file, validate=True)
          model_state_dict = torch.load(
-             model_file, weights_only=weights_only, map_location=map_location
+             model_file,
+             weights_only=weights_only,
+             map_location=map_location,
          )
  
          if isinstance(model_config, iSTFTNetConfig):
              h = model_config
-         else:
+         elif isinstance(model_config, dict):
              h = iSTFTNetConfig(**model_config)
-
+         elif isinstance(model_config, (str, Path, bytes)):
+             h = iSTFTNetConfig(**load_json(model_config, iSTFTNetConfig().state_dict()))
          model = cls(h)
          if remove_norms:
             model.remove_norms()
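Unlike the other generators, iSTFTNetGenerator returns a (spec, phase) pair and leaves the inverse STFT to the caller. A reconstruction sketch; the n_fft/hop values are illustrative and must match the config actually used, which this diff does not show:

    import torch
    from lt_tensor.model_zoo.audio_models.istft import iSTFTNetGenerator

    gen = iSTFTNetGenerator()                  # default iSTFTNetConfig
    spec, phase = gen(torch.randn(1, 80, 32))  # magnitude and phase spectra
    n_fft, hop = 16, 4                         # illustrative values only
    wav = torch.istft(
        spec * torch.exp(1j * phase),          # rebuild the complex spectrum
        n_fft=n_fft,
        hop_length=hop,
        window=torch.hann_window(n_fft),
    )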
{lt_tensor-0.0.1a38 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/convs.py
@@ -1,11 +1,7 @@
  __all__ = ["ConvNets", "ConvEXT"]
- import math
  from lt_utils.common import *
- import torch.nn.functional as F
  from lt_tensor.torch_commons import *
  from lt_tensor.model_base import Model
- from lt_tensor.misc_utils import log_tensor
- from lt_tensor.model_zoo.fusion import AdaFusion1D, AdaIN1D
  from lt_utils.misc_utils import default
  
  
@@ -52,6 +48,41 @@ class ConvNets(Model):
          if "Conv" in m.__class__.__name__:
              m.weight.data.normal_(mean, std)
  
+     def load_weights(
+         self,
+         path,
+         strict=False,
+         assign=False,
+         weights_only=False,
+         mmap=None,
+         raise_if_not_exists=False,
+         **pickle_load_args,
+     ):
+         try:
+             return super().load_weights(
+                 path,
+                 raise_if_not_exists,
+                 strict,
+                 assign,
+                 weights_only,
+                 mmap,
+                 **pickle_load_args,
+             )
+         except RuntimeError as e:
+             try:
+                 self.remove_norms()
+                 return super().load_weights(
+                     path,
+                     raise_if_not_exists,
+                     strict,
+                     assign,
+                     weights_only,
+                     mmap,
+                     **pickle_load_args,
+                 )
+             except:
+                 raise e
+
  
  class ConvEXT(ConvNets):
      def __init__(
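This consolidates the copy-pasted load_weights overrides removed from bigvgan, hifigan, and istft above into the shared ConvNets base, and tightens the fallback: the retry is itself guarded, and the original RuntimeError is re-raised if stripping the norms does not help. A behavior sketch with a placeholder path:

    from lt_tensor.model_zoo.audio_models.hifigan import HifiganGenerator

    gen = HifiganGenerator()
    # If the first load raises RuntimeError (typically a state-dict mismatch
    # because the checkpoint was saved after remove_norms()), load_weights
    # strips the weight-norm parametrizations and retries once.
    gen.load_weights("generator.pt")  # "generator.pt" is a placeholder path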