lt-tensor 0.0.1a12__py3-none-any.whl → 0.0.1a13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,416 @@
+ __all__ = [
+     "Downsample1D",
+     "Upsample1D",
+     "DiffusionUNet",
+     "UNetConvBlock1D",
+     "UNetUpBlock1D",
+     "NoisePredictor1D",
+     "AdaINFeaturesBlock1D",
+     "UpSampleConv1D",
+ ]
+ import math
+ from lt_tensor.torch_commons import *
+ from lt_tensor.model_base import Model
+ from lt_tensor.model_zoo.residual import ResBlock1D
+ from lt_tensor.misc_utils import log_tensor
+ from lt_tensor.model_zoo.fusion import AdaIN1D, CrossAttentionFusion
+ import torch.nn.functional as F
+ from lt_tensor.misc_utils import get_activated_conv
+ from lt_utils.common import *
+
+
+ class FeatureExtractor(Model):
+     def __init__(
+         self,
+         in_channels: int = 1,
+         out_channels: int = 32,
+         hidden: int = 128,
+         groups: tuple[int, int, int, int, int] = (1, 1, 1, 1, 1),
+         kernels: tuple[int, int, int, int, int] = (5, 3, 3, 3, 3),
+         padding: tuple[int, int, int, int, int] = (2, 1, 1, 1, 1),
+         stride: tuple[int, int, int, int, int] = (2, 2, 2, 1, 1),
+         network: Literal[
+             "Conv1d",
+             "Conv2d",
+             "Conv3d",
+             "ConvTranspose1d",
+             "ConvTranspose2d",
+             "ConvTranspose3d",
+         ] = "Conv1d",
+     ):
+         super().__init__()
+         self.pre_nt = get_activated_conv(
+             in_channels,
+             hidden,
+             kernels[0],
+             stride[0],
+             padding[0],
+             groups[0],
+             conv_type=network,
+         )
+
+         # nn.Sequential takes modules as positional arguments, so the
+         # intermediate blocks are unpacked rather than passed as a list.
+         self.net = nn.Sequential(
+             *[
+                 get_activated_conv(
+                     hidden,
+                     hidden,
+                     kernels[i],
+                     stride[i],
+                     padding[i],
+                     groups[i],
+                     conv_type=network,
+                 )
+                 for i in range(1, 4)
+             ]
+         )
+         self.post = get_activated_conv(
+             hidden,
+             out_channels,
+             kernels[4],
+             stride[4],
+             padding[4],
+             groups[4],
+             conv_type=network,
+         )
+
+     def forward(self, x: Tensor):
+         x = self.pre_nt(x)
+         x = self.net(x)
+         return self.post(x)
+
+
+ class Downsample1D(Model):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+     ):
+         super().__init__()
+         self.pool = nn.Conv1d(in_channels, out_channels, 4, stride=2, padding=1)
+
+     def forward(self, x):
+         return self.pool(x)
+
+
+ class Upsample1D(Model):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         activation=nn.ReLU(inplace=True),
+     ):
+         super().__init__()
+         self.up = nn.Sequential(
+             nn.ConvTranspose1d(
+                 in_channels, out_channels, kernel_size=4, stride=2, padding=1
+             ),
+             nn.BatchNorm1d(out_channels),
+             activation,
+         )
+
+     def forward(self, x):
+         return self.up(x)
+
+
+ class DiffusionUNet(Model):
+     def __init__(self, in_channels=1, base_channels=64, out_channels=1, depth=4):
+         super().__init__()
+
+         self.depth = depth
+         self.encoder_blocks = nn.ModuleList()
+         self.downsamples = nn.ModuleList()
+         self.upsamples = nn.ModuleList()
+         self.decoder_blocks = nn.ModuleList()
+         # Keep track of channel sizes per layer for skip connections
+         self.channels = [in_channels]  # starting input channel
+         for i in range(depth):
+             enc_in = self.channels[-1]
+             enc_out = base_channels * (2**i)
+             # Encoder block and downsample
+             self.encoder_blocks.append(ResBlock1D(enc_in, enc_out))
+             self.downsamples.append(
+                 Downsample1D(enc_out, enc_out)
+             )  # halve time, keep channels
+             self.channels.append(enc_out)
+         # Bottleneck
+         bottleneck_ch = self.channels[-1]
+         self.bottleneck = ResBlock1D(bottleneck_ch, bottleneck_ch)
+         # Decoder blocks (reverse channel flow)
+         for i in reversed(range(depth)):
+             skip_ch = self.channels[i + 1]  # from encoder
+             dec_out = self.channels[i]  # match earlier stage's output
+             self.upsamples.append(Upsample1D(skip_ch, skip_ch))
+             self.decoder_blocks.append(ResBlock1D(skip_ch * 2, dec_out))
+         # Final output projection (out_channels)
+         self.final = nn.Conv1d(in_channels, out_channels, kernel_size=1)
+
+     def forward(self, x: Tensor):
+         skips = []
+
+         # Encoder
+         for enc, down in zip(self.encoder_blocks, self.downsamples):
+             # log_tensor(x, "before enc")
+             x = enc(x)
+             skips.append(x)
+             x = down(x)
+
+         # Bottleneck
+         x = self.bottleneck(x)
+
+         # Decoder
+         for up, dec, skip in zip(self.upsamples, self.decoder_blocks, reversed(skips)):
+             x = up(x)
+
+             # Match lengths via trimming or padding
+             if x.shape[-1] > skip.shape[-1]:
+                 x = x[..., : skip.shape[-1]]
+             elif x.shape[-1] < skip.shape[-1]:
+                 diff = skip.shape[-1] - x.shape[-1]
+                 x = F.pad(x, (0, diff))
+
+             x = torch.cat([x, skip], dim=1)  # concat on channels
+             x = dec(x)
+
+         # Final 1x1 conv
+         return self.final(x)
+
+
+ class UNetConvBlock1D(Model):
+     def __init__(self, in_channels: int, out_channels: int, down: bool = True):
+         super().__init__()
+         self.down = down
+         self.conv = nn.Sequential(
+             nn.Conv1d(
+                 in_channels,
+                 out_channels,
+                 kernel_size=3,
+                 stride=2 if down else 1,
+                 padding=1,
+             ),
+             nn.BatchNorm1d(out_channels),
+             nn.LeakyReLU(0.2),
+             nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1),
+             nn.BatchNorm1d(out_channels),
+             nn.LeakyReLU(0.2),
+         )
+         # The residual path must match both the channel count and the stride of
+         # the main branch, so it is projected whenever either of them changes.
+         self.downsample = (
+             nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=2 if down else 1)
+             if in_channels != out_channels or down
+             else nn.Identity()
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: [B, C, T]
+         residual = self.downsample(x)
+         return self.conv(x) + residual
+
+
+ class UNetUpBlock1D(Model):
+     def __init__(self, in_channels: int, out_channels: int):
+         super().__init__()
+         self.conv = nn.Sequential(
+             nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
+             nn.BatchNorm1d(out_channels),
+             nn.LeakyReLU(0.2),
+             nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1),
+             nn.BatchNorm1d(out_channels),
+             nn.LeakyReLU(0.2),
+         )
+         self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+
+     def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
+         x = self.upsample(x)
+         x = torch.cat([x, skip], dim=1)  # skip connection
+         return self.conv(x)
+
+
+ class NoisePredictor1D(Model):
+     def __init__(self, in_channels: int, cond_dim: int = 0, hidden: int = 128):
+         """
+         Args:
+             in_channels: channels of the noisy input [B, C, T]
+             cond_dim: optional condition vector [B, cond_dim]
+         """
+         super().__init__()
+         # The condition is added directly to the input, so it is projected to
+         # in_channels to make the broadcast addition shape-compatible.
+         self.proj = nn.Linear(cond_dim, in_channels) if cond_dim > 0 else None
+         self.net = nn.Sequential(
+             nn.Conv1d(in_channels, hidden, kernel_size=3, padding=1),
+             nn.SiLU(),
+             nn.Conv1d(hidden, in_channels, kernel_size=3, padding=1),
+         )
+
+     def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
+         # x: [B, C, T], cond: [B, cond_dim]
+         if cond is not None:
+             cond_proj = self.proj(cond).unsqueeze(-1)  # [B, C, 1]
+             x = x + cond_proj  # simple conditioning
+         return self.net(x)  # [B, C, T]
+
+
+ class UpSampleConv1D(nn.Module):
+     def __init__(self, upsample: bool = False, dim_in: int = 0, dim_out: int = 0):
+         super().__init__()
+         if upsample:
+             self.upsample = lambda x: F.interpolate(x, scale_factor=2, mode="nearest")
+         else:
+             self.upsample = nn.Identity()
+
+         if dim_in == dim_out:
+             self.learned = nn.Identity()
+         else:
+             self.learned = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+     def forward(self, x):
+         x = self.upsample(x)
+         return self.learned(x)
+
+
+ class AdaINFeaturesBlock1D(Model):
+     def __init__(
+         self,
+         dim_in: int,
+         dim_out: int,
+         style_dim: int = 64,
+         actv=nn.LeakyReLU(0.2),
+         upsample: bool = False,
+         dropout_p=0.0,
+     ):
+         super().__init__()
+         self.upsample = UpSampleConv1D(upsample, dim_in, dim_out)
+         self.res_net = nn.ModuleDict(
+             dict(
+                 norm_1=AdaIN1D(style_dim, dim_in),
+                 sq1=nn.Sequential(
+                     actv,
+                     (
+                         nn.Identity()
+                         if not upsample
+                         else weight_norm(
+                             nn.ConvTranspose1d(
+                                 dim_in,
+                                 dim_in,
+                                 kernel_size=3,
+                                 stride=2,
+                                 groups=dim_in,
+                                 padding=1,
+                                 output_padding=1,
+                             )
+                         )
+                     ),
+                     weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1)),
+                     nn.Dropout(dropout_p),
+                 ),
+                 norm_2=AdaIN1D(style_dim, dim_out),
+                 sq2=nn.Sequential(
+                     actv,
+                     weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1)),
+                     nn.Dropout(dropout_p),
+                 ),
+             )
+         )
+         self.sq2 = math.sqrt(2)
+
+     def forward(self, x: Tensor, y: Tensor):
+         u = self.res_net["norm_1"](x, y)
+         u = self.res_net["sq1"](u)
+         u = self.res_net["norm_2"](u, y)
+         u = self.res_net["sq2"](u)
+         return (u + self.upsample(x)) / self.sq2
+
+
+ class AudioEncoder(Model):
+     """Untested, hypothetical item"""
+
+     def __init__(
+         self,
+         channels: int = 80,
+         alpha: float = 4.0,
+         interp_mode: Literal[
+             "nearest",
+             "linear",
+             "bilinear",
+             "bicubic",
+             "trilinear",
+         ] = "nearest",
+         activation: nn.Module = nn.LeakyReLU(0.1),
+     ):
+         super().__init__()
+
+         self.net = nn.Sequential(
+             nn.Conv1d(
+                 channels, channels, kernel_size=3, stride=2, padding=5, groups=channels
+             ),
+             nn.LeakyReLU(0.1),
+             nn.Conv1d(channels, channels, kernel_size=7, stride=1, padding=1, groups=1),
+         )
+         self.fc = nn.Linear(channels, channels)
+         self.activation = activation
+         self.channels = channels
+         self.mode = interp_mode
+         self.alpha = alpha
+
+     def forward(self, mels: Tensor, cr_audio: Tensor):
+         sin = torch.asin(cr_audio)
+         cos = torch.acos(cr_audio)
+         mod = (sin * cos) / self.alpha
+         mod = (mod - mod.median(dim=-1, keepdim=True).values) / (
+             mod.std(dim=-1, keepdim=True) + 1e-5
+         )
+         x = self.net(mod)
+         x = (
+             F.interpolate(
+                 x,
+                 size=mels.shape[-1],
+                 mode=self.mode,
+             )
+             .transpose(-1, -2)
+             .contiguous()
+         )
+         x = self.activation(x)
+         return self.fc(x).transpose(-1, -2)
+
+
+ class AudioEncoderAttn(Model):
+     def __init__(
+         self,
+         channels: int = 80,
+         alpha: float = 4.0,
+         interp_mode: Literal[
+             "nearest",
+             "linear",
+             "bilinear",
+             "bicubic",
+             "trilinear",
+         ] = "nearest",
+         activation: nn.Module = nn.LeakyReLU(0.1),
+     ):
+         super().__init__()
+
+         self.net = nn.Sequential(
+             nn.Conv1d(
+                 channels, channels, kernel_size=3, stride=2, padding=5, groups=channels
+             ),
+             nn.LeakyReLU(0.1),
+             nn.Conv1d(channels, channels, kernel_size=7, stride=1, padding=1, groups=1),
+         )
+         self.fusion = CrossAttentionFusion(channels, channels, 2, d_model=channels)
+         self.channels = channels
+         self.mode = interp_mode
+         self.alpha = alpha
+         self.activation = activation
+
+     def forward(self, mels: Tensor, cr_audio: Tensor):
+         sin = torch.asin(cr_audio)
+         cos = torch.acos(cr_audio)
+         mod = (sin * cos) / self.alpha
+         mod = (mod - mod.median(dim=-1, keepdim=True).values) / (
+             mod.std(dim=-1, keepdim=True) + 1e-5
+         )
+         x = self.activation(self.net(mod))
+         x = F.interpolate(x, size=mels.shape[-1], mode=self.mode)
+
+         # Make the transposed views contiguous before attention
+         x_t = x.transpose(-2, -1).contiguous()
+         mels_t = mels.transpose(-2, -1).contiguous()
+
+         return self.fusion(x_t, mels_t).transpose(-2, -1)
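
For orientation, here is a minimal shape-check sketch of the DiffusionUNet added above. It is a hypothetical usage example, not part of the package: it assumes DiffusionUNet is in scope (the diff does not show the new file's path, so the import line is left as a commented placeholder), and the tensor sizes are made up for illustration.

    import torch
    # Hypothetical import; the diff does not name the module file that contains DiffusionUNet.
    # from lt_tensor.model_zoo.<new_module> import DiffusionUNet

    model = DiffusionUNet(in_channels=1, base_channels=64, out_channels=1, depth=4)
    model.eval()
    x = torch.randn(2, 1, 1024)  # [B, C, T]; T divisible by 2**depth avoids the trim/pad branch
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([2, 1, 1024])

With depth=4 the encoder halves the time axis four times (1024 -> 64) while the decoder's skip connections restore it, so the output matches the input length.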
@@ -0,0 +1,164 @@
+ __all__ = [
+     "ConcatFusion",
+     "FiLMFusion",
+     "BilinearFusion",
+     "CrossAttentionFusion",
+     "GatedFusion",
+ ]
+ from lt_utils.common import *
+ from lt_tensor.torch_commons import *
+ from lt_tensor.model_base import Model
+ import torch.nn.functional as F
+
+
+ class ConcatFusion(Model):
+     def __init__(self, in_dim_a: int, in_dim_b: int, out_dim: int):
+         super().__init__()
+         self.proj = nn.Linear(in_dim_a + in_dim_b, out_dim)
+
+     def forward(self, a: Tensor, b: Tensor) -> Tensor:
+         x = torch.cat([a, b], dim=-1)
+         return self.proj(x)
+
+
+ class FiLMFusion(Model):
+     def __init__(self, cond_dim: int, feature_dim: int):
+         super().__init__()
+         self.modulator = nn.Linear(cond_dim, 2 * feature_dim)
+
+     def forward(self, x: Tensor, cond: Tensor) -> Tensor:
+         scale, shift = self.modulator(cond).chunk(2, dim=-1)
+         return x * scale + shift
+
+
+ class BilinearFusion(Model):
+     def __init__(self, in_dim_a: int, in_dim_b: int, out_dim: int):
+         super().__init__()
+         self.bilinear = nn.Bilinear(in_dim_a, in_dim_b, out_dim)
+
+     def forward(self, a: Tensor, b: Tensor) -> Tensor:
+         return self.bilinear(a, b)
+
+
+ class CrossAttentionFusion(Model):
+     def __init__(self, q_dim: int, kv_dim: int, n_heads: int = 4, d_model: int = 256):
+         super().__init__()
+         self.q_proj = nn.Linear(q_dim, d_model)
+         self.k_proj = nn.Linear(kv_dim, d_model)
+         self.v_proj = nn.Linear(kv_dim, d_model)
+         self.attn = nn.MultiheadAttention(
+             embed_dim=d_model, num_heads=n_heads, batch_first=True
+         )
+
+     def forward(
+         self, query: Tensor, context: Tensor, mask: Optional[Tensor] = None
+     ) -> Tensor:
+         Q = self.q_proj(query)
+         K = self.k_proj(context)
+         V = self.v_proj(context)
+         output, _ = self.attn(Q, K, V, key_padding_mask=mask)
+         return output
+
+
+ class GatedFusion(Model):
+     def __init__(self, in_dim: int):
+         super().__init__()
+         self.gate = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.Sigmoid())
+
+     def forward(self, a: Tensor, b: Tensor) -> Tensor:
+         gate = self.gate(torch.cat([a, b], dim=-1))
+         return gate * a + (1 - gate) * b
+
+
+ class AdaFusion1D(Model):
+     def __init__(self, channels: int, num_features: int):
+         super().__init__()
+         self.fc = nn.Linear(channels, num_features * 2)
+         self.norm = nn.InstanceNorm1d(num_features, affine=False)
+
+     def forward(self, x: torch.Tensor, y: torch.Tensor, alpha: torch.Tensor):
+         h = self.fc(y)
+         h = h.view(h.size(0), h.size(1), 1)
+         gamma, beta = torch.chunk(h, chunks=2, dim=1)
+         t = (1.0 + gamma) * self.norm(x) + beta
+         return t + (1 / alpha) * (torch.sin(alpha * t) ** 2)
+
+
+ class AdaIN1D(Model):
+     def __init__(self, channels: int, num_features: int):
+         super().__init__()
+         self.norm = nn.InstanceNorm1d(num_features, affine=False)
+         self.fc = nn.Linear(channels, num_features * 2)
+
+     def forward(self, x, y):
+         h = self.fc(y)
+         h = h.view(h.size(0), h.size(1), 1)
+         gamma, beta = torch.chunk(h, chunks=2, dim=1)
+         return (1 + gamma) * self.norm(x) + beta
+
+ class AdaIN(Model):
+     def __init__(self, cond_dim, num_features, eps=1e-5):
+         """
+         cond_dim: size of the conditioning input
+         num_features: number of channels in the input feature map
+         """
+         super().__init__()
+         self.linear = nn.Linear(cond_dim, num_features * 2)
+         self.eps = eps
+
+     def forward(self, x, cond):
+         """
+         x: [B, C, T] - input features
+         cond: [B, cond_dim] - global conditioning vector (e.g., speaker/style)
+         """
+         B, C, T = x.size()
+         # Instance normalization
+         mean = x.mean(dim=2, keepdim=True)  # [B, C, 1]
+         std = x.std(dim=2, keepdim=True) + self.eps  # [B, C, 1]
+         x_norm = (x - mean) / std  # [B, C, T]
+
+         # Conditioning
+         gamma_beta = self.linear(cond)  # [B, 2*C]
+         gamma, beta = gamma_beta.chunk(2, dim=1)  # [B, C], [B, C]
+         gamma = gamma.unsqueeze(-1)  # [B, C, 1]
+         beta = beta.unsqueeze(-1)  # [B, C, 1]
+
+         return gamma * x_norm + beta
+
+ class FiLMBlock(Model):
+     def __init__(self, activation: nn.Module = nn.Identity()):
+         super().__init__()
+         self.activation = activation
+
+     def forward(self, x: Tensor, gamma: Tensor, beta: Tensor):
+         beta = beta.view(x.size(0), x.size(1), 1, 1)
+         gamma = gamma.view(x.size(0), x.size(1), 1, 1)
+         return self.activation(gamma * x + beta)
+
+
+ class InterpolatedFusion(Model):
+     def __init__(
+         self,
+         in_dim_a: int,
+         in_dim_b: int,
+         out_dim: int,
+         mode: Literal[
+             "nearest",
+             "linear",
+             "bilinear",
+             "bicubic",
+             "trilinear",
+             "area",
+             "nearest-exact",
+         ] = "nearest",
+     ):
+         super().__init__()
+         self.fuse = nn.Linear(in_dim_a + in_dim_b, out_dim)
+         self.mode = mode
+
+     def forward(self, a: Tensor, b: Tensor) -> Tensor:
+         # a: [B, T1, D1], b: [B, T2, D2] → T1 != T2
+         B, T1, _ = a.shape
+         # align_corners is only valid for the interpolating modes, so it is
+         # left unset (None) for "nearest", "area" and "nearest-exact".
+         align = False if self.mode in ("linear", "bilinear", "bicubic", "trilinear") else None
+         b_interp = F.interpolate(
+             b.transpose(1, 2), size=T1, mode=self.mode, align_corners=align
+         )
+         b_interp = b_interp.transpose(1, 2)  # [B, T1, D2]
+         return self.fuse(torch.cat([a, b_interp], dim=-1))
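
The fusion blocks above are length-agnostic, which is easiest to see with a small sketch. As before, this is hypothetical usage, assuming the classes are in scope (the new file's import path is not shown in the diff); the tensors and dimensions are made up for illustration.

    import torch

    B = 2
    q = torch.randn(B, 50, 80)      # query sequence   [B, T_q, q_dim]
    ctx = torch.randn(B, 120, 256)  # context sequence [B, T_kv, kv_dim]

    attn = CrossAttentionFusion(q_dim=80, kv_dim=256, n_heads=4, d_model=128)
    fused = attn(q, ctx)            # [B, 50, 128]; query and context lengths may differ

    film = FiLMFusion(cond_dim=64, feature_dim=80)
    cond = torch.randn(B, 50, 64)
    modulated = film(q, cond)       # [B, 50, 80]: x * scale + shift, per feature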
@@ -6,12 +6,12 @@ from lt_utils.common import *
  from lt_tensor.torch_commons import *
  from lt_tensor.model_base import Model
  from lt_tensor.misc_utils import log_tensor
- from lt_tensor.model_zoo.rsd import ResBlock1D, ConvNets, get_weight_norm
+ from lt_tensor.model_zoo.residual import ResBlock1D, ConvNets, get_weight_norm
  from lt_utils.misc_utils import log_traceback
  from lt_tensor.processors import AudioProcessor
  from lt_utils.type_utils import is_dir, is_pathlike
  from lt_tensor.misc_utils import set_seed, clear_cache
- from lt_tensor.model_zoo.disc import MultiPeriodDiscriminator, MultiScaleDiscriminator
+ from lt_tensor.model_zoo.discriminator import MultiPeriodDiscriminator, MultiScaleDiscriminator
  import torch.nn.functional as F
  from lt_tensor.config_templates import updateDict, ModelConfig
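
For downstream code, the last hunk amounts to the module renames below. This is a sketch of the before/after imports; only this one file is visible in the diff, so whether the old rsd/disc names remain importable as aliases is not confirmed.

    # 0.0.1a12
    # from lt_tensor.model_zoo.rsd import ResBlock1D, ConvNets, get_weight_norm
    # from lt_tensor.model_zoo.disc import MultiPeriodDiscriminator, MultiScaleDiscriminator

    # 0.0.1a13
    from lt_tensor.model_zoo.residual import ResBlock1D, ConvNets, get_weight_norm
    from lt_tensor.model_zoo.discriminator import MultiPeriodDiscriminator, MultiScaleDiscriminator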