lt-tensor 0.0.1a12__py3-none-any.whl → 0.0.1a13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,142 @@
+ import torch
+ import torch.nn as nn
+ import math
+ from einops import repeat
+
+
+ class SineGen(nn.Module):
+     def __init__(
+         self,
+         samp_rate,
+         upsample_scale,
+         harmonic_num=0,
+         sine_amp=0.1,
+         noise_std=0.003,
+         voiced_threshold=0,
+         flag_for_pulse=False,
+     ):
+         super().__init__()
+         self.sampling_rate = samp_rate
+         self.upsample_scale = upsample_scale
+         self.harmonic_num = harmonic_num
+         self.sine_amp = sine_amp
+         self.noise_std = noise_std
+         self.voiced_threshold = voiced_threshold
+         self.flag_for_pulse = flag_for_pulse
+         self.dim = self.harmonic_num + 1  # fundamental + harmonics
+
+     def _f02uv_b(self, f0):
+         return (f0 > self.voiced_threshold).float()  # same shape as f0
+
+     def _f02uv(self, f0):
+         return (f0 > self.voiced_threshold).float().unsqueeze(-1)  # -> (B, T, 1)
+
+     @torch.no_grad()
+     def _f02sine(self, f0_values):
+         """
+         f0_values: (B, T, 1)
+         Output: sine waves (B, T * upsample, dim)
+         """
+         B, T, _ = f0_values.size()
+         f0_upsampled = repeat(
+             f0_values, "b t d -> b (t r) d", r=self.upsample_scale
+         )  # (B, T_up, 1)
+
+         # Create harmonics
+         harmonics = (
+             torch.arange(1, self.dim + 1, device=f0_values.device)
+             .float()
+             .view(1, 1, -1)
+         )
+         f0_harm = f0_upsampled * harmonics  # (B, T_up, dim)
+
+         # Convert Hz to normalized frequency (f / sr), then integrate to get phase
+         rad_values = f0_harm / self.sampling_rate  # cycles per sample
+         rad_values = rad_values % 1.0  # drop whole cycles (multiples of 2π after scaling)
+
+         # Random initial phase per harmonic (all zeros in pulse mode)
+         if self.flag_for_pulse:
+             rand_ini = torch.zeros((B, 1, self.dim), device=f0_values.device)
+         else:
+             rand_ini = torch.rand((B, 1, self.dim), device=f0_values.device)
+
+         rand_ini = rand_ini * 2 * math.pi
+
+         # Compute cumulative phase
+         rad_values = rad_values * 2 * math.pi
+         phase = torch.cumsum(rad_values, dim=1) + rand_ini  # (B, T_up, dim)
+
+         sine_waves = torch.sin(phase)  # (B, T_up, dim)
+         return sine_waves
+
+     def _forward(self, f0):
+         """
+         f0: (B, T, 1)
+         returns: sine signal with harmonics and noise added
+         """
+         sine_waves = self._f02sine(f0)  # (B, T_up, dim)
+         uv = self._f02uv_b(f0)  # (B, T, 1)
+         uv = repeat(uv, "b t d -> b (t r) d", r=self.upsample_scale)  # (B, T_up, 1)
+
+         # voiced sine + unvoiced noise
+         sine_signal = self.sine_amp * sine_waves * uv  # (B, T_up, dim)
+         noise = torch.randn_like(sine_signal) * self.noise_std
+         output = sine_signal + noise * (1.0 - uv)  # noise added only on unvoiced
+
+         return output  # (B, T_up, dim)
+
+     def forward(self, f0):
+         """
+         Args:
+             f0: (B, T) in Hz (before upsampling)
+         Returns:
+             sine_waves: (B, T_up, dim)
+             uv: (B, T_up, 1)
+             noise: (B, T_up, 1)
+         """
+         B, T = f0.shape
+         device = f0.device
+
+         # Get uv mask (before upsampling)
+         uv = self._f02uv(f0)  # (B, T, 1)
+
+         # Expand f0 to include harmonics: (B, T, dim)
+         f0 = f0.unsqueeze(-1)  # (B, T, 1)
+         harmonics = (
+             torch.arange(1, self.dim + 1, device=device).float().view(1, 1, -1)
+         )  # (1, 1, dim)
+         f0_harm = f0 * harmonics  # (B, T, dim)
+
+         # Upsample
+         f0_harm_up = repeat(
+             f0_harm, "b t d -> b (t r) d", r=self.upsample_scale
+         )  # (B, T_up, dim)
+         uv_up = repeat(uv, "b t d -> b (t r) d", r=self.upsample_scale)  # (B, T_up, 1)
+
+         # Convert to radians
+         rad_per_sample = f0_harm_up / self.sampling_rate  # Hz -> cycles/sample
+         rad_per_sample = rad_per_sample * 2 * math.pi  # cycles -> radians/sample
+
+         # Random initial phase per harmonic
+         B, T_up, D = rad_per_sample.shape
+         rand_phase = torch.rand(B, D, device=device) * 2 * math.pi  # (B, D)
+
+         # Compute cumulative phase
+         phase = torch.cumsum(rad_per_sample, dim=1) + rand_phase.unsqueeze(
+             1
+         )  # (B, T_up, D)
+
+         # Apply sine
+         sine_waves = torch.sin(phase) * self.sine_amp  # (B, T_up, D)
+
+         # Handle unvoiced: create noise only for the fundamental
+         noise = torch.randn(B, T_up, 1, device=device) * self.noise_std
+         if self.flag_for_pulse:
+             # If pulse mode is on, align phase at the start of voiced segments.
+             # Optional and tricky to implement; may require segmenting uv.
+             pass
+
+         # Replace sine by noise for unvoiced frames (fundamental only)
+         sine_waves[:, :, 0:1] = sine_waves[:, :, 0:1] * uv_up + noise * (1 - uv_up)
+
+         return sine_waves, uv_up, noise
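This new module is an NSF-style sine source: frame-level F0 is expanded to harmonics, upsampled to sample rate, converted to per-sample phase increments, and integrated with cumsum, while unvoiced samples fall back to Gaussian noise on the fundamental. A minimal usage sketch, assuming the SineGen definition above (the sample rate, upsample scale, and harmonic count here are illustrative, not taken from the package):

# Minimal usage sketch for SineGen; all hyperparameter values are assumptions.
import torch

sine_gen = SineGen(samp_rate=22050, upsample_scale=256, harmonic_num=8)
f0 = torch.full((1, 50), 220.0)  # (B, T): frame-level F0 in Hz
f0[:, 25:] = 0.0                 # mark the second half as unvoiced
sine_waves, uv, noise = sine_gen(f0)
print(sine_waves.shape)  # torch.Size([1, 12800, 9]) = (B, T * upsample_scale, harmonic_num + 1)
print(uv.shape)          # torch.Size([1, 12800, 1]): voiced/unvoiced mask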
@@ -13,18 +13,45 @@ from lt_tensor.misc_utils import set_seed, clear_cache
  from lt_utils.type_utils import is_dir, is_pathlike, is_file
  from lt_tensor.config_templates import updateDict, ModelConfig
  from lt_tensor.model_zoo.istft.generator import iSTFTGenerator
- from lt_tensor.model_zoo.rsd import ResBlock1D, ConvNets, get_weight_norm
- from lt_tensor.model_zoo.disc import MultiPeriodDiscriminator, MultiScaleDiscriminator
+ from lt_tensor.model_zoo.residual import ResBlock1D, ConvNets, get_weight_norm
+ from lt_tensor.model_zoo.discriminator import MultiPeriodDiscriminator, MultiScaleDiscriminator


- def feature_loss(real_feats, fake_feats):
-     loss = 0.0
-     for r, f in zip(real_feats, fake_feats):
-         for ri, fi in zip(r, f):
-             loss += F.l1_loss(ri, fi)
+ def feature_loss(fmap_r, fmap_g):
+     loss = 0
+     for dr, dg in zip(fmap_r, fmap_g):
+         for rl, gl in zip(dr, dg):
+             loss += torch.mean(torch.abs(rl - gl))
+     return loss * 2
+
+
+ def generator_adv_loss(disc_outputs):
+     loss = 0
+     for dg in disc_outputs:
+         loss += torch.mean((1 - dg) ** 2)
+     return loss
+
+
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+     loss = 0
+     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+         r_loss = torch.mean((1 - dr) ** 2)
+         g_loss = torch.mean(dg**2)
+         loss += r_loss + g_loss
      return loss


+ """def feature_loss(fmap_r, fmap_g):
+     loss = 0
+     for dr, dg in zip(fmap_r, fmap_g):
+         for rl, gl in zip(dr, dg):
+             loss += torch.mean(torch.abs(rl - gl))
+     return loss * 2
+
+
  def generator_adv_loss(fake_preds):
      loss = 0.0
      for f in fake_preds:
@@ -37,6 +64,7 @@ def discriminator_loss(real_preds, fake_preds):
      for r, f in zip(real_preds, fake_preds):
          loss += torch.mean((r - 1.0) ** 2) + torch.mean(f**2)
      return loss
+ """


  class AudioSettings(ModelConfig):
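The replacement losses are the standard LSGAN objectives used by HiFi-GAN-style vocoders: the discriminator is pushed toward 1 on real audio and 0 on generated audio, the generator is pushed toward 1, and feature matching takes an L1 distance over every intermediate discriminator feature map, doubled as in the HiFi-GAN reference. A quick sanity-check sketch with dummy discriminator outputs (values are illustrative only):

# Sanity check of the loss functions above using dummy tensors.
import torch

real_out = [torch.ones(2, 10)]   # discriminator output on real audio
fake_out = [torch.zeros(2, 10)]  # discriminator output on generated audio
print(discriminator_loss(real_out, fake_out))  # tensor(0.) -> ideal discriminator
print(generator_adv_loss(fake_out))            # tensor(1.) -> generator fooled nothing

fmap_r = [[torch.ones(2, 4)], [torch.ones(2, 4)]]
fmap_g = [[torch.zeros(2, 4)], [torch.ones(2, 4)]]
print(feature_loss(fmap_r, fmap_g))  # tensor(2.) = (mean|1-0| + mean|1-1|) * 2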
@@ -284,9 +312,6 @@ class AudioDecoder(Model):
              win_length=self.settings.n_fft,
              # length=real_audio.shape[-1]
          )[:, : real_audio.shape[-1]]
-         # smallest = min(real_audio.shape[-1], fake_audio.shape[-1])
-         # real_audio = real_audio[:, :, :smallest].squeeze(1)
-         # fake_audio = fake_audio[:, :smallest]

          disc_kwargs = dict(
              real_audio=real_audio,
@@ -372,13 +397,13 @@ class AudioDecoder(Model):

          loss_stft = self.audio_processor.stft_loss(fake_audio, real_audio) * stft_scale
          loss_mel = (
-             F.l1_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
+             F.huber_loss(self.audio_processor.compute_mel(fake_audio), mels) * mel_scale
          )
          loss_fm = ((loss_fm_mpd + loss_fm_msd) * fm_scale) + fm_add

          loss_adv = (loss_adv_mpd + loss_adv_msd) * adv_scale

-         loss_g = loss_adv + loss_fm + loss_stft + loss_mel
+         loss_g = loss_adv + loss_fm + loss_stft  # + loss_mel
          if not am_i_frozen:
              self.g_optim.zero_grad()
              loss_g.backward()
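Two training changes land in this hunk: the mel reconstruction term switches from F.l1_loss to F.huber_loss, and loss_mel is dropped from the generator objective (it is still computed, just commented out of the sum). Huber is quadratic for small residuals and linear for large ones, so stray mel bins stop dominating the gradient. A small illustration with hypothetical values:

# L1 vs. Huber on the same residual; tensors are hypothetical.
import torch
import torch.nn.functional as F

pred = torch.tensor([0.1, 0.2, 5.0])  # one outlier component
target = torch.zeros(3)
print(F.l1_loss(pred, target))     # tensor(1.7667): the outlier contributes 5.0/3
print(F.huber_loss(pred, target))  # tensor(1.5083): above delta=1.0 it grows only linearly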
@@ -0,0 +1,217 @@
+ __all__ = [
+     "spectral_norm_select",
+     "get_weight_norm",
+     "ResBlock1D",
+     "ResBlock2D",
+     "ResBlock1DShuffled",
+     "AdaResBlock1D",
+ ]
+ import math
+ from lt_utils.common import *
+ from lt_tensor.torch_commons import *
+ from lt_tensor.model_base import Model
+ from lt_tensor.misc_utils import log_tensor
+ import torch.nn.functional as F
+ from lt_tensor.model_zoo.fusion import AdaFusion1D, AdaIN1D
+
+
+ def spectral_norm_select(module: nn.Module, enabled: bool):
+     if enabled:
+         return spectral_norm(module)
+     return module
+
+
+ def get_weight_norm(norm_type: Optional[Literal["weight", "spectral"]] = None):
+     if not norm_type:
+         return lambda x: x
+     if norm_type == "weight":
+         return lambda x: weight_norm(x)
+     return lambda x: spectral_norm(x)
+
+
+ class ConvNets(Model):
+     def remove_weight_norm(self):
+         for module in self.modules():
+             try:
+                 remove_weight_norm(module)
+             except ValueError:
+                 pass
+
+     @staticmethod
+     def init_weights(m, mean=0.0, std=0.01):
+         classname = m.__class__.__name__
+         if "Conv" in classname:
+             m.weight.data.normal_(mean, std)
+
+
+ class ResBlock1D(ConvNets):
+     def __init__(
+         self,
+         channels,
+         kernel_size=3,
+         dilation=(1, 3, 5),
+         activation: nn.Module = nn.LeakyReLU(0.1),
+     ):
+         super().__init__()
+
+         self.conv_nets = nn.ModuleList(
+             [
+                 self._get_conv_layer(i, channels, kernel_size, 1, dilation, activation)
+                 for i in range(3)
+             ]
+         )
+         self.conv_nets.apply(self.init_weights)
+         self.last_index = len(self.conv_nets) - 1
+
+     def _get_conv_layer(self, id, ch, k, stride, d, actv):
+         get_padding = lambda ks, d: int((ks * d - d) / 2)  # "same" padding for dilated convs
+         return nn.Sequential(
+             actv,  # 1
+             weight_norm(
+                 nn.Conv1d(
+                     ch, ch, k, stride, dilation=d[id], padding=get_padding(k, d[id])
+                 )
+             ),  # 2
+             actv,  # 3
+             weight_norm(
+                 nn.Conv1d(ch, ch, k, stride, dilation=1, padding=get_padding(k, 1))
+             ),  # 4
+         )
+
+     def forward(self, x: Tensor):
+         for cnn in self.conv_nets:
+             x = cnn(x) + x
+         return x
+
+
+ class ResBlock1DShuffled(ConvNets):
+     def __init__(
+         self,
+         channels,
+         kernel_size=3,
+         dilation=(1, 3, 5),
+         activation: nn.Module = nn.LeakyReLU(0.1),
+         add_channel_shuffle: bool = False,  # requires pytorch 2.7.0+
+         channel_shuffle_groups=1,
+     ):
+         super().__init__()
+
+         self.channel_shuffle = (
+             nn.ChannelShuffle(channel_shuffle_groups)
+             if add_channel_shuffle
+             else nn.Identity()
+         )
+
+         self.conv_nets = nn.ModuleList(
+             [
+                 self._get_conv_layer(i, channels, kernel_size, 1, dilation, activation)
+                 for i in range(3)
+             ]
+         )
+         self.conv_nets.apply(self.init_weights)
+         self.last_index = len(self.conv_nets) - 1
+
+     def _get_conv_layer(self, id, ch, k, stride, d, actv):
+         get_padding = lambda ks, d: int((ks * d - d) / 2)
+         return nn.Sequential(
+             actv,  # 1
+             weight_norm(
+                 nn.Conv1d(
+                     ch, ch, k, stride, dilation=d[id], padding=get_padding(k, d[id])
+                 )
+             ),  # 2
+             actv,  # 3
+             weight_norm(
+                 nn.Conv1d(ch, ch, k, stride, dilation=1, padding=get_padding(k, 1))
+             ),  # 4
+         )
+
+     def forward(self, x: Tensor):
+         b = x.clone() * 0.5  # scaled copy of the input, reused as every stage's residual
+         for cnn in self.conv_nets:
+             x = cnn(self.channel_shuffle(x)) + b
+         return x
+
+
+ class ResBlock2D(Model):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         downsample=False,
+     ):
+         super().__init__()
+         stride = 2 if downsample else 1
+
+         self.block = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, 3, stride, 1),
+             nn.LeakyReLU(0.2),
+             nn.Conv2d(out_channels, out_channels, 3, 1, 1),
+         )
+
+         self.skip = nn.Identity()
+         if downsample or in_channels != out_channels:
+             self.skip = spectral_norm_select(
+                 nn.Conv2d(in_channels, out_channels, 1, stride), True
+             )
+         # precomputed: one less value to compute every forward pass
+         self.sqrt_2 = math.sqrt(2)
+
+     def forward(self, x: Tensor):
+         return (self.block(x) + self.skip(x)) / self.sqrt_2
+
+
+ class AdaResBlock1D(ConvNets):
+     def __init__(
+         self,
+         res_block_channels: int,
+         ada_channel_in: int,
+         kernel_size=3,
+         dilation=(1, 3, 5),
+         activation: nn.Module = nn.LeakyReLU(0.1),
+     ):
+         super().__init__()
+
+         self.conv_nets = nn.ModuleList(
+             [
+                 self._get_conv_layer(
+                     i,
+                     res_block_channels,
+                     ada_channel_in,
+                     kernel_size,
+                     1,
+                     dilation,
+                 )
+                 for i in range(3)
+             ]
+         )
+         self.conv_nets.apply(self.init_weights)
+         self.last_index = len(self.conv_nets) - 1
+         self.activation = activation
+
+     def _get_conv_layer(self, id, ch, ada_ch, k, stride, d):
+         get_padding = lambda ks, d: int((ks * d - d) / 2)
+         layer = nn.ModuleDict(
+             dict(
+                 norm1=AdaFusion1D(ada_ch, ch),
+                 norm2=AdaFusion1D(ada_ch, ch),
+                 conv1=weight_norm(
+                     nn.Conv1d(
+                         ch, ch, k, stride, dilation=d[id], padding=get_padding(k, d[id])
+                     )
+                 ),  # 2
+                 conv2=weight_norm(
+                     nn.Conv1d(ch, ch, k, stride, dilation=1, padding=get_padding(k, 1))
+                 ),  # 4
+             )
+         )
+         # nn.Parameter is not a Module, so the alphas cannot live inside the
+         # ModuleDict mapping; register them as attributes on the container instead.
+         layer.alpha1 = nn.Parameter(torch.ones(1, ada_ch, 1))
+         layer.alpha2 = nn.Parameter(torch.ones(1, ada_ch, 1))
+         return layer
+
+     def forward(self, x: torch.Tensor, y: torch.Tensor):
+         for cnn in self.conv_nets:
+             xt = self.activation(cnn["norm1"](x, y, cnn.alpha1))
+             xt = cnn["conv1"](xt)
+             xt = self.activation(cnn["norm2"](xt, y, cnn.alpha2))
+             x = cnn["conv2"](xt) + x
+         return x
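ResBlock1D mirrors the HiFi-GAN residual stack: each of its three stages runs a dilated convolution followed by a dilation-1 convolution, with padding (k * d - d) / 2 chosen so the time dimension is preserved and the residual addition lines up. A shape-check sketch, assuming the definitions above (channel and length values are illustrative):

# Shape-check sketch for ResBlock1D; parameter values are assumptions.
import torch

block = ResBlock1D(channels=64, kernel_size=3, dilation=(1, 3, 5))
x = torch.randn(2, 64, 1000)  # (batch, channels, time)
print(block(x).shape)  # torch.Size([2, 64, 1000]): length preserved by "same" padding
# e.g. kernel_size=3, dilation=5 -> padding = (3*5 - 5) // 2 = 5

block.remove_weight_norm()  # strip weight_norm from every conv before inference export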
@@ -11,8 +11,8 @@ from lt_tensor.torch_commons import *
  from lt_tensor.model_base import Model
  from lt_utils.misc_utils import default
  from typing import Optional
- from lt_tensor.model_zoo.pos import *
- from lt_tensor.model_zoo.bsc import FeedForward
+ from lt_tensor.model_zoo.pos_encoder import *
+ from lt_tensor.model_zoo.basic import FeedForward


  def init_weights(module):