lt-tensor 0.0.1a12__py3-none-any.whl → 0.0.1a13__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- lt_tensor/datasets/audio.py +141 -46
- lt_tensor/misc_utils.py +37 -0
- lt_tensor/model_zoo/__init__.py +18 -9
- lt_tensor/model_zoo/{bsc.py → basic.py} +118 -2
- lt_tensor/model_zoo/features.py +416 -0
- lt_tensor/model_zoo/fusion.py +164 -0
- lt_tensor/model_zoo/istft/generator.py +2 -2
- lt_tensor/model_zoo/istft/sg.py +142 -0
- lt_tensor/model_zoo/istft/trainer.py +37 -12
- lt_tensor/model_zoo/residual.py +217 -0
- lt_tensor/model_zoo/{tfrms.py → transformer.py} +2 -2
- lt_tensor/processors/audio.py +218 -80
- lt_tensor/transform.py +7 -16
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a13.dist-info}/METADATA +6 -4
- lt_tensor-0.0.1a13.dist-info/RECORD +32 -0
- lt_tensor/model_zoo/fsn.py +0 -67
- lt_tensor/model_zoo/gns.py +0 -185
- lt_tensor/model_zoo/istft.py +0 -591
- lt_tensor/model_zoo/rsd.py +0 -107
- lt_tensor-0.0.1a12.dist-info/RECORD +0 -32
- /lt_tensor/model_zoo/{disc.py → discriminator.py} +0 -0
- /lt_tensor/model_zoo/{pos.py → pos_encoder.py} +0 -0
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a13.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a13.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a12.dist-info → lt_tensor-0.0.1a13.dist-info}/top_level.txt +0 -0
lt_tensor/model_zoo/features.py
@@ -0,0 +1,416 @@
+__all__ = [
+    "Downsample1D",
+    "Upsample1D",
+    "DiffusionUNet",
+    "UNetConvBlock1D",
+    "UNetUpBlock1D",
+    "NoisePredictor1D",
+    "AdaINFeaturesBlock1D",
+    "UpSampleConv1D",
+]
+import math
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model
+from lt_tensor.model_zoo.residual import ResBlock1D
+from lt_tensor.misc_utils import log_tensor
+from lt_tensor.model_zoo.fusion import AdaIN1D, CrossAttentionFusion
+import torch.nn.functional as F
+from lt_tensor.misc_utils import get_activated_conv
+from lt_utils.common import *
+
+
+class FeatureExtractor(Model):
+    def __init__(
+        self,
+        in_channels: int = 1,
+        out_channels: int = 32,
+        hidden: int = 128,
+        groups: tuple[int, int, int, int, int] = (1, 1, 1, 1, 1),
+        kernels: tuple[int, int, int, int, int] = (5, 3, 3, 3, 3),
+        padding: tuple[int, int, int, int, int] = (2, 1, 1, 1, 1),
+        stride: tuple[int, int, int, int, int] = (2, 2, 2, 1, 1),
+        network: Literal[
+            "Conv1d",
+            "Conv2d",
+            "Conv3d",
+            "ConvTranspose1d",
+            "ConvTranspose2d",
+            "ConvTranspose3d",
+        ] = "Conv1d",
+    ):
+        super().__init__()
+        self.pre_nt = get_activated_conv(
+            in_channels,
+            hidden,
+            kernels[0],
+            stride[0],
+            padding[0],
+            groups[0],
+            conv_type=network,
+        )
+        # nn.Sequential takes modules as positional arguments, hence the unpacking
+        self.net = nn.Sequential(
+            *[
+                get_activated_conv(
+                    hidden,
+                    hidden,
+                    kernels[i],
+                    stride[i],
+                    padding[i],
+                    groups[i],
+                    conv_type=network,
+                )
+                for i in range(1, 4)
+            ]
+        )
+        self.post = get_activated_conv(
+            hidden,
+            out_channels,
+            kernels[4],
+            stride[4],
+            padding[4],
+            groups[4],
+            conv_type=network,
+        )
+
+    def forward(self, x: Tensor):
+        x = self.pre_nt(x)
+        x = self.net(x)
+        return self.post(x)
+
+
+class Downsample1D(Model):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+    ):
+        super().__init__()
+        self.pool = nn.Conv1d(in_channels, out_channels, 4, stride=2, padding=1)
+
+    def forward(self, x):
+        return self.pool(x)
+
+
+class Upsample1D(Model):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        activation=nn.ReLU(inplace=True),
+    ):
+        super().__init__()
+        self.up = nn.Sequential(
+            nn.ConvTranspose1d(
+                in_channels, out_channels, kernel_size=4, stride=2, padding=1
+            ),
+            nn.BatchNorm1d(out_channels),
+            activation,
+        )
+
+    def forward(self, x):
+        return self.up(x)
+
+
+class DiffusionUNet(Model):
+    def __init__(self, in_channels=1, base_channels=64, out_channels=1, depth=4):
+        super().__init__()
+
+        self.depth = depth
+        self.encoder_blocks = nn.ModuleList()
+        self.downsamples = nn.ModuleList()
+        self.upsamples = nn.ModuleList()
+        self.decoder_blocks = nn.ModuleList()
+        # Keep track of channel sizes per layer for skip connections
+        self.channels = [in_channels]  # starting input channel
+        for i in range(depth):
+            enc_in = self.channels[-1]
+            enc_out = base_channels * (2**i)
+            # Encoder block and downsample
+            self.encoder_blocks.append(ResBlock1D(enc_in, enc_out))
+            self.downsamples.append(
+                Downsample1D(enc_out, enc_out)
+            )  # halve time, keep channels
+            self.channels.append(enc_out)
+        # Bottleneck
+        bottleneck_ch = self.channels[-1]
+        self.bottleneck = ResBlock1D(bottleneck_ch, bottleneck_ch)
+        # Decoder blocks (reverse channel flow)
+        for i in reversed(range(depth)):
+            skip_ch = self.channels[i + 1]  # from encoder
+            dec_out = self.channels[i]  # match earlier stage's output
+            self.upsamples.append(Upsample1D(skip_ch, skip_ch))
+            self.decoder_blocks.append(ResBlock1D(skip_ch * 2, dec_out))
+        # Final output projection; the last decoder stage emits
+        # self.channels[0] == in_channels channels
+        self.final = nn.Conv1d(in_channels, out_channels, kernel_size=1)
+
+    def forward(self, x: Tensor):
+        skips = []
+
+        # Encoder
+        for enc, down in zip(self.encoder_blocks, self.downsamples):
+            # log_tensor(x, "before enc")
+            x = enc(x)
+            skips.append(x)
+            x = down(x)
+
+        # Bottleneck
+        x = self.bottleneck(x)
+
+        # Decoder
+        for up, dec, skip in zip(self.upsamples, self.decoder_blocks, reversed(skips)):
+            x = up(x)
+
+            # Match lengths via trimming or padding
+            if x.shape[-1] > skip.shape[-1]:
+                x = x[..., : skip.shape[-1]]
+            elif x.shape[-1] < skip.shape[-1]:
+                diff = skip.shape[-1] - x.shape[-1]
+                x = F.pad(x, (0, diff))
+
+            x = torch.cat([x, skip], dim=1)  # concat on channels
+            x = dec(x)
+
+        # Final 1x1 conv
+        return self.final(x)
+
+
+class UNetConvBlock1D(Model):
+    def __init__(self, in_channels: int, out_channels: int, down: bool = True):
+        super().__init__()
+        self.down = down
+        self.conv = nn.Sequential(
+            nn.Conv1d(
+                in_channels,
+                out_channels,
+                kernel_size=3,
+                stride=2 if down else 1,
+                padding=1,
+            ),
+            nn.BatchNorm1d(out_channels),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm1d(out_channels),
+            nn.LeakyReLU(0.2),
+        )
+        # The residual path needs a strided 1x1 conv whenever the channel count
+        # or the temporal stride of the main path changes
+        self.downsample = (
+            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=2 if down else 1)
+            if in_channels != out_channels or down
+            else nn.Identity()
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: [B, C, T]
+        residual = self.downsample(x)
+        return self.conv(x) + residual
+
+
+class UNetUpBlock1D(Model):
+    def __init__(self, in_channels: int, out_channels: int):
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm1d(out_channels),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm1d(out_channels),
+            nn.LeakyReLU(0.2),
+        )
+        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+
+    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
+        x = self.upsample(x)
+        x = torch.cat([x, skip], dim=1)  # skip connection
+        return self.conv(x)
+
+
+class NoisePredictor1D(Model):
+    def __init__(self, in_channels: int, cond_dim: int = 0, hidden: int = 128):
+        """
+        Args:
+            in_channels: channels of the noisy input [B, C, T]
+            cond_dim: optional condition vector [B, cond_dim]
+        """
+        super().__init__()
+        # project the condition to the input channel count so the broadcast
+        # add in forward() lines up with [B, C, T]
+        self.proj = nn.Linear(cond_dim, in_channels) if cond_dim > 0 else None
+        self.net = nn.Sequential(
+            nn.Conv1d(in_channels, hidden, kernel_size=3, padding=1),
+            nn.SiLU(),
+            nn.Conv1d(hidden, in_channels, kernel_size=3, padding=1),
+        )
+
+    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
+        # x: [B, C, T], cond: [B, cond_dim]
+        if cond is not None:
+            cond_proj = self.proj(cond).unsqueeze(-1)  # [B, C, 1]
+            x = x + cond_proj  # simple additive conditioning
+        return self.net(x)  # [B, C, T]
+
+
+class UpSampleConv1D(nn.Module):
+    def __init__(self, upsample: bool = False, dim_in: int = 0, dim_out: int = 0):
+        super().__init__()
+        if upsample:
+            self.upsample = lambda x: F.interpolate(x, scale_factor=2, mode="nearest")
+        else:
+            self.upsample = nn.Identity()
+
+        if dim_in == dim_out:
+            self.learned = nn.Identity()
+        else:
+            self.learned = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+    def forward(self, x):
+        x = self.upsample(x)
+        return self.learned(x)
+
+
+class AdaINFeaturesBlock1D(Model):
+    def __init__(
+        self,
+        dim_in: int,
+        dim_out: int,
+        style_dim: int = 64,
+        actv=nn.LeakyReLU(0.2),
+        upsample: bool = False,
+        dropout_p=0.0,
+    ):
+        super().__init__()
+        self.upsample = UpSampleConv1D(upsample, dim_in, dim_out)
+        self.res_net = nn.ModuleDict(
+            dict(
+                norm_1=AdaIN1D(style_dim, dim_in),
+                sq1=nn.Sequential(
+                    actv,
+                    (
+                        nn.Identity()
+                        if not upsample
+                        else weight_norm(
+                            nn.ConvTranspose1d(
+                                dim_in,
+                                dim_in,
+                                kernel_size=3,
+                                stride=2,
+                                groups=dim_in,
+                                padding=1,
+                                output_padding=1,
+                            )
+                        )
+                    ),
+                    weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1)),
+                    nn.Dropout(dropout_p),
+                ),
+                norm_2=AdaIN1D(style_dim, dim_out),
+                sq2=nn.Sequential(
+                    actv,
+                    weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1)),
+                    nn.Dropout(dropout_p),
+                ),
+            )
+        )
+        self.sq2 = math.sqrt(2)
+
+    def forward(self, x: Tensor, y: Tensor):
+        u = self.res_net["norm_1"](x, y)
+        u = self.res_net["sq1"](u)
+        u = self.res_net["norm_2"](u, y)
+        u = self.res_net["sq2"](u)
+        return (u + self.upsample(x)) / self.sq2
+
+
+class AudioEncoder(Model):
+    """Untested, hypothetical item"""
+
+    def __init__(
+        self,
+        channels: int = 80,
+        alpha: float = 4.0,
+        interp_mode: Literal[
+            "nearest",
+            "linear",
+            "bilinear",
+            "bicubic",
+            "trilinear",
+        ] = "nearest",
+        activation: nn.Module = nn.LeakyReLU(0.1),
+    ):
+        super().__init__()
+
+        self.net = nn.Sequential(
+            nn.Conv1d(
+                channels, channels, kernel_size=3, stride=2, padding=5, groups=channels
+            ),
+            nn.LeakyReLU(0.1),
+            nn.Conv1d(channels, channels, kernel_size=7, stride=1, padding=1, groups=1),
+        )
+        self.fc = nn.Linear(channels, channels)
+        self.activation = activation
+        self.channels = channels
+        self.mode = interp_mode
+        self.alpha = alpha
+
+    def forward(self, mels: Tensor, cr_audio: Tensor):
+        # asin/acos expect cr_audio in [-1, 1] (e.g. a normalized waveform)
+        sin = torch.asin(cr_audio)
+        cos = torch.acos(cr_audio)
+        mod = (sin * cos) / self.alpha
+        mod = (mod - mod.median(dim=-1, keepdim=True).values) / (
+            mod.std(dim=-1, keepdim=True) + 1e-5
+        )
+        x = self.net(mod)
+        x = (
+            F.interpolate(
+                x,
+                size=mels.shape[-1],
+                mode=self.mode,
+            )
+            .transpose(-1, -2)
+            .contiguous()
+        )
+        x = self.activation(x)
+        return self.fc(x).transpose(-1, -2)
+
+
+class AudioEncoderAttn(Model):
+    def __init__(
+        self,
+        channels: int = 80,
+        alpha: float = 4.0,
+        interp_mode: Literal[
+            "nearest",
+            "linear",
+            "bilinear",
+            "bicubic",
+            "trilinear",
+        ] = "nearest",
+        activation: nn.Module = nn.LeakyReLU(0.1),
+    ):
+        super().__init__()
+
+        self.net = nn.Sequential(
+            nn.Conv1d(
+                channels, channels, kernel_size=3, stride=2, padding=5, groups=channels
+            ),
+            nn.LeakyReLU(0.1),
+            nn.Conv1d(channels, channels, kernel_size=7, stride=1, padding=1, groups=1),
+        )
+        self.fusion = CrossAttentionFusion(channels, channels, 2, d_model=channels)
+        self.channels = channels
+        self.mode = interp_mode
+        self.alpha = alpha
+        self.activation = activation
+
+    def forward(self, mels: Tensor, cr_audio: Tensor):
+        sin = torch.asin(cr_audio)
+        cos = torch.acos(cr_audio)
+        mod = (sin * cos) / self.alpha
+        mod = (mod - mod.median(dim=-1, keepdim=True).values) / (
+            mod.std(dim=-1, keepdim=True) + 1e-5
+        )
+        x = self.activation(self.net(mod))
+        x = F.interpolate(x, size=mels.shape[-1], mode=self.mode)
+
+        # make the transposed views contiguous before attention
+        x_t = x.transpose(-2, -1).contiguous()
+        mels_t = mels.transpose(-2, -1).contiguous()
+
+        return self.fusion(x_t, mels_t).transpose(-2, -1)
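To make the new module concrete, here is a minimal usage sketch for `DiffusionUNet` (hypothetical: it assumes `lt-tensor` 0.0.1a13 and `torch` are installed, and that `ResBlock1D(in_channels, out_channels)` preserves sequence length as its use above implies; untested against the released wheel):

```python
import torch
from lt_tensor.model_zoo.features import DiffusionUNet

model = DiffusionUNet(in_channels=1, base_channels=64, out_channels=1, depth=4)

# An input length divisible by 2**depth avoids the trim/pad branch in the skip path.
x = torch.randn(2, 1, 1024)  # [batch, channels, time]
out = model(x)               # [2, 1, 1024]: the encoder halves T four times, the decoder restores it
```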
lt_tensor/model_zoo/fusion.py
@@ -0,0 +1,164 @@
+__all__ = [
+    "ConcatFusion",
+    "FiLMFusion",
+    "BilinearFusion",
+    "CrossAttentionFusion",
+    "GatedFusion",
+]
+from lt_utils.common import *
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model
+import torch.nn.functional as F
+
+
+class ConcatFusion(Model):
+    def __init__(self, in_dim_a: int, in_dim_b: int, out_dim: int):
+        super().__init__()
+        self.proj = nn.Linear(in_dim_a + in_dim_b, out_dim)
+
+    def forward(self, a: Tensor, b: Tensor) -> Tensor:
+        x = torch.cat([a, b], dim=-1)
+        return self.proj(x)
+
+
+class FiLMFusion(Model):
+    def __init__(self, cond_dim: int, feature_dim: int):
+        super().__init__()
+        self.modulator = nn.Linear(cond_dim, 2 * feature_dim)
+
+    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
+        scale, shift = self.modulator(cond).chunk(2, dim=-1)
+        return x * scale + shift
+
+
+class BilinearFusion(Model):
+    def __init__(self, in_dim_a: int, in_dim_b: int, out_dim: int):
+        super().__init__()
+        self.bilinear = nn.Bilinear(in_dim_a, in_dim_b, out_dim)
+
+    def forward(self, a: Tensor, b: Tensor) -> Tensor:
+        return self.bilinear(a, b)
+
+
+class CrossAttentionFusion(Model):
+    def __init__(self, q_dim: int, kv_dim: int, n_heads: int = 4, d_model: int = 256):
+        super().__init__()
+        self.q_proj = nn.Linear(q_dim, d_model)
+        self.k_proj = nn.Linear(kv_dim, d_model)
+        self.v_proj = nn.Linear(kv_dim, d_model)
+        self.attn = nn.MultiheadAttention(
+            embed_dim=d_model, num_heads=n_heads, batch_first=True
+        )
+
+    def forward(
+        self, query: Tensor, context: Tensor, mask: Optional[Tensor] = None
+    ) -> Tensor:
+        Q = self.q_proj(query)
+        K = self.k_proj(context)
+        V = self.v_proj(context)
+        output, _ = self.attn(Q, K, V, key_padding_mask=mask)
+        return output
+
+
+class GatedFusion(Model):
+    def __init__(self, in_dim: int):
+        super().__init__()
+        self.gate = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.Sigmoid())
+
+    def forward(self, a: Tensor, b: Tensor) -> Tensor:
+        gate = self.gate(torch.cat([a, b], dim=-1))
+        return gate * a + (1 - gate) * b
+
+
+class AdaFusion1D(Model):
+    def __init__(self, channels: int, num_features: int):
+        super().__init__()
+        self.fc = nn.Linear(channels, num_features * 2)
+        self.norm = nn.InstanceNorm1d(num_features, affine=False)
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor, alpha: torch.Tensor):
+        h = self.fc(y)
+        h = h.view(h.size(0), h.size(1), 1)
+        gamma, beta = torch.chunk(h, chunks=2, dim=1)
+        t = (1.0 + gamma) * self.norm(x) + beta
+        # snake-style periodic term on top of the AdaIN output
+        return t + (1 / alpha) * (torch.sin(alpha * t) ** 2)
+
+
+class AdaIN1D(Model):
+    def __init__(self, channels: int, num_features: int):
+        super().__init__()
+        self.norm = nn.InstanceNorm1d(num_features, affine=False)
+        self.fc = nn.Linear(channels, num_features * 2)
+
+    def forward(self, x, y):
+        h = self.fc(y)
+        h = h.view(h.size(0), h.size(1), 1)
+        gamma, beta = torch.chunk(h, chunks=2, dim=1)
+        return (1 + gamma) * self.norm(x) + beta
+
+
+class AdaIN(Model):
+    def __init__(self, cond_dim, num_features, eps=1e-5):
+        """
+        cond_dim: size of the conditioning input
+        num_features: number of channels in the input feature map
+        """
+        super().__init__()
+        self.linear = nn.Linear(cond_dim, num_features * 2)
+        self.eps = eps
+
+    def forward(self, x, cond):
+        """
+        x: [B, C, T] - input features
+        cond: [B, cond_dim] - global conditioning vector (e.g., speaker/style)
+        """
+        B, C, T = x.size()
+        # Instance normalization
+        mean = x.mean(dim=2, keepdim=True)  # [B, C, 1]
+        std = x.std(dim=2, keepdim=True) + self.eps  # [B, C, 1]
+        x_norm = (x - mean) / std  # [B, C, T]
+
+        # Conditioning
+        gamma_beta = self.linear(cond)  # [B, 2*C]
+        gamma, beta = gamma_beta.chunk(2, dim=1)  # [B, C], [B, C]
+        gamma = gamma.unsqueeze(-1)  # [B, C, 1]
+        beta = beta.unsqueeze(-1)  # [B, C, 1]
+
+        return gamma * x_norm + beta
+
+
+class FiLMBlock(Model):
+    def __init__(self, activation: nn.Module = nn.Identity()):
+        super().__init__()
+        self.activation = activation
+
+    def forward(self, x: Tensor, gamma: Tensor, beta: Tensor):
+        beta = beta.view(x.size(0), x.size(1), 1, 1)
+        gamma = gamma.view(x.size(0), x.size(1), 1, 1)
+        return self.activation(gamma * x + beta)
+
+
+class InterpolatedFusion(Model):
+    def __init__(
+        self,
+        in_dim_a: int,
+        in_dim_b: int,
+        out_dim: int,
+        mode: Literal[
+            "nearest",
+            "linear",
+            "bilinear",
+            "bicubic",
+            "trilinear",
+            "area",
+            "nearest-exact",
+        ] = "nearest",
+    ):
+        super().__init__()
+        self.fuse = nn.Linear(in_dim_a + in_dim_b, out_dim)
+        self.mode = mode
+
+    def forward(self, a: Tensor, b: Tensor) -> Tensor:
+        # a: [B, T1, D1], b: [B, T2, D2] → T1 != T2
+        B, T1, _ = a.shape
+        # align_corners is only accepted by the interpolating modes
+        align = (
+            False
+            if self.mode in ("linear", "bilinear", "bicubic", "trilinear")
+            else None
+        )
+        b_interp = F.interpolate(
+            b.transpose(1, 2), size=T1, mode=self.mode, align_corners=align
+        )
+        b_interp = b_interp.transpose(1, 2)  # [B, T1, D2]
+        return self.fuse(torch.cat([a, b_interp], dim=-1))
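As with the features module, a short hypothetical sketch of the fusion blocks (same assumptions; `CrossAttentionFusion` takes batch-first `[B, T, D]` inputs since its `nn.MultiheadAttention` is built with `batch_first=True`):

```python
import torch
from lt_tensor.model_zoo.fusion import CrossAttentionFusion, GatedFusion

attn = CrossAttentionFusion(q_dim=80, kv_dim=64, n_heads=4, d_model=256)
q = torch.randn(2, 100, 80)   # query sequence   [B, T_q, q_dim]
ctx = torch.randn(2, 50, 64)  # context sequence [B, T_kv, kv_dim]
fused = attn(q, ctx)          # [2, 100, 256]; context length may differ from query length

gate = GatedFusion(in_dim=256)
mixed = gate(fused, torch.randn(2, 100, 256))  # sigmoid-gated mix, [2, 100, 256]
```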
lt_tensor/model_zoo/istft/generator.py
@@ -6,12 +6,12 @@ from lt_utils.common import *
 from lt_tensor.torch_commons import *
 from lt_tensor.model_base import Model
 from lt_tensor.misc_utils import log_tensor
-from lt_tensor.model_zoo.
+from lt_tensor.model_zoo.residual import ResBlock1D, ConvNets, get_weight_norm
 from lt_utils.misc_utils import log_traceback
 from lt_tensor.processors import AudioProcessor
 from lt_utils.type_utils import is_dir, is_pathlike
 from lt_tensor.misc_utils import set_seed, clear_cache
-from lt_tensor.model_zoo.
+from lt_tensor.model_zoo.discriminator import MultiPeriodDiscriminator, MultiScaleDiscriminator
 import torch.nn.functional as F
 from lt_tensor.config_templates import updateDict, ModelConfig
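The two import fixes above track the module renames listed in the file summary at the top. For downstream code, a hypothetical migration sketch (derived only from that rename list):

```python
# 0.0.1a12 module            -> 0.0.1a13 module
# lt_tensor.model_zoo.bsc    -> lt_tensor.model_zoo.basic
# lt_tensor.model_zoo.tfrms  -> lt_tensor.model_zoo.transformer
# lt_tensor.model_zoo.disc   -> lt_tensor.model_zoo.discriminator
# lt_tensor.model_zoo.pos    -> lt_tensor.model_zoo.pos_encoder
from lt_tensor.model_zoo.discriminator import (
    MultiPeriodDiscriminator,
    MultiScaleDiscriminator,
)
```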