lt-tensor 0.0.1a37__tar.gz → 0.0.1a38__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/PKG-INFO +1 -1
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/__init__.py +1 -1
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/losses.py +10 -4
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py +68 -81
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/convs.py +25 -16
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/processors/audio.py +2 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor.egg-info/PKG-INFO +1 -1
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/setup.py +1 -1
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/LICENSE +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/README.md +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/config_templates.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/lr_schedulers.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/math_ops.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/misc_utils.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_base.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/__init__.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/activations/__init__.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/activations/alias_free/__init__.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/activations/alias_free/act.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/activations/alias_free/filter.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/activations/alias_free/resample.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/activations/snake/__init__.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/audio_models/__init__.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/audio_models/hifigan/__init__.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/audio_models/istft/__init__.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/audio_models/resblocks.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/basic.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/features.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/fusion.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/losses/CQT/__init__.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/losses/CQT/transforms.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/losses/CQT/utils.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/losses/__init__.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/losses/discriminators.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/pos_encoder.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/residual.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/transformer.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/monotonic_align.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/noise_tools.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/processors/__init__.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/torch_commons.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/transform.py +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor.egg-info/SOURCES.txt +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor.egg-info/dependency_links.txt +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor.egg-info/requires.txt +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor.egg-info/top_level.txt +0 -0
- {lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/setup.cfg +0 -0
@@ -133,7 +133,7 @@ class MultiMelScaleLoss(Model):
|
|
133
133
|
loss_mel_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
|
134
134
|
loss_pitch_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
|
135
135
|
loss_rms_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
|
136
|
-
center: bool =
|
136
|
+
center: bool = False,
|
137
137
|
power: float = 1.0,
|
138
138
|
normalized: bool = False,
|
139
139
|
pad_mode: str = "reflect",
|
@@ -149,6 +149,7 @@ class MultiMelScaleLoss(Model):
|
|
149
149
|
lambda_rms: float = 1.0,
|
150
150
|
lambda_pitch: float = 1.0,
|
151
151
|
weight: float = 1.0,
|
152
|
+
mel: Literal["librosa", "torch"] = "torch",
|
152
153
|
):
|
153
154
|
super().__init__()
|
154
155
|
assert (
|
@@ -188,6 +189,7 @@ class MultiMelScaleLoss(Model):
|
|
188
189
|
onesided,
|
189
190
|
std,
|
190
191
|
mean,
|
192
|
+
mel,
|
191
193
|
)
|
192
194
|
|
193
195
|
def _setup_mels(
|
@@ -206,6 +208,7 @@ class MultiMelScaleLoss(Model):
|
|
206
208
|
onesided: Optional[bool],
|
207
209
|
std: int,
|
208
210
|
mean: int,
|
211
|
+
mel: str,
|
209
212
|
):
|
210
213
|
assert (
|
211
214
|
len(n_mels)
|
@@ -224,6 +227,7 @@ class MultiMelScaleLoss(Model):
|
|
224
227
|
pad_mode=pad_mode,
|
225
228
|
std=std,
|
226
229
|
mean=mean,
|
230
|
+
mel_default=mel,
|
227
231
|
)
|
228
232
|
self.mel_spectrograms: List[AudioProcessor] = nn.ModuleList(
|
229
233
|
[
|
@@ -247,12 +251,14 @@ class MultiMelScaleLoss(Model):
|
|
247
251
|
def forward(
|
248
252
|
self, input_wave: torch.Tensor, target_wave: torch.Tensor
|
249
253
|
) -> torch.Tensor:
|
250
|
-
assert self.use_istft_norm or input_wave.shape[-1] == target_wave.shape[-1]
|
254
|
+
assert self.use_istft_norm or input_wave.shape[-1] == target_wave.shape[-1], (
|
255
|
+
f"Size mismatch! input_wave {input_wave.shape[-1]} must match target_wave: {target_wave.shape[-1]}. "
|
256
|
+
"Alternatively 'use_istft_norm' can be set to Trie with will automatically force the audio to that size."
|
257
|
+
)
|
251
258
|
target_wave = target_wave.to(input_wave.device)
|
252
259
|
losses = 0.0
|
253
260
|
for M in self.mel_spectrograms:
|
254
|
-
|
255
|
-
if self.use_istft_norm:
|
261
|
+
if self.use_istft_norm and input_proc.shape[-1] != target_proc.shape[-1]:
|
256
262
|
input_proc = M.istft_norm(input_wave, length=target_wave.shape[-1])
|
257
263
|
target_proc = M.istft_norm(target_wave, length=target_wave.shape[-1])
|
258
264
|
else:
|
{lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py
RENAMED
@@ -1,14 +1,15 @@
|
|
1
|
-
__all__ = ["DiffWave", "DiffWaveConfig", "
|
1
|
+
__all__ = ["DiffWave", "DiffWaveConfig", "SpectrogramUpsampler", "DiffusionEmbedding"]
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
from lt_tensor.torch_commons import *
|
5
5
|
from torch.nn import functional as F
|
6
6
|
from lt_tensor.config_templates import ModelConfig
|
7
7
|
from lt_tensor.torch_commons import *
|
8
|
-
from lt_tensor.model_zoo.convs import ConvNets,
|
8
|
+
from lt_tensor.model_zoo.convs import ConvNets, ConvEXT
|
9
9
|
from lt_tensor.model_base import Model
|
10
10
|
from math import sqrt
|
11
11
|
from lt_utils.common import *
|
12
|
+
from lt_tensor.misc_utils import log_tensor
|
12
13
|
|
13
14
|
|
14
15
|
class DiffWaveConfig(ModelConfig):
|
@@ -21,12 +22,8 @@ class DiffWaveConfig(ModelConfig):
|
|
21
22
|
unconditional = False
|
22
23
|
apply_norm: Optional[Literal["weight", "spectral"]] = None
|
23
24
|
apply_norm_resblock: Optional[Literal["weight", "spectral"]] = None
|
24
|
-
noise_schedule: list[int] = np.linspace(1e-4, 0.05,
|
25
|
+
noise_schedule: list[int] = np.linspace(1e-4, 0.05, 25).tolist()
|
25
26
|
# settings for auto-fixes
|
26
|
-
interpolate = False
|
27
|
-
interpolation_mode: Literal[
|
28
|
-
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
|
29
|
-
] = "nearest"
|
30
27
|
|
31
28
|
def __init__(
|
32
29
|
self,
|
@@ -37,16 +34,6 @@ class DiffWaveConfig(ModelConfig):
|
|
37
34
|
dilation_cycle_length=10,
|
38
35
|
unconditional=False,
|
39
36
|
noise_schedule: list[int] = np.linspace(1e-4, 0.05, 50).tolist(),
|
40
|
-
interpolate_cond=False,
|
41
|
-
interpolation_mode: Literal[
|
42
|
-
"nearest",
|
43
|
-
"linear",
|
44
|
-
"bilinear",
|
45
|
-
"bicubic",
|
46
|
-
"trilinear",
|
47
|
-
"area",
|
48
|
-
"nearest-exact",
|
49
|
-
] = "nearest",
|
50
37
|
apply_norm: Optional[Literal["weight", "spectral"]] = None,
|
51
38
|
apply_norm_resblock: Optional[Literal["weight", "spectral"]] = None,
|
52
39
|
):
|
@@ -58,8 +45,6 @@ class DiffWaveConfig(ModelConfig):
|
|
58
45
|
"residual_channels": residual_channels,
|
59
46
|
"unconditional": unconditional,
|
60
47
|
"noise_schedule": noise_schedule,
|
61
|
-
"interpolate": interpolate_cond,
|
62
|
-
"interpolation_mode": interpolation_mode,
|
63
48
|
"apply_norm": apply_norm,
|
64
49
|
"apply_norm_resblock": apply_norm_resblock,
|
65
50
|
}
|
@@ -102,19 +87,34 @@ class DiffusionEmbedding(Model):
|
|
102
87
|
return table
|
103
88
|
|
104
89
|
|
105
|
-
class
|
90
|
+
class SpectrogramUpsampler(Model):
|
106
91
|
def __init__(self):
|
107
92
|
super().__init__()
|
108
|
-
self.
|
109
|
-
|
110
|
-
|
93
|
+
self.conv_net = nn.Sequential(
|
94
|
+
ConvEXT(
|
95
|
+
1,
|
96
|
+
1,
|
97
|
+
[3, 32],
|
98
|
+
stride=[1, 16],
|
99
|
+
padding=[1, 8],
|
100
|
+
module_type="2d",
|
101
|
+
transpose=True,
|
102
|
+
),
|
103
|
+
nn.LeakyReLU(0.1),
|
104
|
+
ConvEXT(
|
105
|
+
1,
|
106
|
+
1,
|
107
|
+
[3, 32],
|
108
|
+
stride=[1, 16],
|
109
|
+
padding=[1, 8],
|
110
|
+
module_type="2d",
|
111
|
+
transpose=True,
|
112
|
+
),
|
113
|
+
nn.LeakyReLU(0.1),
|
114
|
+
)
|
111
115
|
|
112
|
-
def forward(self, x):
|
113
|
-
|
114
|
-
x = self.activation(self.conv1(x))
|
115
|
-
x = self.activation(self.conv2(x))
|
116
|
-
x = torch.squeeze(x, 1)
|
117
|
-
return x
|
116
|
+
def forward(self, x: Tensor):
|
117
|
+
return self.conv_net(x.unsqueeze(0)).squeeze(1)
|
118
118
|
|
119
119
|
|
120
120
|
class ResidualBlock(Model):
|
@@ -133,7 +133,7 @@ class ResidualBlock(Model):
|
|
133
133
|
:param uncond: disable spectrogram conditional
|
134
134
|
"""
|
135
135
|
super().__init__()
|
136
|
-
self.dilated_conv =
|
136
|
+
self.dilated_conv = ConvEXT(
|
137
137
|
residual_channels,
|
138
138
|
2 * residual_channels,
|
139
139
|
3,
|
@@ -142,18 +142,18 @@ class ResidualBlock(Model):
|
|
142
142
|
apply_norm=apply_norm,
|
143
143
|
)
|
144
144
|
self.diffusion_projection = nn.Linear(512, residual_channels)
|
145
|
-
|
146
|
-
|
145
|
+
self.uncoditional = uncond
|
146
|
+
self.conditioner_projection = None
|
147
|
+
if not uncond:
|
148
|
+
self.conditioner_projection = ConvEXT(
|
147
149
|
n_mels,
|
148
150
|
2 * residual_channels,
|
149
151
|
1,
|
150
152
|
apply_norm=apply_norm,
|
151
153
|
)
|
152
|
-
else: # unconditional model
|
153
|
-
self.conditioner_projection = None
|
154
154
|
|
155
|
-
self.output_projection =
|
156
|
-
residual_channels, 2 * residual_channels, 1, apply_norm
|
155
|
+
self.output_projection = ConvEXT(
|
156
|
+
residual_channels, 2 * residual_channels, 1, apply_norm=apply_norm
|
157
157
|
)
|
158
158
|
|
159
159
|
def forward(
|
@@ -164,20 +164,15 @@ class ResidualBlock(Model):
|
|
164
164
|
):
|
165
165
|
|
166
166
|
diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
|
167
|
-
y = x + diffusion_step
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
y = self.dilated_conv(y)
|
172
|
-
else:
|
173
|
-
conditioner = self.conditioner_projection(conditioner)
|
174
|
-
y = self.dilated_conv(y) + conditioner
|
175
|
-
|
176
|
-
gate, filter = torch.chunk(y, 2, dim=1)
|
177
|
-
y = torch.sigmoid(gate) * torch.tanh(filter)
|
167
|
+
y = (x + diffusion_step).squeeze(1)
|
168
|
+
y = self.dilated_conv(y)
|
169
|
+
if not self.uncoditional and conditioner is not None:
|
170
|
+
y = y + self.conditioner_projection(conditioner)
|
178
171
|
|
172
|
+
gate, _filter = y.chunk(2, dim=1)
|
173
|
+
y = gate.sigmoid() * _filter.tanh()
|
179
174
|
y = self.output_projection(y)
|
180
|
-
residual, skip =
|
175
|
+
residual, skip = y.chunk(2, dim=1)
|
181
176
|
return (x + residual) / sqrt(2.0), skip
|
182
177
|
|
183
178
|
|
@@ -186,19 +181,17 @@ class DiffWave(Model):
|
|
186
181
|
super().__init__()
|
187
182
|
self.params = params
|
188
183
|
self.n_hop = self.params.hop_samples
|
189
|
-
self.
|
190
|
-
self.interpolate_mode = self.params.interpolation_mode
|
191
|
-
self.input_projection = Conv1dEXT(
|
184
|
+
self.input_projection = ConvEXT(
|
192
185
|
in_channels=1,
|
193
186
|
out_channels=params.residual_channels,
|
194
187
|
kernel_size=1,
|
195
188
|
apply_norm=self.params.apply_norm,
|
189
|
+
activation_out=nn.LeakyReLU(0.1),
|
196
190
|
)
|
197
191
|
self.diffusion_embedding = DiffusionEmbedding(len(params.noise_schedule))
|
198
|
-
|
199
|
-
self.
|
200
|
-
|
201
|
-
self.spectrogram_upsample = SpectrogramUpsample()
|
192
|
+
self.spectrogram_upsampler = (
|
193
|
+
SpectrogramUpsampler() if not self.params.unconditional else None
|
194
|
+
)
|
202
195
|
|
203
196
|
self.residual_layers = nn.ModuleList(
|
204
197
|
[
|
@@ -212,18 +205,18 @@ class DiffWave(Model):
|
|
212
205
|
for i in range(params.residual_layers)
|
213
206
|
]
|
214
207
|
)
|
215
|
-
self.skip_projection =
|
208
|
+
self.skip_projection = ConvEXT(
|
216
209
|
in_channels=params.residual_channels,
|
217
210
|
out_channels=params.residual_channels,
|
218
211
|
kernel_size=1,
|
219
212
|
apply_norm=self.params.apply_norm,
|
213
|
+
activation_out=nn.LeakyReLU(0.1),
|
220
214
|
)
|
221
|
-
self.output_projection =
|
222
|
-
params.residual_channels, 1, 1, apply_norm=self.params.apply_norm
|
215
|
+
self.output_projection = ConvEXT(
|
216
|
+
params.residual_channels, 1, 1, apply_norm=self.params.apply_norm, init_weights=True,
|
223
217
|
)
|
224
218
|
self.activation = nn.LeakyReLU(0.1)
|
225
|
-
self.
|
226
|
-
nn.init.zeros_(self.output_projection.weight)
|
219
|
+
self._res_d = sqrt(len(self.residual_layers))
|
227
220
|
|
228
221
|
def forward(
|
229
222
|
self,
|
@@ -231,31 +224,25 @@ class DiffWave(Model):
|
|
231
224
|
diffusion_step: Tensor,
|
232
225
|
spectrogram: Optional[Tensor] = None,
|
233
226
|
):
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
227
|
+
if not self.params.unconditional:
|
228
|
+
assert spectrogram is not None
|
229
|
+
if audio.ndim < 3:
|
230
|
+
if audio.ndim == 2:
|
231
|
+
audio = audio.unsqueeze(1)
|
232
|
+
else:
|
233
|
+
audio = audio.unsqueeze(0).unsqueeze(0)
|
238
234
|
|
235
|
+
x = self.input_projection(audio)
|
239
236
|
diffusion_step = self.diffusion_embedding(diffusion_step)
|
240
|
-
if
|
241
|
-
|
242
|
-
# a little heavy, but helps a lot to fix mismatched shapes,
|
243
|
-
# not always recommended due to data loss
|
244
|
-
spectrogram = F.interpolate(
|
245
|
-
input=spectrogram,
|
246
|
-
size=int(T * self.n_hop),
|
247
|
-
mode=self.interpolate_mode,
|
248
|
-
)
|
249
|
-
spectrogram = self.spectrogram_upsample(spectrogram)
|
237
|
+
if not self.params.unconditional: # use conditional model
|
238
|
+
spectrogram = self.spectrogram_upsampler(spectrogram)
|
250
239
|
|
251
|
-
skip =
|
240
|
+
skip = torch.zeros_like(x, device=x.device)
|
252
241
|
for i, layer in enumerate(self.residual_layers):
|
253
242
|
x, skip_connection = layer(x, diffusion_step, spectrogram)
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
x = skip / self.r_sqrt
|
259
|
-
x = self.activation(self.skip_projection(x))
|
243
|
+
skip += skip_connection
|
244
|
+
|
245
|
+
x = skip / self._res_d
|
246
|
+
x = self.skip_projection(x)
|
260
247
|
x = self.output_projection(x)
|
261
248
|
return x
|
@@ -1,4 +1,4 @@
|
|
1
|
-
__all__ = ["ConvNets", "
|
1
|
+
__all__ = ["ConvNets", "ConvEXT"]
|
2
2
|
import math
|
3
3
|
from lt_utils.common import *
|
4
4
|
import torch.nn.functional as F
|
@@ -6,6 +6,7 @@ from lt_tensor.torch_commons import *
|
|
6
6
|
from lt_tensor.model_base import Model
|
7
7
|
from lt_tensor.misc_utils import log_tensor
|
8
8
|
from lt_tensor.model_zoo.fusion import AdaFusion1D, AdaIN1D
|
9
|
+
from lt_utils.misc_utils import default
|
9
10
|
|
10
11
|
|
11
12
|
def spectral_norm_select(module: nn.Module, enabled: bool):
|
@@ -52,10 +53,7 @@ class ConvNets(Model):
|
|
52
53
|
m.weight.data.normal_(mean, std)
|
53
54
|
|
54
55
|
|
55
|
-
class
|
56
|
-
|
57
|
-
# TODO: Use this module to replace all that are using normalizations, mostly those in `audio_models`
|
58
|
-
|
56
|
+
class ConvEXT(ConvNets):
|
59
57
|
def __init__(
|
60
58
|
self,
|
61
59
|
in_channels: int,
|
@@ -72,6 +70,10 @@ class Conv1dEXT(ConvNets):
|
|
72
70
|
apply_norm: Optional[Literal["weight", "spectral"]] = None,
|
73
71
|
activation_in: nn.Module = nn.Identity(),
|
74
72
|
activation_out: nn.Module = nn.Identity(),
|
73
|
+
module_type: Literal["1d", "2d", "3d"] = "1d",
|
74
|
+
transpose: bool = False,
|
75
|
+
weight_init: Optional[Callable[[nn.Module], None]] = None,
|
76
|
+
init_weights: bool = True,
|
75
77
|
*args,
|
76
78
|
**kwargs,
|
77
79
|
):
|
@@ -91,23 +93,30 @@ class Conv1dEXT(ConvNets):
|
|
91
93
|
device=device,
|
92
94
|
dtype=dtype,
|
93
95
|
)
|
96
|
+
match module_type.lower():
|
97
|
+
case "1d":
|
98
|
+
md = nn.Conv1d if not transpose else nn.ConvTranspose1d
|
99
|
+
case "2d":
|
100
|
+
md = nn.Conv2d if not transpose else nn.ConvTranspose2d
|
101
|
+
case "3d":
|
102
|
+
md = nn.Conv3d if not transpose else nn.ConvTranspose3d
|
103
|
+
case _:
|
104
|
+
raise ValueError(
|
105
|
+
f"module_type {module_type} is not a valid module type! use '1d', '2d' or '3d'"
|
106
|
+
)
|
107
|
+
|
94
108
|
if apply_norm is None:
|
95
|
-
self.cnn =
|
96
|
-
self.has_wn = False
|
109
|
+
self.cnn = md(**cnn_kwargs)
|
97
110
|
else:
|
98
|
-
self.has_wn = True
|
99
111
|
if apply_norm == "spectral":
|
100
|
-
self.cnn = spectral_norm(
|
112
|
+
self.cnn = spectral_norm(md(**cnn_kwargs))
|
101
113
|
else:
|
102
|
-
self.cnn = weight_norm(
|
114
|
+
self.cnn = weight_norm(md(**cnn_kwargs))
|
103
115
|
self.actv_in = activation_in
|
104
116
|
self.actv_out = activation_out
|
105
|
-
|
117
|
+
if init_weights:
|
118
|
+
weight_init = default(weight_init, self.init_weights)
|
119
|
+
self.cnn.apply(weight_init)
|
106
120
|
|
107
121
|
def forward(self, input: Tensor):
|
108
122
|
return self.actv_out(self.cnn(self.actv_in(input)))
|
109
|
-
|
110
|
-
def remove_norms(self, name="weight"):
|
111
|
-
if self.has_wn:
|
112
|
-
remove_norm(self.cnn, name)
|
113
|
-
self.has_wn = False
|
@@ -4,7 +4,7 @@ with open("README.md", "r", encoding="utf-8") as f:
|
|
4
4
|
long_description = f.read()
|
5
5
|
|
6
6
|
setup(
|
7
|
-
version="0.0.
|
7
|
+
version="0.0.1a38",
|
8
8
|
name="lt-tensor",
|
9
9
|
description="General utilities for PyTorch and others. Built for general use.",
|
10
10
|
long_description=long_description,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/activations/alias_free/__init__.py
RENAMED
File without changes
|
File without changes
|
{lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/activations/alias_free/filter.py
RENAMED
File without changes
|
{lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/activations/alias_free/resample.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
{lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/audio_models/bigvgan/__init__.py
RENAMED
File without changes
|
{lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/audio_models/hifigan/__init__.py
RENAMED
File without changes
|
{lt_tensor-0.0.1a37 → lt_tensor-0.0.1a38}/lt_tensor/model_zoo/audio_models/istft/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|