lt-tensor 0.0.1a34__py3-none-any.whl → 0.0.1a36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +1 -1
- lt_tensor/losses.py +11 -7
- lt_tensor/lr_schedulers.py +147 -21
- lt_tensor/misc_utils.py +35 -42
- lt_tensor/model_zoo/activations/__init__.py +3 -0
- lt_tensor/model_zoo/activations/alias_free/__init__.py +3 -0
- lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/act.py +8 -6
- lt_tensor/model_zoo/activations/snake/__init__.py +41 -43
- lt_tensor/model_zoo/audio_models/__init__.py +2 -2
- lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +243 -0
- lt_tensor/model_zoo/audio_models/hifigan/__init__.py +22 -357
- lt_tensor/model_zoo/audio_models/istft/__init__.py +14 -349
- lt_tensor/model_zoo/audio_models/resblocks.py +248 -0
- lt_tensor/model_zoo/convs.py +21 -32
- lt_tensor/model_zoo/losses/CQT/__init__.py +0 -0
- lt_tensor/model_zoo/losses/CQT/transforms.py +336 -0
- lt_tensor/model_zoo/losses/CQT/utils.py +519 -0
- lt_tensor/model_zoo/losses/discriminators.py +375 -37
- lt_tensor/processors/audio.py +67 -57
- {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/METADATA +1 -1
- lt_tensor-0.0.1a36.dist-info/RECORD +43 -0
- lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +0 -1
- lt_tensor-0.0.1a34.dist-info/RECORD +0 -37
- /lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/filter.py +0 -0
- /lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/resample.py +0 -0
- {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a36.dist-info}/top_level.txt +0 -0
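The headline changes are a new BigVGAN vocoder backed by a shared `resblocks.py` module, new CQT loss utilities, and the rename of the `alias_free_torch` activations package to `alias_free`. Downstream code that imported the old activation path has to follow the rename; below is a minimal, hedged sketch of the updated imports, mirroring how the new BigVGAN module wires these pieces together (the channel count and activation choice are illustrative, and the old path is inferred from the rename entries above):

```python
# Sketch only: the 0.0.1a34 path is inferred from the rename entries above,
# and the channel count / activation choice are illustrative.
# Old (0.0.1a34):
#   from lt_tensor.model_zoo.activations import alias_free_torch
# New (0.0.1a36), as used by the new BigVGAN module:
from lt_tensor.model_zoo.activations import alias_free
from lt_tensor.model_zoo.audio_models.resblocks import get_snake

snake_cls = get_snake("snakebeta")  # "snake" or "snakebeta", per BigVGANConfig.activation
post_act = alias_free.Activation1d(activation=snake_cls(256, alpha_logscale=True))
```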
lt_tensor/model_zoo/audio_models/bigvgan/__init__.py (new file)

```diff
@@ -0,0 +1,243 @@
+from lt_utils.common import *
+from lt_tensor.torch_commons import *
+from lt_tensor.model_zoo.convs import ConvNets
+from lt_tensor.config_templates import ModelConfig
+from lt_tensor.model_zoo.activations import snake, alias_free
+from lt_tensor.model_zoo.audio_models.resblocks import AMPBlock1, AMPBlock2, get_snake
+from lt_utils.file_ops import load_json, is_file, is_dir, is_path_valid
+
+
+class BigVGANConfig(ModelConfig):
+    # Training params
+    in_channels: int = 80
+    upsample_rates: List[Union[int, List[int]]] = [4, 4, 2, 2, 2, 2]
+    upsample_kernel_sizes: List[Union[int, List[int]]] = [8, 8, 4, 4, 4, 4]
+    upsample_initial_channel: int = 1536
+    resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11]
+    resblock_dilation_sizes: List[Union[int, List[int]]] = [
+        [1, 3, 5],
+        [1, 3, 5],
+        [1, 3, 5],
+    ]
+
+    activation: Literal["snake", "snakebeta"] = "snakebeta"
+    resblock_activation: Literal["snake", "snakebeta"] = "snakebeta"
+    resblock: int = 0
+    use_bias_at_final: bool = True
+    use_tanh_at_final: bool = True
+    snake_logscale: bool = True
+
+    def __init__(
+        self,
+        in_channels: int = 80,
+        upsample_rates: List[Union[int, List[int]]] = [4, 4, 2, 2, 2, 2],
+        upsample_kernel_sizes: List[Union[int, List[int]]] = [8, 8, 4, 4, 4, 4],
+        upsample_initial_channel: int = 1536,
+        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+        resblock_dilation_sizes: List[Union[int, List[int]]] = [
+            [1, 3, 5],
+            [1, 3, 5],
+            [1, 3, 5],
+        ],
+        activation: Literal["snake", "snakebeta"] = "snakebeta",
+        resblock_activation: Literal["snake", "snakebeta"] = "snakebeta",
+        resblock: Union[int, str] = "1",
+        use_bias_at_final: bool = False,
+        use_tanh_at_final: bool = False,
+        *args,
+        **kwargs,
+    ):
+        settings = {
+            "in_channels": in_channels,
+            "upsample_rates": upsample_rates,
+            "upsample_kernel_sizes": upsample_kernel_sizes,
+            "upsample_initial_channel": upsample_initial_channel,
+            "resblock_kernel_sizes": resblock_kernel_sizes,
+            "resblock_dilation_sizes": resblock_dilation_sizes,
+            "activation": activation,
+            "resblock_activation": resblock_activation,
+            "resblock": resblock,
+            "use_bias_at_final": use_bias_at_final,
+            "use_tanh_at_final": use_tanh_at_final,
+        }
+        super().__init__(**settings)
+
+    def post_process(self):
+        if isinstance(self.resblock, str):
+            self.resblock = 0 if self.resblock == "1" else 1
+
+
+class BigVGAN(ConvNets):
+    """Modified from 'https://github.com/NVIDIA/BigVGAN/blob/main/bigvgan.py' under mit license.
+
+    BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
+    New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
+
+    Args:
+        cfg (BigVGANConfig): Hyperparameters.
+
+    """
+
+    def __init__(self, cfg: BigVGANConfig):
+        super().__init__()
+        self.cfg = cfg
+        actv = get_snake(self.cfg.activation)
+
+        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
+
+        self.num_kernels = len(cfg.resblock_kernel_sizes)
+        self.num_upsamples = len(cfg.upsample_rates)
+
+        # Pre-conv
+        self.conv_pre = weight_norm(
+            nn.Conv1d(cfg.in_channels, cfg.upsample_initial_channel, 7, 1, padding=3)
+        )
+
+        # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
+        resblock_class = AMPBlock1 if cfg.resblock == 0 else AMPBlock2
+
+        # Transposed conv-based upsamplers. does not apply anti-aliasing
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(cfg.upsample_rates, cfg.upsample_kernel_sizes)):
+            self.ups.append(
+                nn.ModuleList(
+                    [
+                        weight_norm(
+                            nn.ConvTranspose1d(
+                                cfg.upsample_initial_channel // (2**i),
+                                cfg.upsample_initial_channel // (2 ** (i + 1)),
+                                k,
+                                u,
+                                padding=(k - u) // 2,
+                            )
+                        )
+                    ]
+                )
+            )
+
+        # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = cfg.upsample_initial_channel // (2 ** (i + 1))
+            for k, d in zip(cfg.resblock_kernel_sizes, cfg.resblock_dilation_sizes):
+                self.resblocks.append(
+                    resblock_class(
+                        ch,
+                        k,
+                        d,
+                        snake_logscale=cfg.snake_logscale,
+                        activation=cfg.resblock_activation,
+                    )
+                )
+
+        # Post-conv
+        activation_post = actv(ch, alpha_logscale=cfg.snake_logscale)
+
+        self.activation_post = alias_free.Activation1d(activation=activation_post)
+
+        # Whether to use bias for the final conv_post. Default to True for backward compatibility
+        self.conv_post = weight_norm(
+            nn.Conv1d(ch, 1, 7, 1, padding=3, bias=self.cfg.use_bias_at_final)
+        )
+
+        # Weight initialization
+        for i in range(len(self.ups)):
+            self.ups[i].apply(self.init_weights)
+        self.conv_post.apply(self.init_weights)
+
+        # Final tanh activation. Defaults to True for backward compatibility
+        self.use_tanh_at_final = cfg.use_tanh_at_final
+
+    def forward(self, x):
+        # Pre-conv
+        x = self.conv_pre(x)
+
+        for i in range(self.num_upsamples):
+            # Upsampling
+            for i_up in range(len(self.ups[i])):
+                x = self.ups[i][i_up](x)
+            # AMP blocks
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+
+        # Post-conv
+        x = self.activation_post(x)
+        x: Tensor = self.conv_post(x)
+        # Final tanh activation
+        if self.use_tanh_at_final:
+            return x.tanh()
+        return x.clamp(min=-1.0, max=1.0)
+
+    def load_weights(
+        self,
+        path,
+        strict=False,
+        assign=False,
+        weights_only=False,
+        mmap=None,
+        raise_if_not_exists=False,
+        **pickle_load_args,
+    ):
+        try:
+            return super().load_weights(
+                path,
+                raise_if_not_exists,
+                strict,
+                assign,
+                weights_only,
+                mmap,
+                **pickle_load_args,
+            )
+        except RuntimeError:
+            self.remove_norms()
+            return super().load_weights(
+                path,
+                raise_if_not_exists,
+                strict,
+                assign,
+                weights_only,
+                mmap,
+                **pickle_load_args,
+            )
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_file: PathLike,
+        model_config: Union[BigVGANConfig, Dict[str, Any]],
+        *,
+        remove_norms: bool = False,
+        strict: bool = False,
+        map_location: str = "cpu",
+        weights_only: bool = False,
+        **kwargs,
+    ):
+
+        is_file(model_file, validate=True)
+        model_state_dict = torch.load(
+            model_file, weights_only=weights_only, map_location=map_location
+        )
+
+        if isinstance(model_config, BigVGANConfig):
+            h = model_config
+        else:
+            h = BigVGANConfig(**model_config)
+
+        model = cls(h)
+        if remove_norms:
+            model.remove_norms()
+        try:
+            model.load_state_dict(model_state_dict, strict=strict)
+            return model
+        except RuntimeError:
+            print(
+                f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
+            )
+            model.remove_norms()
+            model.load_state_dict(model_state_dict, strict=strict)
+            return model
```
lt_tensor/model_zoo/audio_models/hifigan/__init__.py

```diff
@@ -1,48 +1,45 @@
 __all__ = ["HifiganGenerator", "HifiganConfig"]
+
+
 from lt_utils.common import *
 from lt_tensor.torch_commons import *
 from lt_tensor.model_zoo.convs import ConvNets
-from
-from lt_utils.file_ops import
-from lt_tensor.
+from lt_tensor.config_templates import ModelConfig
+from lt_utils.file_ops import is_file
+from lt_tensor.model_zoo.audio_models.resblocks import ResBlock1, ResBlock2


 def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)


-from lt_tensor.config_templates import ModelConfig
-
-
 class HifiganConfig(ModelConfig):
     # Training params
     in_channels: int = 80
-    upsample_rates: List[Union[int, List[int]]] = [8,
-    upsample_kernel_sizes: List[Union[int, List[int]]] = [16,
+    upsample_rates: List[Union[int, List[int]]] = [8,8,2,2]
+    upsample_kernel_sizes: List[Union[int, List[int]]] = [16,16,4,4]
     upsample_initial_channel: int = 512
     resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11]
-    resblock_dilation_sizes: List[Union[int, List[int]]] = [
-        [1, 3, 5],
-        [1, 3, 5],
-        [1, 3, 5],
-    ]
+    resblock_dilation_sizes: List[Union[int, List[int]]] = [[1,3,5], [1,3,5], [1,3,5]]

     activation: nn.Module = nn.LeakyReLU(0.1)
+    resblock_activation: nn.Module = nn.LeakyReLU(0.1)
     resblock: int = 0

     def __init__(
         self,
         in_channels: int = 80,
-        upsample_rates: List[Union[int, List[int]]] = [8,
-        upsample_kernel_sizes: List[Union[int, List[int]]] = [16,
+        upsample_rates: List[Union[int, List[int]]] = [8,8,2,2],
+        upsample_kernel_sizes: List[Union[int, List[int]]] = [16,16,4,4],
         upsample_initial_channel: int = 512,
-        resblock_kernel_sizes: List[Union[int, List[int]]] = [3,
+        resblock_kernel_sizes: List[Union[int, List[int]]] = [3,7,11],
         resblock_dilation_sizes: List[Union[int, List[int]]] = [
             [1, 3, 5],
             [1, 3, 5],
             [1, 3, 5],
         ],
         activation: nn.Module = nn.LeakyReLU(0.1),
+        resblock_activation: nn.Module = nn.LeakyReLU(0.1),
         resblock: Union[int, str] = "1",
         *args,
         **kwargs,
@@ -55,6 +52,7 @@ class HifiganConfig(ModelConfig):
             "resblock_kernel_sizes": resblock_kernel_sizes,
             "resblock_dilation_sizes": resblock_dilation_sizes,
             "activation": activation,
+            "resblock_activation": resblock_activation,
             "resblock": resblock,
         }
         super().__init__(**settings)
@@ -64,128 +62,6 @@ class HifiganConfig(ModelConfig):
             self.resblock = 0 if self.resblock == "1" else 1


-class ResBlock1(ConvNets):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super().__init__()
-
-        self.convs1 = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[2],
-                        padding=get_padding(kernel_size, dilation[2]),
-                    )
-                ),
-            ]
-        )
-        self.convs1.apply(self.init_weights)
-
-        self.convs2 = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-            ]
-        )
-        self.convs2.apply(self.init_weights)
-        self.activation = nn.LeakyReLU(0.1)
-
-    def forward(self, x):
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = c1(self.activation(x))
-            xt = c2(self.activation(xt))
-            x = xt + x
-        return x
-
-
-class ResBlock2(ConvNets):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
-        super().__init__()
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-            ]
-        )
-        self.convs.apply(self.init_weights)
-        self.activation = nn.LeakyReLU(0.1)
-
-    def forward(self, x):
-        for c in self.convs:
-            xt = c(self.activation(x))
-            x = xt + x
-        return x
-
-
 class HifiganGenerator(ConvNets):
     def __init__(self, cfg: HifiganConfig = HifiganConfig()):
         super().__init__()
@@ -219,7 +95,7 @@ class HifiganGenerator(ConvNets):
             for j, (k, d) in enumerate(
                 zip(cfg.resblock_kernel_sizes, cfg.resblock_dilation_sizes)
             ):
-                self.resblocks.append(resblock(ch, k, d))
+                self.resblocks.append(resblock(ch, k, d, cfg.resblock_activation))

         self.conv_post = weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
         self.ups.apply(self.init_weights)
@@ -237,9 +113,7 @@ class HifiganGenerator(ConvNets):
                     xs += self.resblocks[i * self.num_kernels + j](x)
             x = xs / self.num_kernels
         x = self.conv_post(self.activation(x))
-        x
-
-        return x
+        return x.tanh()

     def load_weights(
         self,
@@ -252,7 +126,7 @@ class HifiganGenerator(ConvNets):
         **pickle_load_args,
     ):
         try:
-
+            return super().load_weights(
                 path,
                 raise_if_not_exists,
                 strict,
@@ -261,18 +135,6 @@ class HifiganGenerator(ConvNets):
                 mmap,
                 **pickle_load_args,
             )
-            if incompatible_keys:
-                self.remove_norms()
-                incompatible_keys = super().load_weights(
-                    path,
-                    raise_if_not_exists,
-                    strict,
-                    assign,
-                    weights_only,
-                    mmap,
-                    **pickle_load_args,
-                )
-            return incompatible_keys
         except RuntimeError:
             self.remove_norms()
             return super().load_weights(
@@ -291,6 +153,7 @@ class HifiganGenerator(ConvNets):
         model_file: PathLike,
         model_config: Union[HifiganConfig, Dict[str, Any]],
         *,
+        remove_norms: bool = False,
         strict: bool = False,
         map_location: str = "cpu",
         weights_only: bool = False,
@@ -308,11 +171,11 @@ class HifiganGenerator(ConvNets):
             h = HifiganConfig(**model_config)

         model = cls(h)
+        if remove_norms:
+            model.remove_norms()
         try:
-
-
-            model.remove_norms()
-            model.load_state_dict(model_state_dict, strict=strict)
+            model.load_state_dict(model_state_dict, strict=strict)
+            return model
         except RuntimeError:
             print(
                 f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
@@ -320,201 +183,3 @@ class HifiganGenerator(ConvNets):
             model.remove_norms()
             model.load_state_dict(model_state_dict, strict=strict)
             return model
-
-
-class DiscriminatorP(ConvNets):
-    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
-        super(DiscriminatorP, self).__init__()
-        self.period = period
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(
-                    nn.Conv2d(
-                        1,
-                        32,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(5, 1), 0),
-                    )
-                ),
-                norm_f(
-                    nn.Conv2d(
-                        32,
-                        128,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(5, 1), 0),
-                    )
-                ),
-                norm_f(
-                    nn.Conv2d(
-                        128,
-                        512,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(5, 1), 0),
-                    )
-                ),
-                norm_f(
-                    nn.Conv2d(
-                        512,
-                        1024,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(5, 1), 0),
-                    )
-                ),
-                norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
-            ]
-        )
-        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
-        self.activation = nn.LeakyReLU(0.1)
-
-    def forward(self, x):
-        fmap = []
-
-        # 1d to 2d
-        b, c, t = x.shape
-        if t % self.period != 0:  # pad first
-            n_pad = self.period - (t % self.period)
-            x = F.pad(x, (0, n_pad), "reflect")
-            t = t + n_pad
-        x = x.view(b, c, t // self.period, self.period)
-
-        for l in self.convs:
-            x = l(x)
-            x = self.activation(x)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class MultiPeriodDiscriminator(ConvNets):
-    def __init__(self):
-        super(MultiPeriodDiscriminator, self).__init__()
-        self.discriminators = nn.ModuleList(
-            [
-                DiscriminatorP(2),
-                DiscriminatorP(3),
-                DiscriminatorP(5),
-                DiscriminatorP(7),
-                DiscriminatorP(11),
-            ]
-        )
-
-    def forward(self, y, y_hat):
-        y_d_rs = []
-        y_d_gs = []
-        fmap_rs = []
-        fmap_gs = []
-        for i, d in enumerate(self.discriminators):
-            y_d_r, fmap_r = d(y)
-            y_d_g, fmap_g = d(y_hat)
-            y_d_rs.append(y_d_r)
-            fmap_rs.append(fmap_r)
-            y_d_gs.append(y_d_g)
-            fmap_gs.append(fmap_g)
-
-        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-class DiscriminatorS(ConvNets):
-    def __init__(self, use_spectral_norm=False):
-        super(DiscriminatorS, self).__init__()
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
-                norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
-                norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
-                norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
-                norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
-                norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
-                norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
-            ]
-        )
-        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
-        self.activation = nn.LeakyReLU(0.1)
-
-    def forward(self, x):
-        fmap = []
-        for l in self.convs:
-            x = l(x)
-            x = self.activation(x)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class MultiScaleDiscriminator(ConvNets):
-    def __init__(self):
-        super(MultiScaleDiscriminator, self).__init__()
-        self.discriminators = nn.ModuleList(
-            [
-                DiscriminatorS(use_spectral_norm=True),
-                DiscriminatorS(),
-                DiscriminatorS(),
-            ]
-        )
-        self.meanpools = nn.ModuleList(
-            [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
-        )
-
-    def forward(self, y, y_hat):
-        y_d_rs = []
-        y_d_gs = []
-        fmap_rs = []
-        fmap_gs = []
-        for i, d in enumerate(self.discriminators):
-            if i != 0:
-                y = self.meanpools[i - 1](y)
-                y_hat = self.meanpools[i - 1](y_hat)
-            y_d_r, fmap_r = d(y)
-            y_d_g, fmap_g = d(y_hat)
-            y_d_rs.append(y_d_r)
-            fmap_rs.append(fmap_r)
-            y_d_gs.append(y_d_g)
-            fmap_gs.append(fmap_g)
-
-        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-def feature_loss(fmap_r, fmap_g):
-    loss = 0
-    for dr, dg in zip(fmap_r, fmap_g):
-        for rl, gl in zip(dr, dg):
-            loss += torch.mean(torch.abs(rl - gl))
-
-    return loss * 2
-
-
-def discriminator_loss(disc_real_outputs, disc_generated_outputs):
-    loss = 0
-    r_losses = []
-    g_losses = []
-    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-        r_loss = torch.mean((1 - dr) ** 2)
-        g_loss = torch.mean(dg**2)
-        loss += r_loss + g_loss
-        r_losses.append(r_loss.item())
-        g_losses.append(g_loss.item())
-
-    return loss, r_losses, g_losses
-
-
-def generator_loss(disc_outputs):
-    loss = 0
-    gen_losses = []
-    for dg in disc_outputs:
-        l = torch.mean((1 - dg) ** 2)
-        gen_losses.append(l)
-        loss += l
-
-    return loss, gen_losses
```