lt-tensor 0.0.1a14__py3-none-any.whl → 0.0.1a16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/datasets/audio.py +23 -6
- lt_tensor/model_base.py +163 -123
- lt_tensor/model_zoo/__init__.py +8 -6
- lt_tensor/model_zoo/audio_models/__init__.py +1 -0
- lt_tensor/model_zoo/audio_models/diffwave/__init__.py +3 -0
- lt_tensor/model_zoo/audio_models/diffwave/model.py +201 -0
- lt_tensor/model_zoo/audio_models/hifigan/__init__.py +393 -0
- lt_tensor/model_zoo/audio_models/istft/__init__.py +409 -0
- lt_tensor/model_zoo/basic.py +139 -0
- lt_tensor/model_zoo/features.py +102 -11
- lt_tensor/model_zoo/residual.py +133 -64
- {lt_tensor-0.0.1a14.dist-info → lt_tensor-0.0.1a16.dist-info}/METADATA +1 -1
- {lt_tensor-0.0.1a14.dist-info → lt_tensor-0.0.1a16.dist-info}/RECORD +16 -16
- lt_tensor/model_zoo/discriminator.py +0 -196
- lt_tensor/model_zoo/istft/__init__.py +0 -5
- lt_tensor/model_zoo/istft/generator.py +0 -90
- lt_tensor/model_zoo/istft/sg.py +0 -142
- lt_tensor/model_zoo/istft/trainer.py +0 -618
- {lt_tensor-0.0.1a14.dist-info → lt_tensor-0.0.1a16.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a14.dist-info → lt_tensor-0.0.1a16.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a14.dist-info → lt_tensor-0.0.1a16.dist-info}/top_level.txt +0 -0

lt_tensor/model_zoo/audio_models/diffwave/model.py
@@ -0,0 +1,201 @@
__all__ = ["DiffWave", "SpectrogramUpsampler", "DiffusionEmbedding"]
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from math import sqrt


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

    def override(self, attrs):
        if isinstance(attrs, dict):
            self.__dict__.update(**attrs)
        elif isinstance(attrs, (list, tuple, set)):
            for attr in attrs:
                self.override(attr)
        elif attrs is not None:
            raise NotImplementedError
        return self


params = AttrDict(
    # Training params
    batch_size=16,
    learning_rate=2e-4,
    max_grad_norm=None,
    # Data params
    sample_rate=22050,
    n_mels=80,
    n_fft=1024,
    hop_samples=256,
    crop_mel_frames=62,  # Probably an error in paper.
    # Model params
    residual_layers=30,
    residual_channels=64,
    dilation_cycle_length=10,
    unconditional=False,
    noise_schedule=np.linspace(1e-4, 0.05, 50).tolist(),
    inference_noise_schedule=[0.0001, 0.001, 0.01, 0.05, 0.2, 0.5],
    # unconditional sample len
    audio_len=22050 * 5,  # unconditional_synthesis_samples
)
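
For context, a minimal sketch (not part of the package) of how this AttrDict config behaves: entries are exposed as attributes, and override() merges a dict, or a sequence of dicts, in place.

# Illustrative only; the values here are hypothetical, not package defaults.
cfg = AttrDict(batch_size=16, learning_rate=2e-4)
assert cfg.batch_size == 16                        # attribute access
cfg.override({"batch_size": 32})                   # merge a single dict
cfg.override(({"learning_rate": 1e-4}, {"max_grad_norm": 1.0}))
assert cfg.batch_size == 32 and cfg.max_grad_norm == 1.0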
def Conv1d(*args, **kwargs):
    layer = nn.Conv1d(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer


class DiffusionEmbedding(nn.Module):
    def __init__(self, max_steps):
        super().__init__()
        self.register_buffer(
            "embedding", self._build_embedding(max_steps), persistent=False
        )
        self.projection1 = nn.Linear(128, 512)
        self.projection2 = nn.Linear(512, 512)
        self.activation = nn.SiLU()

    def forward(self, diffusion_step):
        if diffusion_step.dtype in [torch.int32, torch.int64]:
            x = self.embedding[diffusion_step]
        else:
            x = self._lerp_embedding(diffusion_step)
        x = self.projection1(x)
        x = self.activation(x)
        x = self.projection2(x)
        x = self.activation(x)
        return x

    def _lerp_embedding(self, t):
        low_idx = torch.floor(t).long()
        high_idx = torch.ceil(t).long()
        low = self.embedding[low_idx]
        high = self.embedding[high_idx]
        return low + (high - low) * (t - low_idx)

    def _build_embedding(self, max_steps):
        steps = torch.arange(max_steps).unsqueeze(1)  # [T,1]
        dims = torch.arange(64).unsqueeze(0)  # [1,64]
        table = steps * 10.0 ** (dims * 4.0 / 63.0)  # [T,64]
        table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)
        return table


class SpectrogramUpsampler(nn.Module):
    def __init__(self, n_mels):
        super().__init__()
        self.conv1 = nn.ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
        self.conv2 = nn.ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])

    def forward(self, x):
        x = torch.unsqueeze(x, 1)
        x = self.conv1(x)
        x = F.leaky_relu(x, 0.4)
        x = self.conv2(x)
        x = F.leaky_relu(x, 0.4)
        x = torch.squeeze(x, 1)
        return x


class ResidualBlock(nn.Module):
    def __init__(self, n_mels, residual_channels, dilation, uncond=False):
        """
        :param n_mels: inplanes of conv1x1 for spectrogram conditional
        :param residual_channels: audio conv channels
        :param dilation: audio conv dilation
        :param uncond: disable spectrogram conditional
        """
        super().__init__()
        self.dilated_conv = Conv1d(
            residual_channels,
            2 * residual_channels,
            3,
            padding=dilation,
            dilation=dilation,
        )
        self.diffusion_projection = nn.Linear(512, residual_channels)
        if not uncond:  # conditional model
            self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1)
        else:  # unconditional model
            self.conditioner_projection = None

        self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)

    def forward(self, x, diffusion_step, conditioner=None):
        assert (conditioner is None and self.conditioner_projection is None) or (
            conditioner is not None and self.conditioner_projection is not None
        )

        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
        y = x + diffusion_step
        if self.conditioner_projection is None:  # using an unconditional model
            y = self.dilated_conv(y)
        else:
            conditioner = self.conditioner_projection(conditioner)
            y = self.dilated_conv(y) + conditioner

        gate, filter = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter)

        y = self.output_projection(y)
        residual, skip = torch.chunk(y, 2, dim=1)
        return (x + residual) / sqrt(2.0), skip


class DiffWave(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.params = params
        self.input_projection = Conv1d(1, params.residual_channels, 1)
        self.diffusion_embedding = DiffusionEmbedding(len(params.noise_schedule))
        if self.params.unconditional:  # use unconditional model
            self.spectrogram_upsampler = None
        else:
            self.spectrogram_upsampler = SpectrogramUpsampler(params.n_mels)

        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    params.n_mels,
                    params.residual_channels,
                    2 ** (i % params.dilation_cycle_length),
                    uncond=params.unconditional,
                )
                for i in range(params.residual_layers)
            ]
        )
        self.skip_projection = Conv1d(
            params.residual_channels, params.residual_channels, 1
        )
        self.output_projection = Conv1d(params.residual_channels, 1, 1)
        nn.init.zeros_(self.output_projection.weight)

    def forward(self, audio, diffusion_step, spectrogram=None):
        assert (spectrogram is None and self.spectrogram_upsampler is None) or (
            spectrogram is not None and self.spectrogram_upsampler is not None
        )
        x = audio.unsqueeze(1)
        x = self.input_projection(x)
        x = F.relu(x)

        diffusion_step = self.diffusion_embedding(diffusion_step)
        if self.spectrogram_upsampler:  # use conditional model
            spectrogram = self.spectrogram_upsampler(spectrogram)

        skip = None
        for layer in self.residual_layers:
            x, skip_connection = layer(x, diffusion_step, spectrogram)
            skip = skip_connection if skip is None else skip_connection + skip

        x = skip / sqrt(len(self.residual_layers))
        x = self.skip_projection(x)
        x = F.relu(x)
        x = self.output_projection(x)
        return x
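
A hedged end-to-end sketch (assumed usage, not shipped in the diff): with the defaults above, the spectrogram upsampler stretches mel frames by 16 * 16 = 256, so the audio length must equal frames * hop_samples for the conditioner to line up with the dilated-conv output.

# Usage sketch for the conditional model; shapes follow the defaults above.
model = DiffWave(params)
frames = params.crop_mel_frames                       # 62
audio = torch.randn(1, frames * params.hop_samples)   # [B, T], T = 62 * 256
step = torch.randint(0, len(params.noise_schedule), (1,))
mel = torch.randn(1, params.n_mels, frames)           # [B, 80, frames]
noise_pred = model(audio, step, mel)                  # [B, 1, T]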
lt_tensor/model_zoo/audio_models/hifigan/__init__.py
@@ -0,0 +1,393 @@
__all__ = ["HifiganGenerator"]
from lt_utils.common import *
from lt_tensor.torch_commons import *
from lt_tensor.model_zoo.residual import ConvNets

import torch
import torch.nn as nn
import torch.nn.functional as F

# weight_norm, remove_weight_norm and spectral_norm are assumed to be
# re-exported by the torch_commons star import (the torch.nn.utils functions).


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


class ResBlock1(ConvNets):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(self.init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )
        self.convs2.apply(self.init_weights)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = c1(self.activation(x))
            xt = c2(self.activation(xt))
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(ConvNets):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs.apply(self.init_weights)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        for c in self.convs:
            xt = c(self.activation(x))
            x = xt + x
        return x


class HifiganGenerator(ConvNets):
    def __init__(self, h):
        super().__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(
            nn.Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
        )
        resblock = ResBlock1 if h.resblock == "1" else ResBlock2
        self.activation = nn.LeakyReLU(0.1)
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        h.upsample_initial_channel // (2**i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                # ResBlock1/ResBlock2 take (channels, kernel_size, dilation);
                # the config `h` is not part of their signature here.
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(self.init_weights)
        self.conv_post.apply(self.init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = self.ups[i](self.activation(x))
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = self.conv_post(self.activation(x))
        x = torch.tanh(x)

        return x
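
A hedged instantiation sketch for the generator; the config fields below are exactly the attributes HifiganGenerator reads from `h`, with values borrowed from the common HiFi-GAN v1 setup (an assumption, not taken from this diff).

from types import SimpleNamespace

# Hypothetical config; any attribute-style namespace works for `h`.
h = SimpleNamespace(
    resblock="1",
    upsample_rates=[8, 8, 2, 2],            # product = 256 samples per frame
    upsample_kernel_sizes=[16, 16, 4, 4],
    upsample_initial_channel=512,
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
)
gen = HifiganGenerator(h)
mel = torch.randn(1, 80, 32)   # [B, n_mels, frames]; conv_pre expects 80 mels
wav = gen(mel)                 # [B, 1, frames * 256]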
class DiscriminatorP(ConvNets):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    nn.Conv2d(
                        c_in,
                        c_out,
                        (kernel_size, 1),
                        (stride, 1),
                        # note: the padding is hard-coded for the default
                        # kernel_size of 5 (get_padding(5, 1) == 2).
                        padding=(get_padding(5, 1), 0),
                    )
                )
                for c_in, c_out in ((1, 32), (32, 128), (128, 512), (512, 1024))
            ]
            + [norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0)))]
        )
        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        fmap = []

        # 1d to 2d: fold the waveform into (frames, period) so each column
        # holds samples that are `period` steps apart.
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = self.activation(x)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(ConvNets):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(p) for p in (2, 3, 5, 7, 11)]
        )

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(ConvNets):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
                norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
                norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
                norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
                norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
                norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
                norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = self.activation(x)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(ConvNets):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorS(use_spectral_norm=True),
                DiscriminatorS(),
                DiscriminatorS(),
            ]
        )
        self.meanpools = nn.ModuleList(
            [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
        )

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:  # downsample 2x before the second and third scales
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg**2)
        loss += r_loss + g_loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses
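
Finally, a sketch of how the discriminators and loss helpers are typically wired in a training step, continuing the generator example above (the standard HiFi-GAN recipe, assumed rather than shipped in this diff).

# Assumed training-step wiring; `gen`, `mel`, `wav` come from the sketch above.
mpd = MultiPeriodDiscriminator()
y = torch.randn(1, 1, 32 * 256)  # real waveform [B, 1, T], matching `wav`

# Discriminator step: real scores pushed toward 1, generated toward 0.
y_d_r, y_d_g, _, _ = mpd(y, wav.detach())
loss_disc, _, _ = discriminator_loss(y_d_r, y_d_g)

# Generator step: adversarial plus feature-matching terms
# (MultiScaleDiscriminator is used the same way).
y_d_r, y_d_g, fmap_r, fmap_g = mpd(y, wav)
loss_g = generator_loss(y_d_g)[0] + feature_loss(fmap_r, fmap_g)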