lt-tensor 0.0.1a38__py3-none-any.whl → 0.0.1a40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +1 -1
- lt_tensor/model_zoo/audio_models/bemaganv2/__init__.py +205 -0
- lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +14 -39
- lt_tensor/model_zoo/audio_models/diffwave/__init__.py +20 -19
- lt_tensor/model_zoo/audio_models/hifigan/__init__.py +24 -44
- lt_tensor/model_zoo/audio_models/istft/__init__.py +15 -39
- lt_tensor/model_zoo/convs.py +35 -4
- lt_tensor/model_zoo/losses/_envelope_disc/__init__.py +116 -0
- lt_tensor/model_zoo/losses/discriminators.py +34 -64
- lt_tensor/noise_tools.py +22 -13
- lt_tensor/processors/audio.py +116 -62
- {lt_tensor-0.0.1a38.dist-info → lt_tensor-0.0.1a40.dist-info}/METADATA +1 -1
- {lt_tensor-0.0.1a38.dist-info → lt_tensor-0.0.1a40.dist-info}/RECORD +16 -14
- {lt_tensor-0.0.1a38.dist-info → lt_tensor-0.0.1a40.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a38.dist-info → lt_tensor-0.0.1a40.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a38.dist-info → lt_tensor-0.0.1a40.dist-info}/top_level.txt +0 -0
lt_tensor/__init__.py
CHANGED
@@ -0,0 +1,205 @@
|
|
1
|
+
from lt_utils.common import *
|
2
|
+
from lt_tensor.torch_commons import *
|
3
|
+
from lt_tensor.model_zoo.convs import ConvNets
|
4
|
+
from lt_tensor.config_templates import ModelConfig
|
5
|
+
from lt_utils.file_ops import is_file, load_json
|
6
|
+
from lt_tensor.model_zoo.audio_models.resblocks import ResBlock1, ResBlock2
|
7
|
+
from lt_tensor.model_zoo.activations import snake, alias_free
|
8
|
+
from lt_tensor.model_zoo.audio_models.resblocks import AMPBlock1, AMPBlock2, get_snake
|
9
|
+
|
10
|
+
|
11
|
+
class BemaGANv2Config(ModelConfig):
    """Hyper-parameters consumed by ``BemaGANv2Generator``.

    The class-level annotations document the available fields; the values
    actually stored on an instance are the ones passed to ``__init__`` and
    forwarded to ``ModelConfig.__init__``.
    """

    # Training params
    in_channels: int = 80
    upsample_rates: List[Union[int, List[int]]] = [8, 8, 2, 2]
    upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16, 4, 4]
    upsample_initial_channel: int = 1536
    resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11]
    resblock_dilation_sizes: List[Union[int, List[int]]] = [
        [1, 3, 5],
        [1, 3, 5],
        [1, 3, 5],
    ]

    activation: Literal["snake", "snakebeta"] = "snakebeta"
    resblock_activation: Literal["snake", "snakebeta"] = "snakebeta"
    resblock: int = 0
    # NOTE(review): the two class-level defaults below are ``True`` while the
    # ``__init__`` defaults are ``False`` -- confirm which one is intended.
    use_bias_at_final: bool = True
    use_tanh_at_final: bool = True
    snake_logscale: bool = True

    def __init__(
        self,
        in_channels: int = 80,
        upsample_rates: Union[List[Union[int, List[int]]], None] = None,
        upsample_kernel_sizes: Union[List[Union[int, List[int]]], None] = None,
        upsample_initial_channel: int = 1536,
        resblock_kernel_sizes: Union[List[Union[int, List[int]]], None] = None,
        resblock_dilation_sizes: Union[List[Union[int, List[int]]], None] = None,
        activation: Literal["snake", "snakebeta"] = "snakebeta",
        resblock_activation: Literal["snake", "snakebeta"] = "snakebeta",
        resblock: Union[int, str] = "1",
        use_bias_at_final: bool = False,
        use_tanh_at_final: bool = False,
        *args,
        snake_logscale: bool = True,
        **kwargs,
    ):
        """Collect the settings and hand them to ``ModelConfig.__init__``.

        Args:
            in_channels: Number of input channels (default 80).
            upsample_rates: Stride of each upsampling stage; ``None`` selects
                ``[8, 8, 2, 2]``.
            upsample_kernel_sizes: Kernel size of each upsampling stage;
                ``None`` selects ``[16, 16, 4, 4]``.
            upsample_initial_channel: Channel count before the first
                upsampling stage.
            resblock_kernel_sizes: Kernel size of each residual block within
                a stage; ``None`` selects ``[3, 7, 11]``.
            resblock_dilation_sizes: Dilations of each residual block;
                ``None`` selects three copies of ``[1, 3, 5]``.
            activation: Name of the activation used before the final conv.
            resblock_activation: Name of the activation used inside the
                residual blocks.
            resblock: Residual block selector. A string is converted by
                ``post_process`` (``"1"`` becomes ``0``, anything else ``1``).
            use_bias_at_final: Whether the final convolution has a bias.
            use_tanh_at_final: Whether the generator output goes through tanh.
            snake_logscale: Keyword-only. Forwarded to the activations as
                their log-scale flag.
            *args: Accepted for signature compatibility and ignored.
            **kwargs: Accepted for signature compatibility and ignored.
        """
        # Build fresh lists per instance: list literals used directly as
        # parameter defaults would be shared by every config object.
        if upsample_rates is None:
            upsample_rates = [8, 8, 2, 2]
        if upsample_kernel_sizes is None:
            upsample_kernel_sizes = [16, 16, 4, 4]
        if resblock_kernel_sizes is None:
            resblock_kernel_sizes = [3, 7, 11]
        if resblock_dilation_sizes is None:
            resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]

        settings = {
            "in_channels": in_channels,
            "upsample_rates": upsample_rates,
            "upsample_kernel_sizes": upsample_kernel_sizes,
            "upsample_initial_channel": upsample_initial_channel,
            "resblock_kernel_sizes": resblock_kernel_sizes,
            "resblock_dilation_sizes": resblock_dilation_sizes,
            "activation": activation,
            "resblock_activation": resblock_activation,
            "resblock": resblock,
            "use_bias_at_final": use_bias_at_final,
            "use_tanh_at_final": use_tanh_at_final,
            # Previously this could not be configured at all: a value coming
            # from a dict/JSON config landed in **kwargs and was dropped.
            "snake_logscale": snake_logscale,
        }
        super().__init__(**settings)

    def post_process(self):
        """Normalise a string ``resblock`` selector to its integer form."""
        if isinstance(self.resblock, str):
            self.resblock = 0 if self.resblock == "1" else 1
|
69
|
+
|
70
|
+
|
71
|
+
class BemaGANv2Generator(ConvNets):
    """Waveform generator built from transposed-conv upsamplers and AMP blocks.

    The layout is: ``conv_pre`` -> for every upsampling stage (transposed
    convolution, then the average of ``num_kernels`` residual blocks) ->
    ``activation_post`` -> ``conv_post`` -> optional ``tanh``.
    """

    def __init__(
        self, cfg: Union[BemaGANv2Config, Dict[str, object], None] = None
    ):
        """Build the network.

        Args:
            cfg: A ``BemaGANv2Config``, a mapping of its keyword arguments, or
                ``None`` to use a fresh default config. (A config instance
                used directly as the parameter default would be created once
                and shared by every model.)
        """
        super().__init__()
        if cfg is None:
            cfg = BemaGANv2Config()
        elif not isinstance(cfg, BemaGANv2Config):
            cfg = BemaGANv2Config(**cfg)
        self.cfg = cfg

        actv = get_snake(self.cfg.activation)

        self.num_kernels = len(cfg.resblock_kernel_sizes)
        self.num_upsamples = len(cfg.upsample_rates)

        # The config declares ``in_channels``; ``num_mels`` (the name this code
        # originally read) is honoured only when it is present on the config.
        input_channels = getattr(cfg, "num_mels", cfg.in_channels)
        self.conv_pre = weight_norm(
            nn.Conv1d(input_channels, cfg.upsample_initial_channel, 7, 1, padding=3)
        )

        # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
        resblock = AMPBlock1 if cfg.resblock == 0 else AMPBlock2

        # transposed conv-based upsamplers. does not apply anti-aliasing
        # Each stage stays wrapped in its own ModuleList so parameter names
        # keep the ``ups.<stage>.0.*`` form.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(cfg.upsample_rates, cfg.upsample_kernel_sizes)):
            self.ups.append(
                nn.ModuleList(
                    [
                        weight_norm(
                            nn.ConvTranspose1d(
                                cfg.upsample_initial_channel // (2**i),
                                cfg.upsample_initial_channel // (2 ** (i + 1)),
                                k,
                                u,
                                padding=(k - u) // 2,
                            )
                        )
                    ]
                )
            )

        # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            stage_channels = cfg.upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(cfg.resblock_kernel_sizes, cfg.resblock_dilation_sizes):
                self.resblocks.append(
                    resblock(
                        stage_channels,
                        k,
                        d,
                        snake_logscale=cfg.snake_logscale,
                        activation=cfg.resblock_activation,
                    )
                )

        # Channel count after the last stage. Computed directly instead of
        # reusing the loop variable, which is undefined when there are no
        # upsampling stages; for one or more stages the value is identical.
        ch = cfg.upsample_initial_channel // (2 ** len(self.ups))

        self.activation_post = actv(ch, alpha_logscale=cfg.snake_logscale)

        # post conv
        self.conv_post = weight_norm(
            nn.Conv1d(ch, 1, 7, 1, padding=3, bias=self.cfg.use_bias_at_final)
        )
        self._use_tanh = self.cfg.use_tanh_at_final

        # weight initialization
        for stage in self.ups:
            stage.apply(self.init_weights)
        self.conv_post.apply(self.init_weights)

    def forward(self, x: Tensor):
        """Run the generator on ``x`` and return the ``conv_post`` output.

        The output is passed through ``tanh`` when ``use_tanh_at_final`` is
        set in the config, otherwise it is returned unbounded.
        """
        # pre conv
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            # upsampling
            for up in self.ups[i]:
                x = up(x)
            # AMP blocks: average the outputs of this stage's residual blocks
            xs = None
            for j in range(self.num_kernels):
                block_out = self.resblocks[i * self.num_kernels + j](x)
                if xs is None:
                    xs = block_out
                else:
                    xs += block_out
            x = xs / self.num_kernels

        # post conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        if self._use_tanh:
            return x.tanh()
        return x

    @classmethod
    def from_pretrained(
        cls,
        model_file: PathLike,
        model_config: Union[BemaGANv2Config, Dict[str, Any], PathLike, None] = None,
        *,
        remove_norms: bool = False,
        strict: bool = True,
        map_location: str = "cpu",
        weights_only: bool = False,
        **kwargs,
    ):
        """Create a generator and load a checkpoint into it.

        Args:
            model_file: Path of the checkpoint passed to ``torch.load``.
            model_config: A ``BemaGANv2Config``, a mapping of its keyword
                arguments, a path to a JSON file, or ``None`` for the default
                config.
            remove_norms: Call ``remove_norms()`` on the model before loading.
            strict: Forwarded to ``load_state_dict``.
            map_location: Forwarded to ``torch.load``.
            weights_only: Forwarded to ``torch.load``.
            **kwargs: Accepted and ignored.

        Returns:
            The model with the checkpoint loaded.

        Raises:
            TypeError: If ``model_config`` is of an unsupported type.
            RuntimeError: If loading still fails after the norms were removed.
        """
        is_file(model_file, validate=True)
        # SECURITY NOTE: with ``weights_only=False`` (the default here)
        # ``torch.load`` unpickles arbitrary objects; only load checkpoints
        # from trusted sources or pass ``weights_only=True``.
        model_state_dict = torch.load(
            model_file,
            weights_only=weights_only,
            map_location=map_location,
        )

        if model_config is None:
            h = BemaGANv2Config()
        elif isinstance(model_config, BemaGANv2Config):
            h = model_config
        elif isinstance(model_config, dict):
            h = BemaGANv2Config(**model_config)
        elif isinstance(model_config, (str, Path, bytes)):
            h = BemaGANv2Config(
                **load_json(model_config, BemaGANv2Config().state_dict())
            )
        else:
            # Previously this fell through and failed later on an unbound name.
            raise TypeError(
                "model_config must be a BemaGANv2Config, a dict or a path, "
                f"got {type(model_config).__name__}"
            )

        model = cls(h)
        if remove_norms:
            model.remove_norms()
        try:
            model.load_state_dict(model_state_dict, strict=strict)
            return model
        except RuntimeError:
            # The first attempt failed; retry once with the norms removed.
            print(
                "[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
            )
            model.remove_norms()
            model.load_state_dict(model_state_dict, strict=strict)
            return model
|
@@ -2,9 +2,9 @@ from lt_utils.common import *
|
|
2
2
|
from lt_tensor.torch_commons import *
|
3
3
|
from lt_tensor.model_zoo.convs import ConvNets
|
4
4
|
from lt_tensor.config_templates import ModelConfig
|
5
|
-
from lt_tensor.model_zoo.activations import
|
5
|
+
from lt_tensor.model_zoo.activations import alias_free
|
6
6
|
from lt_tensor.model_zoo.audio_models.resblocks import AMPBlock1, AMPBlock2, get_snake
|
7
|
-
from lt_utils.file_ops import load_json, is_file
|
7
|
+
from lt_utils.file_ops import load_json, is_file
|
8
8
|
|
9
9
|
|
10
10
|
class BigVGANConfig(ModelConfig):
|
@@ -78,8 +78,9 @@ class BigVGAN(ConvNets):
|
|
78
78
|
|
79
79
|
"""
|
80
80
|
|
81
|
-
def __init__(self, cfg: BigVGANConfig):
|
81
|
+
def __init__(self, cfg: Union[BigVGANConfig, Dict[str, object]] = BigVGANConfig()):
|
82
82
|
super().__init__()
|
83
|
+
cfg = cfg if isinstance(cfg, BigVGANConfig) else BigVGANConfig(**cfg)
|
83
84
|
self.cfg = cfg
|
84
85
|
actv = get_snake(self.cfg.activation)
|
85
86
|
|
@@ -173,46 +174,16 @@ class BigVGAN(ConvNets):
|
|
173
174
|
return x.tanh()
|
174
175
|
return x.clamp(min=-1.0, max=1.0)
|
175
176
|
|
176
|
-
def load_weights(
|
177
|
-
self,
|
178
|
-
path,
|
179
|
-
strict=False,
|
180
|
-
assign=False,
|
181
|
-
weights_only=False,
|
182
|
-
mmap=None,
|
183
|
-
raise_if_not_exists=False,
|
184
|
-
**pickle_load_args,
|
185
|
-
):
|
186
|
-
try:
|
187
|
-
return super().load_weights(
|
188
|
-
path,
|
189
|
-
raise_if_not_exists,
|
190
|
-
strict,
|
191
|
-
assign,
|
192
|
-
weights_only,
|
193
|
-
mmap,
|
194
|
-
**pickle_load_args,
|
195
|
-
)
|
196
|
-
except RuntimeError:
|
197
|
-
self.remove_norms()
|
198
|
-
return super().load_weights(
|
199
|
-
path,
|
200
|
-
raise_if_not_exists,
|
201
|
-
strict,
|
202
|
-
assign,
|
203
|
-
weights_only,
|
204
|
-
mmap,
|
205
|
-
**pickle_load_args,
|
206
|
-
)
|
207
|
-
|
208
177
|
@classmethod
|
209
178
|
def from_pretrained(
|
210
179
|
cls,
|
211
180
|
model_file: PathLike,
|
212
|
-
model_config: Union[
|
181
|
+
model_config: Union[
|
182
|
+
BigVGANConfig, Dict[str, Any], Dict[str, Any], PathLike
|
183
|
+
] = BigVGANConfig(),
|
213
184
|
*,
|
214
185
|
remove_norms: bool = False,
|
215
|
-
strict: bool =
|
186
|
+
strict: bool = True,
|
216
187
|
map_location: str = "cpu",
|
217
188
|
weights_only: bool = False,
|
218
189
|
**kwargs,
|
@@ -220,13 +191,17 @@ class BigVGAN(ConvNets):
|
|
220
191
|
|
221
192
|
is_file(model_file, validate=True)
|
222
193
|
model_state_dict = torch.load(
|
223
|
-
model_file,
|
194
|
+
model_file,
|
195
|
+
weights_only=weights_only,
|
196
|
+
map_location=map_location,
|
224
197
|
)
|
225
198
|
|
226
199
|
if isinstance(model_config, BigVGANConfig):
|
227
200
|
h = model_config
|
228
|
-
|
201
|
+
elif isinstance(model_config, dict):
|
229
202
|
h = BigVGANConfig(**model_config)
|
203
|
+
elif isinstance(model_config, (str, Path, bytes)):
|
204
|
+
h = BigVGANConfig(**load_json(model_config, BigVGANConfig().state_dict()))
|
230
205
|
|
231
206
|
model = cls(h)
|
232
207
|
if remove_norms:
|
@@ -177,43 +177,44 @@ class ResidualBlock(Model):
|
|
177
177
|
|
178
178
|
|
179
179
|
class DiffWave(Model):
|
180
|
-
def __init__(self,
|
180
|
+
def __init__(self, cfg: Union[DiffWaveConfig, dict[str, object]] = DiffWaveConfig()):
|
181
181
|
super().__init__()
|
182
|
-
|
183
|
-
self.
|
182
|
+
cfg = cfg if isinstance(cfg, DiffWaveConfig) else DiffWaveConfig(**cfg)
|
183
|
+
self.cfg = cfg
|
184
|
+
self.n_hop = self.cfg.hop_samples
|
184
185
|
self.input_projection = ConvEXT(
|
185
186
|
in_channels=1,
|
186
|
-
out_channels=
|
187
|
+
out_channels=cfg.residual_channels,
|
187
188
|
kernel_size=1,
|
188
|
-
apply_norm=self.
|
189
|
+
apply_norm=self.cfg.apply_norm,
|
189
190
|
activation_out=nn.LeakyReLU(0.1),
|
190
191
|
)
|
191
|
-
self.diffusion_embedding = DiffusionEmbedding(len(
|
192
|
+
self.diffusion_embedding = DiffusionEmbedding(len(cfg.noise_schedule))
|
192
193
|
self.spectrogram_upsampler = (
|
193
|
-
SpectrogramUpsampler() if not self.
|
194
|
+
SpectrogramUpsampler() if not self.cfg.unconditional else None
|
194
195
|
)
|
195
196
|
|
196
197
|
self.residual_layers = nn.ModuleList(
|
197
198
|
[
|
198
199
|
ResidualBlock(
|
199
|
-
|
200
|
-
|
201
|
-
2 ** (i %
|
202
|
-
uncond=
|
203
|
-
apply_norm=self.
|
200
|
+
cfg.n_mels,
|
201
|
+
cfg.residual_channels,
|
202
|
+
2 ** (i % cfg.dilation_cycle_length),
|
203
|
+
uncond=cfg.unconditional,
|
204
|
+
apply_norm=self.cfg.apply_norm_resblock,
|
204
205
|
)
|
205
|
-
for i in range(
|
206
|
+
for i in range(cfg.residual_layers)
|
206
207
|
]
|
207
208
|
)
|
208
209
|
self.skip_projection = ConvEXT(
|
209
|
-
in_channels=
|
210
|
-
out_channels=
|
210
|
+
in_channels=cfg.residual_channels,
|
211
|
+
out_channels=cfg.residual_channels,
|
211
212
|
kernel_size=1,
|
212
|
-
apply_norm=self.
|
213
|
+
apply_norm=self.cfg.apply_norm,
|
213
214
|
activation_out=nn.LeakyReLU(0.1),
|
214
215
|
)
|
215
216
|
self.output_projection = ConvEXT(
|
216
|
-
|
217
|
+
cfg.residual_channels, 1, 1, apply_norm=self.cfg.apply_norm, init_weights=True,
|
217
218
|
)
|
218
219
|
self.activation = nn.LeakyReLU(0.1)
|
219
220
|
self._res_d = sqrt(len(self.residual_layers))
|
@@ -224,7 +225,7 @@ class DiffWave(Model):
|
|
224
225
|
diffusion_step: Tensor,
|
225
226
|
spectrogram: Optional[Tensor] = None,
|
226
227
|
):
|
227
|
-
if not self.
|
228
|
+
if not self.cfg.unconditional:
|
228
229
|
assert spectrogram is not None
|
229
230
|
if audio.ndim < 3:
|
230
231
|
if audio.ndim == 2:
|
@@ -234,7 +235,7 @@ class DiffWave(Model):
|
|
234
235
|
|
235
236
|
x = self.input_projection(audio)
|
236
237
|
diffusion_step = self.diffusion_embedding(diffusion_step)
|
237
|
-
if not self.
|
238
|
+
if not self.cfg.unconditional: # use conditional model
|
238
239
|
spectrogram = self.spectrogram_upsampler(spectrogram)
|
239
240
|
|
240
241
|
skip = torch.zeros_like(x, device=x.device)
|
@@ -5,7 +5,7 @@ from lt_utils.common import *
|
|
5
5
|
from lt_tensor.torch_commons import *
|
6
6
|
from lt_tensor.model_zoo.convs import ConvNets
|
7
7
|
from lt_tensor.config_templates import ModelConfig
|
8
|
-
from lt_utils.file_ops import is_file
|
8
|
+
from lt_utils.file_ops import is_file, load_json
|
9
9
|
from lt_tensor.model_zoo.audio_models.resblocks import ResBlock1, ResBlock2
|
10
10
|
|
11
11
|
|
@@ -16,11 +16,15 @@ def get_padding(kernel_size, dilation=1):
|
|
16
16
|
class HifiganConfig(ModelConfig):
|
17
17
|
# Training params
|
18
18
|
in_channels: int = 80
|
19
|
-
upsample_rates: List[Union[int, List[int]]] = [8,8,2,2]
|
20
|
-
upsample_kernel_sizes: List[Union[int, List[int]]] = [16,16,4,4]
|
19
|
+
upsample_rates: List[Union[int, List[int]]] = [8, 8, 2, 2]
|
20
|
+
upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16, 4, 4]
|
21
21
|
upsample_initial_channel: int = 512
|
22
22
|
resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11]
|
23
|
-
resblock_dilation_sizes: List[Union[int, List[int]]] = [
|
23
|
+
resblock_dilation_sizes: List[Union[int, List[int]]] = [
|
24
|
+
[1, 3, 5],
|
25
|
+
[1, 3, 5],
|
26
|
+
[1, 3, 5],
|
27
|
+
]
|
24
28
|
|
25
29
|
activation: nn.Module = nn.LeakyReLU(0.1)
|
26
30
|
resblock_activation: nn.Module = nn.LeakyReLU(0.1)
|
@@ -29,10 +33,10 @@ class HifiganConfig(ModelConfig):
|
|
29
33
|
def __init__(
|
30
34
|
self,
|
31
35
|
in_channels: int = 80,
|
32
|
-
upsample_rates: List[Union[int, List[int]]] = [8,8,2,2],
|
33
|
-
upsample_kernel_sizes: List[Union[int, List[int]]] = [16,16,4,4],
|
36
|
+
upsample_rates: List[Union[int, List[int]]] = [8, 8, 2, 2],
|
37
|
+
upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16, 4, 4],
|
34
38
|
upsample_initial_channel: int = 512,
|
35
|
-
resblock_kernel_sizes: List[Union[int, List[int]]] = [3,7,11],
|
39
|
+
resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
|
36
40
|
resblock_dilation_sizes: List[Union[int, List[int]]] = [
|
37
41
|
[1, 3, 5],
|
38
42
|
[1, 3, 5],
|
@@ -63,9 +67,11 @@ class HifiganConfig(ModelConfig):
|
|
63
67
|
|
64
68
|
|
65
69
|
class HifiganGenerator(ConvNets):
|
66
|
-
def __init__(self, cfg: HifiganConfig = HifiganConfig()):
|
70
|
+
def __init__(self, cfg: Union[HifiganConfig, Dict[str, object]] = HifiganConfig()):
|
67
71
|
super().__init__()
|
72
|
+
cfg = cfg if isinstance(cfg, HifiganConfig) else HifiganConfig(**cfg)
|
68
73
|
self.cfg = cfg
|
74
|
+
|
69
75
|
self.num_kernels = len(cfg.resblock_kernel_sizes)
|
70
76
|
self.num_upsamples = len(cfg.upsample_rates)
|
71
77
|
self.conv_pre = weight_norm(
|
@@ -115,46 +121,16 @@ class HifiganGenerator(ConvNets):
|
|
115
121
|
x = self.conv_post(self.activation(x))
|
116
122
|
return x.tanh()
|
117
123
|
|
118
|
-
def load_weights(
|
119
|
-
self,
|
120
|
-
path,
|
121
|
-
strict=False,
|
122
|
-
assign=False,
|
123
|
-
weights_only=False,
|
124
|
-
mmap=None,
|
125
|
-
raise_if_not_exists=False,
|
126
|
-
**pickle_load_args,
|
127
|
-
):
|
128
|
-
try:
|
129
|
-
return super().load_weights(
|
130
|
-
path,
|
131
|
-
raise_if_not_exists,
|
132
|
-
strict,
|
133
|
-
assign,
|
134
|
-
weights_only,
|
135
|
-
mmap,
|
136
|
-
**pickle_load_args,
|
137
|
-
)
|
138
|
-
except RuntimeError:
|
139
|
-
self.remove_norms()
|
140
|
-
return super().load_weights(
|
141
|
-
path,
|
142
|
-
raise_if_not_exists,
|
143
|
-
strict,
|
144
|
-
assign,
|
145
|
-
weights_only,
|
146
|
-
mmap,
|
147
|
-
**pickle_load_args,
|
148
|
-
)
|
149
|
-
|
150
124
|
@classmethod
|
151
125
|
def from_pretrained(
|
152
126
|
cls,
|
153
127
|
model_file: PathLike,
|
154
|
-
model_config: Union[
|
128
|
+
model_config: Union[
|
129
|
+
HifiganConfig, Dict[str, Any], Dict[str, Any], PathLike
|
130
|
+
] = HifiganConfig(),
|
155
131
|
*,
|
156
132
|
remove_norms: bool = False,
|
157
|
-
strict: bool =
|
133
|
+
strict: bool = True,
|
158
134
|
map_location: str = "cpu",
|
159
135
|
weights_only: bool = False,
|
160
136
|
**kwargs,
|
@@ -162,13 +138,17 @@ class HifiganGenerator(ConvNets):
|
|
162
138
|
|
163
139
|
is_file(model_file, validate=True)
|
164
140
|
model_state_dict = torch.load(
|
165
|
-
model_file,
|
141
|
+
model_file,
|
142
|
+
weights_only=weights_only,
|
143
|
+
map_location=map_location,
|
166
144
|
)
|
167
145
|
|
168
146
|
if isinstance(model_config, HifiganConfig):
|
169
147
|
h = model_config
|
170
|
-
|
148
|
+
elif isinstance(model_config, dict):
|
171
149
|
h = HifiganConfig(**model_config)
|
150
|
+
elif isinstance(model_config, (str, Path, bytes)):
|
151
|
+
h = HifiganConfig(**load_json(model_config, HifiganConfig().state_dict()))
|
172
152
|
|
173
153
|
model = cls(h)
|
174
154
|
if remove_norms:
|
@@ -3,7 +3,7 @@ from lt_utils.common import *
|
|
3
3
|
from lt_tensor.torch_commons import *
|
4
4
|
from lt_tensor.model_zoo.convs import ConvNets
|
5
5
|
from lt_tensor.config_templates import ModelConfig
|
6
|
-
from lt_utils.file_ops import is_file
|
6
|
+
from lt_utils.file_ops import is_file, load_json
|
7
7
|
from lt_tensor.model_zoo.audio_models.resblocks import ResBlock1, ResBlock2
|
8
8
|
|
9
9
|
|
@@ -67,8 +67,11 @@ class iSTFTNetConfig(ModelConfig):
|
|
67
67
|
|
68
68
|
|
69
69
|
class iSTFTNetGenerator(ConvNets):
|
70
|
-
def __init__(
|
70
|
+
def __init__(
|
71
|
+
self, cfg: Union[iSTFTNetConfig, Dict[str, object]] = iSTFTNetConfig()
|
72
|
+
):
|
71
73
|
super().__init__()
|
74
|
+
cfg = cfg if isinstance(cfg, iSTFTNetConfig) else iSTFTNetConfig(**cfg)
|
72
75
|
self.cfg = cfg
|
73
76
|
self.num_kernels = len(cfg.resblock_kernel_sizes)
|
74
77
|
self.num_upsamples = len(cfg.upsample_rates)
|
@@ -146,46 +149,16 @@ class iSTFTNetGenerator(ConvNets):
|
|
146
149
|
|
147
150
|
return spec, phase
|
148
151
|
|
149
|
-
def load_weights(
|
150
|
-
self,
|
151
|
-
path,
|
152
|
-
strict=False,
|
153
|
-
assign=False,
|
154
|
-
weights_only=False,
|
155
|
-
mmap=None,
|
156
|
-
raise_if_not_exists=False,
|
157
|
-
**pickle_load_args,
|
158
|
-
):
|
159
|
-
try:
|
160
|
-
return super().load_weights(
|
161
|
-
path,
|
162
|
-
raise_if_not_exists,
|
163
|
-
strict,
|
164
|
-
assign,
|
165
|
-
weights_only,
|
166
|
-
mmap,
|
167
|
-
**pickle_load_args,
|
168
|
-
)
|
169
|
-
except RuntimeError:
|
170
|
-
self.remove_norms()
|
171
|
-
return super().load_weights(
|
172
|
-
path,
|
173
|
-
raise_if_not_exists,
|
174
|
-
strict,
|
175
|
-
assign,
|
176
|
-
weights_only,
|
177
|
-
mmap,
|
178
|
-
**pickle_load_args,
|
179
|
-
)
|
180
|
-
|
181
152
|
@classmethod
|
182
153
|
def from_pretrained(
|
183
154
|
cls,
|
184
155
|
model_file: PathLike,
|
185
|
-
model_config: Union[
|
156
|
+
model_config: Union[
|
157
|
+
iSTFTNetConfig, Dict[str, Any], Dict[str, Any], PathLike
|
158
|
+
] = iSTFTNetConfig(),
|
186
159
|
*,
|
187
160
|
remove_norms: bool = False,
|
188
|
-
strict: bool =
|
161
|
+
strict: bool = True,
|
189
162
|
map_location: str = "cpu",
|
190
163
|
weights_only: bool = False,
|
191
164
|
**kwargs,
|
@@ -193,14 +166,17 @@ class iSTFTNetGenerator(ConvNets):
|
|
193
166
|
|
194
167
|
is_file(model_file, validate=True)
|
195
168
|
model_state_dict = torch.load(
|
196
|
-
model_file,
|
169
|
+
model_file,
|
170
|
+
weights_only=weights_only,
|
171
|
+
map_location=map_location,
|
197
172
|
)
|
198
173
|
|
199
174
|
if isinstance(model_config, iSTFTNetConfig):
|
200
175
|
h = model_config
|
201
|
-
|
176
|
+
elif isinstance(model_config, dict):
|
202
177
|
h = iSTFTNetConfig(**model_config)
|
203
|
-
|
178
|
+
elif isinstance(model_config, (str, Path, bytes)):
|
179
|
+
h = iSTFTNetConfig(**load_json(model_config, iSTFTNetConfig().state_dict()))
|
204
180
|
model = cls(h)
|
205
181
|
if remove_norms:
|
206
182
|
model.remove_norms()
|
lt_tensor/model_zoo/convs.py
CHANGED
@@ -1,11 +1,7 @@
|
|
1
1
|
__all__ = ["ConvNets", "ConvEXT"]
|
2
|
-
import math
|
3
2
|
from lt_utils.common import *
|
4
|
-
import torch.nn.functional as F
|
5
3
|
from lt_tensor.torch_commons import *
|
6
4
|
from lt_tensor.model_base import Model
|
7
|
-
from lt_tensor.misc_utils import log_tensor
|
8
|
-
from lt_tensor.model_zoo.fusion import AdaFusion1D, AdaIN1D
|
9
5
|
from lt_utils.misc_utils import default
|
10
6
|
|
11
7
|
|
@@ -52,6 +48,41 @@ class ConvNets(Model):
|
|
52
48
|
if "Conv" in m.__class__.__name__:
|
53
49
|
m.weight.data.normal_(mean, std)
|
54
50
|
|
51
|
+
def load_weights(
    self,
    path,
    strict=False,
    assign=False,
    weights_only=False,
    mmap=None,
    raise_if_not_exists=False,
    **pickle_load_args,
):
    """Load weights through the parent class, retrying once without norms.

    The parent ``load_weights`` is called first. If it raises
    ``RuntimeError``, ``remove_norms()`` is applied to this module and the
    load is attempted a second time. If the second attempt also fails, the
    ``RuntimeError`` from the first attempt is raised.

    All arguments are forwarded unchanged to the parent implementation.

    Returns:
        Whatever the parent ``load_weights`` returns.

    Raises:
        RuntimeError: The error from the first attempt, when the retry fails
            as well.
    """
    try:
        return super().load_weights(
            path,
            raise_if_not_exists,
            strict,
            assign,
            weights_only,
            mmap,
            **pickle_load_args,
        )
    except RuntimeError as e:
        try:
            self.remove_norms()
            return super().load_weights(
                path,
                raise_if_not_exists,
                strict,
                assign,
                weights_only,
                mmap,
                **pickle_load_args,
            )
        except Exception:
            # Only ordinary errors are replaced by the original failure; a
            # bare ``except`` would also swallow KeyboardInterrupt/SystemExit
            # and report them as the stale RuntimeError.
            raise e
|
85
|
+
|
55
86
|
|
56
87
|
class ConvEXT(ConvNets):
|
57
88
|
def __init__(
|