lt-tensor 0.0.1a22__tar.gz → 0.0.1a26__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/PKG-INFO +1 -1
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/config_templates.py +9 -5
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/misc_utils.py +15 -3
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py +41 -14
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/audio_models/hifigan/__init__.py +40 -82
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/audio_models/istft/__init__.py +41 -83
- lt_tensor-0.0.1a26/lt_tensor/model_zoo/convs.py +124 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/residual.py +1 -136
- lt_tensor-0.0.1a26/lt_tensor/processors/__init__.py +3 -0
- lt_tensor-0.0.1a26/lt_tensor/processors/audio.py +527 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor.egg-info/PKG-INFO +1 -1
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor.egg-info/SOURCES.txt +1 -2
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/setup.py +1 -1
- lt_tensor-0.0.1a22/lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +0 -536
- lt_tensor-0.0.1a22/lt_tensor/model_zoo/audio_models/bigvgan/cuda/__init__.py +0 -160
- lt_tensor-0.0.1a22/lt_tensor/processors/__init__.py +0 -3
- lt_tensor-0.0.1a22/lt_tensor/processors/audio.py +0 -456
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/LICENSE +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/README.md +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/__init__.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/losses.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/lr_schedulers.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/math_ops.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_base.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/__init__.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/activations/alias_free_torch/act.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/activations/alias_free_torch/filter.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/activations/alias_free_torch/resample.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/activations/snake/__init__.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/audio_models/__init__.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/basic.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/features.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/fusion.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/pos_encoder.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/transformer.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/monotonic_align.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/noise_tools.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/torch_commons.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/transform.py +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor.egg-info/dependency_links.txt +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor.egg-info/requires.txt +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor.egg-info/top_level.txt +0 -0
- {lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/setup.cfg +0 -0
@@ -6,9 +6,7 @@ from lt_tensor.misc_utils import updateDict
|
|
6
6
|
|
7
7
|
|
8
8
|
class ModelConfig(ABC, OrderedDict):
|
9
|
-
_default_settings: Dict[str, Any] = {}
|
10
9
|
_forbidden_list: List[str] = [
|
11
|
-
"_default_settings",
|
12
10
|
"_forbidden_list",
|
13
11
|
]
|
14
12
|
|
@@ -16,12 +14,15 @@ class ModelConfig(ABC, OrderedDict):
|
|
16
14
|
self,
|
17
15
|
**settings,
|
18
16
|
):
|
19
|
-
self.
|
20
|
-
self.set_state_dict(self._default_settings)
|
17
|
+
self.set_state_dict(settings)
|
21
18
|
|
22
19
|
def reset_settings(self):
|
23
20
|
raise NotImplementedError("Not implemented")
|
24
|
-
|
21
|
+
|
22
|
+
def post_process(self):
|
23
|
+
"""Implement the post process, to do a final check to the input data"""
|
24
|
+
pass
|
25
|
+
|
25
26
|
def save_config(
|
26
27
|
self,
|
27
28
|
path: str,
|
@@ -48,6 +49,7 @@ class ModelConfig(ABC, OrderedDict):
|
|
48
49
|
}
|
49
50
|
updateDict(self, new_state)
|
50
51
|
self.update(**new_state)
|
52
|
+
self.post_process()
|
51
53
|
|
52
54
|
def state_dict(self):
|
53
55
|
return {k: y for k, y in self.__dict__.items() if k not in self._forbidden_list}
|
@@ -89,3 +91,5 @@ class ModelConfig(ABC, OrderedDict):
|
|
89
91
|
settings.pop("path_name", None)
|
90
92
|
|
91
93
|
return ModelConfig(**settings)
|
94
|
+
|
95
|
+
|
@@ -111,10 +111,10 @@ def get_weights(directory: Union[str, PathLike]):
|
|
111
111
|
directory = Path(directory)
|
112
112
|
if is_file(directory):
|
113
113
|
if directory.name.endswith((".pt", ".ckpt", ".pth")):
|
114
|
-
return directory
|
114
|
+
return [directory]
|
115
115
|
directory = directory.parent
|
116
116
|
res = sorted(find_files(directory, ["*.pt", "*.ckpt", "*.pth"]))
|
117
|
-
return res
|
117
|
+
return res
|
118
118
|
|
119
119
|
|
120
120
|
def get_config(directory: Union[str, PathLike], default: Optional[Any] = None):
|
@@ -128,7 +128,19 @@ def get_config(directory: Union[str, PathLike], default: Optional[Any] = None):
|
|
128
128
|
return load_json(directory, default)
|
129
129
|
return load_yaml(directory, default)
|
130
130
|
directory = directory.parent
|
131
|
-
res = sorted(
|
131
|
+
res = sorted(
|
132
|
+
find_files(
|
133
|
+
directory,
|
134
|
+
[
|
135
|
+
"config*.json",
|
136
|
+
"*config.json",
|
137
|
+
"config*.yml",
|
138
|
+
"*config.yml",
|
139
|
+
"*config.yaml",
|
140
|
+
"config*.yaml",
|
141
|
+
],
|
142
|
+
)
|
143
|
+
)
|
132
144
|
if res:
|
133
145
|
res = res[-1]
|
134
146
|
if Path(res).name.endswith(".json"):
|
{lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py
RENAMED
@@ -5,6 +5,7 @@ from lt_tensor.torch_commons import *
|
|
5
5
|
from torch.nn import functional as F
|
6
6
|
from lt_tensor.config_templates import ModelConfig
|
7
7
|
from lt_tensor.torch_commons import *
|
8
|
+
from lt_tensor.model_zoo.convs import ConvNets, Conv1dEXT
|
8
9
|
from lt_tensor.model_base import Model
|
9
10
|
from math import sqrt
|
10
11
|
from lt_utils.common import *
|
@@ -18,6 +19,8 @@ class DiffWaveConfig(ModelConfig):
|
|
18
19
|
residual_channels = 64
|
19
20
|
dilation_cycle_length = 10
|
20
21
|
unconditional = False
|
22
|
+
apply_norm: Optional[Literal["weight", "spectral"]] = None
|
23
|
+
apply_norm_resblock: Optional[Literal["weight", "spectral"]] = None
|
21
24
|
noise_schedule: list[int] = np.linspace(1e-4, 0.05, 50).tolist()
|
22
25
|
# settings for auto-fixes
|
23
26
|
interpolate = False
|
@@ -44,6 +47,8 @@ class DiffWaveConfig(ModelConfig):
|
|
44
47
|
"area",
|
45
48
|
"nearest-exact",
|
46
49
|
] = "nearest",
|
50
|
+
apply_norm: Optional[Literal["weight", "spectral"]] = None,
|
51
|
+
apply_norm_resblock: Optional[Literal["weight", "spectral"]] = None,
|
47
52
|
):
|
48
53
|
settings = {
|
49
54
|
"n_mels": n_mels,
|
@@ -55,16 +60,12 @@ class DiffWaveConfig(ModelConfig):
|
|
55
60
|
"noise_schedule": noise_schedule,
|
56
61
|
"interpolate": interpolate_cond,
|
57
62
|
"interpolation_mode": interpolation_mode,
|
63
|
+
"apply_norm": apply_norm,
|
64
|
+
"apply_norm_resblock": apply_norm_resblock,
|
58
65
|
}
|
59
66
|
super().__init__(**settings)
|
60
67
|
|
61
68
|
|
62
|
-
def Conv1d(*args, **kwargs):
|
63
|
-
layer = nn.Conv1d(*args, **kwargs)
|
64
|
-
nn.init.kaiming_normal_(layer.weight)
|
65
|
-
return layer
|
66
|
-
|
67
|
-
|
68
69
|
class DiffusionEmbedding(Model):
|
69
70
|
def __init__(self, max_steps: int):
|
70
71
|
super().__init__()
|
@@ -117,7 +118,14 @@ class SpectrogramUpsample(Model):
|
|
117
118
|
|
118
119
|
|
119
120
|
class ResidualBlock(Model):
|
120
|
-
def __init__(
|
121
|
+
def __init__(
|
122
|
+
self,
|
123
|
+
n_mels,
|
124
|
+
residual_channels,
|
125
|
+
dilation,
|
126
|
+
uncond=False,
|
127
|
+
apply_norm: Optional[Literal["weight", "spectral"]] = None,
|
128
|
+
):
|
121
129
|
"""
|
122
130
|
:param n_mels: inplanes of conv1x1 for spectrogram conditional
|
123
131
|
:param residual_channels: audio conv
|
@@ -125,20 +133,28 @@ class ResidualBlock(Model):
|
|
125
133
|
:param uncond: disable spectrogram conditional
|
126
134
|
"""
|
127
135
|
super().__init__()
|
128
|
-
self.dilated_conv =
|
136
|
+
self.dilated_conv = Conv1dEXT(
|
129
137
|
residual_channels,
|
130
138
|
2 * residual_channels,
|
131
139
|
3,
|
132
140
|
padding=dilation,
|
133
141
|
dilation=dilation,
|
142
|
+
apply_norm=apply_norm,
|
134
143
|
)
|
135
144
|
self.diffusion_projection = nn.Linear(512, residual_channels)
|
136
145
|
if not uncond: # conditional model
|
137
|
-
self.conditioner_projection =
|
146
|
+
self.conditioner_projection = Conv1dEXT(
|
147
|
+
n_mels,
|
148
|
+
2 * residual_channels,
|
149
|
+
1,
|
150
|
+
apply_norm=apply_norm,
|
151
|
+
)
|
138
152
|
else: # unconditional model
|
139
153
|
self.conditioner_projection = None
|
140
154
|
|
141
|
-
self.output_projection =
|
155
|
+
self.output_projection = Conv1dEXT(
|
156
|
+
residual_channels, 2 * residual_channels, 1, apply_norm == apply_norm
|
157
|
+
)
|
142
158
|
|
143
159
|
def forward(
|
144
160
|
self,
|
@@ -172,7 +188,12 @@ class DiffWave(Model):
|
|
172
188
|
self.n_hop = self.params.hop_samples
|
173
189
|
self.interpolate = self.params.interpolate
|
174
190
|
self.interpolate_mode = self.params.interpolation_mode
|
175
|
-
self.input_projection =
|
191
|
+
self.input_projection = Conv1dEXT(
|
192
|
+
in_channels=1,
|
193
|
+
out_channels=params.residual_channels,
|
194
|
+
kernel_size=1,
|
195
|
+
apply_norm=self.params.apply_norm,
|
196
|
+
)
|
176
197
|
self.diffusion_embedding = DiffusionEmbedding(len(params.noise_schedule))
|
177
198
|
if self.params.unconditional: # use unconditional model
|
178
199
|
self.spectrogram_upsample = None
|
@@ -186,14 +207,20 @@ class DiffWave(Model):
|
|
186
207
|
params.residual_channels,
|
187
208
|
2 ** (i % params.dilation_cycle_length),
|
188
209
|
uncond=params.unconditional,
|
210
|
+
apply_norm=self.params.apply_norm_resblock,
|
189
211
|
)
|
190
212
|
for i in range(params.residual_layers)
|
191
213
|
]
|
192
214
|
)
|
193
|
-
self.skip_projection =
|
194
|
-
params.residual_channels,
|
215
|
+
self.skip_projection = Conv1dEXT(
|
216
|
+
in_channels=params.residual_channels,
|
217
|
+
out_channels=params.residual_channels,
|
218
|
+
kernel_size=1,
|
219
|
+
apply_norm=self.params.apply_norm,
|
220
|
+
)
|
221
|
+
self.output_projection = Conv1dEXT(
|
222
|
+
params.residual_channels, 1, 1, apply_norm=self.params.apply_norm
|
195
223
|
)
|
196
|
-
self.output_projection = Conv1d(params.residual_channels, 1, 1)
|
197
224
|
self.activation = nn.LeakyReLU(0.1)
|
198
225
|
self.r_sqrt = sqrt(len(self.residual_layers))
|
199
226
|
nn.init.zeros_(self.output_projection.weight)
|
{lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/audio_models/hifigan/__init__.py
RENAMED
@@ -1,10 +1,10 @@
|
|
1
1
|
__all__ = ["HifiganGenerator", "HifiganConfig"]
|
2
2
|
from lt_utils.common import *
|
3
3
|
from lt_tensor.torch_commons import *
|
4
|
-
from lt_tensor.model_zoo.
|
4
|
+
from lt_tensor.model_zoo.convs import ConvNets
|
5
5
|
from torch.nn import functional as F
|
6
6
|
from lt_utils.file_ops import load_json, is_file, is_dir, is_path_valid
|
7
|
-
from
|
7
|
+
from lt_tensor.misc_utils import get_config, get_weights
|
8
8
|
|
9
9
|
|
10
10
|
def get_padding(kernel_size, dilation=1):
|
@@ -43,7 +43,7 @@ class HifiganConfig(ModelConfig):
|
|
43
43
|
[1, 3, 5],
|
44
44
|
],
|
45
45
|
activation: nn.Module = nn.LeakyReLU(0.1),
|
46
|
-
resblock: int =
|
46
|
+
resblock: Union[int, str] = "1",
|
47
47
|
*args,
|
48
48
|
**kwargs,
|
49
49
|
):
|
@@ -59,6 +59,10 @@ class HifiganConfig(ModelConfig):
|
|
59
59
|
}
|
60
60
|
super().__init__(**settings)
|
61
61
|
|
62
|
+
def post_process(self):
|
63
|
+
if isinstance(self.resblock, str):
|
64
|
+
self.resblock = 0 if self.resblock == "1" else 1
|
65
|
+
|
62
66
|
|
63
67
|
class ResBlock1(ConvNets):
|
64
68
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
@@ -240,15 +244,15 @@ class HifiganGenerator(ConvNets):
|
|
240
244
|
def load_weights(
|
241
245
|
self,
|
242
246
|
path,
|
243
|
-
raise_if_not_exists=False,
|
244
247
|
strict=False,
|
245
248
|
assign=False,
|
246
|
-
weights_only=
|
249
|
+
weights_only=False,
|
247
250
|
mmap=None,
|
251
|
+
raise_if_not_exists=False,
|
248
252
|
**pickle_load_args,
|
249
253
|
):
|
250
254
|
try:
|
251
|
-
|
255
|
+
incompatible_keys = super().load_weights(
|
252
256
|
path,
|
253
257
|
raise_if_not_exists,
|
254
258
|
strict,
|
@@ -257,6 +261,18 @@ class HifiganGenerator(ConvNets):
|
|
257
261
|
mmap,
|
258
262
|
**pickle_load_args,
|
259
263
|
)
|
264
|
+
if incompatible_keys:
|
265
|
+
self.remove_norms()
|
266
|
+
incompatible_keys = super().load_weights(
|
267
|
+
path,
|
268
|
+
raise_if_not_exists,
|
269
|
+
strict,
|
270
|
+
assign,
|
271
|
+
weights_only,
|
272
|
+
mmap,
|
273
|
+
**pickle_load_args,
|
274
|
+
)
|
275
|
+
return incompatible_keys
|
260
276
|
except RuntimeError:
|
261
277
|
self.remove_norms()
|
262
278
|
return super().load_weights(
|
@@ -272,95 +288,37 @@ class HifiganGenerator(ConvNets):
|
|
272
288
|
@classmethod
|
273
289
|
def from_pretrained(
|
274
290
|
cls,
|
275
|
-
|
276
|
-
|
277
|
-
local_files_only: bool = False,
|
278
|
-
strict: bool = False,
|
291
|
+
model_file: PathLike,
|
292
|
+
model_config: Union[HifiganConfig, Dict[str, Any]],
|
279
293
|
*,
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
cache_dir: str | Path | None = None,
|
284
|
-
force_download: bool = False,
|
285
|
-
proxies: Dict | None = None,
|
286
|
-
token: bool | str | None = None,
|
287
|
-
resume_download: bool | None = None,
|
288
|
-
local_dir_use_symlinks: bool | Literal["auto"] = "auto",
|
294
|
+
strict: bool = False,
|
295
|
+
map_location: str = "cpu",
|
296
|
+
weights_only: bool = False,
|
289
297
|
**kwargs,
|
290
298
|
):
|
291
|
-
"""Load Pytorch pretrained weights and return the loaded model."""
|
292
|
-
hub_kwargs = dict(
|
293
|
-
repo_id=model_id,
|
294
|
-
subfolder=subfolder,
|
295
|
-
repo_type=repo_type,
|
296
|
-
revision=revision,
|
297
|
-
cache_dir=cache_dir,
|
298
|
-
force_download=force_download,
|
299
|
-
proxies=proxies,
|
300
|
-
resume_download=resume_download,
|
301
|
-
token=token,
|
302
|
-
local_files_only=local_files_only,
|
303
|
-
local_dir_use_symlinks=local_dir_use_symlinks,
|
304
|
-
)
|
305
|
-
|
306
|
-
# Download and load hyperparameters (h) used by BigVGAN
|
307
|
-
_model_path = Path(model_id)
|
308
|
-
config_file = None
|
309
|
-
if is_path_valid(model_id):
|
310
|
-
if is_file(model_id):
|
311
|
-
_p_conf = _model_path.parent / "config.json"
|
312
|
-
else:
|
313
|
-
_p_conf = _model_path / "config.json"
|
314
|
-
|
315
|
-
if is_file(_p_conf):
|
316
|
-
print("Loading config.json from local directory")
|
317
|
-
config_file = Path(model_id, "config.json")
|
318
|
-
else:
|
319
|
-
if not local_files_only:
|
320
|
-
print(f"Loading config from {model_id}")
|
321
|
-
config_file = hf_hub_download(filename="config.json", **hub_kwargs)
|
322
|
-
else:
|
323
|
-
if not local_files_only:
|
324
|
-
print(f"Loading config from {model_id}")
|
325
|
-
config_file = hf_hub_download(filename="config.json", **hub_kwargs)
|
326
|
-
|
327
|
-
if config_file is not None:
|
328
|
-
model = cls(HifiganConfig(**load_json(config_file)))
|
329
|
-
else:
|
330
|
-
model = cls()
|
331
299
|
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
filename="generator.pt",
|
300
|
+
is_file(model_file, validate=True)
|
301
|
+
model_state_dict = torch.load(
|
302
|
+
model_file, weights_only=weights_only, map_location=map_location
|
336
303
|
)
|
337
|
-
|
338
|
-
if
|
339
|
-
|
340
|
-
path = path / "generator.pt"
|
341
|
-
if path.exists():
|
342
|
-
print("Loading weights from local directory")
|
343
|
-
model_file = str(path)
|
344
|
-
else:
|
345
|
-
print(f"Loading weights from {model_id}")
|
346
|
-
model_file = hf_hub_download(**_retrieve_kwargs)
|
347
|
-
else:
|
348
|
-
print("Loading weights from local directory")
|
349
|
-
model_file = str(path)
|
304
|
+
|
305
|
+
if isinstance(model_config, HifiganConfig):
|
306
|
+
h = model_config
|
350
307
|
else:
|
351
|
-
|
352
|
-
model_file = hf_hub_download(**_retrieve_kwargs)
|
353
|
-
checkpoint_dict = torch.load(model_file, map_location=map_location)
|
308
|
+
h = HifiganConfig(**model_config)
|
354
309
|
|
310
|
+
model = cls(h)
|
355
311
|
try:
|
356
|
-
model.load_state_dict(
|
312
|
+
incompatible_keys = model.load_state_dict(model_state_dict, strict=strict)
|
313
|
+
if incompatible_keys:
|
314
|
+
model.remove_norms()
|
315
|
+
model.load_state_dict(model_state_dict, strict=strict)
|
357
316
|
except RuntimeError:
|
358
317
|
print(
|
359
318
|
f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
360
319
|
)
|
361
320
|
model.remove_norms()
|
362
|
-
model.load_state_dict(
|
363
|
-
|
321
|
+
model.load_state_dict(model_state_dict, strict=strict)
|
364
322
|
return model
|
365
323
|
|
366
324
|
|
{lt_tensor-0.0.1a22 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/audio_models/istft/__init__.py
RENAMED
@@ -1,11 +1,11 @@
|
|
1
1
|
__all__ = ["iSTFTNetGenerator", "iSTFTNetConfig"]
|
2
2
|
from lt_utils.common import *
|
3
3
|
from lt_tensor.torch_commons import *
|
4
|
-
from lt_tensor.model_zoo.
|
4
|
+
from lt_tensor.model_zoo.convs import ConvNets
|
5
5
|
from torch.nn import functional as F
|
6
6
|
from lt_tensor.config_templates import ModelConfig
|
7
|
+
from lt_tensor.misc_utils import get_config, get_weights
|
7
8
|
from lt_utils.file_ops import load_json, is_file, is_dir, is_path_valid
|
8
|
-
from huggingface_hub import hf_hub_download
|
9
9
|
|
10
10
|
|
11
11
|
class iSTFTNetConfig(ModelConfig):
|
@@ -39,7 +39,7 @@ class iSTFTNetConfig(ModelConfig):
|
|
39
39
|
[1, 3, 5],
|
40
40
|
],
|
41
41
|
activation: nn.Module = nn.LeakyReLU(0.1),
|
42
|
-
resblock: int =
|
42
|
+
resblock: Union[int, str] = "1",
|
43
43
|
gen_istft_n_fft: int = 16,
|
44
44
|
sampling_rate: Number = 24000,
|
45
45
|
*args,
|
@@ -59,7 +59,11 @@ class iSTFTNetConfig(ModelConfig):
|
|
59
59
|
}
|
60
60
|
super().__init__(**settings)
|
61
61
|
|
62
|
-
|
62
|
+
def post_process(self):
|
63
|
+
if isinstance(self.resblock, str):
|
64
|
+
self.resblock = 0 if self.resblock == "1" else 1
|
65
|
+
|
66
|
+
|
63
67
|
def get_padding(ks, d):
|
64
68
|
return int((ks * d - d) / 2)
|
65
69
|
|
@@ -271,15 +275,15 @@ class iSTFTNetGenerator(ConvNets):
|
|
271
275
|
def load_weights(
|
272
276
|
self,
|
273
277
|
path,
|
274
|
-
raise_if_not_exists=False,
|
275
278
|
strict=False,
|
276
279
|
assign=False,
|
277
|
-
weights_only=
|
280
|
+
weights_only=False,
|
278
281
|
mmap=None,
|
282
|
+
raise_if_not_exists=False,
|
279
283
|
**pickle_load_args,
|
280
284
|
):
|
281
285
|
try:
|
282
|
-
|
286
|
+
incompatible_keys = super().load_weights(
|
283
287
|
path,
|
284
288
|
raise_if_not_exists,
|
285
289
|
strict,
|
@@ -288,6 +292,18 @@ class iSTFTNetGenerator(ConvNets):
|
|
288
292
|
mmap,
|
289
293
|
**pickle_load_args,
|
290
294
|
)
|
295
|
+
if incompatible_keys:
|
296
|
+
self.remove_norms()
|
297
|
+
incompatible_keys = super().load_weights(
|
298
|
+
path,
|
299
|
+
raise_if_not_exists,
|
300
|
+
strict,
|
301
|
+
assign,
|
302
|
+
weights_only,
|
303
|
+
mmap,
|
304
|
+
**pickle_load_args,
|
305
|
+
)
|
306
|
+
return incompatible_keys
|
291
307
|
except RuntimeError:
|
292
308
|
self.remove_norms()
|
293
309
|
return super().load_weights(
|
@@ -303,95 +319,37 @@ class iSTFTNetGenerator(ConvNets):
|
|
303
319
|
@classmethod
|
304
320
|
def from_pretrained(
|
305
321
|
cls,
|
306
|
-
|
307
|
-
|
308
|
-
local_files_only: bool = False,
|
309
|
-
strict: bool = False,
|
322
|
+
model_file: PathLike,
|
323
|
+
model_config: Union[iSTFTNetConfig, Dict[str, Any]],
|
310
324
|
*,
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
cache_dir: str | Path | None = None,
|
315
|
-
force_download: bool = False,
|
316
|
-
proxies: Dict | None = None,
|
317
|
-
token: bool | str | None = None,
|
318
|
-
resume_download: bool | None = None,
|
319
|
-
local_dir_use_symlinks: bool | Literal["auto"] = "auto",
|
325
|
+
strict: bool = False,
|
326
|
+
map_location: str = "cpu",
|
327
|
+
weights_only: bool = False,
|
320
328
|
**kwargs,
|
321
329
|
):
|
322
|
-
"""Load Pytorch pretrained weights and return the loaded model."""
|
323
|
-
hub_kwargs = dict(
|
324
|
-
repo_id=model_id,
|
325
|
-
subfolder=subfolder,
|
326
|
-
repo_type=repo_type,
|
327
|
-
revision=revision,
|
328
|
-
cache_dir=cache_dir,
|
329
|
-
force_download=force_download,
|
330
|
-
proxies=proxies,
|
331
|
-
resume_download=resume_download,
|
332
|
-
token=token,
|
333
|
-
local_files_only=local_files_only,
|
334
|
-
local_dir_use_symlinks=local_dir_use_symlinks,
|
335
|
-
)
|
336
330
|
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
if is_path_valid(model_id):
|
341
|
-
if is_file(model_id):
|
342
|
-
_p_conf = _model_path.parent / "config.json"
|
343
|
-
else:
|
344
|
-
_p_conf = _model_path / "config.json"
|
345
|
-
|
346
|
-
if is_file(_p_conf):
|
347
|
-
print("Loading config.json from local directory")
|
348
|
-
config_file = Path(model_id, "config.json")
|
349
|
-
else:
|
350
|
-
if not local_files_only:
|
351
|
-
print(f"Loading config from {model_id}")
|
352
|
-
config_file = hf_hub_download(filename="config.json", **hub_kwargs)
|
353
|
-
else:
|
354
|
-
if not local_files_only:
|
355
|
-
print(f"Loading config from {model_id}")
|
356
|
-
config_file = hf_hub_download(filename="config.json", **hub_kwargs)
|
357
|
-
|
358
|
-
if config_file is not None:
|
359
|
-
model = cls(iSTFTNetConfig(**load_json(config_file)))
|
360
|
-
else:
|
361
|
-
model = cls()
|
362
|
-
|
363
|
-
# Download and load pretrained generator weight
|
364
|
-
_retrieve_kwargs = dict(
|
365
|
-
**hub_kwargs,
|
366
|
-
filename="generator.pt",
|
331
|
+
is_file(model_file, validate=True)
|
332
|
+
model_state_dict = torch.load(
|
333
|
+
model_file, weights_only=weights_only, map_location=map_location
|
367
334
|
)
|
368
|
-
|
369
|
-
if
|
370
|
-
|
371
|
-
path = path / "generator.pt"
|
372
|
-
if path.exists():
|
373
|
-
print("Loading weights from local directory")
|
374
|
-
model_file = str(path)
|
375
|
-
else:
|
376
|
-
print(f"Loading weights from {model_id}")
|
377
|
-
model_file = hf_hub_download(**_retrieve_kwargs)
|
378
|
-
else:
|
379
|
-
print("Loading weights from local directory")
|
380
|
-
model_file = str(path)
|
335
|
+
|
336
|
+
if isinstance(model_config, iSTFTNetConfig):
|
337
|
+
h = model_config
|
381
338
|
else:
|
382
|
-
|
383
|
-
model_file = hf_hub_download(**_retrieve_kwargs)
|
384
|
-
checkpoint_dict = torch.load(model_file, map_location=map_location)
|
339
|
+
h = iSTFTNetConfig(**model_config)
|
385
340
|
|
341
|
+
model = cls(h)
|
386
342
|
try:
|
387
|
-
model.load_state_dict(
|
343
|
+
incompatible_keys = model.load_state_dict(model_state_dict, strict=strict)
|
344
|
+
if incompatible_keys:
|
345
|
+
model.remove_norms()
|
346
|
+
model.load_state_dict(model_state_dict, strict=strict)
|
388
347
|
except RuntimeError:
|
389
348
|
print(
|
390
349
|
f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
391
350
|
)
|
392
351
|
model.remove_norms()
|
393
|
-
model.load_state_dict(
|
394
|
-
|
352
|
+
model.load_state_dict(model_state_dict, strict=strict)
|
395
353
|
return model
|
396
354
|
|
397
355
|
|