lt-tensor 0.0.1a33__tar.gz → 0.0.1a35__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/PKG-INFO +1 -1
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/__init__.py +1 -1
- lt_tensor-0.0.1a35/lt_tensor/losses.py +281 -0
- lt_tensor-0.0.1a35/lt_tensor/lr_schedulers.py +240 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/misc_utils.py +35 -42
- lt_tensor-0.0.1a35/lt_tensor/model_zoo/activations/__init__.py +3 -0
- lt_tensor-0.0.1a35/lt_tensor/model_zoo/activations/alias_free/__init__.py +3 -0
- {lt_tensor-0.0.1a33/lt_tensor/model_zoo/activations/alias_free_torch → lt_tensor-0.0.1a35/lt_tensor/model_zoo/activations/alias_free}/act.py +8 -6
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/activations/snake/__init__.py +41 -43
- lt_tensor-0.0.1a35/lt_tensor/model_zoo/audio_models/__init__.py +3 -0
- lt_tensor-0.0.1a35/lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +243 -0
- lt_tensor-0.0.1a35/lt_tensor/model_zoo/audio_models/hifigan/__init__.py +189 -0
- lt_tensor-0.0.1a35/lt_tensor/model_zoo/audio_models/istft/__init__.py +216 -0
- lt_tensor-0.0.1a35/lt_tensor/model_zoo/audio_models/resblocks.py +248 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/convs.py +21 -32
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/losses/discriminators.py +143 -230
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor.egg-info/PKG-INFO +1 -1
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor.egg-info/SOURCES.txt +7 -4
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/setup.py +1 -1
- lt_tensor-0.0.1a33/lt_tensor/losses.py +0 -159
- lt_tensor-0.0.1a33/lt_tensor/lr_schedulers.py +0 -114
- lt_tensor-0.0.1a33/lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +0 -1
- lt_tensor-0.0.1a33/lt_tensor/model_zoo/audio_models/__init__.py +0 -3
- lt_tensor-0.0.1a33/lt_tensor/model_zoo/audio_models/hifigan/__init__.py +0 -520
- lt_tensor-0.0.1a33/lt_tensor/model_zoo/audio_models/istft/__init__.py +0 -551
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/LICENSE +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/README.md +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/config_templates.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/math_ops.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_base.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/__init__.py +0 -0
- {lt_tensor-0.0.1a33/lt_tensor/model_zoo/activations/alias_free_torch → lt_tensor-0.0.1a35/lt_tensor/model_zoo/activations/alias_free}/filter.py +0 -0
- {lt_tensor-0.0.1a33/lt_tensor/model_zoo/activations/alias_free_torch → lt_tensor-0.0.1a35/lt_tensor/model_zoo/activations/alias_free}/resample.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/basic.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/features.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/fusion.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/losses/__init__.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/pos_encoder.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/residual.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/transformer.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/monotonic_align.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/noise_tools.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/processors/__init__.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/processors/audio.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/torch_commons.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/transform.py +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor.egg-info/dependency_links.txt +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor.egg-info/requires.txt +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor.egg-info/top_level.txt +0 -0
- {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/setup.cfg +0 -0
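
For consumers of the package, the most visible change in this list is the rename of the alias-free activation subpackage (`alias_free_torch` → `alias_free`) and the reorganized audio models. A hedged sketch of the resulting import-path change (only the paths are visible in this diff; the exported names are assumed to carry over unchanged):

```python
# Import-path change implied by the renames above.
# 0.0.1a33 (old layout):
#   from lt_tensor.model_zoo.activations.alias_free_torch.act import Activation1d
# 0.0.1a35 (new layout):
from lt_tensor.model_zoo.activations.alias_free.act import Activation1d
```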
--- /dev/null
+++ lt_tensor-0.0.1a35/lt_tensor/losses.py
@@ -0,0 +1,281 @@
+__all__ = [
+    "masked_cross_entropy",
+    "adaptive_l1_loss",
+    "contrastive_loss",
+    "smooth_l1_loss",
+    "hybrid_loss",
+    "diff_loss",
+    "cosine_loss",
+    "ft_n_loss",
+    "MultiMelScaleLoss",
+]
+import math
+import random
+from lt_tensor.torch_commons import *
+from lt_utils.common import *
+import torch.nn.functional as F
+from lt_tensor.model_base import Model
+from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
+from lt_tensor.math_ops import normalize_minmax, normalize_unit_norm, normalize_zscore
+
+
+def ft_n_loss(output: Tensor, target: Tensor, weight: Optional[Tensor] = None):
+    if weight is not None:
+        return torch.mean((torch.abs(output - target) + weight) ** 0.5)
+    return torch.mean(torch.abs(output - target) ** 0.5)
+
+
+def adaptive_l1_loss(
+    inp: Tensor,
+    tgt: Tensor,
+    weight: Optional[Tensor] = None,
+    scale: float = 1.0,
+    inverted: bool = False,
+):
+
+    if weight is not None:
+        loss = torch.mean(torch.abs((inp - tgt) + weight.mean()))
+    else:
+        loss = torch.mean(torch.abs(inp - tgt))
+    loss *= scale
+    if inverted:
+        return -loss
+    return loss
+
+
+def smooth_l1_loss(inp: Tensor, tgt: Tensor, beta=1.0, weight=None):
+    diff = torch.abs(inp - tgt)
+    loss = torch.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)
+    if weight is not None:
+        loss *= weight
+    return loss.mean()
+
+
+def contrastive_loss(x1: Tensor, x2: Tensor, label: Tensor, margin: float = 1.0):
+    # label == 1: similar, label == 0: dissimilar
+    dist = torch.nn.functional.pairwise_distance(x1, x2)
+    loss = label * dist**2 + (1 - label) * torch.clamp(margin - dist, min=0.0) ** 2
+    return loss.mean()
+
+
+def cosine_loss(inp, tgt):
+    cos = torch.nn.functional.cosine_similarity(inp, tgt, dim=-1)
+    return 1 - cos.mean()  # Lower is better
+
+
+def masked_cross_entropy(
+    logits: torch.Tensor,  # [B, T, V]
+    targets: torch.Tensor,  # [B, T]
+    lengths: torch.Tensor,  # [B]
+    reduction: str = "mean",
+) -> torch.Tensor:
+    """
+    CrossEntropyLoss with masking for variable-length sequences.
+    - logits: unnormalized scores [B, T, V]
+    - targets: ground truth indices [B, T]
+    - lengths: actual sequence lengths [B]
+    """
+    B, T, V = logits.size()
+    logits = logits.view(-1, V)
+    targets = targets.view(-1)
+
+    # Create mask
+    mask = torch.arange(T, device=lengths.device).expand(B, T) < lengths.unsqueeze(1)
+    mask = mask.reshape(-1)
+
+    # Apply CE only where mask == True
+    loss = F.cross_entropy(
+        logits[mask], targets[mask], reduction="mean" if reduction == "mean" else "none"
+    )
+    if reduction == "none":
+        return loss
+    return loss
+
+
+def diff_loss(pred_noise, true_noise, mask=None):
+    """Standard diffusion noise-prediction loss (e.g., DDPM)"""
+    if mask is not None:
+        return F.mse_loss(pred_noise * mask, true_noise * mask)
+    return F.mse_loss(pred_noise, true_noise)
+
+
+def hybrid_diff_loss(pred_noise, true_noise, alpha=0.5):
+    """Combines L1 and L2"""
+    l1 = F.l1_loss(pred_noise, true_noise)
+    l2 = F.mse_loss(pred_noise, true_noise)
+    return alpha * l1 + (1 - alpha) * l2
+
+
+def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
+    loss = 0
+    for real, fake in zip(real_preds, fake_preds):
+        if use_lsgan:
+            loss += F.mse_loss(real, torch.ones_like(real)) + F.mse_loss(
+                fake, torch.zeros_like(fake)
+            )
+        else:
+            loss += -torch.mean(torch.log(real + 1e-7)) - torch.mean(
+                torch.log(1 - fake + 1e-7)
+            )
+    return loss
+
+
+class MultiMelScaleLoss(Model):
+    def __init__(
+        self,
+        sample_rate: int,
+        n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
+        window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
+        n_ffts: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
+        hops: List[int] = [8, 16, 32, 64, 128, 256, 512],
+        f_min: float = [0, 0, 0, 0, 0, 0, 0],
+        f_max: Optional[float] = [None, None, None, None, None, None, None],
+        loss_mel_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
+        loss_pitch_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
+        loss_rms_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
+        center: bool = True,
+        power: float = 1.0,
+        normalized: bool = False,
+        pad_mode: str = "reflect",
+        onesided: Optional[bool] = None,
+        std: int = 4,
+        mean: int = -4,
+        use_istft_norm: bool = True,
+        use_pitch_loss: bool = True,
+        use_rms_loss: bool = True,
+        norm_pitch_fn: Callable[[Tensor], Tensor] = normalize_minmax,
+        norm_rms_fn: Callable[[Tensor], Tensor] = normalize_zscore,
+        lambda_mel: float = 1.0,
+        lambda_rms: float = 1.0,
+        lambda_pitch: float = 1.0,
+        weight: float = 1.0,
+    ):
+        super().__init__()
+        assert (
+            len(n_mels)
+            == len(window_lengths)
+            == len(n_ffts)
+            == len(hops)
+            == len(f_min)
+            == len(f_max)
+        )
+        self.loss_mel_fn = loss_mel_fn
+        self.loss_pitch_fn = loss_pitch_fn
+        self.loss_rms_fn = loss_rms_fn
+        self.lambda_mel = lambda_mel
+        self.weight = weight
+        self.use_istft_norm = use_istft_norm
+        self.use_pitch_loss = use_pitch_loss
+        self.use_rms_loss = use_rms_loss
+        self.lambda_pitch = lambda_pitch
+        self.lambda_rms = lambda_rms
+
+        self.norm_pitch_fn = norm_pitch_fn
+        self.norm_rms = norm_rms_fn
+
+        self._setup_mels(
+            sample_rate,
+            n_mels,
+            window_lengths,
+            n_ffts,
+            hops,
+            f_min,
+            f_max,
+            center,
+            power,
+            normalized,
+            pad_mode,
+            onesided,
+            std,
+            mean,
+        )
+
+    def _setup_mels(
+        self,
+        sample_rate: int,
+        n_mels: List[int],
+        window_lengths: List[int],
+        n_ffts: List[int],
+        hops: List[int],
+        f_min: List[float],
+        f_max: List[Optional[float]],
+        center: bool,
+        power: float,
+        normalized: bool,
+        pad_mode: str,
+        onesided: Optional[bool],
+        std: int,
+        mean: int,
+    ):
+        assert (
+            len(n_mels)
+            == len(window_lengths)
+            == len(n_ffts)
+            == len(hops)
+            == len(f_min)
+            == len(f_max)
+        )
+        _mel_kwargs = dict(
+            sample_rate=sample_rate,
+            center=center,
+            onesided=onesided,
+            normalized=normalized,
+            power=power,
+            pad_mode=pad_mode,
+            std=std,
+            mean=mean,
+        )
+        self.mel_spectrograms: List[AudioProcessor] = nn.ModuleList(
+            [
+                AudioProcessor(
+                    AudioProcessorConfig(
+                        **_mel_kwargs,
+                        n_mels=mel,
+                        n_fft=n_fft,
+                        win_length=win,
+                        hop_length=hop,
+                        f_min=fmin,
+                        f_max=fmax,
+                    )
+                )
+                for mel, win, n_fft, hop, fmin, fmax in zip(
+                    n_mels, window_lengths, n_ffts, hops, f_min, f_max
+                )
+            ]
+        )
+
+    def forward(
+        self, input_wave: torch.Tensor, target_wave: torch.Tensor
+    ) -> torch.Tensor:
+        assert self.use_istft_norm or input_wave.shape[-1] == target_wave.shape[-1]
+        target_wave = target_wave.to(input_wave.device)
+        losses = 0.0
+        for M in self.mel_spectrograms:
+            # Apply normalization if requested
+            if self.use_istft_norm:
+                input_proc = M.istft_norm(input_wave, length=target_wave.shape[-1])
+                target_proc = M.istft_norm(target_wave, length=target_wave.shape[-1])
+            else:
+                input_proc, target_proc = input_wave, target_wave
+
+            x_mels = M(input_proc)
+            y_mels = M(target_proc)
+
+            loss = self.loss_mel_fn(x_mels.squeeze(), y_mels.squeeze())
+            losses += loss * self.lambda_mel
+
+            # pitch/f0 loss
+            if self.use_pitch_loss:
+                x_pitch = self.norm_pitch_fn(M.compute_pitch(input_proc))
+                y_pitch = self.norm_pitch_fn(M.compute_pitch(target_proc))
+                f0_loss = self.loss_pitch_fn(x_pitch, y_pitch)
+                losses += f0_loss * self.lambda_pitch
+
+            # energy/rms loss
+            if self.use_rms_loss:
+                x_rms = self.norm_rms(M.compute_rms(input_proc, x_mels))
+                y_rms = self.norm_rms(M.compute_rms(target_proc, y_mels))
+                rms_loss = self.loss_rms_fn(x_rms, y_rms)
+                losses += rms_loss * self.lambda_rms
+
+        return losses * self.weight
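
The hunk above adds a set of standalone loss helpers plus a multi-resolution mel loss built on the bundled `AudioProcessor`. A hedged usage sketch follows; the tensor shapes are illustrative, and whether `AudioProcessorConfig` accepts the defaults shown above cannot be verified from this diff alone, so the `MultiMelScaleLoss` part is an assumption about the rest of the package.

```python
import torch
from lt_tensor.losses import masked_cross_entropy, MultiMelScaleLoss

# Variable-length cross entropy: positions beyond each sequence length are masked out.
logits = torch.randn(2, 10, 32)            # [B, T, V]
targets = torch.randint(0, 32, (2, 10))    # [B, T]
lengths = torch.tensor([10, 7])            # [B]; second sequence is padding after step 7
ce = masked_cross_entropy(logits, targets, lengths)

# Multi-resolution mel loss between generated and reference audio (assumes the
# bundled AudioProcessor builds correctly from the defaults shown in the diff).
mel_loss = MultiMelScaleLoss(sample_rate=24000)
fake = torch.randn(1, 24000)
real = torch.randn(1, 24000)
total = mel_loss(fake, real)   # weighted sum of mel (plus optional pitch/RMS) terms
```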
--- /dev/null
+++ lt_tensor-0.0.1a35/lt_tensor/lr_schedulers.py
@@ -0,0 +1,240 @@
+__all__ = [
+    "WarmupDecayScheduler",
+    "AdaptiveDropScheduler",
+    "SinusoidalDecayLR",
+    "GuidedWaveringLR",
+    "FloorExponentialLR",
+]
+
+import math
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+from typing import Optional
+from numbers import Number
+from lt_tensor.misc_utils import update_lr
+
+
+class WarmupDecayScheduler(LRScheduler):
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        warmup_steps: int,
+        total_steps: int,
+        decay_type: str = "linear",  # or "cosine"
+        min_lr: float = 0.0,
+        last_epoch: int = -1,
+    ):
+        self.warmup_steps = warmup_steps
+        self.total_steps = total_steps
+        self.decay_type = decay_type
+        self.min_lr = min_lr
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        step = self.last_epoch + 1
+        warmup = self.warmup_steps
+        total = self.total_steps
+        lrs = []
+
+        for base_lr in self.base_lrs:
+            if step < warmup:
+                lr = base_lr * step / warmup
+            else:
+                progress = (step - warmup) / max(1, total - warmup)
+                if self.decay_type == "linear":
+                    lr = base_lr * (1.0 - progress)
+                elif self.decay_type == "cosine":
+                    lr = base_lr * 0.5 * (1 + math.cos(math.pi * progress))
+                else:
+                    raise ValueError(f"Unknown decay type: {self.decay_type}")
+
+            lr = max(self.min_lr, lr)
+            lrs.append(lr)
+
+        return lrs
+
+
+class AdaptiveDropScheduler(LRScheduler):
+    def __init__(
+        self,
+        optimizer,
+        drop_factor=0.5,
+        patience=10,
+        min_lr=1e-6,
+        cooldown=5,
+        last_epoch=-1,
+    ):
+        self.drop_factor = drop_factor
+        self.patience = patience
+        self.min_lr = min_lr
+        self.cooldown = cooldown
+        self.cooldown_counter = 0
+        self.best_loss = float("inf")
+        self.bad_steps = 0
+        super().__init__(optimizer, last_epoch)
+
+    def step(self, val_loss=None):
+        if val_loss is not None:
+            if val_loss < self.best_loss:
+                self.best_loss = val_loss
+                self.bad_steps = 0
+                self.cooldown_counter = 0
+            else:
+                self.bad_steps += 1
+                if self.bad_steps >= self.patience and self.cooldown_counter == 0:
+                    for i, group in enumerate(self.optimizer.param_groups):
+                        new_lr = max(group["lr"] * self.drop_factor, self.min_lr)
+                        group["lr"] = new_lr
+                    self.cooldown_counter = self.cooldown
+                    self.bad_steps = 0
+        if self.cooldown_counter > 0:
+            self.cooldown_counter -= 1
+
+    def get_lr(self):
+        return [group["lr"] for group in self.optimizer.param_groups]
+
+
+class SinusoidalDecayLR(LRScheduler):
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        initial_lr: float = 1e-3,
+        target_lr: float = 1e-5,
+        floor_lr: float = 1e-7,
+        decay_rate: float = 1e-6,  # decay per period
+        wave_amplitude: float = 1e-5,
+        period: int = 256,
+        last_epoch: int = -1,
+    ):
+        assert decay_rate != 0.0, "decay_rate must be different from 0.0"
+        assert (
+            initial_lr >= target_lr >= floor_lr
+        ), "Must satisfy: initial_lr ≥ target_lr ≥ floor_lr"
+
+        self.initial_lr = initial_lr
+        self.target_lr = target_lr
+        self.floor_lr = floor_lr
+        self.decay_rate = decay_rate
+        self.wave_amplitude = wave_amplitude
+        self.period = period
+
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        step = self.last_epoch + 1
+        cycles = step // self.period
+        t = step % self.period
+        # Decay center down to target_lr, then freeze
+        center_decay = math.exp(-self.decay_rate * cycles)
+        center = max(self.target_lr, self.initial_lr * center_decay)
+        # Decay amplitude in sync with center (relative to initial)
+        amplitude_decay = math.exp(-self.decay_rate * cycles)
+        current_amplitude = self.wave_amplitude * self.initial_lr * amplitude_decay
+        sin_offset = math.sin(2 * math.pi * t / self.period)
+        lr = max(center + current_amplitude * sin_offset, self.floor_lr)
+        return [lr for _ in self.optimizer.param_groups]
+
+
+class GuidedWaveringLR(LRScheduler):
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        initial_lr: float = 1e-3,
+        target_lr: float = 1e-5,
+        floor_lr: float = 1e-7,
+        decay_rate: float = 0.01,
+        wave_amplitude: float = 0.02,
+        period: int = 256,
+        stop_decay_after: int = None,
+        last_epoch: int = -1,
+    ):
+        assert decay_rate != 0.0, "decay_rate must be non-zero"
+        assert (
+            initial_lr >= target_lr >= floor_lr
+        ), "Must satisfy: initial ≥ target ≥ floor"
+
+        self.initial_lr = initial_lr
+        self.target_lr = target_lr
+        self.floor_lr = floor_lr
+        self.decay_rate = decay_rate
+        self.wave_amplitude = wave_amplitude
+        self.period = period
+        self.stop_decay_after = stop_decay_after
+
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        step = self.last_epoch + 1
+        cycles = step // self.period
+        t = step % self.period
+
+        decay_cycles = (
+            min(cycles, self.stop_decay_after) if self.stop_decay_after else cycles
+        )
+        center = max(
+            self.target_lr, self.initial_lr * math.exp(-self.decay_rate * decay_cycles)
+        )
+        amp = (
+            self.wave_amplitude
+            * self.initial_lr
+            * math.exp(-self.decay_rate * decay_cycles)
+        )
+        phase = 2 * math.pi * t / self.period
+        wave = math.sin(phase) * math.cos(phase)
+        lr = max(center + amp * wave, self.floor_lr)
+        return [lr for _ in self.optimizer.param_groups]
+
+
+class FloorExponentialLR(LRScheduler):
+    """Modified version from exponential lr, to have a minimum and reset functions.
+
+    Decays the learning rate of each parameter group by gamma every epoch.
+
+    When last_epoch=-1, sets initial lr as lr.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        gamma (float): Multiplicative factor of learning rate decay.
+        last_epoch (int): The index of last epoch. Default: -1.
+    """
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        initial_lr: float = 1e-4,
+        gamma: float = 0.99998,
+        last_epoch: int = -1,
+        floor_lr: float = 1e-6,
+    ):
+        self.gamma = gamma
+        self.floor_lr = floor_lr
+        self.initial_lr = initial_lr
+
+        super().__init__(optimizer, last_epoch)
+
+    def set_floor(self, new_value: float):
+        assert isinstance(new_value, Number)
+        self.floor_lr = new_value
+
+    def reset_lr(self, new_value: Optional[float] = None):
+        new_lr = new_value if isinstance(new_value, Number) else self.initial_lr
+        self.initial_lr = new_lr
+        update_lr(self.optimizer, new_lr)
+
+    def get_lr(self):
+
+        if self.last_epoch == 0:
+            return [
+                max(group["lr"], self.floor_lr) for group in self.optimizer.param_groups
+            ]
+
+        return [
+            max(group["lr"] * self.gamma, self.floor_lr)
+            for group in self.optimizer.param_groups
+        ]
+
+    def _get_closed_form_lr(self):
+        return [
+            max(base_lr * self.gamma**self.last_epoch, self.floor_lr)
+            for base_lr in self.base_lrs
+        ]
--- lt_tensor-0.0.1a33/lt_tensor/misc_utils.py
+++ lt_tensor-0.0.1a35/lt_tensor/misc_utils.py
@@ -24,6 +24,7 @@ __all__ = [
     "plot_view",
     "get_weights",
     "get_activated_conv",
+    "update_lr",
 ]

 import re
@@ -77,6 +78,33 @@ def get_activated_conv(
     )


+def get_loss_average(losses: List[float]):
+    """A little helper for training, for example:
+    ```python
+    losses = []
+    for epoch in range(100):
+        for inp, label in dataloader:
+            optimizer.zero_grad()
+            out = model(inp)
+            loss = loss_fn(out, label)
+            optimizer.step()
+            losses.append(loss.item())
+        print(f"Epoch {epoch+1} | Loss: {get_loss_average(losses):.4f}")
+    """
+    if not losses:
+        return float("nan")
+    return sum(losses) / len(losses)
+
+
+def update_lr(optimizer: optim.Optimizer, new_value: float = 1e-4):
+    for param_group in optimizer.param_groups:
+        if isinstance(param_group["lr"], Tensor):
+            param_group["lr"].fill_(new_value)
+        else:
+            param_group["lr"] = new_value
+    return optimizer
+
+
 def plot_view(
     data: Dict[str, List[Any]],
     title: str = "Loss",
@@ -520,49 +548,14 @@ def sample_tensor(tensor: torch.Tensor, num_samples: int = 5):
     return flat[idx]


-class TorchCacheUtils:
-    cached_shortcuts: dict[str, Callable[[None], None]] = {}
-
-    has_cuda: bool = torch.cuda.is_available()
-    has_xpu: bool = torch.xpu.is_available()
-    has_mps: bool = torch.mps.is_available()
-
-    _ignore: list[str] = []
-
-    def __init__(self):
-        pass
-
-    def _apply_clear(self, device: str):
-        if device in self._ignore:
-            gc.collect()
-            return
-        try:
-            clear_fn = self.cached_shortcuts.get(
-                device, getattr(torch, device).empty_cache
-            )
-            if device not in self.cached_shortcuts:
-                self.cached_shortcuts.update({device: clear_fn})
-
-        except Exception as e:
-            print(e)
-            self._ignore.append(device)
-
-    def clear(self):
-        gc.collect()
-        if self.has_xpu:
-            self._apply_clear("xpu")
-        if self.has_cuda:
-            self._apply_clear("cuda")
-        if self.has_mps:
-            self._apply_clear("mps")
-        gc.collect()
-
-
-_clear_cache_cls = TorchCacheUtils()
-
-
 def clear_cache():
-
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    if torch.mps.is_available():
+        torch.mps.empty_cache()
+    if torch.xpu.is_available():
+        torch.xpu.empty_cache()
+    gc.collect()


 @cache_wrapper
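
The helpers touched above are small enough to show end to end. This is a sketch against the code in the hunks; it assumes a reasonably recent PyTorch build where `torch.mps`/`torch.xpu` exist, since the simplified `clear_cache` probes them unconditionally.

```python
import torch
from lt_tensor.misc_utils import update_lr, get_loss_average, clear_cache

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(4))], lr=1e-3)
update_lr(opt, 5e-4)                      # overwrites "lr" in every param group (Tensor or float)

avg = get_loss_average([0.9, 0.7, 0.4])   # plain arithmetic mean; nan for an empty list

clear_cache()                             # empty_cache() per available backend, then gc.collect()
```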
--- lt_tensor-0.0.1a33/lt_tensor/model_zoo/activations/alias_free_torch/act.py
+++ lt_tensor-0.0.1a35/lt_tensor/model_zoo/activations/alias_free/act.py
@@ -1,15 +1,17 @@
-import torch
 import torch.nn as nn
-
-
-
+from lt_tensor.model_zoo.activations.alias_free.resample import (
+    UpSample2d,
+    DownSample2d,
+    UpSample1d,
+    DownSample1d,
+)


 class Activation1d(nn.Module):

     def __init__(
         self,
-        activation,
+        activation: nn.Module,
         up_ratio: int = 2,
         down_ratio: int = 2,
         up_kernel_size: int = 12,
@@ -34,7 +36,7 @@ class Activation2d(nn.Module):

     def __init__(
         self,
-        activation,
+        activation: nn.Module,
         up_ratio: int = 2,
         down_ratio: int = 2,
         up_kernel_size: int = 12,