lt-tensor 0.0.1a32__py3-none-any.whl → 0.0.1a34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.0.1a"
+ __version__ = "0.0.1a34"
 
  from . import (
  lr_schedulers,
lt_tensor/losses.py CHANGED
@@ -6,19 +6,24 @@ __all__ = [
  "hybrid_loss",
  "diff_loss",
  "cosine_loss",
- "gan_loss",
  "ft_n_loss",
+ "MultiMelScaleLoss",
  ]
  import math
  import random
  from lt_tensor.torch_commons import *
  from lt_utils.common import *
  import torch.nn.functional as F
+ from lt_tensor.model_base import Model
+ from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
+ from lt_tensor.math_ops import normalize_minmax, normalize_unit_norm, normalize_zscore
+
 
  def ft_n_loss(output: Tensor, target: Tensor, weight: Optional[Tensor] = None):
  if weight is not None:
- return torch.mean((torch.abs(output - target) + weight) **0.5)
- return torch.mean(torch.abs(output - target)**0.5)
+ return torch.mean((torch.abs(output - target) + weight) ** 0.5)
+ return torch.mean(torch.abs(output - target) ** 0.5)
+
 
  def adaptive_l1_loss(
  inp: Tensor,
@@ -58,50 +63,6 @@ def cosine_loss(inp, tgt):
  return 1 - cos.mean() # Lower is better
 
 
- class GanLosses:
- @staticmethod
- def get_loss(
- pred: Tensor,
- target_is_real: bool,
- loss_type: Literal["bce", "mse", "hinge", "wasserstein"] = "bce",
- ) -> Tensor:
- if loss_type == "bce": # Standard GAN
- target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
- return F.binary_cross_entropy_with_logits(pred, target)
-
- elif loss_type == "mse": # LSGAN
- target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
- return F.mse_loss(torch.sigmoid(pred), target)
-
- elif loss_type == "hinge":
- if target_is_real:
- return torch.mean(F.relu(1.0 - pred))
- else:
- return torch.mean(F.relu(1.0 + pred))
-
- elif loss_type == "wasserstein":
- return -pred.mean() if target_is_real else pred.mean()
-
- else:
- raise ValueError(f"Unknown loss_type: {loss_type}")
-
- @staticmethod
- def generator_loss(fake_pred: Tensor, loss_type: str = "bce") -> Tensor:
- return GanLosses.get_loss(fake_pred, target_is_real=True, loss_type=loss_type)
-
- @staticmethod
- def discriminator_loss(
- real_pred: Tensor, fake_pred: Tensor, loss_type: str = "bce"
- ) -> Tensor:
- real_loss = GanLosses.get_loss(
- real_pred, target_is_real=True, loss_type=loss_type
- )
- fake_loss = GanLosses.get_loss(
- fake_pred.detach(), target_is_real=False, loss_type=loss_type
- )
- return (real_loss + fake_loss) * 0.5
-
-
  def masked_cross_entropy(
  logits: torch.Tensor, # [B, T, V]
  targets: torch.Tensor, # [B, T]
@@ -157,3 +118,160 @@ def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
  torch.log(1 - fake + 1e-7)
  )
  return loss
+
+
+ class MultiMelScaleLoss(Model):
+ def __init__(
+ self,
+ sample_rate: int,
+ n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
+ window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
+ n_ffts: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
+ hops: List[int] = [8, 16, 32, 64, 128, 256, 512],
+ f_min: float = [0, 0, 0, 0, 0, 0, 0],
+ f_max: Optional[float] = [None, None, None, None, None, None, None],
+ loss_fn: Callable = nn.L1Loss(),
+ center: bool = True,
+ power: float = 1.0,
+ normalized: bool = False,
+ pad_mode: str = "reflect",
+ onesided: Optional[bool] = None,
+ std: int = 4,
+ mean: int = -4,
+ use_istft_norm: bool = True,
+ use_pitch_loss: bool = True,
+ use_rms_loss: bool = True,
+ norm_pitch_fn: Callable[[Tensor], Tensor] = normalize_unit_norm,
+ norm_rms_fn: Callable[[Tensor], Tensor] = normalize_unit_norm,
+ lambda_mel: float = 1.0,
+ lambda_rms: float = 1.0,
+ lambda_pitch: float = 1.0,
+ weight: float = 1.0,
+ ):
+ super().__init__()
+ assert (
+ len(n_mels)
+ == len(window_lengths)
+ == len(n_ffts)
+ == len(hops)
+ == len(f_min)
+ == len(f_max)
+ )
+ self.loss_fn = loss_fn
+ self.lambda_mel = lambda_mel
+ self.weight = weight
+ self.use_istft_norm = use_istft_norm
+ self.use_pitch_loss = use_pitch_loss
+ self.use_rms_loss = use_rms_loss
+ self.lambda_pitch = lambda_pitch
+ self.lambda_rms = lambda_rms
+
+ self.norm_pitch_fn = norm_pitch_fn
+ self.norm_rms = norm_rms_fn
+
+ self._setup_mels(
+ sample_rate,
+ n_mels,
+ window_lengths,
+ n_ffts,
+ hops,
+ f_min,
+ f_max,
+ center,
+ power,
+ normalized,
+ pad_mode,
+ onesided,
+ std,
+ mean,
+ )
+
+ def _setup_mels(
+ self,
+ sample_rate: int,
+ n_mels: List[int],
+ window_lengths: List[int],
+ n_ffts: List[int],
+ hops: List[int],
+ f_min: List[float],
+ f_max: List[Optional[float]],
+ center: bool,
+ power: float,
+ normalized: bool,
+ pad_mode: str,
+ onesided: Optional[bool],
+ std: int,
+ mean: int,
+ ):
+ assert (
+ len(n_mels)
+ == len(window_lengths)
+ == len(n_ffts)
+ == len(hops)
+ == len(f_min)
+ == len(f_max)
+ )
+ _mel_kwargs = dict(
+ sample_rate=sample_rate,
+ center=center,
+ onesided=onesided,
+ normalized=normalized,
+ power=power,
+ pad_mode=pad_mode,
+ std=std,
+ mean=mean,
+ )
+ self.mel_spectrograms: List[AudioProcessor] = nn.ModuleList(
+ [
+ AudioProcessor(
+ AudioProcessorConfig(
+ **_mel_kwargs,
+ n_mels=mel,
+ n_fft=n_fft,
+ win_length=win,
+ hop_length=hop,
+ f_min=fmin,
+ f_max=fmax,
+ )
+ )
+ for mel, win, n_fft, hop, fmin, fmax in zip(
+ n_mels, window_lengths, n_ffts, hops, f_min, f_max
+ )
+ ]
+ )
+
+ def forward(
+ self, input_wave: torch.Tensor, target_wave: torch.Tensor
+ ) -> torch.Tensor:
+ assert self.use_istft_norm or input_wave.shape[-1] == target_wave.shape[-1]
+ target_wave = target_wave.to(input_wave.device)
+ losses = 0.0
+ for M in self.mel_spectrograms:
+ # Apply normalization if requested
+ if self.use_istft_norm:
+ input_proc = M.istft_norm(input_wave, length=target_wave.shape[-1])
+ target_proc = M.istft_norm(target_wave, length=target_wave.shape[-1])
+ else:
+ input_proc, target_proc = input_wave, target_wave
+
+ x_mels = M(input_proc)
+ y_mels = M(target_proc)
+
+ loss = self.loss_fn(x_mels.squeeze(), y_mels.squeeze())
+ losses += loss * self.lambda_mel
+
+ # pitch/f0 loss
+ if self.use_pitch_loss:
+ x_pitch = self.norm_pitch_fn(M.compute_pitch(input_proc))
+ y_pitch = self.norm_pitch_fn(M.compute_pitch(target_proc))
+ f0_loss = self.loss_fn(x_pitch, y_pitch)
+ losses += f0_loss * self.lambda_pitch
+
+ # energy/rms loss
+ if self.use_rms_loss:
+ x_rms = self.norm_rms(M.compute_rms(input_proc, x_mels))
+ y_rms = self.norm_rms(M.compute_rms(target_proc, y_mels))
+ rms_loss = self.loss_fn(x_rms, y_rms)
+ losses += rms_loss * self.lambda_rms
+
+ return losses * self.weight
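losses.py drops the GanLosses helper class and gains MultiMelScaleLoss, which sums loss_fn (L1 by default) over mel spectrograms computed at seven resolutions and can add pitch and RMS terms. A minimal usage sketch; the sample rate, tensor shapes, and the disabled pitch/RMS terms are illustrative assumptions, not package defaults:

```python
import torch
from lt_tensor.losses import MultiMelScaleLoss

sample_rate = 24000  # assumed rate for the sketch
pred_wave = torch.randn(2, sample_rate, requires_grad=True)  # stand-in generator output
target_wave = torch.randn(2, sample_rate)                    # stand-in reference audio

# Defaults build 7 AudioProcessor instances (5..320 mels); pitch/RMS terms are
# switched off here so the sketch stays purely tensor-based and differentiable.
criterion = MultiMelScaleLoss(
    sample_rate=sample_rate,
    use_pitch_loss=False,
    use_rms_loss=False,
)

loss = criterion(pred_wave, target_wave)  # scalar: sum of per-resolution mel losses
loss.backward()
```

With use_istft_norm=True (the default) both waveforms are passed through each processor's istft_norm before the mel comparison, so the input lengths do not have to match beforehand.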
lt_tensor/math_ops.py CHANGED
@@ -6,10 +6,12 @@ __all__ = [
  "apply_window",
  "shift_ring",
  "dot_product",
- "normalize_tensor",
  "log_magnitude",
  "shift_time",
  "phase",
+ "normalize_unit_norm",
+ "normalize_minmax",
+ "normalize_zscore",
  ]
 
  from lt_tensor.torch_commons import *
@@ -61,11 +63,6 @@ def dot_product(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
  return torch.sum(x * y, dim=dim)
 
 
- def normalize_tensor(x: Tensor, eps: float = 1e-8) -> Tensor:
- """Normalizes a tensor to unit norm (L2)."""
- return x / (torch.norm(x, dim=-1, keepdim=True) + eps)
-
-
  def log_magnitude(stft_complex: Tensor, eps: float = 1e-5) -> Tensor:
  """Returns log magnitude from complex STFT."""
  magnitude = torch.abs(stft_complex)
@@ -76,3 +73,19 @@ def phase(stft_complex: Tensor) -> Tensor:
  """Returns phase from complex STFT."""
  return torch.angle(stft_complex)
 
+
+ def normalize_unit_norm(x: torch.Tensor, eps: float = 1e-6):
+ norm = torch.norm(x, dim=-1, keepdim=True)
+ return x / (norm + eps)
+
+
+ def normalize_minmax(x: torch.Tensor, eps: float = 1e-6):
+ min_val = x.amin(dim=-1, keepdim=True)
+ max_val = x.amax(dim=-1, keepdim=True)
+ return (x - min_val) / (max_val - min_val + eps)
+
+
+ def normalize_zscore(x: torch.Tensor, eps: float = 1e-6):
+ mean = x.mean(dim=-1, keepdim=True)
+ std = x.std(dim=-1, keepdim=True)
+ return (x - mean) / (std + eps)
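math_ops.py replaces the single normalize_tensor helper with three explicit normalizations, all taken over the last dimension. A small worked sketch of what each returns for a tiny input:

```python
import torch
from lt_tensor.math_ops import normalize_minmax, normalize_unit_norm, normalize_zscore

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])

normalize_unit_norm(x)  # x / ||x||_2           -> approx [[0.183, 0.365, 0.548, 0.730]]
normalize_minmax(x)     # (x - min) / (max - min) -> [[0.000, 0.333, 0.667, 1.000]]
normalize_zscore(x)     # (x - mean) / std        -> zero-mean row with unit (unbiased) std
```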
lt_tensor/model_zoo/losses/discriminators.py CHANGED
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  from lt_tensor.model_zoo.audio_models.hifigan import ConvNets
  from lt_utils.common import *
  from lt_tensor.torch_commons import *
@@ -5,6 +7,8 @@ from lt_tensor.model_base import Model
  from lt_tensor.model_zoo.convs import ConvNets
  from torch.nn import functional as F
  from torchaudio import transforms as T
+ from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
+
 
  MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
  List[Tensor],
@@ -14,11 +18,75 @@ MULTI_DISC_OUT_TYPE: TypeAlias = Tuple[
  ]
 
 
+ class MultiDiscriminatorWrapper(Model):
+ def __init__(self, list_discriminator: List["_MultiDiscriminatorT"]):
+ """Setup example:
+ model_d = MultiDiscriminatorStep(
+ [
+ MultiEnvelopeDiscriminator(),
+ MultiBandDiscriminator(),
+ MultiResolutionDiscriminator(),
+ MultiPeriodDiscriminator(0.5),
+ ]
+ )
+ """
+ super().__init__()
+ self.disc: Sequence[_MultiDiscriminatorT] = nn.ModuleList(list_discriminator)
+ self.total = len(self.disc)
+
+ def forward(
+ self,
+ y: Tensor,
+ y_hat: Tensor,
+ step_type: Literal["discriminator", "generator"],
+ ) -> Union[
+ Tuple[Tensor, Tensor, List[float]], Tuple[Tensor, List[float], List[float]]
+ ]:
+ """
+ It returns the content based on the choice of "step_type", being it a
+ 'discriminator' or 'generator'
+
+ For generator it returns:
+ Tuple[Tensor, Tensor, List[float]]
+ "gen_loss, feat_loss, all_g_losses"
+
+ For 'discriminator' it returns:
+ Tuple[Tensor, List[float], List[float]]
+ "disc_loss, disc_real_losses, disc_gen_losses"
+ """
+ if step_type == "generator":
+ all_g_losses: List[float] = []
+ feat_loss: Tensor = 0
+ gen_loss: Tensor = 0
+ else:
+ disc_loss: Tensor = 0
+ disc_real_losses: List[float] = []
+ disc_gen_losses: List[float] = []
+
+ for disc in self.disc:
+ if step_type == "generator":
+ # feature loss, generator loss, list of generator losses (float)]
+ f_loss, g_loss, g_losses = disc.gen_step(y, y_hat)
+ gen_loss += g_loss
+ feat_loss += f_loss
+ all_g_losses.extend(g_losses)
+ else:
+ # [discriminator loss, (disc losses real, disc losses generated)]
+ d_loss, (d_real_losses, d_gen_losses) = disc.disc_step(y, y_hat)
+ disc_loss += d_loss
+ disc_real_losses.extend(d_real_losses)
+ disc_gen_losses.extend(d_gen_losses)
+
+ if step_type == "generator":
+ return gen_loss, feat_loss, all_g_losses
+ return disc_loss, disc_real_losses, disc_gen_losses
+
+
  def get_padding(kernel_size, dilation=1):
  return int((kernel_size * dilation - dilation) / 2)
 
 
- class MultiDiscriminatorWrapper(ConvNets):
+ class _MultiDiscriminatorT(ConvNets):
  """Base for all multi-steps type of discriminators"""
 
  def __init__(self, *args, **kwargs):
@@ -171,7 +239,7 @@ class DiscriminatorP(ConvNets):
  return x.flatten(1, -1), fmap
 
 
- class MultiPeriodDiscriminator(MultiDiscriminatorWrapper):
+ class MultiPeriodDiscriminator(_MultiDiscriminatorT):
  def __init__(
  self,
  discriminator_channel_mult: Number = 1,
@@ -258,7 +326,7 @@ class DiscriminatorEnvelope(ConvNets):
  return x.flatten(1), fmap
 
 
- class MultiEnvelopeDiscriminator(MultiDiscriminatorWrapper):
+ class MultiEnvelopeDiscriminator(_MultiDiscriminatorT):
  def __init__(self, use_spectral_norm: bool = False):
  super().__init__()
  self.discriminators = nn.ModuleList(
@@ -375,7 +443,7 @@ class DiscriminatorB(ConvNets):
  return x, fmap
 
 
- class MultiBandDiscriminator(MultiDiscriminatorWrapper):
+ class MultiBandDiscriminator(_MultiDiscriminatorT):
  """
  Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
  and the modified code adapted from https://github.com/gemelo-ai/vocos.
@@ -514,7 +582,7 @@ class DiscriminatorR(ConvNets):
  return mag
 
 
- class MultiResolutionDiscriminator(MultiDiscriminatorWrapper):
+ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
  def __init__(
  self,
  use_spectral_norm: bool = False,
@@ -552,71 +620,3 @@ class MultiResolutionDiscriminator(MultiDiscriminatorWrapper):
  y_d_gs.append(y_d_g)
  fmap_gs.append(fmap_g)
  return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
- class MultiDiscriminatorStep(Model):
- def __init__(
- self, list_discriminator: List[MultiDiscriminatorWrapper]
- ):
- """Setup example:
- model_d = MultiDiscriminatorStep(
- [
- MultiEnvelopeDiscriminator(),
- MultiBandDiscriminator(),
- MultiResolutionDiscriminator(),
- MultiPeriodDiscriminator(0.5),
- ]
- )
- """
- super().__init__()
- self.disc: Sequence[MultiDiscriminatorWrapper] = nn.ModuleList(
- list_discriminator
- )
- self.total = len(self.disc)
-
- def forward(
- self,
- y: Tensor,
- y_hat: Tensor,
- step_type: Literal["discriminator", "generator"],
- ) -> Union[
- Tuple[Tensor, Tensor, List[float]], Tuple[Tensor, List[float], List[float]]
- ]:
- """
- It returns the content based on the choice of "step_type", being it a
- 'discriminator' or 'generator'
-
- For generator it returns:
- Tuple[Tensor, Tensor, List[float]]
- "gen_loss, feat_loss, all_g_losses"
-
- For 'discriminator' it returns:
- Tuple[Tensor, List[float], List[float]]
- "disc_loss, disc_real_losses, disc_gen_losses"
- """
- if step_type == "generator":
- all_g_losses: List[float] = []
- feat_loss: Tensor = 0
- gen_loss: Tensor = 0
- else:
- disc_loss: Tensor = 0
- disc_real_losses: List[float] = []
- disc_gen_losses: List[float] = []
-
- for disc in self.disc:
- if step_type == "generator":
- # feature loss, generator loss, list of generator losses (float)]
- f_loss, g_loss, g_losses = disc.gen_step(y, y_hat)
- gen_loss += g_loss
- feat_loss += f_loss
- all_g_losses.extend(g_losses)
- else:
- # [discriminator loss, (disc losses real, disc losses generated)]
- d_loss, (d_real_losses, d_gen_losses) = disc.disc_step(y, y_hat)
- disc_loss += d_loss
- disc_real_losses.extend(d_real_losses)
- disc_gen_losses.extend(d_gen_losses)
-
- if step_type == "generator":
- return gen_loss, feat_loss, all_g_losses
- return disc_loss, disc_real_losses, disc_gen_losses
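discriminators.py promotes the former MultiDiscriminatorStep to the public MultiDiscriminatorWrapper (now a Model), while the old MultiDiscriminatorWrapper base class becomes the private _MultiDiscriminatorT. A training-step sketch; the waveform shapes and the shortened discriminator list are assumptions for illustration:

```python
import torch
from lt_tensor.model_zoo.losses.discriminators import (
    MultiDiscriminatorWrapper,
    MultiEnvelopeDiscriminator,
    MultiPeriodDiscriminator,
)

y = torch.randn(2, 1, 8192)      # reference audio (assumed shape)
y_hat = torch.randn(2, 1, 8192)  # generator output (assumed shape)

model_d = MultiDiscriminatorWrapper(
    [MultiEnvelopeDiscriminator(), MultiPeriodDiscriminator(0.5)]
)

# Discriminator update: summed loss plus per-discriminator real/generated losses.
disc_loss, disc_real_losses, disc_gen_losses = model_d(
    y, y_hat, step_type="discriminator"
)

# Generator update: adversarial loss, feature-matching loss, per-output losses.
gen_loss, feat_loss, all_g_losses = model_d(y, y_hat, step_type="generator")
```

In a real loop the discriminator step would normally receive a detached y_hat, and each step would be followed by its own backward pass and optimizer update.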
lt_tensor/processors/audio.py CHANGED
@@ -23,7 +23,7 @@ class AudioProcessorConfig(ModelConfig):
  win_length: int = 1024
  hop_length: int = 256
  f_min: float = 0
- f_max: float = 8000.0
+ f_max: Optional[float] = None
  center: bool = True
  mel_scale: Literal["htk" "slaney"] = "htk"
  std: int = 4
@@ -41,8 +41,8 @@ class AudioProcessorConfig(ModelConfig):
  n_fft: int = 1024,
  win_length: Optional[int] = None,
  hop_length: Optional[int] = None,
- f_min: float = 1,
- f_max: float = 12000.0,
+ f_min: float = 0,
+ f_max: Optional[float] = None,
  center: bool = True,
  mel_scale: Literal["htk", "slaney"] = "htk",
  std: int = 4,
@@ -71,9 +71,12 @@ class AudioProcessorConfig(ModelConfig):
  self.post_process()
 
  def post_process(self):
- self.f_min = max(self.f_min, 1)
- self.f_max = max(min(self.f_max, self.n_fft // 2), self.f_min + 1)
  self.n_stft = self.n_fft // 2 + 1
+ # some functions needs this to be a non-zero or not None value.
+ self.f_min = max(self.f_min, (self.sample_rate / (self.n_fft - 1)) * 2)
+ self.default_f_max = min(
+ default(self.f_max, self.sample_rate // 2), self.sample_rate // 2
+ )
  self.hop_length = default(self.hop_length, self.n_fft // 4)
  self.win_length = default(self.win_length, self.n_fft)
 
@@ -105,7 +108,7 @@ class AudioProcessor(Model):
  onesided=self.cfg.onesided,
  normalized=self.cfg.normalized,
  )
- self.mel_rscale = torchaudio.transforms.InverseMelScale(
+ self._mel_rscale = torchaudio.transforms.InverseMelScale(
  n_stft=self.cfg.n_stft,
  n_mels=self.cfg.n_mels,
  sample_rate=self.cfg.sample_rate,
@@ -119,32 +122,39 @@ class AudioProcessor(Model):
  (torch.hann_window(self.cfg.win_length) if window is None else window),
  )
 
- def from_numpy(
- self,
- array: np.ndarray,
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None,
- ):
- converted = torch.from_numpy(array)
- if device is None:
- device = self.device
- return converted.to(device=device, dtype=dtype)
 
- def from_numpy_batch(
+
+ def compute_mel(
  self,
- arrays: List[np.ndarray],
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None,
- ):
- stacked = torch.stack([torch.from_numpy(x) for x in arrays])
- if device is None:
- device = self.device
- return stacked.to(device=device, dtype=dtype)
+ wave: Tensor,
+ raw_mel_only: bool = False,
+ eps: float = 1e-5,
+ *,
+ _recall: bool = False,
+ ) -> Tensor:
+ """Returns: [B, M, T]"""
+ try:
+ mel_tensor = self._mel_spec(wave.to(self.device)) # [M, T]
+ if not raw_mel_only:
+ mel_tensor = (
+ torch.log(eps + mel_tensor.unsqueeze(0)) - self.cfg.mean
+ ) / self.cfg.std
+ return mel_tensor.squeeze()
 
- def to_numpy_safe(self, tensor: Union[Tensor, np.ndarray]):
- if isinstance(tensor, np.ndarray):
- return tensor
- return tensor.detach().to(DEFAULT_DEVICE).numpy(force=True)
+ except RuntimeError as e:
+ if not _recall:
+ self._mel_spec.to(self.device)
+ return self.compute_mel(wave, raw_mel_only, eps, _recall=True)
+ raise e
+
+ def compute_inverse_mel(self, melspec: Tensor, *, _recall=False):
+ try:
+ return self._mel_rscale.forward(melspec.to(self.device)).squeeze()
+ except RuntimeError as e:
+ if not _recall:
+ self._mel_rscale.to(self.device)
+ return self.compute_inverse_mel(melspec, _recall=True)
+ raise e
 
  def compute_rms(
  self,
@@ -192,12 +202,44 @@ class AudioProcessor(Model):
  else:
  rms_ = []
  for i in range(B):
- _r = librosa.feature.rms(_comp_rms_helper(i, audio, mel), **rms_kwargs)[
+ _t = _comp_rms_helper(i, audio, mel)
+ _r = librosa.feature.rms(**_t, **rms_kwargs)[
  0
  ]
  rms_.append(_r)
  return self.from_numpy_batch(rms_, default_device, default_dtype).squeeze()
 
+ def pitch_shift(self, audio: torch.Tensor, sample_rate: Optional[int] = None, n_steps: float = 2.0):
+ """
+ Shifts the pitch of an audio tensor by `n_steps` semitones.
+
+ Args:
+ audio (torch.Tensor): Tensor of shape (B, T) or (T,)
+ sample_rate (int, optional): Sample rate of the audio. Will use the class sample rate if unset.
+ n_steps (float): Number of semitones to shift. Can be negative.
+
+ Returns:
+ torch.Tensor: Pitch-shifted audio.
+ """
+ src_device = audio.device
+ src_dtype = audio.dtype
+ audio = audio.squeeze()
+ sample_rate = default(sample_rate, self.cfg.sample_rate)
+ def _shift_one(wav):
+ wav_np = self.to_numpy_safe(wav)
+ shifted_np = librosa.effects.pitch_shift(wav_np, sr=sample_rate, n_steps=n_steps)
+ return torch.from_numpy(shifted_np)
+
+ if audio.ndim == 1:
+ return _shift_one(audio).to(device=src_device, dtype=src_dtype)
+ return torch.stack([_shift_one(a) for a in audio]).to(device=src_device, dtype=src_dtype)
+
+
+ @staticmethod
+ def calc_pitch_fmin(sr:int, frame_length:float):
+ """For pitch f_min"""
+ return (sr / (frame_length - 1)) * 2
+
  def compute_pitch(
  self,
  audio: Tensor,
@@ -218,9 +260,9 @@ class AudioProcessor(Model):
  else:
  B = 1
  sr = default(sr, self.cfg.sample_rate)
- fmin = max(default(fmin, self.cfg.f_min), 65)
- fmax = min(default(fmax, self.cfg.f_max), sr // 2)
  frame_length = default(frame_length, self.cfg.n_fft)
+ fmin = max(default(fmin, self.cfg.f_min), self.calc_pitch_fmin(sr, frame_length))
+ fmax = min(max(default(fmax, self.cfg.default_f_max), fmin+1), sr // 2)
  hop_length = default(hop_length, self.cfg.hop_length)
  center = default(center, self.cfg.center)
  yn_kwargs = dict(
@@ -257,10 +299,10 @@ class AudioProcessor(Model):
  frame_length: Optional[Number] = None,
  ):
  sr = default(sr, self.sample_rate)
- fmin = max(default(fmin, self.f_min), 1)
- fmax = min(default(fmax, self.f_max), sr // 2)
- win_length = default(win_length, self.win_length)
- frame_length = default(frame_length, self.n_fft)
+ win_length = default(win_length, self.cfg.win_length)
+ frame_length = default(frame_length, self.cfg.n_fft)
+ fmin = default(fmin, self.calc_pitch_fmin(sr, frame_length))
+ fmax = default(fmax, self.cfg.default_f_max)
  return detect_pitch_frequency(
  audio,
  sample_rate=sr,
@@ -270,6 +312,33 @@ class AudioProcessor(Model):
  freq_high=fmax,
  ).squeeze()
 
+ def from_numpy(
+ self,
+ array: np.ndarray,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ converted = torch.from_numpy(array)
+ if device is None:
+ device = self.device
+ return converted.to(device=device, dtype=dtype)
+
+ def from_numpy_batch(
+ self,
+ arrays: List[np.ndarray],
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ stacked = torch.stack([torch.from_numpy(x) for x in arrays])
+ if device is None:
+ device = self.device
+ return stacked.to(device=device, dtype=dtype)
+
+ def to_numpy_safe(self, tensor: Union[Tensor, np.ndarray]):
+ if isinstance(tensor, np.ndarray):
+ return tensor
+ return tensor.detach().to(DEFAULT_DEVICE).numpy(force=True)
+
  def interpolate(
  self,
  tensor: Tensor,
@@ -391,29 +460,6 @@ class AudioProcessor(Model):
  return self.istft_norm(wave, length, _recall=True)
  raise e
 
- def compute_mel(
- self,
- wave: Tensor,
- raw_mel_only: bool = False,
- eps: float = 1e-5,
- *,
- _recall: bool = False,
- ) -> Tensor:
- """Returns: [B, M, T]"""
- try:
- mel_tensor = self._mel_spec(wave.to(self.device)) # [M, T]
- if not raw_mel_only:
- mel_tensor = (
- torch.log(eps + mel_tensor.unsqueeze(0)) - self.cfg.mean
- ) / self.cfg.std
- return mel_tensor.squeeze()
-
- except RuntimeError as e:
- if not _recall:
- self._mel_spec.to(self.device)
- return self.compute_mel(wave, raw_mel_only, eps, _recall=True)
- raise e
-
  def load_audio(
  self,
  path: PathLike,
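audio.py makes f_max optional (resolved to sample_rate // 2 as default_f_max), raises f_min during post_process to calc_pitch_fmin(sample_rate, n_fft) = (sr / (n_fft - 1)) * 2, moves compute_mel/compute_inverse_mel earlier in the class, and adds pitch_shift. A sketch of the updated surface; the constructor keywords are inferred from the fields visible in this diff and may not match the released wheel exactly:

```python
import torch
from lt_tensor.processors import AudioProcessor, AudioProcessorConfig

# 24 kHz and a 1024-point FFT are illustrative values.
ap = AudioProcessor(AudioProcessorConfig(sample_rate=24000, n_fft=1024))

# New pitch floor: (24000 / 1023) * 2 ~= 46.9 Hz, replacing the fixed 65 Hz clamp.
fmin = AudioProcessor.calc_pitch_fmin(sr=24000, frame_length=1024)

wave = torch.randn(24000)                   # one second of stand-in audio
mel = ap.compute_mel(wave)                  # log-mel, scaled by cfg.mean / cfg.std
f0 = ap.compute_pitch(wave)                 # search bounded by the new f_min / default_f_max logic
up_two = ap.pitch_shift(wave, n_steps=2.0)  # librosa-backed shift, new in this release
```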
lt_tensor-0.0.1a32.dist-info/METADATA → lt_tensor-0.0.1a34.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lt-tensor
- Version: 0.0.1a32
+ Version: 0.0.1a34
  Summary: General utilities for PyTorch and others. Built for general use.
  Home-page: https://github.com/gr1336/lt-tensor/
  Author: gr1336
@@ -17,7 +17,7 @@ Requires-Dist: numpy>=1.26.4
  Requires-Dist: tokenizers
  Requires-Dist: pyyaml>=6.0.0
  Requires-Dist: numba>0.60.0
- Requires-Dist: lt-utils>=0.0.3
+ Requires-Dist: lt-utils>=0.0.4
  Requires-Dist: librosa==0.11.*
  Requires-Dist: einops
  Requires-Dist: plotly
lt_tensor-0.0.1a32.dist-info/RECORD → lt_tensor-0.0.1a34.dist-info/RECORD
@@ -1,8 +1,8 @@
- lt_tensor/__init__.py,sha256=8FTxpJ6td2bMr_GqzW2tCV6Tr5CelbQle8N5JRWtx8M,439
+ lt_tensor/__init__.py,sha256=WAGPuMPq5c4DGAJ57x1Ykgzg3vMlLq9BiWk5EdJcUsU,441
  lt_tensor/config_templates.py,sha256=F9UvL8paAjkSvio890kp8WznpYeI50pYnm9iqQroBxk,2797
- lt_tensor/losses.py,sha256=zvkCOnE5XpF3v6ymivRIdqPTsMM5zc94ZMom7YDi3zM,4946
+ lt_tensor/losses.py,sha256=fHVMqOFo3ekjORYy89R_aRjmtT6lo27Z1egzOYjQ1W8,8646
  lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
- lt_tensor/math_ops.py,sha256=TkD4WQG42KsQ9Fg7FXOjf8f-ixtW0apf2XjaooecVx4,2257
+ lt_tensor/math_ops.py,sha256=ahX6Z1Mt3X-FhmwSZYZea5mB1B0S8GDuvKPfAm5e_FQ,2646
  lt_tensor/misc_utils.py,sha256=N2r3UmxC4RM2BZBQhpjDZ_BKLrzsyIlKzopTzJbnjFU,28962
  lt_tensor/model_base.py,sha256=5T4dbAh4MXbQmPRpihGtMYwTY8sJTQOhY6An3VboM58,18086
  lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
@@ -27,11 +27,11 @@ lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=PDuDYN1omD1RoAXcmxH
  lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=7GJqKLw7-juXpfp5IFzjASLut0uouDhjZ1CQknf3H68,16533
  lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=ltIuD9t1gmS3bTmCqZIwJHKrhC6DYya3OaXlskWX9kw,17606
  lt_tensor/model_zoo/losses/__init__.py,sha256=B9RAUxBiOZwooztnij1oLeRwZ7_MjnN3mPoum7saD6s,59
- lt_tensor/model_zoo/losses/discriminators.py,sha256=ZA7Qqrhe8kELrI1-IITadGSl8JCgpgPKFCW6qvSOk1E,20724
+ lt_tensor/model_zoo/losses/discriminators.py,sha256=ZpyByFgc7L7uV_XRBsV9vkdVItbJO3z--Y6LlvTvtwY,20765
  lt_tensor/processors/__init__.py,sha256=Pvxhh0KR65zLCgUd53_k5Z0y5JWWcO0ZBXFK9rv0o5w,109
- lt_tensor/processors/audio.py,sha256=1JuxxexfUsXkLjVjWUk-oTRU-QNnCCwvKX3eP0m7LGE,16452
- lt_tensor-0.0.1a32.dist-info/licenses/LICENSE,sha256=tQHc38scHOba4kDBNG4U0U6PpObaloiZG-FvKSgv2b0,11336
- lt_tensor-0.0.1a32.dist-info/METADATA,sha256=gDYEHtmPwgyKRPNLnU3ZDRtDAqnDgrODoVW5wL2ib3c,1062
- lt_tensor-0.0.1a32.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- lt_tensor-0.0.1a32.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
- lt_tensor-0.0.1a32.dist-info/RECORD,,
+ lt_tensor/processors/audio.py,sha256=HNr1GS-6M2q0Rda4cErf5y2Jlc9f4jD58FvpX2ua9d4,18369
+ lt_tensor-0.0.1a34.dist-info/licenses/LICENSE,sha256=TbiyJWLgNqqgqhfCnrGwFIxy7EqGNrIZZcKhHrefcuU,11354
+ lt_tensor-0.0.1a34.dist-info/METADATA,sha256=WkTafcY5nYZbrZ7WzUc3JXnmg9NtUAXrchx42dCok9I,1062
+ lt_tensor-0.0.1a34.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ lt_tensor-0.0.1a34.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+ lt_tensor-0.0.1a34.dist-info/RECORD,,
lt_tensor-0.0.1a32.dist-info/licenses/LICENSE → lt_tensor-0.0.1a34.dist-info/licenses/LICENSE
@@ -186,7 +186,7 @@
  same "printed page" as the copyright notice for easier
  identification within third-party archives.
 
- Copyright 2025 gr1336
+ Copyright 2025 gr1336 (Gabriel Ribeiro)
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.