lt-tensor 0.0.1a14__py3-none-any.whl → 0.0.1a16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/datasets/audio.py +23 -6
- lt_tensor/model_base.py +163 -123
- lt_tensor/model_zoo/__init__.py +8 -6
- lt_tensor/model_zoo/audio_models/__init__.py +1 -0
- lt_tensor/model_zoo/audio_models/diffwave/__init__.py +3 -0
- lt_tensor/model_zoo/audio_models/diffwave/model.py +201 -0
- lt_tensor/model_zoo/audio_models/hifigan/__init__.py +393 -0
- lt_tensor/model_zoo/audio_models/istft/__init__.py +409 -0
- lt_tensor/model_zoo/basic.py +139 -0
- lt_tensor/model_zoo/features.py +102 -11
- lt_tensor/model_zoo/residual.py +133 -64
- {lt_tensor-0.0.1a14.dist-info → lt_tensor-0.0.1a16.dist-info}/METADATA +1 -1
- {lt_tensor-0.0.1a14.dist-info → lt_tensor-0.0.1a16.dist-info}/RECORD +16 -16
- lt_tensor/model_zoo/discriminator.py +0 -196
- lt_tensor/model_zoo/istft/__init__.py +0 -5
- lt_tensor/model_zoo/istft/generator.py +0 -90
- lt_tensor/model_zoo/istft/sg.py +0 -142
- lt_tensor/model_zoo/istft/trainer.py +0 -618
- {lt_tensor-0.0.1a14.dist-info → lt_tensor-0.0.1a16.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a14.dist-info → lt_tensor-0.0.1a16.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a14.dist-info → lt_tensor-0.0.1a16.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,409 @@
|
|
1
|
+
__all__ = ["iSTFTGenerator"]
|
2
|
+
from lt_utils.common import *
|
3
|
+
from lt_tensor.torch_commons import *
|
4
|
+
from lt_tensor.model_zoo.residual import ConvNets
|
5
|
+
from torch.nn import functional as F
|
6
|
+
|
7
|
+
|
8
|
+
def get_padding(ks, d):
    """Return symmetric padding for a stride-1 conv with kernel ``ks`` and dilation ``d``.

    The value is half of ``ks * d - d`` (the dilated kernel's extent minus one),
    truncated toward zero with ``int``. When ``ks * d - d`` is even this keeps
    the sequence length unchanged.
    """
    dilated_extent = ks * d - d
    return int(dilated_extent / 2)
|
10
|
+
|
11
|
+
|
12
|
+
class ResBlock1(ConvNets):
    """Residual block built from three (dilated conv, plain conv) pairs.

    Stage ``n`` computes ``x = conv2_n(act(conv1_n(act(x)))) + x``, where
    ``conv1_n`` uses ``dilation[n]`` and ``conv2_n`` uses dilation 1. Every conv
    is weight-normalized, stride 1, ``channels -> channels``, and padded with
    ``get_padding`` (length-preserving for odd kernel sizes).

    Args:
        h: Config object; only stored as ``self.h``.
        channels: Channel count of every conv (input == output).
        kernel_size: Kernel size shared by all six convs.
        dilation: Sequence whose first three entries set the dilations of the
            first conv in each stage; indexing raises if it has fewer.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        self.h = h

        def make_conv(dil):
            # Weight-normalized stride-1 Conv1d with "same"-style padding.
            return weight_norm(
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dil,
                    padding=get_padding(kernel_size, dil),
                )
            )

        # Keep this create -> init -> create -> init order: nn.Conv1d draws its
        # default parameters from the global RNG at construction time.
        self.convs1 = nn.ModuleList([make_conv(dilation[n]) for n in range(3)])
        self.convs1.apply(self.init_weights)

        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])
        self.activation = nn.LeakyReLU(0.1)
        self.convs2.apply(self.init_weights)

    def forward(self, x):
        """Run the three residual stages over ``x`` and return the result."""
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            branch = dilated_conv(self.activation(x))
            branch = plain_conv(self.activation(branch))
            x = branch + x
        return x
|
97
|
+
|
98
|
+
class ResBlock2(ConvNets):
    """Lighter residual block: two dilated convs, each wrapped in a residual add.

    Stage ``n`` computes ``x = conv_n(act(x)) + x`` with ``conv_n`` using
    ``dilation[n]``. Convs are weight-normalized, stride 1,
    ``channels -> channels``, and padded with ``get_padding``.

    Args:
        h: Config object; only stored as ``self.h``.
        channels: Channel count of both convs.
        kernel_size: Kernel size of both convs.
        dilation: Sequence whose first two entries set the conv dilations;
            indexing raises if it has fewer.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        self.h = h
        layers = []
        for n in range(2):
            dil = dilation[n]
            layers.append(
                weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dil,
                        padding=get_padding(kernel_size, dil),
                    )
                )
            )
        self.convs = nn.ModuleList(layers)
        self.activation = nn.LeakyReLU(0.1)
        self.convs.apply(self.init_weights)

    def forward(self, x):
        """Apply both residual stages to ``x`` and return the result."""
        for conv in self.convs:
            x = conv(self.activation(x)) + x
        return x
|
135
|
+
|
136
|
+
|
137
|
+
class iSTFTGenerator(ConvNets):
    """Generator producing magnitude/phase tensors for an inverse STFT.

    Pipeline: ``conv_pre`` (80 input channels) -> for each upsample stage:
    LeakyReLU -> ConvTranspose1d -> mean of that stage's ``num_kernels``
    residual blocks -> LeakyReLU -> reflection pad (1 step on the left) ->
    ``conv_post`` with ``gen_istft_n_fft + 2`` output channels, which are split
    into ``spec`` and ``phase``.

    Args:
        h: Config object. Reads ``resblock_kernel_sizes``,
            ``resblock_dilation_sizes``, ``upsample_rates``,
            ``upsample_kernel_sizes``, ``upsample_initial_channel``,
            ``resblock`` (``"1"`` selects ``ResBlock1``, anything else
            ``ResBlock2``), ``sampling_rate`` and ``gen_istft_n_fft``.
    """

    def __init__(self, h):
        super().__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        # Input channel count is hard-coded to 80.
        self.conv_pre = weight_norm(
            nn.Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
        )
        resblock = ResBlock1 if h.resblock == "1" else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            if h.sampling_rate % 16000:
                # Sample rate not a multiple of 16 kHz: ``(k - u) // 2`` padding.
                pad_kwargs = {"padding": (k - u) // 2}
            else:
                # Multiples of 16 kHz: with these settings the output length is
                # L*u + k - 2*u, i.e. exactly L*u whenever k == 2*u (odd u too).
                pad_kwargs = {"padding": u // 2 + u % 2, "output_padding": u % 2}
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        h.upsample_initial_channel // (2**i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        **pad_kwargs,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        # Default so conv_post is well-defined even when there are no upsample
        # stages (``ch`` would otherwise only be bound inside the loop).
        ch = h.upsample_initial_channel
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
                self.resblocks.append(resblock(h, ch, k, d))

        self.post_n_fft = h.gen_istft_n_fft
        self.conv_post = weight_norm(
            nn.Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3)
        )
        self.ups.apply(self.init_weights)
        self.conv_post.apply(self.init_weights)
        self.activation = nn.LeakyReLU(0.1)
        self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))

    def forward(self, x):
        """Return ``(spec, phase)`` for input ``x`` (expects 80 channels).

        ``spec`` is ``exp`` of the first ``post_n_fft // 2 + 1`` channels of
        ``conv_post``'s output; ``phase`` is ``sin`` of the remaining channels.
        """
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = self.ups[i](self.activation(x))
            # Average the outputs of this stage's ``num_kernels`` resblocks,
            # which are stored contiguously starting at ``i * num_kernels``.
            offset = i * self.num_kernels
            xs = None
            for j in range(self.num_kernels):
                out = self.resblocks[offset + j](x)
                xs = out if xs is None else xs + out
            x = xs / self.num_kernels
        x = self.activation(x)
        x = self.reflection_pad(x)
        x = self.conv_post(x)
        n_mag = self.post_n_fft // 2 + 1
        spec = torch.exp(x[:, :n_mag, :])
        phase = torch.sin(x[:, n_mag:, :])
        return spec, phase
|
212
|
+
|
213
|
+
|
214
|
+
class DiscriminatorP(ConvNets):
    """Period discriminator: folds a 1-D signal into ``period`` columns, then applies 2-D convs.

    Args:
        period: Fold width; the time axis is right-padded (reflect) up to a
            multiple of it before reshaping.
        kernel_size: Temporal kernel size of the five main convs. Their padding
            is derived from it via ``get_padding`` (2 for the default of 5).
        stride: Temporal stride of the first four convs.
        use_spectral_norm: Truthy selects ``spectral_norm``, otherwise
            ``weight_norm``.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        self.activation = nn.LeakyReLU(0.1)
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        # Padding follows kernel_size instead of being fixed for kernel_size=5.
        pad = get_padding(kernel_size, 1)
        widths = [1, 32, 128, 512, 1024]
        strided = [
            norm_f(
                nn.Conv2d(
                    c_in,
                    c_out,
                    (kernel_size, 1),
                    (stride, 1),
                    padding=(pad, 0),
                )
            )
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        self.convs = nn.ModuleList(
            strided
            + [norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(pad, 0)))]
        )
        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return ``(logits, fmap)`` for ``x`` of shape ``(b, c, t)``.

        ``logits`` is ``conv_post``'s output flattened from dim 1; ``fmap``
        holds the activation after every conv plus the raw ``conv_post`` output.
        """
        fmap = []

        b, c, t = x.shape
        if t % self.period != 0:
            # Right-pad (reflect) so the time axis divides evenly by the period.
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        # 1-D -> 2-D: (b, c, t) -> (b, c, t // period, period)
        x = x.view(b, c, t // self.period, self.period)

        for layer in self.convs:
            x = self.activation(layer(x))
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
|
283
|
+
|
284
|
+
|
285
|
+
class MultiPeriodDiscriminator(ConvNets):
    """Bank of ``DiscriminatorP`` modules, one per period.

    Args:
        periods: Periods to instantiate, in order. Defaults to
            ``(2, 3, 5, 7, 11)``.
    """

    def __init__(self, periods=(2, 3, 5, 7, 11)):
        super().__init__()
        self.discriminators = nn.ModuleList([DiscriminatorP(p) for p in periods])

    def forward(self, y, y_hat):
        """Run every period discriminator on real ``y`` and generated ``y_hat``.

        Returns:
            ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)``: per-discriminator logits
            and feature-map lists for the real and generated inputs.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for d in self.discriminators:
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
312
|
+
|
313
|
+
|
314
|
+
class DiscriminatorS(ConvNets):
    """Scale discriminator: stacked (grouped, strided) 1-D convs over a 1-channel signal.

    Args:
        use_spectral_norm: Truthy selects ``spectral_norm``, otherwise
            ``weight_norm``.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
                norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
                norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
                norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
                norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
                norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
                norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.activation = nn.LeakyReLU(0.1)
        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return ``(logits, fmap)``.

        ``logits`` is ``conv_post``'s output flattened from dim 1; ``fmap``
        holds the activation after every conv plus the raw ``conv_post`` output.
        """
        fmap = []
        for layer in self.convs:
            x = self.activation(layer(x))
            # One feature map per conv layer, matching DiscriminatorP.
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
|
343
|
+
|
344
|
+
|
345
|
+
class MultiScaleDiscriminator(ConvNets):
    """Three ``DiscriminatorS`` modules applied at successively pooled scales.

    The first discriminator uses spectral norm and sees the input as-is; each
    later one sees the previous scale passed through
    ``AvgPool1d(4, 2, padding=2)``.
    """

    def __init__(self):
        super().__init__()
        scales = [DiscriminatorS(use_spectral_norm=True)]
        scales.extend(DiscriminatorS() for _ in range(2))
        self.discriminators = nn.ModuleList(scales)
        self.meanpools = nn.ModuleList(
            [nn.AvgPool1d(4, 2, padding=2) for _ in range(2)]
        )

    def forward(self, y, y_hat):
        """Score real ``y`` and generated ``y_hat`` at every scale.

        Returns:
            ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)``: per-scale logits and
            feature-map lists for the real and generated inputs.
        """
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []
        for idx, disc in enumerate(self.discriminators):
            if idx > 0:
                pool = self.meanpools[idx - 1]
                y, y_hat = pool(y), pool(y_hat)
            score_r, fmap_r = disc(y)
            score_g, fmap_g = disc(y_hat)
            real_scores.append(score_r)
            real_fmaps.append(fmap_r)
            fake_scores.append(score_g)
            fake_fmaps.append(fmap_g)

        return real_scores, fake_scores, real_fmaps, fake_fmaps
|
376
|
+
|
377
|
+
|
378
|
+
def feature_loss(fmap_r, fmap_g):
    """Feature-matching loss: twice the summed per-layer mean absolute difference.

    Args:
        fmap_r: Per-discriminator lists of feature maps for real inputs.
        fmap_g: Same structure for generated inputs. Pairs are formed with
            ``zip``, so unmatched trailing entries on either side are ignored.

    Returns:
        ``2 * sum(mean(|r - g|))`` over all paired layers, or ``0`` if there
        are none.
    """
    layer_terms = (
        torch.mean(torch.abs(real - fake))
        for real_layers, fake_layers in zip(fmap_r, fmap_g)
        for real, fake in zip(real_layers, fake_layers)
    )
    # ``sum`` starts at 0 and adds left to right, like an accumulator loop.
    return sum(layer_terms) * 2
|
385
|
+
|
386
|
+
|
387
|
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss summed over discriminator outputs.

    Each (real, generated) pair contributes
    ``mean((1 - real) ** 2) + mean(generated ** 2)``.

    Returns:
        ``(loss, r_losses, g_losses)``: ``loss`` is the total (``0`` when there
        are no pairs); the lists hold each real/generated term as a Python
        float via ``.item()``.
    """
    total = 0
    real_terms = []
    fake_terms = []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = torch.mean((1 - real_out) ** 2)
        fake_term = torch.mean(fake_out**2)
        total = total + (real_term + fake_term)
        real_terms.append(real_term.item())
        fake_terms.append(fake_term.item())

    return total, real_terms, fake_terms
|
399
|
+
|
400
|
+
|
401
|
+
def generator_loss(disc_outputs):
    """Least-squares generator loss.

    Returns:
        ``(loss, gen_losses)``: ``gen_losses`` holds ``mean((1 - d) ** 2)`` for
        each discriminator output (as tensors) and ``loss`` is their sum, or
        ``0`` when ``disc_outputs`` is empty.
    """
    gen_losses = [torch.mean((1 - d_out) ** 2) for d_out in disc_outputs]
    return sum(gen_losses), gen_losses
|
lt_tensor/model_zoo/basic.py
CHANGED
@@ -16,6 +16,7 @@ from lt_tensor.model_base import Model
|
|
16
16
|
from lt_tensor.transform import get_sinusoidal_embedding
|
17
17
|
from lt_utils.common import *
|
18
18
|
import math
|
19
|
+
from einops import repeat
|
19
20
|
|
20
21
|
|
21
22
|
class FeedForward(Model):
|
@@ -346,3 +347,141 @@ class LoRAConv2DLayer(nn.Module):
|
|
346
347
|
down_hidden_states = self.down(inputs.to(self._down_dt))
|
347
348
|
up_hidden_states = self.up(down_hidden_states) * self.ah
|
348
349
|
return up_hidden_states.to(orig_dtype)
|
350
|
+
|
351
|
+
|
352
|
+
class SineGen(nn.Module):
    """Sine excitation source driven by a frame-level F0 contour.

    Builds ``harmonic_num + 1`` sinusoids (the fundamental and its integer
    multiples) at ``upsample_scale`` samples per F0 frame. Voiced/unvoiced
    handling is controlled by ``voiced_threshold``.

    Args:
        samp_rate: Sample rate in Hz used to convert F0 to cycles per sample.
        upsample_scale: Integer number of output samples per F0 frame; each
            frame is repeated this many times along the time axis.
        harmonic_num: Number of overtones; outputs have ``harmonic_num + 1``
            channels.
        sine_amp: Amplitude multiplier for the sine waves.
        noise_std: Standard deviation of the Gaussian noise.
        voiced_threshold: A frame is voiced when ``f0 > voiced_threshold``.
        flag_for_pulse: If True, every harmonic starts at phase 0; otherwise
            each (batch, harmonic) gets a random initial phase in ``[0, 2*pi)``.
            Honored by both ``forward`` and ``_f02sine``.
    """

    def __init__(
        self,
        samp_rate,
        upsample_scale,
        harmonic_num=0,
        sine_amp=0.1,
        noise_std=0.003,
        voiced_threshold=0,
        flag_for_pulse=False,
    ):
        super().__init__()
        self.sampling_rate = samp_rate
        self.upsample_scale = upsample_scale
        self.harmonic_num = harmonic_num
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.dim = self.harmonic_num + 1  # fundamental + harmonics

    def _upsample(self, x):
        """Repeat every frame ``upsample_scale`` times along dim 1.

        Same result as einops ``"b t d -> b (t r) d"`` with ``r=upsample_scale``.
        """
        return x.repeat_interleave(self.upsample_scale, dim=1)

    def _initial_phase(self, batch, device):
        """Initial phase in radians, shape ``(batch, dim)``.

        Zeros when ``flag_for_pulse`` is set, otherwise uniform in ``[0, 2*pi)``.
        """
        if self.flag_for_pulse:
            return torch.zeros((batch, self.dim), device=device)
        return torch.rand((batch, self.dim), device=device) * 2 * math.pi

    def _f02uv_b(self, f0):
        """Voiced mask ``f0 > voiced_threshold`` as float, same shape as ``f0``."""
        return (f0 > self.voiced_threshold).float()

    def _f02uv(self, f0):
        """Voiced mask with a trailing singleton dim: ``(B, T) -> (B, T, 1)``."""
        return (f0 > self.voiced_threshold).float().unsqueeze(-1)

    @torch.no_grad()
    def _f02sine(self, f0_values):
        """Sine waves for every harmonic from F0 of shape ``(B, T, 1)``.

        Returns:
            Tensor of shape ``(B, T * upsample_scale, dim)``.
        """
        device = f0_values.device
        f0_upsampled = self._upsample(f0_values)  # (B, T_up, 1)

        harmonics = torch.arange(1, self.dim + 1, device=device).float().view(1, 1, -1)
        f0_harm = f0_upsampled * harmonics  # (B, T_up, dim)

        # Hz -> cycles per sample, dropping whole cycles (phase is periodic).
        rad_values = (f0_harm / self.sampling_rate) % 1.0

        # Zero (pulse mode) or random initial phase for every harmonic.
        rand_ini = self._initial_phase(f0_values.size(0), device).unsqueeze(1)

        # Integrate the per-sample phase increment (in radians).
        phase = torch.cumsum(rad_values * 2 * math.pi, dim=1) + rand_ini
        return torch.sin(phase)  # (B, T_up, dim)

    def _forward(self, f0):
        """Sine source with noise on unvoiced samples, from F0 of shape ``(B, T, 1)``.

        Voiced samples carry ``sine_amp * sine`` on every harmonic; unvoiced
        samples carry Gaussian noise with std ``noise_std`` instead.

        Returns:
            Tensor of shape ``(B, T_up, dim)``.
        """
        sine_waves = self._f02sine(f0)  # (B, T_up, dim)
        uv = self._upsample(self._f02uv_b(f0))  # (B, T_up, 1)

        sine_signal = self.sine_amp * sine_waves * uv
        noise = torch.randn_like(sine_signal) * self.noise_std
        return sine_signal + noise * (1.0 - uv)

    def forward(self, f0):
        """Generate harmonic sine waves from frame-level F0.

        Args:
            f0: ``(B, T)`` F0 in Hz, before upsampling.

        Returns:
            sine_waves: ``(B, T_up, dim)``. Channel 0 is replaced by noise on
                unvoiced samples; channels 1.. keep their sine values.
            uv: ``(B, T_up, 1)`` voiced mask.
            noise: ``(B, T_up, 1)`` Gaussian noise with std ``noise_std``
                (returned unmasked).
        """
        batch, _ = f0.shape
        device = f0.device

        uv = self._f02uv(f0)  # (B, T, 1)
        harmonics = torch.arange(1, self.dim + 1, device=device).float().view(1, 1, -1)
        f0_harm = f0.unsqueeze(-1) * harmonics  # (B, T, dim)

        f0_harm_up = self._upsample(f0_harm)  # (B, T_up, dim)
        uv_up = self._upsample(uv)  # (B, T_up, 1)

        # Hz -> cycles/sample -> radians/sample
        rad_per_sample = f0_harm_up / self.sampling_rate * 2 * math.pi
        T_up = rad_per_sample.shape[1]

        # Same initial-phase policy as _f02sine (zero in pulse mode).
        rand_phase = self._initial_phase(batch, device)  # (B, dim)
        phase = torch.cumsum(rad_per_sample, dim=1) + rand_phase.unsqueeze(1)
        sine_waves = torch.sin(phase) * self.sine_amp  # (B, T_up, dim)

        noise = torch.randn(batch, T_up, 1, device=device) * self.noise_std
        # Unvoiced samples: replace only the fundamental with noise.
        # NOTE(review): harmonics 1.. are not masked by uv here, unlike _forward.
        sine_waves[:, :, 0:1] = sine_waves[:, :, 0:1] * uv_up + noise * (1 - uv_up)

        return sine_waves, uv_up, noise
|