lt-tensor 0.0.1a4__py3-none-any.whl → 0.0.1a7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/__init__.py CHANGED
@@ -1,13 +1,17 @@
 __version__ = "0.0.1a"
 
 from . import (
+    lr_schedulers,
     model_zoo,
     model_base,
     math_ops,
     misc_utils,
     monotonic_align,
     transform,
-    lr_schedulers,
+    noise_tools,
+    losses,
+    processors,
+    datasets,
 )
 
 __all__ = [
@@ -18,4 +22,8 @@ __all__ = [
     "monotonic_align",
     "transform",
     "lr_schedulers",
+    "noise_tools",
+    "losses",
+    "processors",
+    "datasets",
 ]
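
With this change the new submodules load on package import; a quick smoke test (hypothetical, not shipped with the package):

```python
# Hypothetical smoke test for the submodules newly re-exported in 0.0.1a7.
from lt_tensor import datasets, losses, noise_tools, processors

print(losses.masked_cross_entropy)  # the one name losses.__all__ exports
```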
@@ -1,110 +1,109 @@
-__all__ = ["AudioProcessor"]
+__all__ = ["WaveMelDatasets"]
 from ..torch_commons import *
-import torchaudio
 from lt_utils.common import *
-import librosa
-from lt_utils.type_utils import is_file
-from torchaudio.functional import resample
-from ..transform import inverse_transform
+import random
+from torch.utils.data import Dataset, DataLoader, Sampler
+from ..processors import AudioProcessor
+import torch.nn.functional as FT
+from ..misc_utils import log_tensor
 
 
-class AudioProcessor:
+class WaveMelDataset(Dataset):
+    """Untested!"""
+
+    data: Union[list[dict[str, Tensor]], Tuple[Tensor, Tensor]] = []
 
     def __init__(
         self,
-        sample_rate: int = 24000,
-        n_mels: int = 80,
-        n_fft: int = 1024,
-        win_length: int = 1024,
-        hop_length: int = 256,
-        f_min: float = 0,
-        f_max: float | None = None,
-        n_iter: int = 32,
-        center: bool = True,
-        mel_scale: Literal["htk", "slaney"] = "htk",
-        inv_n_fft: int = 16,
-        inv_hop: int = 4,
-        std: int = 4,
-        mean: int = -4,
+        audio_processor: AudioProcessor,
+        path: PathLike,
+        limit_files: Optional[int] = None,
+        max_frame_length: Optional[int] = None,
     ):
-        self.mean = mean
-        self.std = std
-        self.n_mels = n_mels
-        self.n_fft = n_fft
-        self.n_stft = n_fft // 2 + 1
-        self.f_min = f_min
-        self.f_max = f_max
-        self.n_iter = n_iter
-        self.hop_length = hop_length
-        self.sample_rate = sample_rate
-        self.mel_spec = torchaudio.transforms.MelSpectrogram(
-            sample_rate=sample_rate,
-            n_mels=n_mels,
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            center=center,
-            f_min=f_min,
-            f_max=f_max,
-            mel_scale=mel_scale,
-        )
-        self.mel_rscale = torchaudio.transforms.InverseMelScale(
-            n_stft=self.n_stft,
-            n_mels=n_mels,
-            sample_rate=sample_rate,
-            f_min=f_min,
-            f_max=f_max,
-            mel_scale=mel_scale,
-        )
-        self.giffin_lim = torchaudio.transforms.GriffinLim(
-            n_fft=n_fft,
-            n_iter=n_iter,
-            win_length=win_length,
-            hop_length=hop_length,
-        )
-        self._inverse_transform = lambda x, y: inverse_transform(
-            x, y, inv_n_fft, inv_hop, inv_n_fft
+        super().__init__()
+        assert max_frame_length is None or max_frame_length >= (
+            (audio_processor.n_fft // 2) + 1
         )
+        self.post_n_fft = (audio_processor.n_fft // 2) + 1
+        self.ap = audio_processor
+        self.files = self.ap.find_audios(path)
+        if limit_files:
+            random.shuffle(self.files)
+            self.files = self.files[:limit_files]
+        self.data = []
 
-    def inverse_transform(self, spec: Tensor, phase: Tensor):
-        return self._inverse_transform(spec, phase)
+        for file in self.files:
+            results = self.load_data(file, max_frame_length)
+            self.data.extend(results)
 
-    def compute_mel(
-        self,
-        wave: Tensor,
-    ) -> Tensor:
-        """Returns: [B, M, ML]"""
-        mel_tensor = self.mel_spec(wave)  # [M, ML]
-        mel_tensor = (mel_tensor - self.mean) / self.std
-        return mel_tensor  # [B, M, ML]
+    def _add_dict(self, audio_raw: Tensor, audio_mel: Tensor, file: PathLike):
+        return {"mel": audio_mel, "raw": audio_raw, "file": file}
 
-    def reverse_mel(self, mel: Tensor, n_iter: Optional[int] = None):
-        if isinstance(n_iter, int) and n_iter != self.n_iter:
-            self.giffin_lim = torchaudio.transforms.GriffinLim(
-                n_fft=self.n_fft,
-                n_iter=n_iter,
-                win_length=self.win_length,
-                hop_length=self.hop_length,
-            )
-            self.n_iter = n_iter
-        return self.giffin_lim.forward(
-            self.mel_rscale(mel),
-        )
+    def load_data(self, file: PathLike, audio_frames_limit: Optional[int] = None):
+        initial_audio = self.ap.load_audio(file)
+        if not audio_frames_limit or initial_audio.shape[-1] <= audio_frames_limit:
+            audio_mel = self.ap.compute_mel(initial_audio, add_base=True)
+            return [self._add_dict(initial_audio, audio_mel, file)]
+        results = []
+        for fragment in torch.split(
+            initial_audio, split_size_or_sections=audio_frames_limit, dim=-1
+        ):
+            if fragment.shape[-1] < self.post_n_fft:
+                # sometimes the tensor will be too small to be able to pass on mel
+                continue
+            audio_mel = self.ap.compute_mel(fragment, add_base=True)
+            results.append(self._add_dict(fragment, audio_mel, file))
+        return results
 
-    def load_audio(
+    def get_data_loader(
         self,
-        path: PathLike,
-        top_db: float = 30,
-    ) -> Tensor:
-        is_file(path, True)
-        wave, sr = librosa.load(str(path), sr=self.sample_rate)
-        wave, _ = librosa.effects.trim(wave, top_db=top_db)
-        return (
-            torch.from_numpy(
-                librosa.resample(wave, orig_sr=sr, target_sr=self.sample_rate)
-                if sr != self.sample_rate
-                else wave
-            )
-            .float()
-            .unsqueeze(0)
+        batch_size: int = 1,
+        shuffle: Optional[bool] = None,
+        sampler: Optional[Union[Sampler, Iterable]] = None,
+        batch_sampler: Optional[Union[Sampler[list], Iterable[list]]] = None,
+        num_workers: int = 0,
+        pin_memory: bool = False,
+        drop_last: bool = False,
+        timeout: float = 0,
+    ):
+        return DataLoader(
+            self,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            sampler=sampler,
+            batch_sampler=batch_sampler,
+            num_workers=num_workers,
+            pin_memory=pin_memory,
+            drop_last=drop_last,
+            timeout=timeout,
+            collate_fn=self.collate_fn,
         )
+
+    @staticmethod
+    def collate_fn(batch: Sequence[Dict[str, Tensor]]):
+        mels = []
+        audios = []
+        files = []
+        for x in batch:
+            mels.append(x["mel"])
+            audios.append(x["raw"])
+            files.append(x["file"])
+        # Find max time in mel (dim -1), and max audio length
+        max_mel_len = max([m.shape[-1] for m in mels])
+        max_audio_len = max([a.shape[-1] for a in audios])
+
+        padded_mels = torch.stack(
+            [FT.pad(m, (0, max_mel_len - m.shape[-1])) for m in mels]
+        )  # shape: [B, 80, T_max]
+
+        padded_audios = torch.stack(
+            [FT.pad(a, (0, max_audio_len - a.shape[-1])) for a in audios]
+        )  # shape: [B, L_max]
+
+        return padded_mels, padded_audios, files
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        return self.data[index]
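
For orientation, a minimal usage sketch of the new dataset. The exact import path is not shown in this diff, so `lt_tensor.datasets` is assumed from the new top-level import, and the `AudioProcessor()` defaults are assumed to match the old 0.0.1a4 signature (24 kHz, `n_fft=1024`, 80 mels):

```python
# Hypothetical usage sketch; the lt_tensor.datasets import path and the
# AudioProcessor() defaults are assumptions, not confirmed by this diff.
from lt_tensor.processors import AudioProcessor
from lt_tensor.datasets import WaveMelDataset

ap = AudioProcessor()
dataset = WaveMelDataset(
    audio_processor=ap,
    path="data/wavs",        # any directory find_audios() can scan
    limit_files=100,         # optional: randomly subsample the file list
    max_frame_length=24000,  # longer clips are split along the last dim
)

# collate_fn pads each batch to its longest mel/audio, so variable-length
# clips are safe; batches come back as (padded_mels, padded_audios, files).
loader = dataset.get_data_loader(batch_size=8, shuffle=True)
mels, audios, files = next(iter(loader))
```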
lt_tensor/losses.py ADDED
@@ -0,0 +1,145 @@
+__all__ = ["masked_cross_entropy"]
+import math
+import random
+from .torch_commons import *
+from lt_utils.common import *
+import torch.nn.functional as F
+
+
+def masked_cross_entropy(
+    logits: torch.Tensor,  # [B, T, V]
+    targets: torch.Tensor,  # [B, T]
+    lengths: torch.Tensor,  # [B]
+    reduction: str = "mean",
+) -> torch.Tensor:
+    """
+    CrossEntropyLoss with masking for variable-length sequences.
+    - logits: unnormalized scores [B, T, V]
+    - targets: ground truth indices [B, T]
+    - lengths: actual sequence lengths [B]
+    """
+    B, T, V = logits.size()
+    logits = logits.view(-1, V)
+    targets = targets.view(-1)
+
+    # Create mask
+    mask = torch.arange(T, device=lengths.device).expand(B, T) < lengths.unsqueeze(1)
+    mask = mask.reshape(-1)
+
+    # Apply CE only where mask == True
+    loss = F.cross_entropy(
+        logits[mask], targets[mask], reduction="mean" if reduction == "mean" else "none"
+    )
+    if reduction == "none":
+        return loss
+    return loss
+
+
+def diff_loss(pred_noise, true_noise, mask=None):
+    """Standard diffusion noise-prediction loss (e.g., DDPM)"""
+    if mask is not None:
+        return F.mse_loss(pred_noise * mask, true_noise * mask)
+    return F.mse_loss(pred_noise, true_noise)
+
+
+def hybrid_diff_loss(pred_noise, true_noise, alpha=0.5):
+    """Combines L1 and L2"""
+    l1 = F.l1_loss(pred_noise, true_noise)
+    l2 = F.mse_loss(pred_noise, true_noise)
+    return alpha * l1 + (1 - alpha) * l2
+
+
+def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
+    loss = 0
+    for real, fake in zip(real_preds, fake_preds):
+        if use_lsgan:
+            loss += F.mse_loss(real, torch.ones_like(real)) + F.mse_loss(
+                fake, torch.zeros_like(fake)
+            )
+        else:
+            loss += -torch.mean(torch.log(real + 1e-7)) - torch.mean(
+                torch.log(1 - fake + 1e-7)
+            )
+    return loss
+
+
+def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
+    loss = 0
+    for real, fake in zip(real_preds, fake_preds):
+        if use_lsgan:
+            loss += F.mse_loss(real, torch.ones_like(real)) + F.mse_loss(
+                fake, torch.zeros_like(fake)
+            )
+        else:
+            loss += -torch.mean(torch.log(real + 1e-7)) - torch.mean(
+                torch.log(1 - fake + 1e-7)
+            )
+    return loss
+
+
+def gan_g_loss(fake_preds, use_lsgan=True):
+    loss = 0
+    for fake in fake_preds:
+        if use_lsgan:
+            loss += F.mse_loss(fake, torch.ones_like(fake))
+        else:
+            loss += -torch.mean(torch.log(fake + 1e-7))
+    return loss
+
+
+def feature_matching_loss(real_feats, fake_feats):
+    """real_feats and fake_feats are lists of intermediate features"""
+    loss = 0
+    for real_layers, fake_layers in zip(real_feats, fake_feats):
+        for r, f in zip(real_layers, fake_layers):
+            loss += F.l1_loss(f, r.detach())
+    return loss
+
+
+def feature_loss(real_fmaps, fake_fmaps, weight=2.0):
+    loss = 0.0
+    for dr, dg in zip(real_fmaps, fake_fmaps):  # Each (layer list from a discriminator)
+        for r_feat, g_feat in zip(dr, dg):
+            loss += F.l1_loss(r_feat, g_feat)
+    return loss * weight
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+    loss = 0.0
+    r_losses = []
+    g_losses = []
+
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        r_loss = F.mse_loss(dr, torch.ones_like(dr))
+        g_loss = F.mse_loss(dg, torch.zeros_like(dg))
+        loss += r_loss + g_loss
+        r_losses.append(r_loss)
+        g_losses.append(g_loss)
+
+    return loss, r_losses, g_losses
+
+
+def generator_loss(fake_outputs):
+    total = 0.0
+    g_losses = []
+    for out in fake_outputs:
+        loss = F.mse_loss(out, torch.ones_like(out))
+        g_losses.append(loss)
+        total += loss
+    return total, g_losses
+
+
+def multi_resolution_stft_loss(y, y_hat, fft_sizes=[512, 1024, 2048]):
+    loss = 0
+    for fft_size in fft_sizes:
+        hop = fft_size // 4
+        win = fft_size
+        y_stft = torch.stft(
+            y, n_fft=fft_size, hop_length=hop, win_length=win, return_complex=True
+        )
+        y_hat_stft = torch.stft(
+            y_hat, n_fft=fft_size, hop_length=hop, win_length=win, return_complex=True
+        )
+
+        loss += F.l1_loss(torch.abs(y_stft), torch.abs(y_hat_stft))
+    return loss
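
A short, self-contained check of `masked_cross_entropy`, the one symbol the new module exports; shapes follow its docstring:

```python
# Minimal sketch exercising masked_cross_entropy from lt_tensor.losses.
import torch
from lt_tensor.losses import masked_cross_entropy

B, T, V = 2, 5, 10
logits = torch.randn(B, T, V, requires_grad=True)  # [B, T, V]
targets = torch.randint(0, V, (B, T))              # [B, T]
lengths = torch.tensor([5, 3])                     # [B]; 2nd sequence has 2 padded steps

# Padded positions are dropped before F.cross_entropy, so they contribute
# neither to the loss value nor to the gradients.
loss = masked_cross_entropy(logits, targets, lengths)
loss.backward()
```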
lt_tensor/math_ops.py CHANGED
@@ -8,6 +8,7 @@ __all__ = [
     "dot_product",
     "normalize_tensor",
     "log_magnitude",
+    "shift_time",
    "phase",
 ]
 
@@ -50,6 +51,11 @@ def shift_ring(x: Tensor, dim: int = -1) -> Tensor:
     return torch.roll(x, shifts=1, dims=dim)
 
 
+def shift_time(x: torch.Tensor, shift: int) -> torch.Tensor:
+    """Shifts tensor along time axis (last dim)."""
+    return torch.roll(x, shifts=shift, dims=-1)
+
+
 def dot_product(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
     """Computes dot product along the specified dimension."""
     return torch.sum(x * y, dim=dim)
@@ -69,3 +75,4 @@ def log_magnitude(stft_complex: Tensor, eps: float = 1e-5) -> Tensor:
 def phase(stft_complex: Tensor) -> Tensor:
     """Returns phase from complex STFT."""
     return torch.angle(stft_complex)
+
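
One nuance of the new `shift_time` worth noting: it delegates to `torch.roll`, so the shift is circular rather than zero-padded:

```python
# Circular behaviour of shift_time (it wraps; it does not zero-pad).
import torch
from lt_tensor.math_ops import shift_time

x = torch.arange(6).view(1, 6)  # [[0, 1, 2, 3, 4, 5]]
print(shift_time(x, 2))         # tensor([[4, 5, 0, 1, 2, 3]])
print(shift_time(x, -1))        # tensor([[1, 2, 3, 4, 5, 0]])
```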
lt_tensor/misc_utils.py CHANGED
@@ -8,7 +8,6 @@ __all__ = [
     "unfreeze_selected_weights",
     "clip_gradients",
     "detach_hidden",
-    "tensor_summary",
     "one_hot",
     "safe_divide",
     "batch_pad",
@@ -18,22 +17,20 @@ __all__ = [
     "default_device",
     "Packing",
     "Padding",
-    "MaskUtils",
-    "masked_cross_entropy",
-    "NoiseScheduler",
+    "Masking",
 ]
 
 import gc
+import sys
 import random
 import numpy as np
 from lt_utils.type_utils import is_str
 from .torch_commons import *
-from lt_utils.misc_utils import log_traceback, cache_wrapper
-from lt_utils.file_ops import load_json, load_yaml, save_json, save_yaml
-import math
+from lt_utils.misc_utils import cache_wrapper
 from lt_utils.common import *
 import torch.nn.functional as F
 
+
 def log_tensor(
     item: Union[Tensor, np.ndarray],
     title: Optional[str] = None,
@@ -64,10 +61,13 @@ def log_tensor(
             print(f"mean: {item.mean(dim=dim):.4f}")
         except:
             pass
-        if print_tensor:
-            print(item)
+    if print_tensor:
+        print(item)
     if has_title:
         print("".join(["-"] * _b), "\n")
+    else:
+        print("\n")
+    sys.stdout.flush()
 
 
 def set_seed(seed: int):
@@ -136,11 +136,6 @@ def detach_hidden(hidden):
     return tuple(detach_hidden(h) for h in hidden)
 
 
-def tensor_summary(tensor: torch.Tensor) -> str:
-    """Prints min/max/mean/std of a tensor for debugging."""
-    return f"Shape: {tuple(tensor.shape)}, dtype: {tensor.dtype}, min: {tensor.min():.4f}, max: {tensor.max():.4f}, mean: {tensor.mean():.4f}, std: {tensor.std():.4f}"
-
-
 def one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
     """One-hot encodes a tensor of labels."""
     return F.one_hot(labels, num_classes).float()
@@ -463,7 +458,7 @@ class Padding:
         return torch.stack(padded), lengths
 
 
-class MaskUtils:
+class Masking:
 
     @staticmethod
     def apply_mask(x: Tensor, mask: Tensor, fill_value: Number = 0) -> Tensor:
@@ -546,84 +541,3 @@ class MaskUtils:
         return (
             causal_mask & pad_mask.unsqueeze(1).expand(-1, pad_mask.size(1), -1).bool()
         )
-
-
-def masked_cross_entropy(
-    logits: torch.Tensor,  # [B, T, V]
-    targets: torch.Tensor,  # [B, T]
-    lengths: torch.Tensor,  # [B]
-    reduction: str = "mean",
-) -> torch.Tensor:
-    """
-    CrossEntropyLoss with masking for variable-length sequences.
-    - logits: unnormalized scores [B, T, V]
-    - targets: ground truth indices [B, T]
-    - lengths: actual sequence lengths [B]
-    """
-    B, T, V = logits.size()
-    logits = logits.view(-1, V)
-    targets = targets.view(-1)
-
-    # Create mask
-    mask = torch.arange(T, device=lengths.device).expand(B, T) < lengths.unsqueeze(1)
-    mask = mask.reshape(-1)
-
-    # Apply CE only where mask == True
-    loss = F.cross_entropy(
-        logits[mask], targets[mask], reduction="mean" if reduction == "mean" else "none"
-    )
-    if reduction == "none":
-        return loss
-    return loss
-
-
-class NoiseScheduler(nn.Module):
-    def __init__(self, timesteps: int = 512):
-        super().__init__()
-
-        betas = torch.linspace(1e-4, 0.02, timesteps)
-        alphas = 1.0 - betas
-        alpha_cumprod = torch.cumprod(alphas, dim=0)
-
-        self.register_buffer("sqrt_alpha_cumprod", torch.sqrt(alpha_cumprod))
-        self.register_buffer(
-            "sqrt_one_minus_alpha_cumprod", torch.sqrt(1.0 - alpha_cumprod)
-        )
-
-        self.timesteps = timesteps
-        self.default_noise = math.sqrt(1.25)
-
-    def get_random_noise(
-        self, min_max: Tuple[float, float] = (-3, 3), seed: int = 0
-    ) -> float:
-        if seed > 0:
-            random.seed(seed)
-        return random.uniform(*min_max)
-
-    def set_noise(
-        self,
-        seed: int = 0,
-        min_max: Tuple[float, float] = (-3, 3),
-        default: bool = False,
-    ):
-        self.default_noise = (
-            math.sqrt(1.25) if default else self.get_random_noise(min_max, seed)
-        )
-
-    def forward(
-        self, x_0: Tensor, t: int, noise: Optional[Union[Tensor, float]] = None
-    ) -> Tensor:
-        if t < 0 or t >= self.timesteps:
-            raise ValueError(
-                f"Time step t={t} is out of bounds for scheduler with {self.timesteps} steps."
-            )
-
-        if noise is None:
-            noise = self.default_noise
-
-        if isinstance(noise, (float, int)):
-            noise = torch.randn_like(x_0) * noise
-
-        alpha_term = self.sqrt_alpha_cumprod[t] * x_0
-        noise_term = self.sqrt_one_minus_alpha_cumprod[t] * noise
-        return alpha_term + noise_term
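
`masked_cross_entropy` moves to the new `losses.py` above; `NoiseScheduler` is simply dropped from `misc_utils` here (this diff does not show where, or whether, it reappears, though the new `noise_tools` module is the likely home). For reference, the forward-diffusion step its `forward` computed is x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps with a linear beta schedule, as in this standalone sketch:

```python
# Standalone sketch of the DDPM-style forward noising step that the removed
# NoiseScheduler computed; plain tensor math, not an lt-tensor API.
from typing import Optional

import torch

timesteps = 512
betas = torch.linspace(1e-4, 0.02, timesteps)      # linear beta schedule
alpha_cumprod = torch.cumprod(1.0 - betas, dim=0)  # cumulative alpha-bar

def add_noise(x_0: torch.Tensor, t: int, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
    """x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps."""
    if noise is None:
        noise = torch.randn_like(x_0)
    return alpha_cumprod[t].sqrt() * x_0 + (1.0 - alpha_cumprod[t]).sqrt() * noise
```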