lt-tensor 0.0.1a14__py3-none-any.whl → 0.0.1a15__py3-none-any.whl

This diff shows the content of publicly released package versions as published to their registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in those public registries.
@@ -0,0 +1,58 @@
+ # Copyright 2020 LMNT, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import numpy as np
+
+
+ class AttrDict(dict):
+     def __init__(self, *args, **kwargs):
+         super(AttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
+
+     def override(self, attrs):
+         if isinstance(attrs, dict):
+             self.__dict__.update(**attrs)
+         elif isinstance(attrs, (list, tuple, set)):
+             for attr in attrs:
+                 self.override(attr)
+         elif attrs is not None:
+             raise NotImplementedError
+         return self
+
+
+ params = AttrDict(
+     # Training params
+     batch_size=16,
+     learning_rate=2e-4,
+     max_grad_norm=None,
+
+     # Data params
+     sample_rate=22050,
+     n_mels=80,
+     n_fft=1024,
+     hop_samples=256,
+     crop_mel_frames=62, # Probably an error in paper.
+
+     # Model params
+     residual_layers=30,
+     residual_channels=64,
+     dilation_cycle_length=10,
+     unconditional = False,
+     noise_schedule=np.linspace(1e-4, 0.05, 50).tolist(),
+     inference_noise_schedule=[0.0001, 0.001, 0.01, 0.05, 0.2, 0.5],
+
+     # unconditional sample len
+     audio_len = 22050*5, # unconditional_synthesis_samples
+ )
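
The file added by this hunk is a flat hyperparameter module in the DiffWave style: AttrDict mirrors its dict entries as attributes, and override() merges a dict, or a list/tuple/set of dicts applied in order, back into the instance. A minimal usage sketch, assuming only the AttrDict shown above (the override values here are illustrative, not taken from the package):

# Sketch only: exercises AttrDict.override as defined in the added file.
cfg = AttrDict(batch_size=16, learning_rate=2e-4)
cfg.override({"batch_size": 32})                       # merge a single dict
cfg.override([{"n_mels": 80}, {"hop_samples": 256}])   # several dicts, applied in order
print(cfg.batch_size, cfg.n_mels)                      # attribute access -> 32 80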
@@ -2,85 +2,118 @@ from lt_tensor.torch_commons import *
  import torch.nn.functional as F
  from lt_tensor.model_base import Model
  from lt_utils.common import *
+ from einops import rearrange
+ import torchaudio
 
 
- class PeriodDiscriminator(Model):
-     def __init__(
-         self,
-         period: int,
-         use_spectral_norm=False,
-         kernel_size: int = 5,
-         stride: int = 3,
-     ):
+ def get_padding(ks, d):
+     return int((ks * d - d) / 2)
+
+
+ class DiscriminatorP(Model):
+     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
          super().__init__()
          self.period = period
-         self.stride = stride
-         self.kernel_size = kernel_size
-         self.norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-
-         self.channels = [32, 128, 512, 1024, 1024]
-         self.first_pass = nn.Sequential(
-             self.norm_f(
-                 nn.Conv2d(
-                     1, self.channels[0], (kernel_size, 1), (stride, 1), padding=(2, 0)
-                 )
-             ),
-             nn.LeakyReLU(0.1),
-         )
-
+         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
          self.convs = nn.ModuleList(
              [
-                 self._get_next(self.channels[i + 1], self.channels[i], i == 3)
-                 for i in range(4)
+                 norm_f(
+                     nn.Conv2d(
+                         1,
+                         32,
+                         (kernel_size, 1),
+                         (stride, 1),
+                         padding=(get_padding(5, 1), 0),
+                     )
+                 ),
+                 norm_f(
+                     nn.Conv2d(
+                         32,
+                         128,
+                         (kernel_size, 1),
+                         (stride, 1),
+                         padding=(get_padding(5, 1), 0),
+                     )
+                 ),
+                 norm_f(
+                     nn.Conv2d(
+                         128,
+                         512,
+                         (kernel_size, 1),
+                         (stride, 1),
+                         padding=(get_padding(5, 1), 0),
+                     )
+                 ),
+                 norm_f(
+                     nn.Conv2d(
+                         512,
+                         1024,
+                         (kernel_size, 1),
+                         (stride, 1),
+                         padding=(get_padding(5, 1), 0),
+                     )
+                 ),
+                 norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
              ]
          )
+         self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+         self.activation = nn.LeakyReLU(0.1)
 
-         self.post_conv = nn.Conv2d(1024, 1, (stride, 1), 1, padding=(1, 0))
-
-     def _get_next(self, out_dim: int, last_in: int, is_last: bool = False):
-         stride = (self.stride, 1) if not is_last else 1
-
-         return nn.Sequential(
-             self.norm_f(
-                 nn.Conv2d(
-                     last_in,
-                     out_dim,
-                     (self.kernel_size, 1),
-                     stride,
-                     padding=(2, 0),
-                 )
-             ),
-             nn.LeakyReLU(0.1),
+     def forward(self, x):
+         fmap = []
+
+         # 1d to 2d
+         b, c, t = x.shape
+         if t % self.period != 0: # pad first
+             n_pad = self.period - (t % self.period)
+             x = F.pad(x, (0, n_pad), "reflect")
+             t = t + n_pad
+         x = x.view(b, c, t // self.period, self.period)
+
+         for l in self.convs:
+             x = l(x)
+             x = self.activation(x)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+
+ class MultiPeriodDiscriminator(Model):
+     def __init__(self):
+         super().__init__()
+         self.discriminators = nn.ModuleList(
+             [
+                 DiscriminatorP(2),
+                 DiscriminatorP(3),
+                 DiscriminatorP(5),
+                 DiscriminatorP(7),
+                 DiscriminatorP(11),
+             ]
          )
 
-     def forward(self, x: torch.Tensor):
-         """
-         x: (B, T)
-         """
-         b, t = x.shape
-         if t % self.period != 0:
-             pad_len = self.period - (t % self.period)
-             x = F.pad(x, (0, pad_len), mode="reflect")
-             t = t + pad_len
-
-         x = x.view(b, 1, t // self.period, self.period) # (B, 1, T//P, P)
-
-         f_map = []
-         x = self.first_pass(x)
-         f_map.append(x)
-         for conv in self.convs:
-             x = conv(x)
-             f_map.append(x)
-         x = self.post_conv(x)
-         f_map.append(x)
-         return x.flatten(1, -1), f_map
-
-
- class ScaleDiscriminator(nn.Module):
+     def forward(self, y, y_hat):
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+         for i, d in enumerate(self.discriminators):
+             y_d_r, fmap_r = d(y)
+             y_d_g, fmap_g = d(y_hat)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ class DiscriminatorS(Model):
      def __init__(self, use_spectral_norm=False):
          super().__init__()
          norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-         self.activation = nn.LeakyReLU(0.1)
          self.convs = nn.ModuleList(
              [
                  norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
@@ -92,105 +125,190 @@ class ScaleDiscriminator(nn.Module):
                  norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
              ]
          )
-         self.post_conv = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
+         self.activation = nn.LeakyReLU(0.1)
+         self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
 
-     def forward(self, x: torch.Tensor):
-         """
-         x: (B, T)
-         """
-         f_map = []
-         x = x.unsqueeze(1) # (B, 1, T)
-         for conv in self.convs:
-             x = self.activation(conv(x))
-             f_map.append(x)
-         x = self.post_conv(x)
-         f_map.append(x)
-         return x.flatten(1, -1), f_map
+     def forward(self, x):
+         fmap = []
+         for l in self.convs:
+             x = l(x)
+             x = self.activation(x)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
 
 
  class MultiScaleDiscriminator(Model):
-     def __init__(self, layers: int = 3):
+     def __init__(self):
          super().__init__()
-         self.pooling = nn.AvgPool1d(4, 2, padding=2)
          self.discriminators = nn.ModuleList(
-             [ScaleDiscriminator(i == 0) for i in range(layers)]
+             [
+                 DiscriminatorS(use_spectral_norm=True),
+                 DiscriminatorS(),
+                 DiscriminatorS(),
+             ]
+         )
+         self.meanpools = nn.ModuleList(
+             [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
          )
 
-     def forward(self, x: torch.Tensor):
-         """
-         x: (B, T)
-         Returns: list of outputs from each scale discriminator
-         """
-         outputs = []
-         features = []
+     def forward(self, y, y_hat):
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
         for i, d in enumerate(self.discriminators):
             if i != 0:
-                 x = self.pooling(x)
-             out, f_map = d(x)
-             outputs.append(out)
-             features.append(f_map)
-         return outputs, features
+                 y = self.meanpools[i - 1](y)
+                 y_hat = self.meanpools[i - 1](y_hat)
+             y_d_r, fmap_r = d(y)
+             y_d_g, fmap_g = d(y_hat)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
 
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
 
- class MultiPeriodDiscriminator(Model):
-     def __init__(self, periods: List[int] = [2, 3, 5, 7, 11]):
-         super().__init__()
-         self.discriminators = nn.ModuleList([PeriodDiscriminator(p) for p in periods])
 
-     def forward(self, x: torch.Tensor):
+ class MultiResolutionDiscriminator(Model):
+     """Source: https://github.com/gemelo-ai/vocos/blob/main/vocos/discriminators.py"""
+
+     def __init__(
+         self,
+         fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
+         num_embeddings: Optional[int] = None,
+     ):
          """
-         x: (B, T)
-         Returns: list of tuples of outputs from each period discriminator and the f_map.
+
+         Args:
+             fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
+             num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
+                 Defaults to None.
          """
-         # torch.log(torch.clip(x, min=clip_val))
-         out_map = []
-         feat_map = []
+
+         super().__init__()
+         self.discriminators = nn.ModuleList(
+             [
+                 DiscriminatorR(window_length=w, num_embeddings=num_embeddings)
+                 for w in fft_sizes
+             ]
+         )
+
+     def forward(
+         self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+     ) -> Tuple[
+         List[torch.Tensor],
+         List[torch.Tensor],
+         List[List[torch.Tensor]],
+         List[List[torch.Tensor]],
+     ]:
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+
         for d in self.discriminators:
-             out, feat = d(x)
-             out_map.append(out)
-             feat_map.append(feat)
-         return out_map, feat_map
-
-
- def discriminator_loss(real_out_map, fake_out_map):
-     loss = 0.0
-     rl, fl = [], []
-     for real_out, fake_out in zip(real_out_map, fake_out_map):
-         real_loss = torch.mean((1.0 - real_out) ** 2)
-         fake_loss = torch.mean(fake_out**2)
-         loss += real_loss + fake_loss
-         rl.append(real_loss.item())
-         fl.append(fake_loss.item())
-     return loss, sum(rl), sum(fl)
-
-
- def generator_adv_loss(fake_disc_outputs: List[Tensor]):
-     loss = 0.0
-     for fake_out in fake_disc_outputs:
-         fake_score = fake_out[0]
-         loss += -torch.mean(fake_score)
-     return loss
-
-
- def feature_loss(
-     fmap_r,
-     fmap_g,
-     weight=2.0,
-     loss_fn: Callable[[Tensor, Tensor], Tensor] = F.l1_loss,
- ):
-     loss = 0.0
-     for dr, dg in zip(fmap_r, fmap_g):
-         for rl, gl in zip(dr, dg):
-             loss += loss_fn(rl - gl)
-     return loss * weight
-
-
- def generator_loss(disc_generated_outputs):
-     loss = 0.0
-     gen_losses = []
-     for dg in disc_generated_outputs:
-         l = torch.mean((1.0 - dg) ** 2)
-         gen_losses.append(l.item())
-         loss += l
-
-     return loss, gen_losses
+             y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
+             y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ class DiscriminatorR(Model):
+     def __init__(
+         self,
+         window_length: int,
+         num_embeddings: Optional[int] = None,
+         channels: int = 32,
+         hop_factor: float = 0.25,
+         bands: Tuple[Tuple[float, float], ...] = (
+             (0.0, 0.1),
+             (0.1, 0.25),
+             (0.25, 0.5),
+             (0.5, 0.75),
+             (0.75, 1.0),
+         ),
+     ):
+         super().__init__()
+         self.window_length = window_length
+         self.hop_factor = hop_factor
+         self.spec_fn = torchaudio.transforms.Spectrogram(
+             n_fft=window_length,
+             hop_length=int(window_length * hop_factor),
+             win_length=window_length,
+             power=None,
+         )
+         n_fft = window_length // 2 + 1
+         bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+         self.bands = bands
+         convs = lambda: nn.ModuleList(
+             [
+                 weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
+                 weight_norm(
+                     nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
+                 ),
+                 weight_norm(
+                     nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
+                 ),
+                 weight_norm(
+                     nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
+                 ),
+                 weight_norm(
+                     nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
+                 ),
+             ]
+         )
+         self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+
+         if num_embeddings is not None:
+             self.emb = torch.nn.Embedding(
+                 num_embeddings=num_embeddings, embedding_dim=channels
+             )
+             torch.nn.init.zeros_(self.emb.weight)
+
+         self.conv_post = weight_norm(
+             nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
+         )
+
+     def spectrogram(self, x):
+         # Remove DC offset
+         x = x - x.mean(dim=-1, keepdims=True)
+         # Peak normalize the volume of input audio
+         x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+         x = self.spec_fn(x)
+         x = torch.view_as_real(x)
+         x = rearrange(x, "b f t c -> b c t f")
+         # Split into bands
+         x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+         return x_bands
+
+     def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
+         x_bands = self.spectrogram(x)
+         fmap = []
+         x = []
+         for band, stack in zip(x_bands, self.band_convs):
+             for i, layer in enumerate(stack):
+                 band = layer(band)
+                 band = torch.nn.functional.leaky_relu(band, 0.1)
+                 if i > 0:
+                     fmap.append(band)
+             x.append(band)
+         x = torch.cat(x, dim=-1)
+         if cond_embedding_id is not None:
+             emb = self.emb(cond_embedding_id)
+             h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
+         else:
+             h = 0
+         x = self.conv_post(x)
+         fmap.append(x)
+         x += h
+
+         return x, fmap
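
After this rewrite the period, scale, and resolution discriminators share one calling convention: forward(y, y_hat) returns real scores, generated scores, real feature maps, and generated feature maps, in that order. A minimal smoke-test sketch, assuming the classes above are importable from the package; the (batch, 1, samples) layout for the period/scale stacks and the (batch, samples) layout for DiscriminatorR are inferred from the Conv2d/Conv1d stacks and the spectrogram path, not documented by the package:

import torch

mpd = MultiPeriodDiscriminator()
msd = MultiScaleDiscriminator()
mrd = MultiResolutionDiscriminator()

y = torch.randn(2, 1, 22050)      # real waveform batch (assumed layout)
y_hat = torch.randn(2, 1, 22050)  # generated waveform batch

# Period and scale discriminators consume (B, 1, T) tensors directly.
real_scores, fake_scores, real_fmaps, fake_fmaps = mpd(y, y_hat)
_ = msd(y, y_hat)

# DiscriminatorR rearranges "b f t c" after the STFT, so it appears to expect
# a (B, T) waveform; bandwidth_id is only relevant when num_embeddings is set.
_ = mrd(y.squeeze(1), y_hat.squeeze(1))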
@@ -323,8 +323,11 @@ class AudioEncoder(Model):
 
      def __init__(
          self,
-         channels: int = 80,
+         channels: int,
          alpha: float = 4.0,
+         feat_channels: int = 64,
+         out_features: Optional[int] = None,
+         out_channels: int = 1,
          interp_mode: Literal[
              "nearest",
              "linear",
@@ -338,16 +341,60 @@ class AudioEncoder(Model):
 
          self.net = nn.Sequential(
              nn.Conv1d(
-                 channels, channels, kernel_size=3, stride=2, padding=5, groups=channels
+                 channels, feat_channels, kernel_size=3, stride=1, padding=5, groups=1
              ),
              nn.LeakyReLU(0.1),
-             nn.Conv1d(channels, channels, kernel_size=7, stride=1, padding=1, groups=1),
+             nn.Conv1d(
+                 feat_channels,
+                 feat_channels,
+                 kernel_size=3,
+                 stride=2,
+                 padding=1,
+                 groups=feat_channels,
+             ),
+             nn.LeakyReLU(0.1),
+             nn.Conv1d(
+                 feat_channels,
+                 feat_channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 groups=feat_channels // 8,
+             ),
+             nn.LeakyReLU(0.1),
+             nn.Conv1d(
+                 feat_channels,
+                 feat_channels,
+                 kernel_size=7,
+                 stride=1,
+                 padding=1,
+                 groups=1,
+             ),
          )
-         self.fc = nn.Linear(channels, channels)
+         self.fc = nn.Linear(feat_channels, channels)
+         self.feat_channels = feat_channels
          self.activation = activation
          self.channels = channels
          self.mode = interp_mode
          self.alpha = alpha
+         self.post_conv = nn.Conv1d(
+             channels,
+             out_channels,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             dilation=1,
+             groups=1,
+             bias=True,
+         )
+         if out_features is not None:
+             self.format_out = lambda tensor: F.interpolate(
+                 tensor,
+                 size=out_features,
+                 mode=interp_mode,
+             )
+         else:
+             self.format_out = nn.Identity()
 
      def forward(self, mels: Tensor, cr_audio: Tensor):
          sin = torch.asin(cr_audio)
@@ -367,14 +414,20 @@ class AudioEncoder(Model):
              .contiguous()
          )
          x = self.activation(x)
-         return self.fc(x).transpose(-1, -2)
+
+         xt = self.fc(x).transpose(-1, -2)
+         out = self.post_conv(xt)
+         return self.format_out(out)
 
 
  class AudioEncoderAttn(Model):
      def __init__(
          self,
-         channels: int = 80,
+         channels: int,
+         feat_channels: int = 64,
          alpha: float = 4.0,
+         out_channels: Optional[int] = None,
+         out_features: int = 1,
          interp_mode: Literal[
              "nearest",
              "linear",
@@ -388,16 +441,54 @@ class AudioEncoderAttn(Model):
 
          self.net = nn.Sequential(
              nn.Conv1d(
-                 channels, channels, kernel_size=3, stride=2, padding=5, groups=channels
+                 channels, feat_channels, kernel_size=3, stride=1, padding=1, groups=1
+             ),
+             nn.LeakyReLU(0.1),
+             nn.Conv1d(
+                 feat_channels,
+                 feat_channels,
+                 kernel_size=3,
+                 stride=2,
+                 padding=5,
+                 groups=feat_channels,
              ),
              nn.LeakyReLU(0.1),
-             nn.Conv1d(channels, channels, kernel_size=7, stride=1, padding=1, groups=1),
+             nn.Conv1d(
+                 feat_channels,
+                 feat_channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 groups=feat_channels // 8,
+             ),
+             nn.LeakyReLU(0.1),
+             nn.Conv1d(
+                 feat_channels, channels, kernel_size=7, stride=1, padding=1, groups=1
+             ),
          )
          self.fusion = CrossAttentionFusion(channels, channels, 2, d_model=channels)
          self.channels = channels
          self.mode = interp_mode
          self.alpha = alpha
          self.activation = activation
+         self.post_conv = nn.Conv1d(
+             channels,
+             out_channels,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             dilation=1,
+             groups=1,
+             bias=True,
+         )
+         if out_features is not None:
+             self.format_out = lambda tensor: F.interpolate(
+                 tensor,
+                 size=out_features,
+                 mode=interp_mode,
+             )
+         else:
+             self.format_out = nn.Identity()
 
      def forward(self, mels: Tensor, cr_audio: Tensor):
          sin = torch.asin(cr_audio)
@@ -408,9 +499,9 @@ class AudioEncoderAttn(Model):
          )
          x = self.activation(self.net(mod))
          x = F.interpolate(x, size=mels.shape[-1], mode=self.mode)
-
-         # Ensure contiguous before transpose
          x_t = x.transpose(-2, -1).contiguous()
          mels_t = mels.transpose(-2, -1).contiguous()
 
-         return self.fusion(x_t, mels_t).transpose(-2, -1)
+         xt = self.fusion(x_t, mels_t).transpose(-2, -1)
+         out = self.post_conv(xt)
+         return self.format_out(out)
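
Both AudioEncoder and AudioEncoderAttn now end in the same output head: a 1x1 Conv1d projecting channels to out_channels, followed by format_out, which either interpolates the time axis to a fixed out_features length or passes the tensor through unchanged. A standalone sketch of that head in isolation (shapes and the "linear" mode are illustrative assumptions, not values fixed by the package):

import torch
import torch.nn as nn
import torch.nn.functional as F

channels, out_channels, out_features = 80, 1, 256   # illustrative sizes
post_conv = nn.Conv1d(channels, out_channels, kernel_size=1)
format_out = (
    (lambda t: F.interpolate(t, size=out_features, mode="linear"))
    if out_features is not None
    else nn.Identity()
)

x = torch.randn(2, channels, 120)   # (batch, channels, frames) from the encoder body
out = format_out(post_conv(x))      # -> torch.Size([2, 1, 256])

Note that the two classes declare different defaults for the new parameters: AudioEncoder uses out_channels=1 and out_features=None, while AudioEncoderAttn uses out_channels=None and out_features=1, so the attention variant needs an explicit out_channels for its post_conv to be constructible.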