lt-tensor 0.0.1a34__py3-none-any.whl → 0.0.1a35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +1 -1
- lt_tensor/losses.py +11 -7
- lt_tensor/lr_schedulers.py +147 -21
- lt_tensor/misc_utils.py +35 -42
- lt_tensor/model_zoo/activations/__init__.py +3 -0
- lt_tensor/model_zoo/activations/alias_free/__init__.py +3 -0
- lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/act.py +8 -6
- lt_tensor/model_zoo/activations/snake/__init__.py +41 -43
- lt_tensor/model_zoo/audio_models/__init__.py +2 -2
- lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +243 -0
- lt_tensor/model_zoo/audio_models/hifigan/__init__.py +16 -347
- lt_tensor/model_zoo/audio_models/istft/__init__.py +14 -349
- lt_tensor/model_zoo/audio_models/resblocks.py +248 -0
- lt_tensor/model_zoo/convs.py +21 -32
- lt_tensor/model_zoo/losses/discriminators.py +143 -37
- {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a35.dist-info}/METADATA +1 -1
- lt_tensor-0.0.1a35.dist-info/RECORD +40 -0
- lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +0 -1
- lt_tensor-0.0.1a34.dist-info/RECORD +0 -37
- /lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/filter.py +0 -0
- /lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/resample.py +0 -0
- {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a35.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a35.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a34.dist-info → lt_tensor-0.0.1a35.dist-info}/top_level.txt +0 -0
lt_tensor/__init__.py
CHANGED
lt_tensor/losses.py
CHANGED
@@ -130,7 +130,9 @@ class MultiMelScaleLoss(Model):
         hops: List[int] = [8, 16, 32, 64, 128, 256, 512],
         f_min: float = [0, 0, 0, 0, 0, 0, 0],
         f_max: Optional[float] = [None, None, None, None, None, None, None],
-
+        loss_mel_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
+        loss_pitch_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
+        loss_rms_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
         center: bool = True,
         power: float = 1.0,
         normalized: bool = False,
@@ -141,8 +143,8 @@ class MultiMelScaleLoss(Model):
         use_istft_norm: bool = True,
         use_pitch_loss: bool = True,
         use_rms_loss: bool = True,
-        norm_pitch_fn: Callable[[Tensor], Tensor] =
-        norm_rms_fn: Callable[[Tensor], Tensor] =
+        norm_pitch_fn: Callable[[Tensor], Tensor] = normalize_minmax,
+        norm_rms_fn: Callable[[Tensor], Tensor] = normalize_zscore,
         lambda_mel: float = 1.0,
         lambda_rms: float = 1.0,
         lambda_pitch: float = 1.0,
@@ -157,7 +159,9 @@ class MultiMelScaleLoss(Model):
             == len(f_min)
            == len(f_max)
        )
-        self.
+        self.loss_mel_fn = loss_mel_fn
+        self.loss_pitch_fn = loss_pitch_fn
+        self.loss_rms_fn = loss_rms_fn
         self.lambda_mel = lambda_mel
         self.weight = weight
         self.use_istft_norm = use_istft_norm
@@ -257,21 +261,21 @@ class MultiMelScaleLoss(Model):
             x_mels = M(input_proc)
             y_mels = M(target_proc)

-            loss = self.
+            loss = self.loss_mel_fn(x_mels.squeeze(), y_mels.squeeze())
             losses += loss * self.lambda_mel

             # pitch/f0 loss
             if self.use_pitch_loss:
                 x_pitch = self.norm_pitch_fn(M.compute_pitch(input_proc))
                 y_pitch = self.norm_pitch_fn(M.compute_pitch(target_proc))
-                f0_loss = self.
+                f0_loss = self.loss_pitch_fn(x_pitch, y_pitch)
                 losses += f0_loss * self.lambda_pitch

             # energy/rms loss
             if self.use_rms_loss:
                 x_rms = self.norm_rms(M.compute_rms(input_proc, x_mels))
                 y_rms = self.norm_rms(M.compute_rms(target_proc, y_mels))
-                rms_loss = self.
+                rms_loss = self.loss_rms_fn(x_rms, y_rms)
                 losses += rms_loss * self.lambda_rms

         return losses * self.weight
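The practical effect of this hunk is that the mel, pitch, and RMS terms each take their own injectable criterion instead of a single hard-coded loss. A minimal usage sketch, assuming the remaining constructor arguments keep their defaults and that the loss is called on a (generated, reference) waveform pair; the tensor shapes are illustrative only:

```python
import torch
import torch.nn as nn
from lt_tensor.losses import MultiMelScaleLoss

# Per-term criteria are now injectable; anything Callable[[Tensor, Tensor], Tensor] works.
criterion = MultiMelScaleLoss(
    loss_mel_fn=nn.L1Loss(),
    loss_pitch_fn=nn.SmoothL1Loss(),
    loss_rms_fn=nn.MSELoss(),
    lambda_mel=1.0,
    lambda_pitch=0.5,
    lambda_rms=0.5,
)

fake = torch.randn(2, 16_000)   # hypothetical generated waveform batch
real = torch.randn(2, 16_000)   # hypothetical reference waveform batch
total = criterion(fake, real)   # weighted sum of mel + pitch + RMS terms
```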
lt_tensor/lr_schedulers.py
CHANGED
@@ -1,15 +1,20 @@
 __all__ = [
     "WarmupDecayScheduler",
     "AdaptiveDropScheduler",
-    "
+    "SinusoidalDecayLR",
+    "GuidedWaveringLR",
+    "FloorExponentialLR",
 ]

 import math
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import
+from torch.optim.lr_scheduler import LRScheduler
+from typing import Optional
+from numbers import Number
+from lt_tensor.misc_utils import update_lr


-class WarmupDecayScheduler(
+class WarmupDecayScheduler(LRScheduler):
     def __init__(
         self,
         optimizer: Optimizer,
@@ -49,7 +54,7 @@ class WarmupDecayScheduler(_LRScheduler):
         return lrs


-class AdaptiveDropScheduler(
+class AdaptiveDropScheduler(LRScheduler):
     def __init__(
         self,
         optimizer,
@@ -89,26 +94,147 @@ class AdaptiveDropScheduler(_LRScheduler):
         return [group["lr"] for group in self.optimizer.param_groups]


-class
+class SinusoidalDecayLR(LRScheduler):
     def __init__(
-        self,
+        self,
+        optimizer: Optimizer,
+        initial_lr: float = 1e-3,
+        target_lr: float = 1e-5,
+        floor_lr: float = 1e-7,
+        decay_rate: float = 1e-6,  # decay per period
+        wave_amplitude: float = 1e-5,
+        period: int = 256,
+        last_epoch: int = -1,
     ):
-        ""
-
-
-
-
-
-        self.
-        self.
+        assert decay_rate != 0.0, "decay_rate must be different from 0.0"
+        assert (
+            initial_lr >= target_lr >= floor_lr
+        ), "Must satisfy: initial_lr ≥ target_lr ≥ floor_lr"
+
+        self.initial_lr = initial_lr
+        self.target_lr = target_lr
+        self.floor_lr = floor_lr
+        self.decay_rate = decay_rate
+        self.wave_amplitude = wave_amplitude
         self.period = period
-
+
         super().__init__(optimizer, last_epoch)

     def get_lr(self):
-
-
-
-
-
+        step = self.last_epoch + 1
+        cycles = step // self.period
+        t = step % self.period
+        # Decay center down to target_lr, then freeze
+        center_decay = math.exp(-self.decay_rate * cycles)
+        center = max(self.target_lr, self.initial_lr * center_decay)
+        # Decay amplitude in sync with center (relative to initial)
+        amplitude_decay = math.exp(-self.decay_rate * cycles)
+        current_amplitude = self.wave_amplitude * self.initial_lr * amplitude_decay
+        sin_offset = math.sin(2 * math.pi * t / self.period)
+        lr = max(center + current_amplitude * sin_offset, self.floor_lr)
+        return [lr for _ in self.optimizer.param_groups]
+
+
+class GuidedWaveringLR(LRScheduler):
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        initial_lr: float = 1e-3,
+        target_lr: float = 1e-5,
+        floor_lr: float = 1e-7,
+        decay_rate: float = 0.01,
+        wave_amplitude: float = 0.02,
+        period: int = 256,
+        stop_decay_after: int = None,
+        last_epoch: int = -1,
+    ):
+        assert decay_rate != 0.0, "decay_rate must be non-zero"
+        assert (
+            initial_lr >= target_lr >= floor_lr
+        ), "Must satisfy: initial ≥ target ≥ floor"
+
+        self.initial_lr = initial_lr
+        self.target_lr = target_lr
+        self.floor_lr = floor_lr
+        self.decay_rate = decay_rate
+        self.wave_amplitude = wave_amplitude
+        self.period = period
+        self.stop_decay_after = stop_decay_after
+
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        step = self.last_epoch + 1
+        cycles = step // self.period
+        t = step % self.period
+
+        decay_cycles = (
+            min(cycles, self.stop_decay_after) if self.stop_decay_after else cycles
+        )
+        center = max(
+            self.target_lr, self.initial_lr * math.exp(-self.decay_rate * decay_cycles)
+        )
+        amp = (
+            self.wave_amplitude
+            * self.initial_lr
+            * math.exp(-self.decay_rate * decay_cycles)
+        )
+        phase = 2 * math.pi * t / self.period
+        wave = math.sin(phase) * math.cos(phase)
+        lr = max(center + amp * wave, self.floor_lr)
+        return [lr for _ in self.optimizer.param_groups]
+
+
+class FloorExponentialLR(LRScheduler):
+    """Modified version from exponential lr, to have a minimum and reset functions.
+
+    Decays the learning rate of each parameter group by gamma every epoch.
+
+    When last_epoch=-1, sets initial lr as lr.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        gamma (float): Multiplicative factor of learning rate decay.
+        last_epoch (int): The index of last epoch. Default: -1.
+    """
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        initial_lr: float = 1e-4,
+        gamma: float = 0.99998,
+        last_epoch: int = -1,
+        floor_lr: float = 1e-6,
+    ):
+        self.gamma = gamma
+        self.floor_lr = floor_lr
+        self.initial_lr = initial_lr
+
+        super().__init__(optimizer, last_epoch)
+
+    def set_floor(self, new_value: float):
+        assert isinstance(new_value, Number)
+        self.floor_lr = new_value
+
+    def reset_lr(self, new_value: Optional[float] = None):
+        new_lr = new_value if isinstance(new_value, Number) else self.initial_lr
+        self.initial_lr = new_lr
+        update_lr(self.optimizer, new_lr)
+
+    def get_lr(self):
+
+        if self.last_epoch == 0:
+            return [
+                max(group["lr"], self.floor_lr) for group in self.optimizer.param_groups
+            ]
+
+        return [
+            max(group["lr"] * self.gamma, self.floor_lr)
+            for group in self.optimizer.param_groups
+        ]
+
+    def _get_closed_form_lr(self):
+        return [
+            max(base_lr * self.gamma**self.last_epoch, self.floor_lr)
+            for base_lr in self.base_lrs
+        ]
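All three new classes subclass `torch.optim.lr_scheduler.LRScheduler`, so they plug into an ordinary training loop through `scheduler.step()`. A small sketch under that assumption (the model and loop body are placeholders):

```python
import torch.nn as nn
from torch import optim
from lt_tensor.lr_schedulers import FloorExponentialLR, SinusoidalDecayLR

model = nn.Linear(16, 16)
opt = optim.AdamW(model.parameters(), lr=1e-4)

# Multiplies the lr by `gamma` each step but never lets it fall below `floor_lr`.
sched = FloorExponentialLR(opt, initial_lr=1e-4, gamma=0.99998, floor_lr=1e-6)

for step in range(1_000):
    # forward / backward elided
    opt.step()
    sched.step()

print(opt.param_groups[0]["lr"])  # decayed, clamped at floor_lr
sched.reset_lr()                  # writes initial_lr back into the optimizer

# SinusoidalDecayLR instead oscillates around a center that decays toward target_lr:
wave_opt = optim.AdamW(model.parameters(), lr=1e-3)
wave_sched = SinusoidalDecayLR(
    wave_opt, initial_lr=1e-3, target_lr=1e-5, wave_amplitude=1e-5, period=256
)
```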
lt_tensor/misc_utils.py
CHANGED
@@ -24,6 +24,7 @@ __all__ = [
     "plot_view",
     "get_weights",
     "get_activated_conv",
+    "update_lr",
 ]

 import re
@@ -77,6 +78,33 @@ def get_activated_conv(
     )


+def get_loss_average(losses: List[float]):
+    """A little helper for training, for example:
+    ```python
+    losses = []
+    for epoch in range(100):
+        for inp, label in dataloader:
+            optimizer.zero_grad()
+            out = model(inp)
+            loss = loss_fn(out, label)
+            optimizer.step()
+            losses.append(loss.item())
+        print(f"Epoch {epoch+1} | Loss: {get_loss_average(losses):.4f}")
+    """
+    if not losses:
+        return float("nan")
+    return sum(losses) / len(losses)
+
+
+def update_lr(optimizer: optim.Optimizer, new_value: float = 1e-4):
+    for param_group in optimizer.param_groups:
+        if isinstance(param_group["lr"], Tensor):
+            param_group["lr"].fill_(new_value)
+        else:
+            param_group["lr"] = new_value
+    return optimizer
+
+
 def plot_view(
     data: Dict[str, List[Any]],
     title: str = "Loss",
@@ -520,49 +548,14 @@ def sample_tensor(tensor: torch.Tensor, num_samples: int = 5):
     return flat[idx]


-class TorchCacheUtils:
-    cached_shortcuts: dict[str, Callable[[None], None]] = {}
-
-    has_cuda: bool = torch.cuda.is_available()
-    has_xpu: bool = torch.xpu.is_available()
-    has_mps: bool = torch.mps.is_available()
-
-    _ignore: list[str] = []
-
-    def __init__(self):
-        pass
-
-    def _apply_clear(self, device: str):
-        if device in self._ignore:
-            gc.collect()
-            return
-        try:
-            clear_fn = self.cached_shortcuts.get(
-                device, getattr(torch, device).empty_cache
-            )
-            if device not in self.cached_shortcuts:
-                self.cached_shortcuts.update({device: clear_fn})
-
-        except Exception as e:
-            print(e)
-            self._ignore.append(device)
-
-    def clear(self):
-        gc.collect()
-        if self.has_xpu:
-            self._apply_clear("xpu")
-        if self.has_cuda:
-            self._apply_clear("cuda")
-        if self.has_mps:
-            self._apply_clear("mps")
-        gc.collect()
-
-
-_clear_cache_cls = TorchCacheUtils()
-
-
 def clear_cache():
-
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    if torch.mps.is_available():
+        torch.mps.empty_cache()
+    if torch.xpu.is_available():
+        torch.xpu.empty_cache()
+    gc.collect()


 @cache_wrapper
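The removal of `TorchCacheUtils` leaves three small, stateless helpers. A quick sketch of how they might be used together, assuming the import path shown in this diff:

```python
import torch.nn as nn
from torch import optim
from lt_tensor.misc_utils import update_lr, get_loss_average, clear_cache

model = nn.Linear(4, 4)
opt = optim.SGD(model.parameters(), lr=1e-3)

# Overwrites the lr of every param group (tensor-valued lrs are filled in place).
update_lr(opt, 5e-4)

# Plain mean of the collected floats; returns NaN for an empty list.
running_losses = [0.91, 0.74, 0.66]
print(f"avg loss: {get_loss_average(running_losses):.4f}")

# Empties whichever backend caches (CUDA / MPS / XPU) are available, then runs gc.
clear_cache()
```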
lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/act.py
CHANGED
@@ -1,15 +1,17 @@
-import torch
 import torch.nn as nn
-
-
-
+from lt_tensor.model_zoo.activations.alias_free.resample import (
+    UpSample2d,
+    DownSample2d,
+    UpSample1d,
+    DownSample1d,
+)


 class Activation1d(nn.Module):

     def __init__(
         self,
-        activation,
+        activation: nn.Module,
         up_ratio: int = 2,
         down_ratio: int = 2,
         up_kernel_size: int = 12,
@@ -34,7 +36,7 @@ class Activation2d(nn.Module):

     def __init__(
         self,
-        activation,
+        activation: nn.Module,
         up_ratio: int = 2,
         down_ratio: int = 2,
         up_kernel_size: int = 12,
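For context, `Activation1d` sandwiches the wrapped activation between the alias-free up- and down-sampling layers imported above; this hunk only moves the import to the renamed package and annotates `activation` as an `nn.Module`. A hypothetical usage sketch, assuming the usual `[batch, channels, time]` layout of the upstream alias-free implementation:

```python
import torch
from lt_tensor.model_zoo.activations.alias_free.act import Activation1d
from lt_tensor.model_zoo.activations.snake import Snake

channels = 64
# Upsample 2x, apply Snake pointwise, then downsample 2x to suppress aliasing.
act = Activation1d(activation=Snake(channels), up_ratio=2, down_ratio=2)

x = torch.randn(1, channels, 4096)  # [batch, channels, time]
y = act(x)                          # assumed to keep the input shape
```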
lt_tensor/model_zoo/activations/snake/__init__.py
CHANGED
@@ -1,8 +1,7 @@
-# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# Implementation adapted and modified from https://github.com/EdwardDixon/snake under the MIT license.

 import torch
 from torch import nn, sin, pow
-from torch.nn import Parameter


 class Snake(nn.Module):
@@ -24,10 +23,11 @@ class Snake(nn.Module):

     def __init__(
         self,
-        in_features,
-        alpha=1.0,
-
-        alpha_logscale=False,
+        in_features: int,
+        alpha: float = 1.0,
+        requires_grad: bool = True,
+        alpha_logscale: bool = False,
+        batched: bool = True,
     ):
         """
         Initialization.
@@ -37,31 +37,27 @@ class Snake(nn.Module):
         alpha is initialized to 1 by default, higher values = higher-frequency.
         alpha will be trained along with the rest of your model.
         """
-        super(
+        super().__init__()
         self.in_features = in_features
-
-        # initialize alpha
         self.alpha_logscale = alpha_logscale
-        if self.alpha_logscale
-
-
-
+        param_fn = torch.zeros if self.alpha_logscale else torch.ones
+        _shape = (1, in_features, 1) if batched else (in_features, 1)
+        self.alpha = nn.Parameter(param_fn(_shape) * alpha, requires_grad=requires_grad)
+        self.eps = 1e-8

-
-
-
+    def _log_scale(self):
+        if self.alpha_logscale:
+            return self.alpha.exp()
+        return self.alpha

-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         """
         Forward pass of the function.
         Applies the function to the input elementwise.
         Snake ∶= x + 1/a * sin^2 (xa)
         """
-        alpha = self.
-
-        alpha = torch.exp(alpha)
-        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
-
+        alpha = self._log_scale()
+        x = x + (1.0 / (alpha + self.eps)) * pow(sin(x * alpha), 2)
         return x
@@ -84,7 +80,12 @@ class SnakeBeta(nn.Module):
     """

     def __init__(
-        self,
+        self,
+        in_features: int,
+        alpha: float = 1.0,
+        requires_grad: bool = True,
+        alpha_logscale: bool = False,
+        batched: bool = True,
     ):
         """
         Initialization.
@@ -96,34 +97,31 @@ class SnakeBeta(nn.Module):
         beta is initialized to 1 by default, higher values = higher-magnitude.
         alpha will be trained along with the rest of your model.
         """
-        super(
+        super().__init__()
         self.in_features = in_features

         # initialize alpha
         self.alpha_logscale = alpha_logscale
-
-
-
-
-
-
-
-        self.alpha
-        self.
+        """
+        if log scale alphas initialized to zeros
+        else linear scale alphas is initialized to ones
+        """
+        param_fn = torch.zeros if alpha_logscale else torch.ones
+        _shape = (1, in_features, 1) if batched else (in_features, 1)
+        self.alpha = nn.Parameter(param_fn(_shape) * alpha, requires_grad=requires_grad)
+        self.beta = nn.Parameter(param_fn(_shape) * alpha, requires_grad=requires_grad)
+        self.eps = 1e-8

-
+    def _log_scale(self):
+        if self.alpha_logscale:
+            return self.alpha.exp(), self.beta.exp()
+        return self.alpha, self.beta

-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         """
         Forward pass of the function.
         Applies the function to the input elementwise.
         SnakeBeta ∶= x + 1/b * sin^2 (xa)
         """
-        alpha = self.
-        beta
-        if self.alpha_logscale:
-            alpha = torch.exp(alpha)
-            beta = torch.exp(beta)
-        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
-
-        return x
+        alpha, beta = self._log_scale()
+        return x + (1.0 / (beta + self.eps)) * pow(sin(x * alpha), 2)
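With the reworked initialization, `alpha` (and `beta` for `SnakeBeta`) becomes a `(1, in_features, 1)` parameter when `batched=True`, so the activation broadcasts per channel over `[batch, channels, time]` inputs, and `requires_grad=False` freezes it. A short sketch based on the signatures above:

```python
import torch
from lt_tensor.model_zoo.activations.snake import Snake, SnakeBeta

x = torch.randn(8, 32, 1024)  # [batch, channels, time]

snake = Snake(in_features=32, alpha=1.0)           # x + (1/alpha) * sin^2(alpha * x)
snake_beta = SnakeBeta(in_features=32, alpha=1.0)  # x + (1/beta)  * sin^2(alpha * x)

y1 = snake(x)       # same shape as x
y2 = snake_beta(x)  # same shape as x

# Frozen variant: the per-channel frequencies stay at their initial value.
frozen = Snake(in_features=32, requires_grad=False)
```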
lt_tensor/model_zoo/audio_models/__init__.py
CHANGED
@@ -1,3 +1,3 @@
-from . import diffwave, istft, hifigan
+from . import diffwave, istft, hifigan, bigvgan

-__all__ = ["diffwave", "istft", "hifigan"]
+__all__ = ["diffwave", "istft", "hifigan", "bigvgan"]