lt-tensor 0.0.1a33__tar.gz → 0.0.1a35__tar.gz

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
Files changed (51)
  1. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/PKG-INFO +1 -1
  2. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/__init__.py +1 -1
  3. lt_tensor-0.0.1a35/lt_tensor/losses.py +281 -0
  4. lt_tensor-0.0.1a35/lt_tensor/lr_schedulers.py +240 -0
  5. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/misc_utils.py +35 -42
  6. lt_tensor-0.0.1a35/lt_tensor/model_zoo/activations/__init__.py +3 -0
  7. lt_tensor-0.0.1a35/lt_tensor/model_zoo/activations/alias_free/__init__.py +3 -0
  8. {lt_tensor-0.0.1a33/lt_tensor/model_zoo/activations/alias_free_torch → lt_tensor-0.0.1a35/lt_tensor/model_zoo/activations/alias_free}/act.py +8 -6
  9. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/activations/snake/__init__.py +41 -43
  10. lt_tensor-0.0.1a35/lt_tensor/model_zoo/audio_models/__init__.py +3 -0
  11. lt_tensor-0.0.1a35/lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +243 -0
  12. lt_tensor-0.0.1a35/lt_tensor/model_zoo/audio_models/hifigan/__init__.py +189 -0
  13. lt_tensor-0.0.1a35/lt_tensor/model_zoo/audio_models/istft/__init__.py +216 -0
  14. lt_tensor-0.0.1a35/lt_tensor/model_zoo/audio_models/resblocks.py +248 -0
  15. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/convs.py +21 -32
  16. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/losses/discriminators.py +143 -230
  17. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor.egg-info/PKG-INFO +1 -1
  18. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor.egg-info/SOURCES.txt +7 -4
  19. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/setup.py +1 -1
  20. lt_tensor-0.0.1a33/lt_tensor/losses.py +0 -159
  21. lt_tensor-0.0.1a33/lt_tensor/lr_schedulers.py +0 -114
  22. lt_tensor-0.0.1a33/lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +0 -1
  23. lt_tensor-0.0.1a33/lt_tensor/model_zoo/audio_models/__init__.py +0 -3
  24. lt_tensor-0.0.1a33/lt_tensor/model_zoo/audio_models/hifigan/__init__.py +0 -520
  25. lt_tensor-0.0.1a33/lt_tensor/model_zoo/audio_models/istft/__init__.py +0 -551
  26. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/LICENSE +0 -0
  27. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/README.md +0 -0
  28. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/config_templates.py +0 -0
  29. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/math_ops.py +0 -0
  30. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_base.py +0 -0
  31. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/__init__.py +0 -0
  32. {lt_tensor-0.0.1a33/lt_tensor/model_zoo/activations/alias_free_torch → lt_tensor-0.0.1a35/lt_tensor/model_zoo/activations/alias_free}/filter.py +0 -0
  33. {lt_tensor-0.0.1a33/lt_tensor/model_zoo/activations/alias_free_torch → lt_tensor-0.0.1a35/lt_tensor/model_zoo/activations/alias_free}/resample.py +0 -0
  34. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py +0 -0
  35. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/basic.py +0 -0
  36. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/features.py +0 -0
  37. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/fusion.py +0 -0
  38. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/losses/__init__.py +0 -0
  39. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/pos_encoder.py +0 -0
  40. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/residual.py +0 -0
  41. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/model_zoo/transformer.py +0 -0
  42. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/monotonic_align.py +0 -0
  43. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/noise_tools.py +0 -0
  44. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/processors/__init__.py +0 -0
  45. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/processors/audio.py +0 -0
  46. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/torch_commons.py +0 -0
  47. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/transform.py +0 -0
  48. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor.egg-info/dependency_links.txt +0 -0
  49. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor.egg-info/requires.txt +0 -0
  50. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor.egg-info/top_level.txt +0 -0
  51. {lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/setup.cfg +0 -0
{lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lt-tensor
- Version: 0.0.1a33
+ Version: 0.0.1a35
  Summary: General utilities for PyTorch and others. Built for general use.
  Home-page: https://github.com/gr1336/lt-tensor/
  Author: gr1336
{lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/__init__.py
@@ -1,4 +1,4 @@
- __version__ = "0.0.1a33"
+ __version__ = "0.0.1a35"

  from . import (
      lr_schedulers,
lt_tensor-0.0.1a35/lt_tensor/losses.py (new file)
@@ -0,0 +1,281 @@
+ __all__ = [
+     "masked_cross_entropy",
+     "adaptive_l1_loss",
+     "contrastive_loss",
+     "smooth_l1_loss",
+     "hybrid_loss",
+     "diff_loss",
+     "cosine_loss",
+     "ft_n_loss",
+     "MultiMelScaleLoss",
+ ]
+ import math
+ import random
+ from lt_tensor.torch_commons import *
+ from lt_utils.common import *
+ import torch.nn.functional as F
+ from lt_tensor.model_base import Model
+ from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
+ from lt_tensor.math_ops import normalize_minmax, normalize_unit_norm, normalize_zscore
+
+
+ def ft_n_loss(output: Tensor, target: Tensor, weight: Optional[Tensor] = None):
+     if weight is not None:
+         return torch.mean((torch.abs(output - target) + weight) ** 0.5)
+     return torch.mean(torch.abs(output - target) ** 0.5)
+
+
+ def adaptive_l1_loss(
+     inp: Tensor,
+     tgt: Tensor,
+     weight: Optional[Tensor] = None,
+     scale: float = 1.0,
+     inverted: bool = False,
+ ):
+
+     if weight is not None:
+         loss = torch.mean(torch.abs((inp - tgt) + weight.mean()))
+     else:
+         loss = torch.mean(torch.abs(inp - tgt))
+     loss *= scale
+     if inverted:
+         return -loss
+     return loss
+
+
+ def smooth_l1_loss(inp: Tensor, tgt: Tensor, beta=1.0, weight=None):
+     diff = torch.abs(inp - tgt)
+     loss = torch.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)
+     if weight is not None:
+         loss *= weight
+     return loss.mean()
+
+
+ def contrastive_loss(x1: Tensor, x2: Tensor, label: Tensor, margin: float = 1.0):
+     # label == 1: similar, label == 0: dissimilar
+     dist = torch.nn.functional.pairwise_distance(x1, x2)
+     loss = label * dist**2 + (1 - label) * torch.clamp(margin - dist, min=0.0) ** 2
+     return loss.mean()
+
+
+ def cosine_loss(inp, tgt):
+     cos = torch.nn.functional.cosine_similarity(inp, tgt, dim=-1)
+     return 1 - cos.mean()  # Lower is better
+
+
+ def masked_cross_entropy(
+     logits: torch.Tensor,  # [B, T, V]
+     targets: torch.Tensor,  # [B, T]
+     lengths: torch.Tensor,  # [B]
+     reduction: str = "mean",
+ ) -> torch.Tensor:
+     """
+     CrossEntropyLoss with masking for variable-length sequences.
+     - logits: unnormalized scores [B, T, V]
+     - targets: ground truth indices [B, T]
+     - lengths: actual sequence lengths [B]
+     """
+     B, T, V = logits.size()
+     logits = logits.view(-1, V)
+     targets = targets.view(-1)
+
+     # Create mask
+     mask = torch.arange(T, device=lengths.device).expand(B, T) < lengths.unsqueeze(1)
+     mask = mask.reshape(-1)
+
+     # Apply CE only where mask == True
+     loss = F.cross_entropy(
+         logits[mask], targets[mask], reduction="mean" if reduction == "mean" else "none"
+     )
+     if reduction == "none":
+         return loss
+     return loss
+
+
+ def diff_loss(pred_noise, true_noise, mask=None):
+     """Standard diffusion noise-prediction loss (e.g., DDPM)"""
+     if mask is not None:
+         return F.mse_loss(pred_noise * mask, true_noise * mask)
+     return F.mse_loss(pred_noise, true_noise)
+
+
+ def hybrid_diff_loss(pred_noise, true_noise, alpha=0.5):
+     """Combines L1 and L2"""
+     l1 = F.l1_loss(pred_noise, true_noise)
+     l2 = F.mse_loss(pred_noise, true_noise)
+     return alpha * l1 + (1 - alpha) * l2
+
+
+ def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
+     loss = 0
+     for real, fake in zip(real_preds, fake_preds):
+         if use_lsgan:
+             loss += F.mse_loss(real, torch.ones_like(real)) + F.mse_loss(
+                 fake, torch.zeros_like(fake)
+             )
+         else:
+             loss += -torch.mean(torch.log(real + 1e-7)) - torch.mean(
+                 torch.log(1 - fake + 1e-7)
+             )
+     return loss
+
+
+ class MultiMelScaleLoss(Model):
+     def __init__(
+         self,
+         sample_rate: int,
+         n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
+         window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
+         n_ffts: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
+         hops: List[int] = [8, 16, 32, 64, 128, 256, 512],
+         f_min: float = [0, 0, 0, 0, 0, 0, 0],
+         f_max: Optional[float] = [None, None, None, None, None, None, None],
+         loss_mel_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
+         loss_pitch_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
+         loss_rms_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
+         center: bool = True,
+         power: float = 1.0,
+         normalized: bool = False,
+         pad_mode: str = "reflect",
+         onesided: Optional[bool] = None,
+         std: int = 4,
+         mean: int = -4,
+         use_istft_norm: bool = True,
+         use_pitch_loss: bool = True,
+         use_rms_loss: bool = True,
+         norm_pitch_fn: Callable[[Tensor], Tensor] = normalize_minmax,
+         norm_rms_fn: Callable[[Tensor], Tensor] = normalize_zscore,
+         lambda_mel: float = 1.0,
+         lambda_rms: float = 1.0,
+         lambda_pitch: float = 1.0,
+         weight: float = 1.0,
+     ):
+         super().__init__()
+         assert (
+             len(n_mels)
+             == len(window_lengths)
+             == len(n_ffts)
+             == len(hops)
+             == len(f_min)
+             == len(f_max)
+         )
+         self.loss_mel_fn = loss_mel_fn
+         self.loss_pitch_fn = loss_pitch_fn
+         self.loss_rms_fn = loss_rms_fn
+         self.lambda_mel = lambda_mel
+         self.weight = weight
+         self.use_istft_norm = use_istft_norm
+         self.use_pitch_loss = use_pitch_loss
+         self.use_rms_loss = use_rms_loss
+         self.lambda_pitch = lambda_pitch
+         self.lambda_rms = lambda_rms
+
+         self.norm_pitch_fn = norm_pitch_fn
+         self.norm_rms = norm_rms_fn
+
+         self._setup_mels(
+             sample_rate,
+             n_mels,
+             window_lengths,
+             n_ffts,
+             hops,
+             f_min,
+             f_max,
+             center,
+             power,
+             normalized,
+             pad_mode,
+             onesided,
+             std,
+             mean,
+         )
+
+     def _setup_mels(
+         self,
+         sample_rate: int,
+         n_mels: List[int],
+         window_lengths: List[int],
+         n_ffts: List[int],
+         hops: List[int],
+         f_min: List[float],
+         f_max: List[Optional[float]],
+         center: bool,
+         power: float,
+         normalized: bool,
+         pad_mode: str,
+         onesided: Optional[bool],
+         std: int,
+         mean: int,
+     ):
+         assert (
+             len(n_mels)
+             == len(window_lengths)
+             == len(n_ffts)
+             == len(hops)
+             == len(f_min)
+             == len(f_max)
+         )
+         _mel_kwargs = dict(
+             sample_rate=sample_rate,
+             center=center,
+             onesided=onesided,
+             normalized=normalized,
+             power=power,
+             pad_mode=pad_mode,
+             std=std,
+             mean=mean,
+         )
+         self.mel_spectrograms: List[AudioProcessor] = nn.ModuleList(
+             [
+                 AudioProcessor(
+                     AudioProcessorConfig(
+                         **_mel_kwargs,
+                         n_mels=mel,
+                         n_fft=n_fft,
+                         win_length=win,
+                         hop_length=hop,
+                         f_min=fmin,
+                         f_max=fmax,
+                     )
+                 )
+                 for mel, win, n_fft, hop, fmin, fmax in zip(
+                     n_mels, window_lengths, n_ffts, hops, f_min, f_max
+                 )
+             ]
+         )
+
+     def forward(
+         self, input_wave: torch.Tensor, target_wave: torch.Tensor
+     ) -> torch.Tensor:
+         assert self.use_istft_norm or input_wave.shape[-1] == target_wave.shape[-1]
+         target_wave = target_wave.to(input_wave.device)
+         losses = 0.0
+         for M in self.mel_spectrograms:
+             # Apply normalization if requested
+             if self.use_istft_norm:
+                 input_proc = M.istft_norm(input_wave, length=target_wave.shape[-1])
+                 target_proc = M.istft_norm(target_wave, length=target_wave.shape[-1])
+             else:
+                 input_proc, target_proc = input_wave, target_wave
+
+             x_mels = M(input_proc)
+             y_mels = M(target_proc)
+
+             loss = self.loss_mel_fn(x_mels.squeeze(), y_mels.squeeze())
+             losses += loss * self.lambda_mel
+
+             # pitch/f0 loss
+             if self.use_pitch_loss:
+                 x_pitch = self.norm_pitch_fn(M.compute_pitch(input_proc))
+                 y_pitch = self.norm_pitch_fn(M.compute_pitch(target_proc))
+                 f0_loss = self.loss_pitch_fn(x_pitch, y_pitch)
+                 losses += f0_loss * self.lambda_pitch
+
+             # energy/rms loss
+             if self.use_rms_loss:
+                 x_rms = self.norm_rms(M.compute_rms(input_proc, x_mels))
+                 y_rms = self.norm_rms(M.compute_rms(target_proc, y_mels))
+                 rms_loss = self.loss_rms_fn(x_rms, y_rms)
+                 losses += rms_loss * self.lambda_rms
+
+         return losses * self.weight
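
As a quick orientation to the new module, here is a minimal usage sketch based only on the signatures visible in the hunk above; the tensor shapes and values are illustrative assumptions, not taken from the package's documentation.

```python
import torch
from lt_tensor.losses import masked_cross_entropy, contrastive_loss, smooth_l1_loss

# masked_cross_entropy: token-level CE over padded batches ([B, T, V] logits)
logits = torch.randn(4, 50, 100)          # [B, T, V]
targets = torch.randint(0, 100, (4, 50))  # [B, T]
lengths = torch.tensor([50, 42, 37, 18])  # valid length of each sequence
ce = masked_cross_entropy(logits, targets, lengths)

# contrastive_loss: label 1 marks a similar pair, 0 a dissimilar pair
x1, x2 = torch.randn(8, 128), torch.randn(8, 128)
labels = torch.randint(0, 2, (8,)).float()
cl = contrastive_loss(x1, x2, labels, margin=1.0)

# smooth_l1_loss: Huber-style loss with an optional element-wise weight
sl = smooth_l1_loss(torch.randn(8, 128), torch.randn(8, 128), beta=1.0)
```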
lt_tensor-0.0.1a35/lt_tensor/lr_schedulers.py (new file)
@@ -0,0 +1,240 @@
+ __all__ = [
+     "WarmupDecayScheduler",
+     "AdaptiveDropScheduler",
+     "SinusoidalDecayLR",
+     "GuidedWaveringLR",
+     "FloorExponentialLR",
+ ]
+
+ import math
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import LRScheduler
+ from typing import Optional
+ from numbers import Number
+ from lt_tensor.misc_utils import update_lr
+
+
+ class WarmupDecayScheduler(LRScheduler):
+     def __init__(
+         self,
+         optimizer: Optimizer,
+         warmup_steps: int,
+         total_steps: int,
+         decay_type: str = "linear",  # or "cosine"
+         min_lr: float = 0.0,
+         last_epoch: int = -1,
+     ):
+         self.warmup_steps = warmup_steps
+         self.total_steps = total_steps
+         self.decay_type = decay_type
+         self.min_lr = min_lr
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         step = self.last_epoch + 1
+         warmup = self.warmup_steps
+         total = self.total_steps
+         lrs = []
+
+         for base_lr in self.base_lrs:
+             if step < warmup:
+                 lr = base_lr * step / warmup
+             else:
+                 progress = (step - warmup) / max(1, total - warmup)
+                 if self.decay_type == "linear":
+                     lr = base_lr * (1.0 - progress)
+                 elif self.decay_type == "cosine":
+                     lr = base_lr * 0.5 * (1 + math.cos(math.pi * progress))
+                 else:
+                     raise ValueError(f"Unknown decay type: {self.decay_type}")
+
+             lr = max(self.min_lr, lr)
+             lrs.append(lr)
+
+         return lrs
+
+
+ class AdaptiveDropScheduler(LRScheduler):
+     def __init__(
+         self,
+         optimizer,
+         drop_factor=0.5,
+         patience=10,
+         min_lr=1e-6,
+         cooldown=5,
+         last_epoch=-1,
+     ):
+         self.drop_factor = drop_factor
+         self.patience = patience
+         self.min_lr = min_lr
+         self.cooldown = cooldown
+         self.cooldown_counter = 0
+         self.best_loss = float("inf")
+         self.bad_steps = 0
+         super().__init__(optimizer, last_epoch)
+
+     def step(self, val_loss=None):
+         if val_loss is not None:
+             if val_loss < self.best_loss:
+                 self.best_loss = val_loss
+                 self.bad_steps = 0
+                 self.cooldown_counter = 0
+             else:
+                 self.bad_steps += 1
+                 if self.bad_steps >= self.patience and self.cooldown_counter == 0:
+                     for i, group in enumerate(self.optimizer.param_groups):
+                         new_lr = max(group["lr"] * self.drop_factor, self.min_lr)
+                         group["lr"] = new_lr
+                     self.cooldown_counter = self.cooldown
+                     self.bad_steps = 0
+         if self.cooldown_counter > 0:
+             self.cooldown_counter -= 1
+
+     def get_lr(self):
+         return [group["lr"] for group in self.optimizer.param_groups]
+
+
+ class SinusoidalDecayLR(LRScheduler):
+     def __init__(
+         self,
+         optimizer: Optimizer,
+         initial_lr: float = 1e-3,
+         target_lr: float = 1e-5,
+         floor_lr: float = 1e-7,
+         decay_rate: float = 1e-6,  # decay per period
+         wave_amplitude: float = 1e-5,
+         period: int = 256,
+         last_epoch: int = -1,
+     ):
+         assert decay_rate != 0.0, "decay_rate must be different from 0.0"
+         assert (
+             initial_lr >= target_lr >= floor_lr
+         ), "Must satisfy: initial_lr ≥ target_lr ≥ floor_lr"
+
+         self.initial_lr = initial_lr
+         self.target_lr = target_lr
+         self.floor_lr = floor_lr
+         self.decay_rate = decay_rate
+         self.wave_amplitude = wave_amplitude
+         self.period = period
+
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         step = self.last_epoch + 1
+         cycles = step // self.period
+         t = step % self.period
+         # Decay center down to target_lr, then freeze
+         center_decay = math.exp(-self.decay_rate * cycles)
+         center = max(self.target_lr, self.initial_lr * center_decay)
+         # Decay amplitude in sync with center (relative to initial)
+         amplitude_decay = math.exp(-self.decay_rate * cycles)
+         current_amplitude = self.wave_amplitude * self.initial_lr * amplitude_decay
+         sin_offset = math.sin(2 * math.pi * t / self.period)
+         lr = max(center + current_amplitude * sin_offset, self.floor_lr)
+         return [lr for _ in self.optimizer.param_groups]
+
+
+ class GuidedWaveringLR(LRScheduler):
+     def __init__(
+         self,
+         optimizer: Optimizer,
+         initial_lr: float = 1e-3,
+         target_lr: float = 1e-5,
+         floor_lr: float = 1e-7,
+         decay_rate: float = 0.01,
+         wave_amplitude: float = 0.02,
+         period: int = 256,
+         stop_decay_after: int = None,
+         last_epoch: int = -1,
+     ):
+         assert decay_rate != 0.0, "decay_rate must be non-zero"
+         assert (
+             initial_lr >= target_lr >= floor_lr
+         ), "Must satisfy: initial ≥ target ≥ floor"
+
+         self.initial_lr = initial_lr
+         self.target_lr = target_lr
+         self.floor_lr = floor_lr
+         self.decay_rate = decay_rate
+         self.wave_amplitude = wave_amplitude
+         self.period = period
+         self.stop_decay_after = stop_decay_after
+
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         step = self.last_epoch + 1
+         cycles = step // self.period
+         t = step % self.period
+
+         decay_cycles = (
+             min(cycles, self.stop_decay_after) if self.stop_decay_after else cycles
+         )
+         center = max(
+             self.target_lr, self.initial_lr * math.exp(-self.decay_rate * decay_cycles)
+         )
+         amp = (
+             self.wave_amplitude
+             * self.initial_lr
+             * math.exp(-self.decay_rate * decay_cycles)
+         )
+         phase = 2 * math.pi * t / self.period
+         wave = math.sin(phase) * math.cos(phase)
+         lr = max(center + amp * wave, self.floor_lr)
+         return [lr for _ in self.optimizer.param_groups]
+
+
+ class FloorExponentialLR(LRScheduler):
+     """Modified version from exponential lr, to have a minimum and reset functions.
+
+     Decays the learning rate of each parameter group by gamma every epoch.
+
+     When last_epoch=-1, sets initial lr as lr.
+
+     Args:
+         optimizer (Optimizer): Wrapped optimizer.
+         gamma (float): Multiplicative factor of learning rate decay.
+         last_epoch (int): The index of last epoch. Default: -1.
+     """
+
+     def __init__(
+         self,
+         optimizer: Optimizer,
+         initial_lr: float = 1e-4,
+         gamma: float = 0.99998,
+         last_epoch: int = -1,
+         floor_lr: float = 1e-6,
+     ):
+         self.gamma = gamma
+         self.floor_lr = floor_lr
+         self.initial_lr = initial_lr
+
+         super().__init__(optimizer, last_epoch)
+
+     def set_floor(self, new_value: float):
+         assert isinstance(new_value, Number)
+         self.floor_lr = new_value
+
+     def reset_lr(self, new_value: Optional[float] = None):
+         new_lr = new_value if isinstance(new_value, Number) else self.initial_lr
+         self.initial_lr = new_lr
+         update_lr(self.optimizer, new_lr)
+
+     def get_lr(self):
+
+         if self.last_epoch == 0:
+             return [
+                 max(group["lr"], self.floor_lr) for group in self.optimizer.param_groups
+             ]
+
+         return [
+             max(group["lr"] * self.gamma, self.floor_lr)
+             for group in self.optimizer.param_groups
+         ]
+
+     def _get_closed_form_lr(self):
+         return [
+             max(base_lr * self.gamma**self.last_epoch, self.floor_lr)
+             for base_lr in self.base_lrs
+         ]
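
A minimal sketch of how the new schedulers might be driven, assuming a standard per-step training loop; the model and optimizer below are placeholders, not part of the package.

```python
import torch
from lt_tensor.lr_schedulers import WarmupDecayScheduler

model = torch.nn.Linear(10, 10)  # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

# linear warmup for the first 1,000 steps, then cosine decay toward min_lr
sched = WarmupDecayScheduler(
    opt, warmup_steps=1_000, total_steps=100_000, decay_type="cosine", min_lr=1e-6
)

for step in range(100_000):
    # ... forward/backward elided ...
    opt.step()
    sched.step()  # advance the schedule once per optimizer step
```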
{lt_tensor-0.0.1a33 → lt_tensor-0.0.1a35}/lt_tensor/misc_utils.py
@@ -24,6 +24,7 @@ __all__ = [
      "plot_view",
      "get_weights",
      "get_activated_conv",
+     "update_lr",
  ]

  import re
@@ -77,6 +78,33 @@ def get_activated_conv(
      )


+ def get_loss_average(losses: List[float]):
+     """A little helper for training, for example:
+     ```python
+     losses = []
+     for epoch in range(100):
+         for inp, label in dataloader:
+             optimizer.zero_grad()
+             out = model(inp)
+             loss = loss_fn(out, label)
+             optimizer.step()
+             losses.append(loss.item())
+         print(f"Epoch {epoch+1} | Loss: {get_loss_average(losses):.4f}")
+     """
+     if not losses:
+         return float("nan")
+     return sum(losses) / len(losses)
+
+
+ def update_lr(optimizer: optim.Optimizer, new_value: float = 1e-4):
+     for param_group in optimizer.param_groups:
+         if isinstance(param_group["lr"], Tensor):
+             param_group["lr"].fill_(new_value)
+         else:
+             param_group["lr"] = new_value
+     return optimizer
+
+
  def plot_view(
      data: Dict[str, List[Any]],
      title: str = "Loss",
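
For context, `update_lr` is the helper that `FloorExponentialLR.reset_lr` (above) calls internally; a small sketch of calling it directly, with an assumed Adam optimizer as the example:

```python
import torch
from lt_tensor.misc_utils import update_lr

opt = torch.optim.Adam(torch.nn.Linear(4, 4).parameters(), lr=1e-3)
update_lr(opt, 5e-4)                        # overwrite the lr of every param group
print([g["lr"] for g in opt.param_groups])  # [0.0005]
```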
@@ -520,49 +548,14 @@ def sample_tensor(tensor: torch.Tensor, num_samples: int = 5):
      return flat[idx]


- class TorchCacheUtils:
-     cached_shortcuts: dict[str, Callable[[None], None]] = {}
-
-     has_cuda: bool = torch.cuda.is_available()
-     has_xpu: bool = torch.xpu.is_available()
-     has_mps: bool = torch.mps.is_available()
-
-     _ignore: list[str] = []
-
-     def __init__(self):
-         pass
-
-     def _apply_clear(self, device: str):
-         if device in self._ignore:
-             gc.collect()
-             return
-         try:
-             clear_fn = self.cached_shortcuts.get(
-                 device, getattr(torch, device).empty_cache
-             )
-             if device not in self.cached_shortcuts:
-                 self.cached_shortcuts.update({device: clear_fn})
-
-         except Exception as e:
-             print(e)
-             self._ignore.append(device)
-
-     def clear(self):
-         gc.collect()
-         if self.has_xpu:
-             self._apply_clear("xpu")
-         if self.has_cuda:
-             self._apply_clear("cuda")
-         if self.has_mps:
-             self._apply_clear("mps")
-         gc.collect()
-
-
- _clear_cache_cls = TorchCacheUtils()
-
-
  def clear_cache():
-     _clear_cache_cls.clear()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+     if torch.mps.is_available():
+         torch.mps.empty_cache()
+     if torch.xpu.is_available():
+         torch.xpu.empty_cache()
+     gc.collect()


  @cache_wrapper
lt_tensor-0.0.1a35/lt_tensor/model_zoo/activations/__init__.py (new file)
@@ -0,0 +1,3 @@
+ from . import alias_free, snake
+
+ __all__ = ["snake", "alias_free"]
lt_tensor-0.0.1a35/lt_tensor/model_zoo/activations/alias_free/__init__.py (new file)
@@ -0,0 +1,3 @@
+ from .act import *
+ from .filter import *
+ from .resample import *
{lt_tensor-0.0.1a33/lt_tensor/model_zoo/activations/alias_free_torch → lt_tensor-0.0.1a35/lt_tensor/model_zoo/activations/alias_free}/act.py
@@ -1,15 +1,17 @@
- import torch
  import torch.nn as nn
- import torch.nn.functional as F
- from .resample import UpSample1d, DownSample1d
- from .resample import UpSample2d, DownSample2d
+ from lt_tensor.model_zoo.activations.alias_free.resample import (
+     UpSample2d,
+     DownSample2d,
+     UpSample1d,
+     DownSample1d,
+ )


  class Activation1d(nn.Module):

      def __init__(
          self,
-         activation,
+         activation: nn.Module,
          up_ratio: int = 2,
          down_ratio: int = 2,
          up_kernel_size: int = 12,
@@ -34,7 +36,7 @@ class Activation2d(nn.Module):

      def __init__(
          self,
-         activation,
+         activation: nn.Module,
          up_ratio: int = 2,
          down_ratio: int = 2,
          up_kernel_size: int = 12,
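
A usage sketch for the relocated alias-free activation wrapper, assuming the usual interface for these modules (forward takes a [B, C, T] tensor and returns a tensor of the same shape after upsample → activation → downsample); that forward signature is not shown in this hunk, so treat it as an assumption.

```python
import torch
import torch.nn as nn
from lt_tensor.model_zoo.activations.alias_free.act import Activation1d

act = Activation1d(activation=nn.SiLU(), up_ratio=2, down_ratio=2)
x = torch.randn(1, 32, 256)  # assumed [batch, channels, time] layout
y = act(x)                   # same shape as x
```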