lt-tensor 0.0.1a10__py3-none-any.whl → 0.0.1a12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
-from ..torch_commons import *
+from lt_tensor.torch_commons import *
 import torch.nn.functional as F
 from lt_tensor.model_base import Model
 from lt_utils.common import *
@@ -76,20 +76,6 @@ class PeriodDiscriminator(Model):
         return x.flatten(1, -1), f_map


-class MultiPeriodDiscriminator(Model):
-    def __init__(self, periods=[2, 3, 5, 7, 11]):
-        super().__init__()
-
-        self.discriminators = nn.ModuleList([PeriodDiscriminator(p) for p in periods])
-
-    def forward(self, x: torch.Tensor):
-        """
-        x: (B, T)
-        Returns: list of tuples of outputs from each period discriminator and the f_map.
-        """
-        return [d(x) for d in self.discriminators]
-
-
 class ScaleDiscriminator(nn.Module):
     def __init__(self, use_spectral_norm=False):
         super().__init__()
@@ -123,11 +109,11 @@ class ScaleDiscriminator(nn.Module):


 class MultiScaleDiscriminator(Model):
-    def __init__(self):
+    def __init__(self, layers: int = 3):
         super().__init__()
         self.pooling = nn.AvgPool1d(4, 2, padding=2)
         self.discriminators = nn.ModuleList(
-            [ScaleDiscriminator(i == 0) for i in range(3)]
+            [ScaleDiscriminator(i == 0) for i in range(layers)]
         )

     def forward(self, x: torch.Tensor):
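`MultiScaleDiscriminator` now takes its stack depth as a `layers` argument (only the first stage uses spectral norm), and per the next hunk its forward pass returns feature maps alongside the logits. A minimal usage sketch; the import path matches the one used by the new generator module below, and the `(B, T)` input shape is an assumption based on the docstrings:

```python
import torch
from lt_tensor.model_zoo.disc import MultiScaleDiscriminator

msd = MultiScaleDiscriminator(layers=3)  # same depth as the old hard-coded range(3)
wave = torch.randn(4, 16000)             # assumed (B, T) batch of raw audio
outputs, features = msd(wave)            # one logit tensor and one f_map per scale
print(len(outputs), len(features))       # 3 3
```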
@@ -136,57 +122,75 @@ class MultiScaleDiscriminator(Model):
         Returns: list of outputs from each scale discriminator
         """
         outputs = []
+        features = []
         for i, d in enumerate(self.discriminators):
             if i != 0:
                 x = self.pooling(x)
-            outputs.append(d(x))
-        return outputs
+            out, f_map = d(x)
+            outputs.append(out)
+            features.append(f_map)
+        return outputs, features


-class GeneralLossDescriminator(Model):
-    """TODO: build an unified loss for both mpd and msd here."""
-
-    def __init__(self):
+class MultiPeriodDiscriminator(Model):
+    def __init__(self, periods: List[int] = [2, 3, 5, 7, 11]):
         super().__init__()
-        self.mpd = MultiPeriodDiscriminator()
-        self.msd = MultiScaleDiscriminator()
-        self.print_trainable_parameters()
-
-    def _get_group_(self):
-        pass
+        self.discriminators = nn.ModuleList([PeriodDiscriminator(p) for p in periods])

-    def forward(self, x: Tensor, y_hat: Tensor):
-        return
+    def forward(self, x: torch.Tensor):
+        """
+        x: (B, T)
+        Returns: a list of outputs from each period discriminator and a list of their f_maps.
+        """
+        out_map = []
+        feat_map = []
+        for d in self.discriminators:
+            out, feat = d(x)
+            out_map.append(out)
+            feat_map.append(feat)
+        return out_map, feat_map


-def discriminator_loss(d_outputs_real, d_outputs_fake):
+def discriminator_loss(real_out_map, fake_out_map):
     loss = 0.0
-    for real_out, fake_out in zip(d_outputs_real, d_outputs_fake):
-        real_score = real_out[0]
-        fake_score = fake_out[0]
-        loss += torch.mean(F.relu(1.0 - real_score)) + torch.mean(
-            F.relu(1.0 + fake_score)
-        )
-    return loss
+    rl, fl = [], []
+    for real_out, fake_out in zip(real_out_map, fake_out_map):
+        real_loss = torch.mean((1.0 - real_out) ** 2)
+        fake_loss = torch.mean(fake_out**2)
+        loss += real_loss + fake_loss
+        rl.append(real_loss.item())
+        fl.append(fake_loss.item())
+    return loss, sum(rl), sum(fl)


-def generator_adv_loss(d_outputs_fake):
+def generator_adv_loss(fake_disc_outputs: List[Tensor]):
     loss = 0.0
-    for fake_out in d_outputs_fake:
+    for fake_out in fake_disc_outputs:
         fake_score = fake_out[0]
         loss += -torch.mean(fake_score)
     return loss


-def feature_matching_loss(
-    d_outputs_real,
-    d_outputs_fake,
-    loss_fn: Callable[[Tensor, Tensor], Tensor] = F.mse_loss,
+def feature_loss(
+    fmap_r,
+    fmap_g,
+    weight=2.0,
+    loss_fn: Callable[[Tensor, Tensor], Tensor] = F.l1_loss,
 ):
     loss = 0.0
-    for real_out, fake_out in zip(d_outputs_real, d_outputs_fake):
-        real_feats = real_out[1]
-        fake_feats = fake_out[1]
-        for real_f, fake_f in zip(real_feats, fake_feats):
-            loss += loss_fn(fake_f, real_f)
-    return loss
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            loss += loss_fn(gl, rl)
+    return loss * weight
+
+
+def generator_loss(disc_generated_outputs):
+    loss = 0.0
+    gen_losses = []
+    for dg in disc_generated_outputs:
+        l = torch.mean((1.0 - dg) ** 2)
+        gen_losses.append(l.item())
+        loss += l
+
+    return loss, gen_losses
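Two substantive changes land here: the hinge objective in `discriminator_loss` is replaced by the least-squares (LSGAN) formulation, and the losses now return per-discriminator scalars alongside the total. A hedged sketch of how the new signatures compose in a training step; `mpd`, `real`, and `fake` are placeholder names, not package API:

```python
import torch

mpd = MultiPeriodDiscriminator()      # any discriminator returning (outputs, f_maps)
real = torch.randn(4, 16000)          # assumed (B, T) real waveforms
fake = torch.randn(4, 16000)          # stand-in for generator output

# Discriminator step: detach the fake so only D receives gradients.
real_out, real_fmap = mpd(real)
fake_out, _ = mpd(fake.detach())
d_loss, d_real, d_fake = discriminator_loss(real_out, fake_out)

# Generator step: LSGAN adversarial term plus weighted feature matching.
fake_out_g, fake_fmap_g = mpd(fake)
adv_loss, per_disc = generator_loss(fake_out_g)
fm_loss = feature_loss(real_fmap, fake_fmap_g)  # L1 over f_maps, scaled by weight=2.0
g_loss = adv_loss + fm_loss
```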
@@ -6,8 +6,8 @@ __all__ = [
     "GatedFusion",
 ]

-from ..torch_commons import *
-from ..model_base import Model
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model


 class ConcatFusion(Model):
@@ -7,10 +7,10 @@ __all__ = [
     "NoisePredictor1D",
 ]

-from ..torch_commons import *
-from ..model_base import Model
-from .rsd import ResBlock1D, ResBlocks
-from ..misc_utils import log_tensor
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model
+from lt_tensor.model_zoo.rsd import ResBlock1D
+from lt_tensor.misc_utils import log_tensor

 import torch.nn.functional as F

@@ -0,0 +1,5 @@
+from .generator import iSTFTGenerator
+from . import trainer
+
+
+__all__ = ["iSTFTGenerator", "trainer"]
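The new `__init__.py` re-exports the generator and the trainer module. Assuming the package lands under `lt_tensor.model_zoo` (the diff does not show the new files' paths), the public surface would be imported as:

```python
# Path is an assumption; the diff omits the new package's location.
from lt_tensor.model_zoo.istft import iSTFTGenerator, trainer
```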
@@ -0,0 +1,150 @@
+__all__ = ["iSTFTGenerator", "ResBlocks"]
+import gc
+import math
+import itertools
+from lt_utils.common import *
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model
+from lt_tensor.misc_utils import log_tensor
+from lt_tensor.model_zoo.rsd import ResBlock1D, ConvNets, get_weight_norm
+from lt_utils.misc_utils import log_traceback
+from lt_tensor.processors import AudioProcessor
+from lt_utils.type_utils import is_dir, is_pathlike
+from lt_tensor.misc_utils import set_seed, clear_cache
+from lt_tensor.model_zoo.disc import MultiPeriodDiscriminator, MultiScaleDiscriminator
+import torch.nn.functional as F
+from lt_tensor.config_templates import updateDict, ModelConfig
+
+
+class ResBlocks(ConvNets):
+    def __init__(
+        self,
+        channels: int,
+        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+        resblock_dilation_sizes: List[Union[int, List[int]]] = [
+            [1, 3, 5],
+            [1, 3, 5],
+            [1, 3, 5],
+        ],
+        activation: nn.Module = nn.LeakyReLU(0.1),
+    ):
+        super().__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.rb = nn.ModuleList()
+        self.activation = activation
+
+        for k, j in zip(resblock_kernel_sizes, resblock_dilation_sizes):
+            self.rb.append(ResBlock1D(channels, k, j, activation))
+
+        self.rb.apply(self.init_weights)
+
+    def forward(self, x: torch.Tensor):
+        xs = None
+        for i, block in enumerate(self.rb):
+            if i == 0:
+                xs = block(x)
+            else:
+                xs += block(x)
+        x = xs / self.num_kernels
+        return self.activation(x)
+
+
+class iSTFTGenerator(ConvNets):
+    def __init__(
+        self,
+        in_channels: int = 80,
+        upsample_rates: List[Union[int, List[int]]] = [8, 8],
+        upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16],
+        upsample_initial_channel: int = 512,
+        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+        resblock_dilation_sizes: List[Union[int, List[int]]] = [
+            [1, 3, 5],
+            [1, 3, 5],
+            [1, 3, 5],
+        ],
+        n_fft: int = 16,
+        activation: nn.Module = nn.LeakyReLU(0.1),
+        hop_length: int = 256,
+    ):
+        super().__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.hop_length = hop_length
+        self.conv_pre = weight_norm(
+            nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
+        )
+        self.blocks = nn.ModuleList()
+        self.activation = activation
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.blocks.append(
+                self._make_blocks(
+                    (i, k, u),
+                    upsample_initial_channel,
+                    resblock_kernel_sizes,
+                    resblock_dilation_sizes,
+                )
+            )
+
+        ch = upsample_initial_channel // (2 ** (i + 1))
+        self.post_n_fft = n_fft // 2 + 1
+        self.conv_post = weight_norm(nn.Conv1d(ch, n_fft + 2, 7, 1, padding=3))
+        self.conv_post.apply(self.init_weights)
+        self.reflection_pad = nn.ReflectionPad1d((1, 0))
+
+        self.phase = nn.Sequential(
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+        )
+        self.spec = nn.Sequential(
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+        )
+
+    def _make_blocks(
+        self,
+        state: Tuple[int, int, int],
+        upsample_initial_channel: int,
+        resblock_kernel_sizes: List[Union[int, List[int]]],
+        resblock_dilation_sizes: List[Union[int, List[int]]],
+    ):
+        i, k, u = state
+        channels = upsample_initial_channel // (2 ** (i + 1))
+        return nn.ModuleDict(
+            dict(
+                up=nn.Sequential(
+                    self.activation,
+                    weight_norm(
+                        nn.ConvTranspose1d(
+                            upsample_initial_channel // (2**i),
+                            channels,
+                            k,
+                            u,
+                            padding=(k - u) // 2,
+                        )
+                    ).apply(self.init_weights),
+                ),
+                residual=ResBlocks(
+                    channels,
+                    resblock_kernel_sizes,
+                    resblock_dilation_sizes,
+                    self.activation,
+                ),
+            )
+        )
+
+    def forward(self, x):
+        x = self.conv_pre(x)
+        for block in self.blocks:
+            x = block["up"](x)
+            x = block["residual"](x)
+
+        x = self.reflection_pad(x)
+        x = self.conv_post(x)
+        spec = torch.exp(self.spec(x[:, : self.post_n_fft, :]))
+        phase = torch.sin(self.phase(x[:, self.post_n_fft :, :]))
+
+        return spec, phase
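The generator stops at a magnitude/phase pair instead of a waveform, leaving inversion to the caller (presumably the new `trainer` module). A sketch of one plausible reconstruction with `torch.istft`, assuming `spec` is linear magnitude and `phase` is in radians; the window and hop values below are illustrative, not the package's:

```python
import torch

gen = iSTFTGenerator()                       # defaults: in_channels=80, n_fft=16
mel = torch.randn(1, 80, 32)                 # (B, in_channels, frames)
spec, phase = gen(mel)                       # each (B, n_fft // 2 + 1, frames')

complex_spec = spec * torch.exp(1j * phase)  # polar form -> complex STFT bins
audio = torch.istft(
    complex_spec,
    n_fft=16,
    hop_length=4,                            # assumption: iSTFT hop, not the model's hop_length
    win_length=16,
    window=torch.hann_window(16),
)
```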