lt-tensor 0.0.1a6__py3-none-any.whl → 0.0.1a7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/__init__.py CHANGED
@@ -10,6 +10,8 @@ from . import (
     transform,
     noise_tools,
     losses,
+    processors,
+    datasets,
 )
 
 __all__ = [
@@ -22,4 +24,6 @@ __all__ = [
     "lr_schedulers",
     "noise_tools",
     "losses",
+    "processors",
+    "datasets",
 ]
lt_tensor/datasets/audio.py CHANGED
@@ -1,154 +1,109 @@
-__all__ = ["AudioProcessor"]
+__all__ = ["WaveMelDataset"]
 from ..torch_commons import *
-import torchaudio
 from lt_utils.common import *
-import librosa
-from lt_utils.type_utils import is_file, is_array
-from torchaudio.functional import resample
-from ..transform import inverse_transform
-from lt_utils.file_ops import FileScan, load_text, get_file_name
+import random
+from torch.utils.data import Dataset, DataLoader, Sampler
+from ..processors import AudioProcessor
+import torch.nn.functional as FT
+from ..misc_utils import log_tensor
 
 
-class AudioProcessor:
+class WaveMelDataset(Dataset):
+    """Untested!"""
+
+    data: Union[list[dict[str, Tensor]], Tuple[Tensor, Tensor]] = []
 
     def __init__(
         self,
-        sample_rate: int = 24000,
-        n_mels: int = 80,
-        n_fft: int = 1024,
-        win_length: int = 1024,
-        hop_length: int = 256,
-        f_min: float = 0,
-        f_max: float | None = None,
-        n_iter: int = 32,
-        center: bool = True,
-        mel_scale: Literal["htk", "slaney"] = "htk",
-        inv_n_fft: int = 16,
-        inv_hop: int = 4,
-        std: int = 4,
-        mean: int = -4,
+        audio_processor: AudioProcessor,
+        path: PathLike,
+        limit_files: Optional[int] = None,
+        max_frame_length: Optional[int] = None,
     ):
-        self.mean = mean
-        self.std = std
-        self.n_mels = n_mels
-        self.n_fft = n_fft
-        self.n_stft = n_fft // 2 + 1
-        self.f_min = f_min
-        self.f_max = f_max
-        self.n_iter = n_iter
-        self.hop_length = hop_length
-        self.sample_rate = sample_rate
-        self.mel_spec = torchaudio.transforms.MelSpectrogram(
-            sample_rate=sample_rate,
-            n_mels=n_mels,
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            center=center,
-            f_min=f_min,
-            f_max=f_max,
-            mel_scale=mel_scale,
-        )
-        self.mel_rscale = torchaudio.transforms.InverseMelScale(
-            n_stft=self.n_stft,
-            n_mels=n_mels,
-            sample_rate=sample_rate,
-            f_min=f_min,
-            f_max=f_max,
-            mel_scale=mel_scale,
-        )
-        self.giffin_lim = torchaudio.transforms.GriffinLim(
-            n_fft=n_fft,
-            n_iter=n_iter,
-            win_length=win_length,
-            hop_length=hop_length,
-        )
-        self._inverse_transform = lambda x, y: inverse_transform(
-            x, y, inv_n_fft, inv_hop, inv_n_fft
+        super().__init__()
+        assert max_frame_length is None or max_frame_length >= (
+            (audio_processor.n_fft // 2) + 1
         )
+        self.post_n_fft = (audio_processor.n_fft // 2) + 1
+        self.ap = audio_processor
+        self.files = self.ap.find_audios(path)
+        if limit_files:
+            random.shuffle(self.files)
+            self.files = self.files[:limit_files]
+        self.data = []
 
-    def inverse_transform(self, spec: Tensor, phase: Tensor):
-        return self._inverse_transform(spec, phase)
+        for file in self.files:
+            results = self.load_data(file, max_frame_length)
+            self.data.extend(results)
 
-    def compute_mel(
-        self,
-        wave: Tensor,
-    ) -> Tensor:
-        """Returns: [B, M, ML]"""
-        mel_tensor = self.mel_spec(wave)  # [M, ML]
-        mel_tensor = (mel_tensor - self.mean) / self.std
-        return mel_tensor  # [B, M, ML]
+    def _add_dict(self, audio_raw: Tensor, audio_mel: Tensor, file: PathLike):
+        return {"mel": audio_mel, "raw": audio_raw, "file": file}
 
-    def reverse_mel(self, mel: Tensor, n_iter: Optional[int] = None):
-        if isinstance(n_iter, int) and n_iter != self.n_iter:
-            self.giffin_lim = torchaudio.transforms.GriffinLim(
-                n_fft=self.n_fft,
-                n_iter=n_iter,
-                win_length=self.win_length,
-                hop_length=self.hop_length,
-            )
-            self.n_iter = n_iter
-        return self.giffin_lim.forward(
-            self.mel_rscale(mel),
-        )
+    def load_data(self, file: PathLike, audio_frames_limit: Optional[int] = None):
+        initial_audio = self.ap.load_audio(file)
+        if not audio_frames_limit or initial_audio.shape[-1] <= audio_frames_limit:
+            audio_mel = self.ap.compute_mel(initial_audio, add_base=True)
+            return [self._add_dict(initial_audio, audio_mel, file)]
+        results = []
+        for fragment in torch.split(
+            initial_audio, split_size_or_sections=audio_frames_limit, dim=-1
+        ):
+            if fragment.shape[-1] < self.post_n_fft:
+                # sometimes the tensor will be too small to be able to pass on mel
+                continue
+            audio_mel = self.ap.compute_mel(fragment, add_base=True)
+            results.append(self._add_dict(fragment, audio_mel, file))
+        return results
 
-    def load_audio(
+    def get_data_loader(
         self,
-        path: PathLike,
-        top_db: float = 30,
-    ) -> Tensor:
-        is_file(path, True)
-        wave, sr = librosa.load(str(path), sr=self.sample_rate)
-        wave, _ = librosa.effects.trim(wave, top_db=top_db)
-        return (
-            torch.from_numpy(
-                librosa.resample(wave, orig_sr=sr, target_sr=self.sample_rate)
-                if sr != self.sample_rate
-                else wave
-            )
-            .float()
-            .unsqueeze(0)
+        batch_size: int = 1,
+        shuffle: Optional[bool] = None,
+        sampler: Optional[Union[Sampler, Iterable]] = None,
+        batch_sampler: Optional[Union[Sampler[list], Iterable[list]]] = None,
+        num_workers: int = 0,
+        pin_memory: bool = False,
+        drop_last: bool = False,
+        timeout: float = 0,
+    ):
+        return DataLoader(
+            self,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            sampler=sampler,
+            batch_sampler=batch_sampler,
+            num_workers=num_workers,
+            pin_memory=pin_memory,
+            drop_last=drop_last,
+            timeout=timeout,
+            collate_fn=self.collate_fn,
         )
 
-    def find_audios(self, path: PathLike, additional_extensions: List[str] = []):
-        extensions = [
-            "*.wav",
-            "*.aac",
-            "*.m4a",
-            "*.mp3",
-            "*.ogg",
-            "*.opus",
-            "*.flac",
-        ]
-        extensions.extend(
-            [x for x in additional_extensions if isinstance(x, str) and "*" in x]
-        )
-        return FileScan.files(
-            path,
-            extensions,
-        )
+    @staticmethod
+    def collate_fn(batch: Sequence[Dict[str, Tensor]]):
+        mels = []
+        audios = []
+        files = []
+        for x in batch:
+            mels.append(x["mel"])
+            audios.append(x["raw"])
+            files.append(x["file"])
+        # Find max time in mel (dim -1), and max audio length
+        max_mel_len = max([m.shape[-1] for m in mels])
+        max_audio_len = max([a.shape[-1] for a in audios])
 
-    def find_audio_text_pairs(
-        self,
-        path,
-        additional_extensions: List[str] = [],
-        text_file_patterns: List[str] = [".normalized.txt", ".original.txt"],
-    ):
-        is_array(text_file_patterns, True, validate=True)  # Rases if empty or not valid
-        additional_extensions = [
-            x
-            for x in additional_extensions
-            if isinstance(x, str)
-            and "*" in x
-            and not any(list(map(lambda y: y in x), text_file_patterns))
-        ]
-        audio_files = self.find_audios(path, additional_extensions)
-        text_files = []
-        for audio in audio_files:
-            base_audio_dir = Path(audio).parent
-            audio_name = get_file_name(audio, False)
-            for pattern in text_file_patterns:
-                possible_txt_file = Path(base_audio_dir, audio_name + pattern)
-                if is_file(possible_txt_file):
-                    text_files.append(audio)
-        return audio_files, text_files
+        padded_mels = torch.stack(
+            [FT.pad(m, (0, max_mel_len - m.shape[-1])) for m in mels]
+        )  # shape: [B, 80, T_max]
+
+        padded_audios = torch.stack(
+            [FT.pad(a, (0, max_audio_len - a.shape[-1])) for a in audios]
+        )  # shape: [B, L_max]
+
+        return padded_mels, padded_audios, files
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        return self.data[index]
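A minimal usage sketch for the new dataset (the directory path and parameter values are illustrative, and the class itself is flagged "Untested!" upstream):

```python
import torch
from lt_tensor.processors import AudioProcessor
from lt_tensor.datasets.audio import WaveMelDataset

# The dataset delegates file discovery, loading and mel computation
# to an AudioProcessor instance.
ap = AudioProcessor(sample_rate=24000, n_mels=80, n_fft=1024)

# Long clips are split into fragments of at most `max_frame_length`
# samples; fragments shorter than (n_fft // 2) + 1 are skipped.
dataset = WaveMelDataset(ap, "path/to/audio_dir", max_frame_length=48000)

# `collate_fn` zero-pads every batch to its longest mel / waveform.
loader = dataset.get_data_loader(batch_size=8, shuffle=True)
padded_mels, padded_audios, files = next(iter(loader))
```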
lt_tensor/model_base.py CHANGED
@@ -41,15 +41,15 @@ class Model(nn.Module, ABC):
     def device(self, device: Union[torch.device, str]):
         assert isinstance(device, (str, torch.device))
         self._device = torch.device(device) if isinstance(device, str) else device
-        self.tp_apply_device_to()
+        self._apply_device_to()
 
-    def tp_apply_device_to(self):
+    def _apply_device_to(self):
         """Add here the components that need the device applied to them,
         which the plain '.to()' call usually fails to reach
 
         example:
         ```
-        def tp_apply_device_to(self):
+        def _apply_device_to(self):
             self.my_tensor = self.my_tensor.to(device=self.device)
         ```
         """
lt_tensor/noise_tools.py CHANGED
@@ -128,7 +128,7 @@ def apply_noise(
             on_error != "raise"
         ), f"Noise '{noise_type}' is not supported for {x.ndim}D input."
         if on_error == "return_unchanged":
-            return x
+            return x, None
         elif on_error == "try_others":
             remaining = [
                 n
@@ -136,27 +136,33 @@ def apply_noise(
                 if n not in last_tries and x.ndim in _NOISE_DIM_SUPPORT[n]
             ]
             if not remaining:
-                return x
+                return x, None
             new_type = random.choice(remaining)
             last_tries.append(new_type)
+            # the recursive call already returns an (output, noise_type) pair,
+            # so it is passed through unwrapped
             return apply_noise(
                 x, new_type, noise_level, seed, on_error, last_tries.copy()
             )
     try:
         if isinstance(seed, int):
             set_seed(seed)
-        return _NOISE_MAP[noise_type](x, noise_level)
+        return _NOISE_MAP[noise_type](x, noise_level), noise_type
     except Exception as e:
         if on_error == "raise":
             raise e
         elif on_error == "return_unchanged":
-            return x
+            return x, None
         if len(last_tries) == len(_VALID_NOISES):
-            return x
+            return x, None
         remaining = [n for n in _VALID_NOISES if n not in last_tries]
         new_type = random.choice(remaining)
         last_tries.append(new_type)
+        # the recursive retry already returns an (output, noise_type) pair
         return apply_noise(x, new_type, noise_level, seed, on_error, last_tries.copy())
 
 
 class NoiseSchedulerA(nn.Module):
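Since `apply_noise` now returns a pair instead of a bare tensor, call sites unpack it; a minimal sketch (the noise name is illustrative, the valid names come from `_VALID_NOISES`):

```python
import torch
from lt_tensor.noise_tools import apply_noise

x = torch.randn(1, 80, 100)

# The second element reports which noise was actually applied,
# or None when the input comes back unchanged.
noisy, used = apply_noise(x, "gaussian", noise_level=0.05, on_error="try_others")
```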
lt_tensor/processors/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .audio import AudioProcessor
+
+__all__ = ["AudioProcessor"]
lt_tensor/processors/audio.py ADDED
@@ -0,0 +1,193 @@
+__all__ = ["AudioProcessor"]
+from ..torch_commons import *
+from lt_utils.common import *
+from lt_utils.type_utils import is_file, is_array
+from ..misc_utils import log_tensor
+import librosa
+import torchaudio
+
+from ..transform import InverseTransformConfig, InverseTransform
+from lt_utils.file_ops import FileScan, get_file_name, path_to_str
+
+from lt_tensor.model_base import Model
+
+
+class AudioProcessor(Model):
+    def __init__(
+        self,
+        sample_rate: int = 24000,
+        n_mels: int = 80,
+        n_fft: int = 1024,
+        win_length: Optional[int] = None,
+        hop_length: Optional[int] = None,
+        f_min: float = 0,
+        f_max: float | None = None,
+        n_iter: int = 32,
+        center: bool = True,
+        mel_scale: Literal["htk", "slaney"] = "htk",
+        std: int = 4,
+        mean: int = -4,
+        inverse_transform_config: Union[
+            Dict[str, Union[Number, Tensor, bool]], InverseTransformConfig
+        ] = dict(n_fft=16, hop_length=4, win_length=16, center=True),
+        *__,
+        **_,
+    ):
+        super().__init__()
+        assert isinstance(inverse_transform_config, (InverseTransformConfig, dict))
+        self.mean = mean
+        self.std = std
+        self.n_mels = n_mels
+        self.n_fft = n_fft
+        self.n_stft = n_fft // 2 + 1
+        self.f_min = f_min
+        self.f_max = f_max
+        self.n_iter = n_iter
+        self.hop_length = hop_length or n_fft // 4
+        self.win_length = win_length or n_fft
+        self.sample_rate = sample_rate
+        self.mel_spec = torchaudio.transforms.MelSpectrogram(
+            sample_rate=sample_rate,
+            n_mels=n_mels,
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            center=center,
+            f_min=f_min,
+            f_max=f_max,
+            mel_scale=mel_scale,
+        )
+        self.mel_rscale = torchaudio.transforms.InverseMelScale(
+            n_stft=self.n_stft,
+            n_mels=n_mels,
+            sample_rate=sample_rate,
+            f_min=f_min,
+            f_max=f_max,
+            mel_scale=mel_scale,
+        )
+        self.giffin_lim = torchaudio.transforms.GriffinLim(
+            n_fft=n_fft,
+            n_iter=n_iter,
+            win_length=win_length,
+            hop_length=hop_length,
+        )
+        if isinstance(inverse_transform_config, dict):
+            inverse_transform_config = InverseTransformConfig(
+                **inverse_transform_config
+            )
+        self._inv_transform = InverseTransform(**inverse_transform_config.to_dict())
+
+    def inverse_transform(self, spec: Tensor, phase: Tensor, *_, **kwargs):
+        return self._inv_transform(spec, phase, **kwargs)
+
+    def compute_mel(
+        self, wave: Tensor, base: float = 1e-6, add_base: bool = False
+    ) -> Tensor:
+        """Returns: [B, M, ML]"""
+        wave_device = wave.device
+        mel_tensor = self.mel_spec(wave.to(self.device))  # [M, ML]
+        if not add_base:
+            return (mel_tensor - self.mean) / self.std
+        return ((torch.log(base + mel_tensor.unsqueeze(0)) - self.mean) / self.std).to(
+            device=wave_device
+        )
+
+    def reverse_mel(self, mel: Tensor, n_iter: Optional[int] = None):
+        if isinstance(n_iter, int) and n_iter != self.n_iter:
+            self.giffin_lim = torchaudio.transforms.GriffinLim(
+                n_fft=self.n_fft,
+                n_iter=n_iter,
+                win_length=self.win_length,
+                hop_length=self.hop_length,
+            )
+            self.n_iter = n_iter
+        return self.giffin_lim.forward(
+            self.mel_rscale(mel),
+        )
+
+    def load_audio(
+        self,
+        path: PathLike,
+        top_db: float = 30,
+    ) -> Tensor:
+        is_file(path, True)
+        wave, sr = librosa.load(str(path), sr=self.sample_rate)
+        wave, _ = librosa.effects.trim(wave, top_db=top_db)
+        return (
+            torch.from_numpy(
+                librosa.resample(wave, orig_sr=sr, target_sr=self.sample_rate)
+                if sr != self.sample_rate
+                else wave
+            )
+            .float()
+            .unsqueeze(0)
+        )
+
+    def find_audios(self, path: PathLike, additional_extensions: List[str] = []):
+        extensions = [
+            "*.wav",
+            "*.aac",
+            "*.m4a",
+            "*.mp3",
+            "*.ogg",
+            "*.opus",
+            "*.flac",
+        ]
+        extensions.extend(
+            [x for x in additional_extensions if isinstance(x, str) and "*" in x]
+        )
+        return FileScan.files(
+            path,
+            extensions,
+        )
+
+    def find_audio_text_pairs(
+        self,
+        path,
+        additional_extensions: List[str] = [],
+        text_file_patterns: List[str] = [".normalized.txt", ".original.txt"],
+    ):
+        is_array(text_file_patterns, True, validate=True)  # Raises if empty or not valid
+        additional_extensions = [
+            x
+            for x in additional_extensions
+            if isinstance(x, str)
+            and "*" in x
+            and not any(y in x for y in text_file_patterns)
+        ]
+        audio_files = self.find_audios(path, additional_extensions)
+        results = []
+        for audio in audio_files:
+            base_audio_dir = Path(audio).parent
+            audio_name = get_file_name(audio, False)
+            for pattern in text_file_patterns:
+                possible_txt_file = Path(base_audio_dir, audio_name + pattern)
+                if is_file(possible_txt_file):
+                    results.append((audio, path_to_str(possible_txt_file)))
+                    break
+        return results
+
+    def stft_loss(self, signal: Tensor, ground: Tensor, base: float = 1e-5):
+        sig_mel = self(signal, base)
+        gnd_mel = self(ground, base)
+        return torch.norm(gnd_mel - sig_mel, p=1) / torch.norm(gnd_mel, p=1)
+
+    def forward(
+        self,
+        *inputs: Union[Tensor, float],
+        ap_task: Literal[
+            "get_mel", "get_loss", "inv_transform", "revert_mel"
+        ] = "get_mel",
+        **inputs_kwargs,
+    ):
+        if ap_task == "get_mel":
+            return self.compute_mel(*inputs, **inputs_kwargs)
+        elif ap_task == "get_loss":
+            return self.stft_loss(*inputs, **inputs_kwargs)
+        elif ap_task == "inv_transform":
+            return self.inverse_transform(*inputs, **inputs_kwargs)
+        elif ap_task == "revert_mel":
+            return self.reverse_mel(*inputs, **inputs_kwargs)
+        else:
+            raise ValueError(f"Invalid task '{ap_task}'")
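A quick usage sketch of the relocated processor (the file path is illustrative):

```python
import torch
from lt_tensor.processors import AudioProcessor

ap = AudioProcessor(sample_rate=24000, n_mels=80, n_fft=1024)

# Load, trim and resample a clip to the processor's sample rate.
wave = ap.load_audio("clip.wav")  # [1, L]

# Normalized log-mel spectrogram; `add_base` guards the log against zeros.
mel = ap.compute_mel(wave, add_base=True)

# `forward` dispatches on `ap_task`, e.g. the mel-domain loss between
# a generated waveform and its reference:
loss = ap(torch.randn_like(wave), wave, ap_task="get_loss")
```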
lt_tensor/transform.py CHANGED
@@ -15,6 +15,8 @@ __all__ = [
     "normalize",
     "window_sumsquare",
     "inverse_transform",
+    "InverseTransformConfig",
+    "InverseTransform",
     "stft_istft_rebuild",
 ]
 
@@ -23,12 +25,15 @@ import torchaudio
 import math
 from .misc_utils import log_tensor
 from lt_utils.common import *
+from lt_utils.misc_utils import cache_wrapper, default
 import torch.nn.functional as F
+from .model_base import Model
+import warnings
 
 
 def to_mel_spectrogram(
     waveform: torch.Tensor,
-    sample_rate: int = 22050,
+    sample_rate: int = 24000,
     n_fft: int = 1024,
     hop_length: Optional[int] = None,
     win_length: Optional[int] = None,
@@ -198,8 +203,7 @@ def generate_window(
         return torch.ones(1, device=device)
 
     n = torch.arange(M, dtype=torch.float32, device=device)
-    window = alpha - (1.0 - alpha) * torch.cos(2.0 * math.pi * n / (M - 1))
-    return window
+    return alpha - (1.0 - alpha) * torch.cos(2.0 * math.pi * n / (M - 1))
 
 
 def pad_center(tensor: torch.Tensor, size: int, axis: int = -1) -> torch.Tensor:
@@ -260,8 +264,8 @@ def normalize(
 def window_sumsquare(
     window_spec: Union[str, int, float, Callable, List[Any], Tuple[Any, ...]],
     n_frames: int,
-    hop_length: int = 300,
-    win_length: int = 1200,
+    hop_length: int = 256,
+    win_length: int = 1024,
     n_fft: int = 2048,
     dtype: torch.dtype = torch.float32,
     norm: Optional[Union[int, float]] = None,
@@ -294,9 +298,9 @@ def window_sumsquare(
 def inverse_transform(
     spec: Tensor,
     phase: Tensor,
-    n_fft: int = 2048,
-    hop_length: int = 300,
-    win_length: int = 1200,
+    n_fft: int = 1024,
+    hop_length: Optional[int] = None,
+    win_length: Optional[int] = None,
     length: Optional[Any] = None,
     window: Optional[Tensor] = None,
 ):
@@ -310,3 +314,168 @@ def inverse_transform(
         window=window,
         length=length,
     )
+
+
+def is_nand(a: bool, b: bool):
+    """NAND: false only when both inputs are true.
+    ```
+    (False, False) -> True
+    (False, True)  -> True
+    (True,  False) -> True
+    (True,  True)  -> False
+    ```
+    """
+    return not (a and b)
+
+
+class InverseTransformConfig:
+    def __init__(
+        self,
+        n_fft: int = 1024,
+        hop_length: Optional[int] = None,
+        win_length: Optional[int] = None,
+        length: Optional[int] = None,
+        window: Optional[Tensor] = None,
+        onesided: Optional[bool] = None,
+        return_complex: bool = False,
+        normalized: bool = False,
+        center: bool = True,
+    ):
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.length = length
+        self.onesided = onesided
+        self.return_complex = return_complex
+        self.normalized = normalized
+        self.center = center
+        self.window = window
+
+    def to_dict(self):
+        return self.__dict__.copy()
+
+
+class InverseTransform(Model):
+    def __init__(
+        self,
+        n_fft: int = 1024,
+        hop_length: Optional[int] = None,
+        win_length: Optional[int] = None,
+        length: Optional[int] = None,
+        window: Optional[Tensor] = None,
+        onesided: Optional[bool] = None,
+        return_complex: bool = False,
+        normalized: bool = False,
+        center: bool = True,
+    ):
+        """
+        Module for inverting a magnitude + phase spectrogram to a waveform using ISTFT.
+
+        This class encapsulates common ISTFT parameters at initialization and applies
+        the inverse transformation in the `forward()` method with minimal per-call overhead.
+
+        Parameters
+        ----------
+        n_fft : int, optional
+            Size of FFT to use during inversion. Default is 1024.
+        hop_length : int, optional
+            Number of audio samples between STFT columns. Defaults to `n_fft // 4`.
+        win_length : int, optional
+            Size of the window function. Defaults to `n_fft`.
+        length : int, optional
+            Output waveform length. If not provided, length will be inferred.
+        window : Tensor, optional
+            Custom window tensor. If None, a Hann window is used.
+        onesided : bool, optional
+            Whether the input STFT was onesided. Required only for consistency checks.
+        return_complex : bool, default=False
+            Must be False if `onesided` is True. Not used internally.
+        normalized : bool, default=False
+            Whether the STFT was normalized.
+        center : bool, default=True
+            Whether the signal was padded during STFT.
+
+        Methods
+        -------
+        forward(spec, phase)
+            Applies ISTFT using stored settings on the given magnitude and phase tensors.
+        update_settings(...)
+            Updates ISTFT parameters dynamically (used internally during forward).
+        """
+        super().__init__()
+        assert window is None or isinstance(window, Tensor)
+        # `return_complex` and `onesided` must not both be enabled (NAND)
+        assert is_nand(return_complex, onesided)
+        self.n_fft = n_fft
+        self.length = length
+        self.win_length = win_length or n_fft
+        self.hop_length = hop_length or n_fft // 4
+        self.center = center
+        self.return_complex = return_complex
+        self.onesided = onesided
+        self.normalized = normalized
+        self.window = torch.hann_window(self.win_length) if window is None else window
+        # ensure `self.kwargs` exists even when `forward` is called without overrides
+        self.update_settings()
+
+    def _apply_device_to(self):
+        """Applies to device while used with module `Model`"""
+        self.window = self.window.to(device=self.device)
+
+    def update_settings(
+        self,
+        *,
+        n_fft: Optional[int] = None,
+        hop_length: Optional[int] = None,
+        win_length: Optional[int] = None,
+        length: Optional[int] = None,
+        window: Optional[Tensor] = None,
+        onesided: Optional[bool] = None,
+        return_complex: Optional[bool] = None,
+        center: Optional[bool] = None,
+        normalized: Optional[bool] = None,
+        **_,
+    ):
+        self.kwargs = dict(
+            n_fft=default(n_fft, self.n_fft),
+            hop_length=default(hop_length, self.hop_length),
+            win_length=default(win_length, self.win_length),
+            length=default(length, self.length),
+            window=default(window, self.window),
+            onesided=default(onesided, self.onesided),
+            return_complex=default(return_complex, self.return_complex),
+            center=default(center, self.center),
+            normalized=default(normalized, self.normalized),
+        )
+        if self.kwargs["onesided"] and self.kwargs["return_complex"]:
+            warnings.warn(
+                "You cannot use return_complex with `onesided` enabled. `return_complex` is set to False."
+            )
+            self.kwargs["return_complex"] = False
+
+    def forward(self, spec: Tensor, phase: Tensor, **kwargs):
+        """
+        Perform the inverse short-time Fourier transform.
+
+        Parameters
+        ----------
+        spec : Tensor
+            Magnitude spectrogram of shape (batch, freq, time).
+        phase : Tensor
+            Phase angles tensor, same shape as `spec`, in radians.
+        **kwargs : dict, optional
+            Optional ISTFT override parameters (same as in `update_settings`).
+
+        Returns
+        -------
+        Tensor
+            Time-domain waveform reconstructed from `spec` and `phase`.
+        """
+        if kwargs:
+            self.update_settings(**kwargs)
+
+        return torch.istft(spec * torch.exp(phase * 1j), **self.kwargs)
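A round-trip sketch of the new module; parameter values are illustrative and assume the STFT was produced with matching settings:

```python
import torch
from lt_tensor.transform import InverseTransform

n_fft, hop = 1024, 256
wave = torch.randn(1, 24000)

# Forward STFT, then split into magnitude and phase.
stft = torch.stft(
    wave,
    n_fft=n_fft,
    hop_length=hop,
    window=torch.hann_window(n_fft),
    center=True,
    return_complex=True,
)
spec, phase = stft.abs(), stft.angle()

# Invert with the same settings; `length` pins the output size.
inv = InverseTransform(n_fft=n_fft, hop_length=hop, length=wave.shape[-1])
restored = inv(spec, phase)  # approximately equal to `wave`
```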
lt_tensor-0.0.1a6.dist-info/METADATA → lt_tensor-0.0.1a7.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a6
+Version: 0.0.1a7
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
@@ -17,7 +17,7 @@ Requires-Dist: numpy>=1.26.4
 Requires-Dist: tokenizers
 Requires-Dist: pyyaml>=6.0.0
 Requires-Dist: numba>0.60.0
-Requires-Dist: lt-utils==0.0.1.a3
+Requires-Dist: lt-utils==0.0.1
 Requires-Dist: librosa>=0.11.0
 Dynamic: author
 Dynamic: classifier
lt_tensor-0.0.1a6.dist-info/RECORD → lt_tensor-0.0.1a7.dist-info/RECORD RENAMED
@@ -1,15 +1,15 @@
-lt_tensor/__init__.py,sha256=D-oEjsuKWhtk1qyiADERgNO78aRCXUJJz0hs65h8LOg,365
+lt_tensor/__init__.py,sha256=uwJ7uiO18VYj8Z1V4KSOQ3ZrnowSgJWKCIiFBrzLMOI,429
 lt_tensor/losses.py,sha256=TinZJP2ypZ7Tdg6d9nnFWFkPyormfgQ0Z9P2ER3sqzE,4341
 lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
 lt_tensor/math_ops.py,sha256=ewIYkvxIy_Lab_9ExjFUgLs-oYLOu8IRRDo7f1pn3i8,2248
 lt_tensor/misc_utils.py,sha256=sjWUkUaHFhaCdN4rZ6X-cQDbPieimfKchKq9VtjiwEA,17029
-lt_tensor/model_base.py,sha256=2W4m6hlvMyfRx1efWJ0NIIwctzLjL4rip208vL9_n0Y,13419
+lt_tensor/model_base.py,sha256=8qN7oklALFanOz-eqVzdnB9RD2kN_3ltynSMAPOl-TI,13413
 lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
-lt_tensor/noise_tools.py,sha256=O4oq5oi0jLJuQNIuxOBZa-rB0S065QXtb1gjQUXVaLs,11212
+lt_tensor/noise_tools.py,sha256=JkWw0-bCMRNNMShwXKKt5KbO3104tvNiBePt-ThPkEo,11366
 lt_tensor/torch_commons.py,sha256=fntsEU8lhBQo0ebonI1iXBkMbWMN3HpBsG13EWlP5s8,718
-lt_tensor/transform.py,sha256=hqsP6nXRn4nqMGkN2hBi4y-kHxEQdlIUS0y89Y1mjVI,8589
+lt_tensor/transform.py,sha256=va4bQjpfhH-tnaBDvJZpmYmfg9zwn5_Y6pPOoTswS-U,14471
 lt_tensor/datasets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-lt_tensor/datasets/audio.py,sha256=frftmRYNk0eXqHEiFggC46RMuCoGyuwBAlnPxfFsS7Y,4858
+lt_tensor/datasets/audio.py,sha256=kVluXRbLX7C5W5aFN7CUMb-O1KTHjTiTvE4f7iBcnZk,3754
 lt_tensor/model_zoo/__init__.py,sha256=jipEk50_DTMQbGg8FnDDukxmh7Bcwvl_QVRS3rkb7aY,283
 lt_tensor/model_zoo/bsc.py,sha256=muxIR7dU-Pvf-HFE-iy3zmRb1sTJlcs1vqdlnbU1Hss,6307
 lt_tensor/model_zoo/disc.py,sha256=jZPhoSV1hlrba3ohXGutYAAcSl4pWkqGYFpOlOoN3eo,4740
@@ -19,8 +19,10 @@ lt_tensor/model_zoo/istft.py,sha256=0Xms2QNPAgz_ib8XTfaWl1SCHgS53oKC6-EkDkl_qe4,
 lt_tensor/model_zoo/pos.py,sha256=N28v-rF8CELouYxQ9r45Jbd4ri5DNydwDgg7nzmQ4Ig,4471
 lt_tensor/model_zoo/rsd.py,sha256=5bba50g1Hm5kMexuJ4SwOIJuyQ1qJd8Acrq-Ax6CqE8,6958
 lt_tensor/model_zoo/tfrms.py,sha256=kauh-A13pk08SZ5OspEE5a-gPKD4rZr6tqMKWu3KGhk,4237
-lt_tensor-0.0.1a6.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-lt_tensor-0.0.1a6.dist-info/METADATA,sha256=-89IqEHsZD3W8moDuKWR8UodkdR2pwefqrG9C7P7y_Y,968
-lt_tensor-0.0.1a6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-lt_tensor-0.0.1a6.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
-lt_tensor-0.0.1a6.dist-info/RECORD,,
+lt_tensor/processors/__init__.py,sha256=4b9MxAJolXiJfSm20ZEspQTDm1tgLazwlPWA_jB1yLM,63
+lt_tensor/processors/audio.py,sha256=mU0usiagVyNPd0uEadL_lC4BFzSMNpjTIwth82gFJRI,6650
+lt_tensor-0.0.1a7.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+lt_tensor-0.0.1a7.dist-info/METADATA,sha256=5A9UrpFdhQikU44UlJbLQmQAVTceVhcbPV25jmLM9Os,965
+lt_tensor-0.0.1a7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a7.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a7.dist-info/RECORD,,