lt-tensor 0.0.1a33__py3-none-any.whl → 0.0.1a35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.0.1a33"
+__version__ = "0.0.1a35"
 
 from . import (
     lr_schedulers,
lt_tensor/losses.py CHANGED
@@ -6,19 +6,24 @@ __all__ = [
     "hybrid_loss",
     "diff_loss",
     "cosine_loss",
-    "gan_loss",
     "ft_n_loss",
+    "MultiMelScaleLoss",
 ]
 import math
 import random
 from lt_tensor.torch_commons import *
 from lt_utils.common import *
 import torch.nn.functional as F
+from lt_tensor.model_base import Model
+from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
+from lt_tensor.math_ops import normalize_minmax, normalize_unit_norm, normalize_zscore
+
 
 def ft_n_loss(output: Tensor, target: Tensor, weight: Optional[Tensor] = None):
     if weight is not None:
-        return torch.mean((torch.abs(output - target) + weight) **0.5)
-    return torch.mean(torch.abs(output - target)**0.5)
+        return torch.mean((torch.abs(output - target) + weight) ** 0.5)
+    return torch.mean(torch.abs(output - target) ** 0.5)
+
 
 def adaptive_l1_loss(
     inp: Tensor,
@@ -58,50 +63,6 @@ def cosine_loss(inp, tgt):
     return 1 - cos.mean()  # Lower is better
 
 
-class GanLosses:
-    @staticmethod
-    def get_loss(
-        pred: Tensor,
-        target_is_real: bool,
-        loss_type: Literal["bce", "mse", "hinge", "wasserstein"] = "bce",
-    ) -> Tensor:
-        if loss_type == "bce":  # Standard GAN
-            target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
-            return F.binary_cross_entropy_with_logits(pred, target)
-
-        elif loss_type == "mse":  # LSGAN
-            target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
-            return F.mse_loss(torch.sigmoid(pred), target)
-
-        elif loss_type == "hinge":
-            if target_is_real:
-                return torch.mean(F.relu(1.0 - pred))
-            else:
-                return torch.mean(F.relu(1.0 + pred))
-
-        elif loss_type == "wasserstein":
-            return -pred.mean() if target_is_real else pred.mean()
-
-        else:
-            raise ValueError(f"Unknown loss_type: {loss_type}")
-
-    @staticmethod
-    def generator_loss(fake_pred: Tensor, loss_type: str = "bce") -> Tensor:
-        return GanLosses.get_loss(fake_pred, target_is_real=True, loss_type=loss_type)
-
-    @staticmethod
-    def discriminator_loss(
-        real_pred: Tensor, fake_pred: Tensor, loss_type: str = "bce"
-    ) -> Tensor:
-        real_loss = GanLosses.get_loss(
-            real_pred, target_is_real=True, loss_type=loss_type
-        )
-        fake_loss = GanLosses.get_loss(
-            fake_pred.detach(), target_is_real=False, loss_type=loss_type
-        )
-        return (real_loss + fake_loss) * 0.5
-
-
 def masked_cross_entropy(
     logits: torch.Tensor,  # [B, T, V]
     targets: torch.Tensor,  # [B, T]
@@ -157,3 +118,164 @@ def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
             torch.log(1 - fake + 1e-7)
         )
     return loss
+
+
+class MultiMelScaleLoss(Model):
+    def __init__(
+        self,
+        sample_rate: int,
+        n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
+        window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
+        n_ffts: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
+        hops: List[int] = [8, 16, 32, 64, 128, 256, 512],
+        f_min: List[float] = [0, 0, 0, 0, 0, 0, 0],
+        f_max: List[Optional[float]] = [None, None, None, None, None, None, None],
+        loss_mel_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
+        loss_pitch_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
+        loss_rms_fn: Callable[[Tensor, Tensor], Tensor] = nn.L1Loss(),
+        center: bool = True,
+        power: float = 1.0,
+        normalized: bool = False,
+        pad_mode: str = "reflect",
+        onesided: Optional[bool] = None,
+        std: int = 4,
+        mean: int = -4,
+        use_istft_norm: bool = True,
+        use_pitch_loss: bool = True,
+        use_rms_loss: bool = True,
+        norm_pitch_fn: Callable[[Tensor], Tensor] = normalize_minmax,
+        norm_rms_fn: Callable[[Tensor], Tensor] = normalize_zscore,
+        lambda_mel: float = 1.0,
+        lambda_rms: float = 1.0,
+        lambda_pitch: float = 1.0,
+        weight: float = 1.0,
+    ):
+        super().__init__()
+        assert (
+            len(n_mels)
+            == len(window_lengths)
+            == len(n_ffts)
+            == len(hops)
+            == len(f_min)
+            == len(f_max)
+        )
+        self.loss_mel_fn = loss_mel_fn
+        self.loss_pitch_fn = loss_pitch_fn
+        self.loss_rms_fn = loss_rms_fn
+        self.lambda_mel = lambda_mel
+        self.weight = weight
+        self.use_istft_norm = use_istft_norm
+        self.use_pitch_loss = use_pitch_loss
+        self.use_rms_loss = use_rms_loss
+        self.lambda_pitch = lambda_pitch
+        self.lambda_rms = lambda_rms
+
+        self.norm_pitch_fn = norm_pitch_fn
+        self.norm_rms = norm_rms_fn
+
+        self._setup_mels(
+            sample_rate,
+            n_mels,
+            window_lengths,
+            n_ffts,
+            hops,
+            f_min,
+            f_max,
+            center,
+            power,
+            normalized,
+            pad_mode,
+            onesided,
+            std,
+            mean,
+        )
+
+    def _setup_mels(
+        self,
+        sample_rate: int,
+        n_mels: List[int],
+        window_lengths: List[int],
+        n_ffts: List[int],
+        hops: List[int],
+        f_min: List[float],
+        f_max: List[Optional[float]],
+        center: bool,
+        power: float,
+        normalized: bool,
+        pad_mode: str,
+        onesided: Optional[bool],
+        std: int,
+        mean: int,
+    ):
+        assert (
+            len(n_mels)
+            == len(window_lengths)
+            == len(n_ffts)
+            == len(hops)
+            == len(f_min)
+            == len(f_max)
+        )
+        _mel_kwargs = dict(
+            sample_rate=sample_rate,
+            center=center,
+            onesided=onesided,
+            normalized=normalized,
+            power=power,
+            pad_mode=pad_mode,
+            std=std,
+            mean=mean,
+        )
+        self.mel_spectrograms: List[AudioProcessor] = nn.ModuleList(
+            [
+                AudioProcessor(
+                    AudioProcessorConfig(
+                        **_mel_kwargs,
+                        n_mels=mel,
+                        n_fft=n_fft,
+                        win_length=win,
+                        hop_length=hop,
+                        f_min=fmin,
+                        f_max=fmax,
+                    )
+                )
+                for mel, win, n_fft, hop, fmin, fmax in zip(
+                    n_mels, window_lengths, n_ffts, hops, f_min, f_max
+                )
+            ]
+        )
+
+    def forward(
+        self, input_wave: torch.Tensor, target_wave: torch.Tensor
+    ) -> torch.Tensor:
+        assert self.use_istft_norm or input_wave.shape[-1] == target_wave.shape[-1]
+        target_wave = target_wave.to(input_wave.device)
+        losses = 0.0
+        for M in self.mel_spectrograms:
+            # Apply normalization if requested
+            if self.use_istft_norm:
+                input_proc = M.istft_norm(input_wave, length=target_wave.shape[-1])
+                target_proc = M.istft_norm(target_wave, length=target_wave.shape[-1])
+            else:
+                input_proc, target_proc = input_wave, target_wave
+
+            x_mels = M(input_proc)
+            y_mels = M(target_proc)
+
+            loss = self.loss_mel_fn(x_mels.squeeze(), y_mels.squeeze())
+            losses += loss * self.lambda_mel
+
+            # pitch/f0 loss
+            if self.use_pitch_loss:
+                x_pitch = self.norm_pitch_fn(M.compute_pitch(input_proc))
+                y_pitch = self.norm_pitch_fn(M.compute_pitch(target_proc))
+                f0_loss = self.loss_pitch_fn(x_pitch, y_pitch)
+                losses += f0_loss * self.lambda_pitch
+
+            # energy/rms loss
+            if self.use_rms_loss:
+                x_rms = self.norm_rms(M.compute_rms(input_proc, x_mels))
+                y_rms = self.norm_rms(M.compute_rms(target_proc, y_mels))
+                rms_loss = self.loss_rms_fn(x_rms, y_rms)
+                losses += rms_loss * self.lambda_rms
+
+        return losses * self.weight
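A minimal usage sketch for the new `MultiMelScaleLoss` (the sample rate, batch size, and waveform length below are illustrative; all other arguments use the defaults shown in the diff):

```python
import torch
from lt_tensor.losses import MultiMelScaleLoss

# Illustrative shapes only: batch of 2 mono waveforms, ~1 s at 24 kHz.
criterion = MultiMelScaleLoss(sample_rate=24000)
generated = torch.randn(2, 24000, requires_grad=True)
reference = torch.randn(2, 24000)

loss = criterion(generated, reference)  # mel + optional pitch/RMS terms, summed over all scales
loss.backward()
```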
lt_tensor/lr_schedulers.py CHANGED
@@ -1,15 +1,20 @@
 __all__ = [
     "WarmupDecayScheduler",
     "AdaptiveDropScheduler",
-    "WaveringLRScheduler",
+    "SinusoidalDecayLR",
+    "GuidedWaveringLR",
+    "FloorExponentialLR",
 ]
 
 import math
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.lr_scheduler import LRScheduler
+from typing import Optional
+from numbers import Number
+from lt_tensor.misc_utils import update_lr
 
 
-class WarmupDecayScheduler(_LRScheduler):
+class WarmupDecayScheduler(LRScheduler):
     def __init__(
         self,
         optimizer: Optimizer,
@@ -49,7 +54,7 @@ class WarmupDecayScheduler(_LRScheduler):
         return lrs
 
 
-class AdaptiveDropScheduler(_LRScheduler):
+class AdaptiveDropScheduler(LRScheduler):
     def __init__(
         self,
         optimizer,
@@ -89,26 +94,147 @@ class AdaptiveDropScheduler(_LRScheduler):
         return [group["lr"] for group in self.optimizer.param_groups]
 
 
-class WaveringLRScheduler(_LRScheduler):
+class SinusoidalDecayLR(LRScheduler):
     def __init__(
-        self, optimizer, base_lr, max_lr, period=1000, decay=0.999, last_epoch=-1
+        self,
+        optimizer: Optimizer,
+        initial_lr: float = 1e-3,
+        target_lr: float = 1e-5,
+        floor_lr: float = 1e-7,
+        decay_rate: float = 1e-6,  # decay per period
+        wave_amplitude: float = 1e-5,
+        period: int = 256,
+        last_epoch: int = -1,
     ):
-        """
-        Sinusoidal-like oscillating LR. Can escape shallow local minima.
-        - base_lr: minimum LR
-        - max_lr: maximum LR
-        - period: full sine cycle in steps
-        - decay: multiplies max_lr each cycle
-        """
-        self.base_lr = base_lr
-        self.max_lr = max_lr
+        assert decay_rate != 0.0, "decay_rate must be different from 0.0"
+        assert (
+            initial_lr >= target_lr >= floor_lr
+        ), "Must satisfy: initial_lr ≥ target_lr ≥ floor_lr"
+
+        self.initial_lr = initial_lr
+        self.target_lr = target_lr
+        self.floor_lr = floor_lr
+        self.decay_rate = decay_rate
+        self.wave_amplitude = wave_amplitude
         self.period = period
-        self.decay = decay
+
         super().__init__(optimizer, last_epoch)
 
     def get_lr(self):
-        cycle = self.last_epoch // self.period
-        step_in_cycle = self.last_epoch % self.period
-        factor = math.sin(math.pi * step_in_cycle / self.period)
-        amplitude = (self.max_lr - self.base_lr) * (self.decay**cycle)
-        return [self.base_lr + amplitude * factor for _ in self.optimizer.param_groups]
+        step = self.last_epoch + 1
+        cycles = step // self.period
+        t = step % self.period
+        # Decay center down to target_lr, then freeze
+        center_decay = math.exp(-self.decay_rate * cycles)
+        center = max(self.target_lr, self.initial_lr * center_decay)
+        # Decay amplitude in sync with center (relative to initial)
+        amplitude_decay = math.exp(-self.decay_rate * cycles)
+        current_amplitude = self.wave_amplitude * self.initial_lr * amplitude_decay
+        sin_offset = math.sin(2 * math.pi * t / self.period)
+        lr = max(center + current_amplitude * sin_offset, self.floor_lr)
+        return [lr for _ in self.optimizer.param_groups]
+
+
+class GuidedWaveringLR(LRScheduler):
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        initial_lr: float = 1e-3,
+        target_lr: float = 1e-5,
+        floor_lr: float = 1e-7,
+        decay_rate: float = 0.01,
+        wave_amplitude: float = 0.02,
+        period: int = 256,
+        stop_decay_after: Optional[int] = None,
+        last_epoch: int = -1,
+    ):
+        assert decay_rate != 0.0, "decay_rate must be non-zero"
+        assert (
+            initial_lr >= target_lr >= floor_lr
+        ), "Must satisfy: initial ≥ target ≥ floor"
+
+        self.initial_lr = initial_lr
+        self.target_lr = target_lr
+        self.floor_lr = floor_lr
+        self.decay_rate = decay_rate
+        self.wave_amplitude = wave_amplitude
+        self.period = period
+        self.stop_decay_after = stop_decay_after
+
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        step = self.last_epoch + 1
+        cycles = step // self.period
+        t = step % self.period
+
+        decay_cycles = (
+            min(cycles, self.stop_decay_after) if self.stop_decay_after else cycles
+        )
+        center = max(
+            self.target_lr, self.initial_lr * math.exp(-self.decay_rate * decay_cycles)
+        )
+        amp = (
+            self.wave_amplitude
+            * self.initial_lr
+            * math.exp(-self.decay_rate * decay_cycles)
+        )
+        phase = 2 * math.pi * t / self.period
+        wave = math.sin(phase) * math.cos(phase)
+        lr = max(center + amp * wave, self.floor_lr)
+        return [lr for _ in self.optimizer.param_groups]
+
+
+class FloorExponentialLR(LRScheduler):
+    """A modified ExponentialLR with a floor (minimum) learning rate and reset helpers.
+
+    Decays the learning rate of each parameter group by gamma every epoch,
+    never dropping below floor_lr.
+
+    When last_epoch=-1, sets initial lr as lr.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        initial_lr (float): Learning rate restored by reset_lr() when no value is given. Default: 1e-4.
+        gamma (float): Multiplicative factor of learning rate decay.
+        last_epoch (int): The index of last epoch. Default: -1.
+        floor_lr (float): Lower bound for the learning rate. Default: 1e-6.
+    """
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        initial_lr: float = 1e-4,
+        gamma: float = 0.99998,
+        last_epoch: int = -1,
+        floor_lr: float = 1e-6,
+    ):
+        self.gamma = gamma
+        self.floor_lr = floor_lr
+        self.initial_lr = initial_lr
+
+        super().__init__(optimizer, last_epoch)
+
+    def set_floor(self, new_value: float):
+        assert isinstance(new_value, Number)
+        self.floor_lr = new_value
+
+    def reset_lr(self, new_value: Optional[float] = None):
+        new_lr = new_value if isinstance(new_value, Number) else self.initial_lr
+        self.initial_lr = new_lr
+        update_lr(self.optimizer, new_lr)
+
+    def get_lr(self):
+        if self.last_epoch == 0:
+            return [
+                max(group["lr"], self.floor_lr) for group in self.optimizer.param_groups
+            ]
+
+        return [
+            max(group["lr"] * self.gamma, self.floor_lr)
+            for group in self.optimizer.param_groups
+        ]
+
+    def _get_closed_form_lr(self):
+        return [
+            max(base_lr * self.gamma**self.last_epoch, self.floor_lr)
+            for base_lr in self.base_lrs
+        ]
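A brief sketch of how the new schedulers are meant to be driven (assuming the module path `lt_tensor.lr_schedulers`, consistent with the `lr_schedulers` import in `__init__.py`; the model and step counts are placeholders):

```python
import torch
from torch.optim import AdamW
from lt_tensor.lr_schedulers import SinusoidalDecayLR, FloorExponentialLR

model = torch.nn.Linear(8, 8)  # placeholder model for illustration
opt = AdamW(model.parameters(), lr=1e-3)

# The center LR decays from initial_lr toward target_lr; a small sine wave of
# amplitude wave_amplitude * initial_lr is overlaid over each 256-step period.
sched = SinusoidalDecayLR(opt, initial_lr=1e-3, target_lr=1e-5, period=256)
for _ in range(1000):
    opt.step()    # the forward/backward pass is omitted here
    sched.step()  # advance the schedule once per optimizer step

# FloorExponentialLR decays by gamma every step but never below floor_lr,
# and reset_lr() jumps every param group back to a chosen value mid-run.
opt2 = AdamW(model.parameters(), lr=1e-4)
sched2 = FloorExponentialLR(opt2, initial_lr=1e-4, gamma=0.99998, floor_lr=1e-6)
sched2.reset_lr(5e-5)
```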
lt_tensor/misc_utils.py CHANGED
@@ -24,6 +24,7 @@ __all__ = [
     "plot_view",
     "get_weights",
     "get_activated_conv",
+    "update_lr",
 ]
 
 import re
@@ -77,6 +78,33 @@ def get_activated_conv(
     )
 
 
+def get_loss_average(losses: List[float]):
+    """A little helper for training; for example:
+    ```python
+    losses = []
+    for epoch in range(100):
+        for inp, label in dataloader:
+            optimizer.zero_grad()
+            out = model(inp)
+            loss = loss_fn(out, label)
+            loss.backward()
+            optimizer.step()
+            losses.append(loss.item())
+        print(f"Epoch {epoch+1} | Loss: {get_loss_average(losses):.4f}")
+    ```
+    """
+    if not losses:
+        return float("nan")
+    return sum(losses) / len(losses)
+
+
+def update_lr(optimizer: optim.Optimizer, new_value: float = 1e-4):
+    for param_group in optimizer.param_groups:
+        if isinstance(param_group["lr"], Tensor):
+            param_group["lr"].fill_(new_value)
+        else:
+            param_group["lr"] = new_value
+    return optimizer
+
+
 def plot_view(
     data: Dict[str, List[Any]],
     title: str = "Loss",
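A short sketch of the new `update_lr` helper, which overwrites the `lr` entry of every param group (handling tensor-valued learning rates); the optimizer here is hypothetical:

```python
import torch
from torch.optim import SGD
from lt_tensor.misc_utils import update_lr

opt = SGD(torch.nn.Linear(4, 4).parameters(), lr=1e-3)
update_lr(opt, 1e-4)              # every param group now uses lr=1e-4
print(opt.param_groups[0]["lr"])  # 0.0001
```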
@@ -520,49 +548,14 @@ def sample_tensor(tensor: torch.Tensor, num_samples: int = 5):
     return flat[idx]
 
 
-class TorchCacheUtils:
-    cached_shortcuts: dict[str, Callable[[None], None]] = {}
-
-    has_cuda: bool = torch.cuda.is_available()
-    has_xpu: bool = torch.xpu.is_available()
-    has_mps: bool = torch.mps.is_available()
-
-    _ignore: list[str] = []
-
-    def __init__(self):
-        pass
-
-    def _apply_clear(self, device: str):
-        if device in self._ignore:
-            gc.collect()
-            return
-        try:
-            clear_fn = self.cached_shortcuts.get(
-                device, getattr(torch, device).empty_cache
-            )
-            if device not in self.cached_shortcuts:
-                self.cached_shortcuts.update({device: clear_fn})
-
-        except Exception as e:
-            print(e)
-            self._ignore.append(device)
-
-    def clear(self):
-        gc.collect()
-        if self.has_xpu:
-            self._apply_clear("xpu")
-        if self.has_cuda:
-            self._apply_clear("cuda")
-        if self.has_mps:
-            self._apply_clear("mps")
-        gc.collect()
-
-
-_clear_cache_cls = TorchCacheUtils()
-
-
 def clear_cache():
-    _clear_cache_cls.clear()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    if torch.mps.is_available():
+        torch.mps.empty_cache()
+    if torch.xpu.is_available():
+        torch.xpu.empty_cache()
+    gc.collect()
 
 
 @cache_wrapper
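The simplified `clear_cache` can be called between training phases once large tensors have been released; a small sketch (which accelerator caches get emptied depends on the machine):

```python
import torch
from lt_tensor.misc_utils import clear_cache

buf = torch.randn(4096, 4096, device="cuda" if torch.cuda.is_available() else "cpu")
del buf
clear_cache()  # empties CUDA/MPS/XPU caches where available, then runs gc.collect()
```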
@@ -0,0 +1,3 @@
+from . import alias_free, snake
+
+__all__ = ["snake", "alias_free"]
@@ -0,0 +1,3 @@
+from .act import *
+from .filter import *
+from .resample import *
@@ -1,15 +1,17 @@
-import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from .resample import UpSample1d, DownSample1d
-from .resample import UpSample2d, DownSample2d
+from lt_tensor.model_zoo.activations.alias_free.resample import (
+    UpSample2d,
+    DownSample2d,
+    UpSample1d,
+    DownSample1d,
+)
 
 
 class Activation1d(nn.Module):
 
     def __init__(
         self,
-        activation,
+        activation: nn.Module,
         up_ratio: int = 2,
         down_ratio: int = 2,
         up_kernel_size: int = 12,
@@ -34,7 +36,7 @@ class Activation2d(nn.Module):
 
     def __init__(
         self,
-        activation,
+        activation: nn.Module,
         up_ratio: int = 2,
         down_ratio: int = 2,
         up_kernel_size: int = 12,
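A sketch of how the now explicitly typed `Activation1d` wrapper would be used (the `[batch, channels, time]` layout and the upsample → activate → downsample behaviour are assumptions based on the alias-free design these modules follow, not confirmed by this diff; the import path mirrors the absolute imports above):

```python
import torch
import torch.nn as nn
from lt_tensor.model_zoo.activations.alias_free.act import Activation1d  # assumed module path

# Wrap a pointwise nonlinearity so it is applied on an upsampled signal and
# then low-pass filtered back down, reducing aliasing from the nonlinearity.
act = Activation1d(activation=nn.SiLU(), up_ratio=2, down_ratio=2)
y = act(torch.randn(1, 8, 256))  # assumed [batch, channels, time] input
```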