lt-tensor 0.0.1a10__py3-none-any.whl → 0.0.1a12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +2 -0
- lt_tensor/config_templates.py +97 -0
- lt_tensor/datasets/audio.py +21 -7
- lt_tensor/losses.py +98 -84
- lt_tensor/math_ops.py +1 -1
- lt_tensor/misc_utils.py +94 -7
- lt_tensor/model_base.py +298 -128
- lt_tensor/model_zoo/__init__.py +2 -2
- lt_tensor/model_zoo/bsc.py +25 -3
- lt_tensor/model_zoo/disc.py +55 -51
- lt_tensor/model_zoo/fsn.py +2 -2
- lt_tensor/model_zoo/gns.py +4 -4
- lt_tensor/model_zoo/istft/__init__.py +5 -0
- lt_tensor/model_zoo/istft/generator.py +150 -0
- lt_tensor/model_zoo/istft/trainer.py +450 -0
- lt_tensor/model_zoo/istft.py +508 -66
- lt_tensor/model_zoo/pos.py +2 -2
- lt_tensor/model_zoo/rsd.py +16 -146
- lt_tensor/model_zoo/tfrms.py +4 -4
- lt_tensor/noise_tools.py +3 -4
- lt_tensor/processors/audio.py +87 -16
- lt_tensor/transform.py +30 -61
- {lt_tensor-0.0.1a10.dist-info → lt_tensor-0.0.1a12.dist-info}/METADATA +3 -2
- lt_tensor-0.0.1a12.dist-info/RECORD +32 -0
- lt_tensor-0.0.1a10.dist-info/RECORD +0 -28
- {lt_tensor-0.0.1a10.dist-info → lt_tensor-0.0.1a12.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a10.dist-info → lt_tensor-0.0.1a12.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a10.dist-info → lt_tensor-0.0.1a12.dist-info}/top_level.txt +0 -0
lt_tensor/model_zoo/disc.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-from
+from lt_tensor.torch_commons import *
 import torch.nn.functional as F
 from lt_tensor.model_base import Model
 from lt_utils.common import *
```
```diff
@@ -76,20 +76,6 @@ class PeriodDiscriminator(Model):
         return x.flatten(1, -1), f_map
 
 
-class MultiPeriodDiscriminator(Model):
-    def __init__(self, periods=[2, 3, 5, 7, 11]):
-        super().__init__()
-
-        self.discriminators = nn.ModuleList([PeriodDiscriminator(p) for p in periods])
-
-    def forward(self, x: torch.Tensor):
-        """
-        x: (B, T)
-        Returns: list of tuples of outputs from each period discriminator and the f_map.
-        """
-        return [d(x) for d in self.discriminators]
-
-
 class ScaleDiscriminator(nn.Module):
     def __init__(self, use_spectral_norm=False):
         super().__init__()
```
```diff
@@ -123,11 +109,11 @@ class ScaleDiscriminator(nn.Module):
 
 
 class MultiScaleDiscriminator(Model):
-    def __init__(self):
+    def __init__(self, layers: int = 3):
         super().__init__()
         self.pooling = nn.AvgPool1d(4, 2, padding=2)
         self.discriminators = nn.ModuleList(
-            [ScaleDiscriminator(i == 0) for i in range(
+            [ScaleDiscriminator(i == 0) for i in range(layers)]
        )
 
     def forward(self, x: torch.Tensor):
```
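The new `layers` argument makes the number of scale discriminators configurable; the default of 3 presumably preserves the old hard-coded count (the removed line is truncated in this diff, so that is an inference). A minimal usage sketch:

```python
from lt_tensor.model_zoo.disc import MultiScaleDiscriminator

msd = MultiScaleDiscriminator()               # default: 3 scales
msd_wide = MultiScaleDiscriminator(layers=5)  # 5 scales
# Scale 0 sees the raw waveform and uses spectral norm (ScaleDiscriminator(i == 0));
# every later scale sees the input pooled once more by nn.AvgPool1d(4, 2, padding=2).
```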
```diff
@@ -136,57 +122,75 @@ class MultiScaleDiscriminator(Model):
         Returns: list of outputs from each scale discriminator
         """
         outputs = []
+        features = []
         for i, d in enumerate(self.discriminators):
             if i != 0:
                 x = self.pooling(x)
-
-
+            out, f_map = d(x)
+            outputs.append(out)
+            features.append(f_map)
+        return outputs, features
 
 
-class
-
-
-    def __init__(self):
+class MultiPeriodDiscriminator(Model):
+    def __init__(self, periods: List[int] = [2, 3, 5, 7, 11]):
         super().__init__()
-        self.
-        self.msd = MultiScaleDiscriminator()
-        self.print_trainable_parameters()
-
-    def _get_group_(self):
-        pass
+        self.discriminators = nn.ModuleList([PeriodDiscriminator(p) for p in periods])
 
-    def forward(self, x: Tensor
-
+    def forward(self, x: torch.Tensor):
+        """
+        x: (B, T)
+        Returns: list of tuples of outputs from each period discriminator and the f_map.
+        """
+        # torch.log(torch.clip(x, min=clip_val))
+        out_map = []
+        feat_map = []
+        for d in self.discriminators:
+            out, feat = d(x)
+            out_map.append(out)
+            feat_map.append(feat)
+        return out_map, feat_map
 
 
-def discriminator_loss(
+def discriminator_loss(real_out_map, fake_out_map):
     loss = 0.0
-
-
-
-
-
-)
-
+    rl, fl = [], []
+    for real_out, fake_out in zip(real_out_map, fake_out_map):
+        real_loss = torch.mean((1.0 - real_out) ** 2)
+        fake_loss = torch.mean(fake_out**2)
+        loss += real_loss + fake_loss
+        rl.append(real_loss.item())
+        fl.append(fake_loss.item())
+    return loss, sum(rl), sum(fl)
 
 
-def generator_adv_loss(
+def generator_adv_loss(fake_disc_outputs: List[Tensor]):
     loss = 0.0
-    for fake_out in
+    for fake_out in fake_disc_outputs:
         fake_score = fake_out[0]
         loss += -torch.mean(fake_score)
     return loss
 
 
-def
-
-
-
+def feature_loss(
+    fmap_r,
+    fmap_g,
+    weight=2.0,
+    loss_fn: Callable[[Tensor, Tensor], Tensor] = F.l1_loss,
 ):
     loss = 0.0
-    for
-
-
-
-
-
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            loss += loss_fn(rl - gl)
+    return loss * weight
+
+
+def generator_loss(disc_generated_outputs):
+    loss = 0.0
+    gen_losses = []
+    for dg in disc_generated_outputs:
+        l = torch.mean((1.0 - dg) ** 2)
+        gen_losses.append(l.item())
+        loss += l
+
+    return loss, gen_losses
```
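Two things stand out in the rewritten disc.py. First, `MultiScaleDiscriminator` and `MultiPeriodDiscriminator` now both return a `(outputs, feature_maps)` pair instead of a list of per-discriminator tuples, and the loss helpers implement the usual least-squares GAN objective (real scores pulled toward 1, fake toward 0). Second, `feature_loss` calls `loss_fn(rl - gl)` with a single argument, which raises a `TypeError` under the default `F.l1_loss`; the likely intent is `loss_fn(rl, gl)`. A short sketch against the new API (the random tensors stand in for real and generated audio):

```python
import torch
from lt_tensor.model_zoo.disc import (
    MultiPeriodDiscriminator,
    discriminator_loss,
    generator_loss,
)

mpd = MultiPeriodDiscriminator()
real = torch.randn(2, 8192)  # (B, T) waveform batch
fake = torch.randn(2, 8192)  # stand-in for generator output

# Each forward now returns (score_list, feature_map_list).
real_scores, real_feats = mpd(real)
fake_scores, fake_feats = mpd(fake.detach())
d_loss, d_real, d_fake = discriminator_loss(real_scores, fake_scores)

# Generator side: per-discriminator least-squares adversarial term.
fake_scores, fake_feats = mpd(fake)
g_adv, per_disc_losses = generator_loss(fake_scores)

# feature_loss as released would crash; the presumable fix is
#     loss += loss_fn(rl, gl)    # not loss_fn(rl - gl)
```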
lt_tensor/model_zoo/fsn.py
CHANGED
lt_tensor/model_zoo/gns.py
CHANGED
```diff
@@ -7,10 +7,10 @@ __all__ = [
     "NoisePredictor1D",
 ]
 
-from
-from
-from .rsd import ResBlock1D
-from
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model
+from lt_tensor.model_zoo.rsd import ResBlock1D
+from lt_tensor.misc_utils import log_tensor
 
 import torch.nn.functional as F
 
```
lt_tensor/model_zoo/istft/generator.py
ADDED
```diff
@@ -0,0 +1,150 @@
+__all__ = ["iSTFTGenerator", "ResBlocks"]
+import gc
+import math
+import itertools
+from lt_utils.common import *
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model
+from lt_tensor.misc_utils import log_tensor
+from lt_tensor.model_zoo.rsd import ResBlock1D, ConvNets, get_weight_norm
+from lt_utils.misc_utils import log_traceback
+from lt_tensor.processors import AudioProcessor
+from lt_utils.type_utils import is_dir, is_pathlike
+from lt_tensor.misc_utils import set_seed, clear_cache
+from lt_tensor.model_zoo.disc import MultiPeriodDiscriminator, MultiScaleDiscriminator
+import torch.nn.functional as F
+from lt_tensor.config_templates import updateDict, ModelConfig
+
+
+class ResBlocks(ConvNets):
+    def __init__(
+        self,
+        channels: int,
+        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+        resblock_dilation_sizes: List[Union[int, List[int]]] = [
+            [1, 3, 5],
+            [1, 3, 5],
+            [1, 3, 5],
+        ],
+        activation: nn.Module = nn.LeakyReLU(0.1),
+    ):
+        super().__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.rb = nn.ModuleList()
+        self.activation = activation
+
+        for k, j in zip(resblock_kernel_sizes, resblock_dilation_sizes):
+            self.rb.append(ResBlock1D(channels, k, j, activation))
+
+        self.rb.apply(self.init_weights)
+
+    def forward(self, x: torch.Tensor):
+        xs = None
+        for i, block in enumerate(self.rb):
+            if i == 0:
+                xs = block(x)
+            else:
+                xs += block(x)
+        x = xs / self.num_kernels
+        return self.activation(x)
+
+
+class iSTFTGenerator(ConvNets):
+    def __init__(
+        self,
+        in_channels: int = 80,
+        upsample_rates: List[Union[int, List[int]]] = [8, 8],
+        upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16],
+        upsample_initial_channel: int = 512,
+        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+        resblock_dilation_sizes: List[Union[int, List[int]]] = [
+            [1, 3, 5],
+            [1, 3, 5],
+            [1, 3, 5],
+        ],
+        n_fft: int = 16,
+        activation: nn.Module = nn.LeakyReLU(0.1),
+        hop_length: int = 256,
+    ):
+        super().__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.hop_length = hop_length
+        self.conv_pre = weight_norm(
+            nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
+        )
+        self.blocks = nn.ModuleList()
+        self.activation = activation
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.blocks.append(
+                self._make_blocks(
+                    (i, k, u),
+                    upsample_initial_channel,
+                    resblock_kernel_sizes,
+                    resblock_dilation_sizes,
+                )
+            )
+
+        ch = upsample_initial_channel // (2 ** (i + 1))
+        self.post_n_fft = n_fft // 2 + 1
+        self.conv_post = weight_norm(nn.Conv1d(ch, n_fft + 2, 7, 1, padding=3))
+        self.conv_post.apply(self.init_weights)
+        self.reflection_pad = nn.ReflectionPad1d((1, 0))
+
+        self.phase = nn.Sequential(
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+        )
+        self.spec = nn.Sequential(
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+        )
+
+    def _make_blocks(
+        self,
+        state: Tuple[int, int, int],
+        upsample_initial_channel: int,
+        resblock_kernel_sizes: List[Union[int, List[int]]],
+        resblock_dilation_sizes: List[int | List[int]],
+    ):
+        i, k, u = state
+        channels = upsample_initial_channel // (2 ** (i + 1))
+        return nn.ModuleDict(
+            dict(
+                up=nn.Sequential(
+                    self.activation,
+                    weight_norm(
+                        nn.ConvTranspose1d(
+                            upsample_initial_channel // (2**i),
+                            channels,
+                            k,
+                            u,
+                            padding=(k - u) // 2,
+                        )
+                    ).apply(self.init_weights),
+                ),
+                residual=ResBlocks(
+                    channels,
+                    resblock_kernel_sizes,
+                    resblock_dilation_sizes,
+                    self.activation,
+                ),
+            )
+        )
+
+    def forward(self, x):
+        x = self.conv_pre(x)
+        for block in self.blocks:
+            x = block["up"](x)
+            x = block["residual"](x)
+
+        x = self.reflection_pad(x)
+        x = self.conv_post(x)
+        spec = torch.exp(self.spec(x[:, : self.post_n_fft, :]))
+        phase = torch.sin(self.phase(x[:, self.post_n_fft :, :]))
+
+        return spec, phase
```
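`iSTFTGenerator` stops short of producing audio: it emits a magnitude spectrogram (`exp` of the spec head) and a phase map (`sin` of the phase head), leaving the inverse STFT to the caller (presumably the accompanying `istft/trainer.py`). A hedged sketch of that final step, treating the phase output as an angle and choosing `hop_length=4` so the total upsampling (8 * 8 * 4) matches the model's `hop_length=256` default; both choices are assumptions following the iSTFTNet recipe rather than anything shown in this diff:

```python
import torch
from lt_tensor.model_zoo.istft.generator import iSTFTGenerator

gen = iSTFTGenerator(in_channels=80, n_fft=16, hop_length=256)
mel = torch.randn(1, 80, 32)   # (B, mels, frames) dummy input
spec, phase = gen(mel)         # each (B, n_fft // 2 + 1, upsampled frames)

# Combine magnitude and phase into a complex spectrogram and invert it.
complex_spec = spec * torch.exp(1j * phase)
wav = torch.istft(
    complex_spec,
    n_fft=16,
    hop_length=4,              # assumed: 8 * 8 * 4 = 256 samples per mel frame
    win_length=16,
    window=torch.hann_window(16),
)
```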