lt-tensor 0.0.1a22__py3-none-any.whl → 0.0.1a27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/config_templates.py +9 -5
- lt_tensor/misc_utils.py +15 -3
- lt_tensor/model_base.py +57 -78
- lt_tensor/model_zoo/audio_models/diffwave/__init__.py +41 -14
- lt_tensor/model_zoo/audio_models/hifigan/__init__.py +40 -82
- lt_tensor/model_zoo/audio_models/istft/__init__.py +41 -83
- lt_tensor/model_zoo/convs.py +124 -0
- lt_tensor/model_zoo/residual.py +1 -136
- lt_tensor/processors/__init__.py +2 -2
- lt_tensor/processors/audio.py +267 -200
- {lt_tensor-0.0.1a22.dist-info → lt_tensor-0.0.1a27.dist-info}/METADATA +1 -1
- {lt_tensor-0.0.1a22.dist-info → lt_tensor-0.0.1a27.dist-info}/RECORD +15 -16
- lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +0 -536
- lt_tensor/model_zoo/audio_models/bigvgan/cuda/__init__.py +0 -160
- {lt_tensor-0.0.1a22.dist-info → lt_tensor-0.0.1a27.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a22.dist-info → lt_tensor-0.0.1a27.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a22.dist-info → lt_tensor-0.0.1a27.dist-info}/top_level.txt +0 -0
lt_tensor/config_templates.py
CHANGED
@@ -6,9 +6,7 @@ from lt_tensor.misc_utils import updateDict
|
|
6
6
|
|
7
7
|
|
8
8
|
class ModelConfig(ABC, OrderedDict):
|
9
|
-
_default_settings: Dict[str, Any] = {}
|
10
9
|
_forbidden_list: List[str] = [
|
11
|
-
"_default_settings",
|
12
10
|
"_forbidden_list",
|
13
11
|
]
|
14
12
|
|
@@ -16,12 +14,15 @@ class ModelConfig(ABC, OrderedDict):
|
|
16
14
|
self,
|
17
15
|
**settings,
|
18
16
|
):
|
19
|
-
self.
|
20
|
-
self.set_state_dict(self._default_settings)
|
17
|
+
self.set_state_dict(settings)
|
21
18
|
|
22
19
|
def reset_settings(self):
|
23
20
|
raise NotImplementedError("Not implemented")
|
24
|
-
|
21
|
+
|
22
|
+
def post_process(self):
|
23
|
+
"""Implement the post process, to do a final check to the input data"""
|
24
|
+
pass
|
25
|
+
|
25
26
|
def save_config(
|
26
27
|
self,
|
27
28
|
path: str,
|
@@ -48,6 +49,7 @@ class ModelConfig(ABC, OrderedDict):
|
|
48
49
|
}
|
49
50
|
updateDict(self, new_state)
|
50
51
|
self.update(**new_state)
|
52
|
+
self.post_process()
|
51
53
|
|
52
54
|
def state_dict(self):
|
53
55
|
return {k: y for k, y in self.__dict__.items() if k not in self._forbidden_list}
|
@@ -89,3 +91,5 @@ class ModelConfig(ABC, OrderedDict):
|
|
89
91
|
settings.pop("path_name", None)
|
90
92
|
|
91
93
|
return ModelConfig(**settings)
|
94
|
+
|
95
|
+
|
lt_tensor/misc_utils.py
CHANGED
@@ -111,10 +111,10 @@ def get_weights(directory: Union[str, PathLike]):
|
|
111
111
|
directory = Path(directory)
|
112
112
|
if is_file(directory):
|
113
113
|
if directory.name.endswith((".pt", ".ckpt", ".pth")):
|
114
|
-
return directory
|
114
|
+
return [directory]
|
115
115
|
directory = directory.parent
|
116
116
|
res = sorted(find_files(directory, ["*.pt", "*.ckpt", "*.pth"]))
|
117
|
-
return res
|
117
|
+
return res
|
118
118
|
|
119
119
|
|
120
120
|
def get_config(directory: Union[str, PathLike], default: Optional[Any] = None):
|
@@ -128,7 +128,19 @@ def get_config(directory: Union[str, PathLike], default: Optional[Any] = None):
|
|
128
128
|
return load_json(directory, default)
|
129
129
|
return load_yaml(directory, default)
|
130
130
|
directory = directory.parent
|
131
|
-
res = sorted(
|
131
|
+
res = sorted(
|
132
|
+
find_files(
|
133
|
+
directory,
|
134
|
+
[
|
135
|
+
"config*.json",
|
136
|
+
"*config.json",
|
137
|
+
"config*.yml",
|
138
|
+
"*config.yml",
|
139
|
+
"*config.yaml",
|
140
|
+
"config*.yaml",
|
141
|
+
],
|
142
|
+
)
|
143
|
+
)
|
132
144
|
if res:
|
133
145
|
res = res[-1]
|
134
146
|
if Path(res).name.endswith(".json"):
|
lt_tensor/model_base.py
CHANGED
@@ -80,6 +80,62 @@ class _Devices_Base(nn.Module):
|
|
80
80
|
assert isinstance(device, (str, torch.device))
|
81
81
|
self._device = torch.device(device) if isinstance(device, str) else device
|
82
82
|
|
83
|
+
def freeze_all(self, exclude: list[str] = []):
|
84
|
+
for name, module in self.named_modules():
|
85
|
+
if name in exclude or not hasattr(module, "requires_grad"):
|
86
|
+
continue
|
87
|
+
try:
|
88
|
+
self.freeze_module(module)
|
89
|
+
except:
|
90
|
+
pass
|
91
|
+
|
92
|
+
def unfreeze_all(self, exclude: list[str] = []):
|
93
|
+
for name, module in self.named_modules():
|
94
|
+
if name in exclude or not hasattr(module, "requires_grad"):
|
95
|
+
continue
|
96
|
+
try:
|
97
|
+
self.unfreeze_module(module)
|
98
|
+
except:
|
99
|
+
pass
|
100
|
+
|
101
|
+
def freeze_module(
|
102
|
+
self, module_or_name: Union[str, nn.Module, nn.Parameter, "Model", Tensor]
|
103
|
+
):
|
104
|
+
self._change_gradient_state(module_or_name, False)
|
105
|
+
|
106
|
+
def unfreeze_module(
|
107
|
+
self, module_or_name: Union[str, nn.Module, nn.Parameter, "Model", Tensor]
|
108
|
+
):
|
109
|
+
self._change_gradient_state(module_or_name, True)
|
110
|
+
|
111
|
+
def _change_gradient_state(
|
112
|
+
self,
|
113
|
+
module_or_name: Union[str, nn.Module, nn.Parameter, "Model", Tensor],
|
114
|
+
new_state: bool, # True = unfreeze
|
115
|
+
):
|
116
|
+
assert isinstance(
|
117
|
+
module_or_name, (str, nn.Module, nn.Parameter, Model, Tensor)
|
118
|
+
), f"Item '{module_or_name}' is not a valid module, parameter, tensor or a string."
|
119
|
+
if isinstance(module_or_name, (nn.Module, nn.Parameter, Model, Tensor)):
|
120
|
+
target = module_or_name
|
121
|
+
else:
|
122
|
+
target = getattr(self, module_or_name)
|
123
|
+
|
124
|
+
if isinstance(target, Tensor):
|
125
|
+
target.requires_grad = new_state
|
126
|
+
elif isinstance(target, nn.Parameter):
|
127
|
+
target.requires_grad = new_state
|
128
|
+
elif isinstance(target, Model):
|
129
|
+
target.freeze_all()
|
130
|
+
elif isinstance(target, nn.Module):
|
131
|
+
for param in target.parameters():
|
132
|
+
if hasattr(param, "requires_grad"):
|
133
|
+
param.requires_grad = new_state
|
134
|
+
else:
|
135
|
+
raise ValueError(
|
136
|
+
f"Item '{module_or_name}' is not a valid module, parameter or tensor."
|
137
|
+
)
|
138
|
+
|
83
139
|
def _apply_device(self):
|
84
140
|
"""Add here components that are needed to have device applied to them,
|
85
141
|
that usually the '.to()' function fails to apply
|
@@ -182,20 +238,12 @@ class Model(_Devices_Base, ABC):
|
|
182
238
|
"""
|
183
239
|
|
184
240
|
_autocast: bool = False
|
185
|
-
|
186
|
-
# list with modules that can be frozen or unfrozen
|
187
|
-
registered_freezable_modules: List[str] = []
|
188
|
-
is_frozen: bool = False
|
189
|
-
_can_be_frozen: bool = (
|
190
|
-
False # to control if the module can or cannot be freezed by other modules from 'Model' class
|
191
|
-
)
|
241
|
+
|
192
242
|
# this is to be used on the case of they module requires low-rank adapters
|
193
243
|
_low_rank_lambda: Optional[Callable[[], nn.Module]] = (
|
194
244
|
None # Example: lambda: nn.Linear(32, 32, True)
|
195
245
|
)
|
196
246
|
low_rank_adapter: Union[nn.Identity, nn.Module, nn.Sequential] = nn.Identity()
|
197
|
-
# never freeze:
|
198
|
-
_never_freeze_modules: List[str] = ["low_rank_adapter"]
|
199
247
|
|
200
248
|
# dont save list:
|
201
249
|
_dont_save_items: List[str] = []
|
@@ -208,75 +256,6 @@ class Model(_Devices_Base, ABC):
|
|
208
256
|
def autocast(self, value: bool):
|
209
257
|
self._autocast = value
|
210
258
|
|
211
|
-
def freeze_all(self, exclude: Optional[List[str]] = None, force: bool = False):
|
212
|
-
no_exclusions = not exclude
|
213
|
-
no_exclusions = not exclude
|
214
|
-
results = []
|
215
|
-
for name, module in self.named_modules():
|
216
|
-
if (
|
217
|
-
name in self._never_freeze_modules
|
218
|
-
or not force
|
219
|
-
and name not in self.registered_freezable_modules
|
220
|
-
):
|
221
|
-
results.append(
|
222
|
-
(
|
223
|
-
name,
|
224
|
-
"Unregistered module, to freeze/unfreeze it add its name into 'registered_freezable_modules'.",
|
225
|
-
)
|
226
|
-
)
|
227
|
-
continue
|
228
|
-
if no_exclusions:
|
229
|
-
self.change_frozen_state(True, module)
|
230
|
-
elif not any(exclusion in name for exclusion in exclude):
|
231
|
-
results.append((name, self.change_frozen_state(True, module)))
|
232
|
-
else:
|
233
|
-
results.append((name, "excluded"))
|
234
|
-
return results
|
235
|
-
|
236
|
-
def unfreeze_all(self, exclude: Optional[list[str]] = None, force: bool = False):
|
237
|
-
"""Unfreezes all model parameters except specified layers."""
|
238
|
-
no_exclusions = not exclude
|
239
|
-
results = []
|
240
|
-
for name, module in self.named_modules():
|
241
|
-
if (
|
242
|
-
name in self._never_freeze_modules
|
243
|
-
or not force
|
244
|
-
and name not in self.registered_freezable_modules
|
245
|
-
):
|
246
|
-
|
247
|
-
results.append(
|
248
|
-
(
|
249
|
-
name,
|
250
|
-
"Unregistered module, to freeze/unfreeze it add it into 'registered_freezable_modules'.",
|
251
|
-
)
|
252
|
-
)
|
253
|
-
continue
|
254
|
-
if no_exclusions:
|
255
|
-
self.change_frozen_state(False, module)
|
256
|
-
elif not any(exclusion in name for exclusion in exclude):
|
257
|
-
results.append((name, self.change_frozen_state(False, module)))
|
258
|
-
else:
|
259
|
-
results.append((name, "excluded"))
|
260
|
-
return results
|
261
|
-
|
262
|
-
def change_frozen_state(self, freeze: bool, module: nn.Module):
|
263
|
-
assert isinstance(module, nn.Module)
|
264
|
-
if module.__class__.__name__ in self._never_freeze_modules:
|
265
|
-
return "Not Allowed"
|
266
|
-
try:
|
267
|
-
if isinstance(module, Model):
|
268
|
-
if module._can_be_frozen:
|
269
|
-
if freeze:
|
270
|
-
return module.freeze_all()
|
271
|
-
return module.unfreeze_all()
|
272
|
-
else:
|
273
|
-
return "Not Allowed"
|
274
|
-
else:
|
275
|
-
module.requires_grad_(not freeze)
|
276
|
-
return not freeze
|
277
|
-
except Exception as e:
|
278
|
-
return e
|
279
|
-
|
280
259
|
def trainable_parameters(self, module_name: Optional[str] = None):
|
281
260
|
"""Gets the number of trainable parameters from either the entire model or from a specific module."""
|
282
261
|
if module_name is not None:
|
@@ -5,6 +5,7 @@ from lt_tensor.torch_commons import *
|
|
5
5
|
from torch.nn import functional as F
|
6
6
|
from lt_tensor.config_templates import ModelConfig
|
7
7
|
from lt_tensor.torch_commons import *
|
8
|
+
from lt_tensor.model_zoo.convs import ConvNets, Conv1dEXT
|
8
9
|
from lt_tensor.model_base import Model
|
9
10
|
from math import sqrt
|
10
11
|
from lt_utils.common import *
|
@@ -18,6 +19,8 @@ class DiffWaveConfig(ModelConfig):
|
|
18
19
|
residual_channels = 64
|
19
20
|
dilation_cycle_length = 10
|
20
21
|
unconditional = False
|
22
|
+
apply_norm: Optional[Literal["weight", "spectral"]] = None
|
23
|
+
apply_norm_resblock: Optional[Literal["weight", "spectral"]] = None
|
21
24
|
noise_schedule: list[int] = np.linspace(1e-4, 0.05, 50).tolist()
|
22
25
|
# settings for auto-fixes
|
23
26
|
interpolate = False
|
@@ -44,6 +47,8 @@ class DiffWaveConfig(ModelConfig):
|
|
44
47
|
"area",
|
45
48
|
"nearest-exact",
|
46
49
|
] = "nearest",
|
50
|
+
apply_norm: Optional[Literal["weight", "spectral"]] = None,
|
51
|
+
apply_norm_resblock: Optional[Literal["weight", "spectral"]] = None,
|
47
52
|
):
|
48
53
|
settings = {
|
49
54
|
"n_mels": n_mels,
|
@@ -55,16 +60,12 @@ class DiffWaveConfig(ModelConfig):
|
|
55
60
|
"noise_schedule": noise_schedule,
|
56
61
|
"interpolate": interpolate_cond,
|
57
62
|
"interpolation_mode": interpolation_mode,
|
63
|
+
"apply_norm": apply_norm,
|
64
|
+
"apply_norm_resblock": apply_norm_resblock,
|
58
65
|
}
|
59
66
|
super().__init__(**settings)
|
60
67
|
|
61
68
|
|
62
|
-
def Conv1d(*args, **kwargs):
|
63
|
-
layer = nn.Conv1d(*args, **kwargs)
|
64
|
-
nn.init.kaiming_normal_(layer.weight)
|
65
|
-
return layer
|
66
|
-
|
67
|
-
|
68
69
|
class DiffusionEmbedding(Model):
|
69
70
|
def __init__(self, max_steps: int):
|
70
71
|
super().__init__()
|
@@ -117,7 +118,14 @@ class SpectrogramUpsample(Model):
|
|
117
118
|
|
118
119
|
|
119
120
|
class ResidualBlock(Model):
|
120
|
-
def __init__(
|
121
|
+
def __init__(
|
122
|
+
self,
|
123
|
+
n_mels,
|
124
|
+
residual_channels,
|
125
|
+
dilation,
|
126
|
+
uncond=False,
|
127
|
+
apply_norm: Optional[Literal["weight", "spectral"]] = None,
|
128
|
+
):
|
121
129
|
"""
|
122
130
|
:param n_mels: inplanes of conv1x1 for spectrogram conditional
|
123
131
|
:param residual_channels: audio conv
|
@@ -125,20 +133,28 @@ class ResidualBlock(Model):
|
|
125
133
|
:param uncond: disable spectrogram conditional
|
126
134
|
"""
|
127
135
|
super().__init__()
|
128
|
-
self.dilated_conv =
|
136
|
+
self.dilated_conv = Conv1dEXT(
|
129
137
|
residual_channels,
|
130
138
|
2 * residual_channels,
|
131
139
|
3,
|
132
140
|
padding=dilation,
|
133
141
|
dilation=dilation,
|
142
|
+
apply_norm=apply_norm,
|
134
143
|
)
|
135
144
|
self.diffusion_projection = nn.Linear(512, residual_channels)
|
136
145
|
if not uncond: # conditional model
|
137
|
-
self.conditioner_projection =
|
146
|
+
self.conditioner_projection = Conv1dEXT(
|
147
|
+
n_mels,
|
148
|
+
2 * residual_channels,
|
149
|
+
1,
|
150
|
+
apply_norm=apply_norm,
|
151
|
+
)
|
138
152
|
else: # unconditional model
|
139
153
|
self.conditioner_projection = None
|
140
154
|
|
141
|
-
self.output_projection =
|
155
|
+
self.output_projection = Conv1dEXT(
|
156
|
+
residual_channels, 2 * residual_channels, 1, apply_norm == apply_norm
|
157
|
+
)
|
142
158
|
|
143
159
|
def forward(
|
144
160
|
self,
|
@@ -172,7 +188,12 @@ class DiffWave(Model):
|
|
172
188
|
self.n_hop = self.params.hop_samples
|
173
189
|
self.interpolate = self.params.interpolate
|
174
190
|
self.interpolate_mode = self.params.interpolation_mode
|
175
|
-
self.input_projection =
|
191
|
+
self.input_projection = Conv1dEXT(
|
192
|
+
in_channels=1,
|
193
|
+
out_channels=params.residual_channels,
|
194
|
+
kernel_size=1,
|
195
|
+
apply_norm=self.params.apply_norm,
|
196
|
+
)
|
176
197
|
self.diffusion_embedding = DiffusionEmbedding(len(params.noise_schedule))
|
177
198
|
if self.params.unconditional: # use unconditional model
|
178
199
|
self.spectrogram_upsample = None
|
@@ -186,14 +207,20 @@ class DiffWave(Model):
|
|
186
207
|
params.residual_channels,
|
187
208
|
2 ** (i % params.dilation_cycle_length),
|
188
209
|
uncond=params.unconditional,
|
210
|
+
apply_norm=self.params.apply_norm_resblock,
|
189
211
|
)
|
190
212
|
for i in range(params.residual_layers)
|
191
213
|
]
|
192
214
|
)
|
193
|
-
self.skip_projection =
|
194
|
-
params.residual_channels,
|
215
|
+
self.skip_projection = Conv1dEXT(
|
216
|
+
in_channels=params.residual_channels,
|
217
|
+
out_channels=params.residual_channels,
|
218
|
+
kernel_size=1,
|
219
|
+
apply_norm=self.params.apply_norm,
|
220
|
+
)
|
221
|
+
self.output_projection = Conv1dEXT(
|
222
|
+
params.residual_channels, 1, 1, apply_norm=self.params.apply_norm
|
195
223
|
)
|
196
|
-
self.output_projection = Conv1d(params.residual_channels, 1, 1)
|
197
224
|
self.activation = nn.LeakyReLU(0.1)
|
198
225
|
self.r_sqrt = sqrt(len(self.residual_layers))
|
199
226
|
nn.init.zeros_(self.output_projection.weight)
|
@@ -1,10 +1,10 @@
|
|
1
1
|
__all__ = ["HifiganGenerator", "HifiganConfig"]
|
2
2
|
from lt_utils.common import *
|
3
3
|
from lt_tensor.torch_commons import *
|
4
|
-
from lt_tensor.model_zoo.
|
4
|
+
from lt_tensor.model_zoo.convs import ConvNets
|
5
5
|
from torch.nn import functional as F
|
6
6
|
from lt_utils.file_ops import load_json, is_file, is_dir, is_path_valid
|
7
|
-
from
|
7
|
+
from lt_tensor.misc_utils import get_config, get_weights
|
8
8
|
|
9
9
|
|
10
10
|
def get_padding(kernel_size, dilation=1):
|
@@ -43,7 +43,7 @@ class HifiganConfig(ModelConfig):
|
|
43
43
|
[1, 3, 5],
|
44
44
|
],
|
45
45
|
activation: nn.Module = nn.LeakyReLU(0.1),
|
46
|
-
resblock: int =
|
46
|
+
resblock: Union[int, str] = "1",
|
47
47
|
*args,
|
48
48
|
**kwargs,
|
49
49
|
):
|
@@ -59,6 +59,10 @@ class HifiganConfig(ModelConfig):
|
|
59
59
|
}
|
60
60
|
super().__init__(**settings)
|
61
61
|
|
62
|
+
def post_process(self):
|
63
|
+
if isinstance(self.resblock, str):
|
64
|
+
self.resblock = 0 if self.resblock == "1" else 1
|
65
|
+
|
62
66
|
|
63
67
|
class ResBlock1(ConvNets):
|
64
68
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
@@ -240,15 +244,15 @@ class HifiganGenerator(ConvNets):
|
|
240
244
|
def load_weights(
|
241
245
|
self,
|
242
246
|
path,
|
243
|
-
raise_if_not_exists=False,
|
244
247
|
strict=False,
|
245
248
|
assign=False,
|
246
|
-
weights_only=
|
249
|
+
weights_only=False,
|
247
250
|
mmap=None,
|
251
|
+
raise_if_not_exists=False,
|
248
252
|
**pickle_load_args,
|
249
253
|
):
|
250
254
|
try:
|
251
|
-
|
255
|
+
incompatible_keys = super().load_weights(
|
252
256
|
path,
|
253
257
|
raise_if_not_exists,
|
254
258
|
strict,
|
@@ -257,6 +261,18 @@ class HifiganGenerator(ConvNets):
|
|
257
261
|
mmap,
|
258
262
|
**pickle_load_args,
|
259
263
|
)
|
264
|
+
if incompatible_keys:
|
265
|
+
self.remove_norms()
|
266
|
+
incompatible_keys = super().load_weights(
|
267
|
+
path,
|
268
|
+
raise_if_not_exists,
|
269
|
+
strict,
|
270
|
+
assign,
|
271
|
+
weights_only,
|
272
|
+
mmap,
|
273
|
+
**pickle_load_args,
|
274
|
+
)
|
275
|
+
return incompatible_keys
|
260
276
|
except RuntimeError:
|
261
277
|
self.remove_norms()
|
262
278
|
return super().load_weights(
|
@@ -272,95 +288,37 @@ class HifiganGenerator(ConvNets):
|
|
272
288
|
@classmethod
|
273
289
|
def from_pretrained(
|
274
290
|
cls,
|
275
|
-
|
276
|
-
|
277
|
-
local_files_only: bool = False,
|
278
|
-
strict: bool = False,
|
291
|
+
model_file: PathLike,
|
292
|
+
model_config: Union[HifiganConfig, Dict[str, Any]],
|
279
293
|
*,
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
cache_dir: str | Path | None = None,
|
284
|
-
force_download: bool = False,
|
285
|
-
proxies: Dict | None = None,
|
286
|
-
token: bool | str | None = None,
|
287
|
-
resume_download: bool | None = None,
|
288
|
-
local_dir_use_symlinks: bool | Literal["auto"] = "auto",
|
294
|
+
strict: bool = False,
|
295
|
+
map_location: str = "cpu",
|
296
|
+
weights_only: bool = False,
|
289
297
|
**kwargs,
|
290
298
|
):
|
291
|
-
"""Load Pytorch pretrained weights and return the loaded model."""
|
292
|
-
hub_kwargs = dict(
|
293
|
-
repo_id=model_id,
|
294
|
-
subfolder=subfolder,
|
295
|
-
repo_type=repo_type,
|
296
|
-
revision=revision,
|
297
|
-
cache_dir=cache_dir,
|
298
|
-
force_download=force_download,
|
299
|
-
proxies=proxies,
|
300
|
-
resume_download=resume_download,
|
301
|
-
token=token,
|
302
|
-
local_files_only=local_files_only,
|
303
|
-
local_dir_use_symlinks=local_dir_use_symlinks,
|
304
|
-
)
|
305
|
-
|
306
|
-
# Download and load hyperparameters (h) used by BigVGAN
|
307
|
-
_model_path = Path(model_id)
|
308
|
-
config_file = None
|
309
|
-
if is_path_valid(model_id):
|
310
|
-
if is_file(model_id):
|
311
|
-
_p_conf = _model_path.parent / "config.json"
|
312
|
-
else:
|
313
|
-
_p_conf = _model_path / "config.json"
|
314
|
-
|
315
|
-
if is_file(_p_conf):
|
316
|
-
print("Loading config.json from local directory")
|
317
|
-
config_file = Path(model_id, "config.json")
|
318
|
-
else:
|
319
|
-
if not local_files_only:
|
320
|
-
print(f"Loading config from {model_id}")
|
321
|
-
config_file = hf_hub_download(filename="config.json", **hub_kwargs)
|
322
|
-
else:
|
323
|
-
if not local_files_only:
|
324
|
-
print(f"Loading config from {model_id}")
|
325
|
-
config_file = hf_hub_download(filename="config.json", **hub_kwargs)
|
326
|
-
|
327
|
-
if config_file is not None:
|
328
|
-
model = cls(HifiganConfig(**load_json(config_file)))
|
329
|
-
else:
|
330
|
-
model = cls()
|
331
299
|
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
filename="generator.pt",
|
300
|
+
is_file(model_file, validate=True)
|
301
|
+
model_state_dict = torch.load(
|
302
|
+
model_file, weights_only=weights_only, map_location=map_location
|
336
303
|
)
|
337
|
-
|
338
|
-
if
|
339
|
-
|
340
|
-
path = path / "generator.pt"
|
341
|
-
if path.exists():
|
342
|
-
print("Loading weights from local directory")
|
343
|
-
model_file = str(path)
|
344
|
-
else:
|
345
|
-
print(f"Loading weights from {model_id}")
|
346
|
-
model_file = hf_hub_download(**_retrieve_kwargs)
|
347
|
-
else:
|
348
|
-
print("Loading weights from local directory")
|
349
|
-
model_file = str(path)
|
304
|
+
|
305
|
+
if isinstance(model_config, HifiganConfig):
|
306
|
+
h = model_config
|
350
307
|
else:
|
351
|
-
|
352
|
-
model_file = hf_hub_download(**_retrieve_kwargs)
|
353
|
-
checkpoint_dict = torch.load(model_file, map_location=map_location)
|
308
|
+
h = HifiganConfig(**model_config)
|
354
309
|
|
310
|
+
model = cls(h)
|
355
311
|
try:
|
356
|
-
model.load_state_dict(
|
312
|
+
incompatible_keys = model.load_state_dict(model_state_dict, strict=strict)
|
313
|
+
if incompatible_keys:
|
314
|
+
model.remove_norms()
|
315
|
+
model.load_state_dict(model_state_dict, strict=strict)
|
357
316
|
except RuntimeError:
|
358
317
|
print(
|
359
318
|
f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
360
319
|
)
|
361
320
|
model.remove_norms()
|
362
|
-
model.load_state_dict(
|
363
|
-
|
321
|
+
model.load_state_dict(model_state_dict, strict=strict)
|
364
322
|
return model
|
365
323
|
|
366
324
|
|