lt-tensor 0.0.1a6__py3-none-any.whl → 0.0.1a7__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- lt_tensor/__init__.py +4 -0
- lt_tensor/datasets/audio.py +92 -137
- lt_tensor/model_base.py +3 -3
- lt_tensor/noise_tools.py +14 -8
- lt_tensor/processors/__init__.py +3 -0
- lt_tensor/processors/audio.py +193 -0
- lt_tensor/transform.py +177 -8
- {lt_tensor-0.0.1a6.dist-info → lt_tensor-0.0.1a7.dist-info}/METADATA +2 -2
- {lt_tensor-0.0.1a6.dist-info → lt_tensor-0.0.1a7.dist-info}/RECORD +12 -10
- {lt_tensor-0.0.1a6.dist-info → lt_tensor-0.0.1a7.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a6.dist-info → lt_tensor-0.0.1a7.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a6.dist-info → lt_tensor-0.0.1a7.dist-info}/top_level.txt +0 -0
lt_tensor/__init__.py
CHANGED
lt_tensor/datasets/audio.py
CHANGED
@@ -1,154 +1,109 @@
-__all__ = ["
+__all__ = ["WaveMelDatasets"]
 from ..torch_commons import *
-import torchaudio
 from lt_utils.common import *
-import
-from
-from
-
-from
+import random
+from torch.utils.data import Dataset, DataLoader, Sampler
+from ..processors import AudioProcessor
+import torch.nn.functional as FT
+from ..misc_utils import log_tensor
 
 
-class
+class WaveMelDataset(Dataset):
+    """Untested!"""
+
+    data: Union[list[dict[str, Tensor]], Tuple[Tensor, Tensor]] = []
 
     def __init__(
         self,
-
-
-
-
-        hop_length: int = 256,
-        f_min: float = 0,
-        f_max: float | None = None,
-        n_iter: int = 32,
-        center: bool = True,
-        mel_scale: Literal["htk", "slaney"] = "htk",
-        inv_n_fft: int = 16,
-        inv_hop: int = 4,
-        std: int = 4,
-        mean: int = -4,
+        audio_processor: AudioProcessor,
+        path: PathLike,
+        limit_files: Optional[int] = None,
+        max_frame_length: Optional[int] = None,
     ):
-
-
-
-        self.n_fft = n_fft
-        self.n_stft = n_fft // 2 + 1
-        self.f_min = f_min
-        self.f_max = f_max
-        self.n_iter = n_iter
-        self.hop_length = hop_length
-        self.sample_rate = sample_rate
-        self.mel_spec = torchaudio.transforms.MelSpectrogram(
-            sample_rate=sample_rate,
-            n_mels=n_mels,
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            center=center,
-            f_min=f_min,
-            f_max=f_max,
-            mel_scale=mel_scale,
-        )
-        self.mel_rscale = torchaudio.transforms.InverseMelScale(
-            n_stft=self.n_stft,
-            n_mels=n_mels,
-            sample_rate=sample_rate,
-            f_min=f_min,
-            f_max=f_max,
-            mel_scale=mel_scale,
-        )
-        self.giffin_lim = torchaudio.transforms.GriffinLim(
-            n_fft=n_fft,
-            n_iter=n_iter,
-            win_length=win_length,
-            hop_length=hop_length,
-        )
-        self._inverse_transform = lambda x, y: inverse_transform(
-            x, y, inv_n_fft, inv_hop, inv_n_fft
+        super().__init__()
+        assert max_frame_length is None or max_frame_length >= (
+            (audio_processor.n_fft // 2) + 1
         )
+        self.post_n_fft = (audio_processor.n_fft // 2) + 1
+        self.ap = audio_processor
+        self.files = self.ap.find_audios(path)
+        if limit_files:
+            random.shuffle(self.files)
+            self.files = self.files[:limit_files]
+        self.data = []
 
-
-
+        for file in self.files:
+            results = self.load_data(file, max_frame_length)
+            self.data.extend(results)
 
-    def
-
-        wave: Tensor,
-    ) -> Tensor:
-        """Returns: [B, M, ML]"""
-        mel_tensor = self.mel_spec(wave)  # [M, ML]
-        mel_tensor = (mel_tensor - self.mean) / self.std
-        return mel_tensor  # [B, M, ML]
+    def _add_dict(self, audio_raw: Tensor, audio_mel: Tensor, file: PathLike):
+        return {"mel": audio_mel, "raw": audio_raw, "file": file}
 
-    def
-
-
-
-
-
-
-
-
-
-
-
+    def load_data(self, file: PathLike, audio_frames_limit: Optional[int] = None):
+        initial_audio = self.ap.load_audio(file)
+        if not audio_frames_limit or initial_audio.shape[-1] <= audio_frames_limit:
+            audio_mel = self.ap.compute_mel(initial_audio, add_base=True)
+            return [self._add_dict(initial_audio, audio_mel, file)]
+        results = []
+        for fragment in torch.split(
+            initial_audio, split_size_or_sections=audio_frames_limit, dim=-1
+        ):
+            if fragment.shape[-1] < self.post_n_fft:
+                # sometimes the tensor will be too small to be able to pass on mel
+                continue
+            audio_mel = self.ap.compute_mel(fragment, add_base=True)
+            results.append(self._add_dict(fragment, audio_mel, file))
+        return results
 
-    def
+    def get_data_loader(
         self,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        batch_size: int = 1,
+        shuffle: Optional[bool] = None,
+        sampler: Optional[Union[Sampler, Iterable]] = None,
+        batch_sampler: Optional[Union[Sampler[list], Iterable[list]]] = None,
+        num_workers: int = 0,
+        pin_memory: bool = False,
+        drop_last: bool = False,
+        timeout: float = 0,
+    ):
+        return DataLoader(
+            self,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            sampler=sampler,
+            batch_sampler=batch_sampler,
+            num_workers=num_workers,
+            pin_memory=pin_memory,
+            drop_last=drop_last,
+            timeout=timeout,
+            collate_fn=self.collate_fn,
        )
 
-
-
-
-
-
-
-            "
-            "
-            "
-
-
-
-        )
-        return FileScan.files(
-            path,
-            extensions,
-        )
+    @staticmethod
+    def collate_fn(batch: Sequence[Dict[str, Tensor]]):
+        mels = []
+        audios = []
+        files = []
+        for x in batch:
+            mels.append(x["mel"])
+            audios.append(x["raw"])
+            files.append(x["file"])
+        # Find max time in mel (dim -1), and max audio length
+        max_mel_len = max([m.shape[-1] for m in mels])
+        max_audio_len = max([a.shape[-1] for a in audios])
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        text_files = []
-        for audio in audio_files:
-            base_audio_dir = Path(audio).parent
-            audio_name = get_file_name(audio, False)
-            for pattern in text_file_patterns:
-                possible_txt_file = Path(base_audio_dir, audio_name + pattern)
-                if is_file(possible_txt_file):
-                    text_files.append(audio)
-        return audio_files, text_files
+        padded_mels = torch.stack(
+            [FT.pad(m, (0, max_mel_len - m.shape[-1])) for m in mels]
+        )  # shape: [B, 80, T_max]
+
+        padded_audios = torch.stack(
+            [FT.pad(a, (0, max_audio_len - a.shape[-1])) for a in audios]
+        )  # shape: [B, L_max]
+
+        return padded_mels, padded_audios, files
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        return self.data[index]
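The rewritten dataset delegates file discovery, loading, and mel computation to AudioProcessor and exposes a padded-batch DataLoader. A minimal usage sketch (untested, matching the signatures above; the data directory is hypothetical). Note that `__all__` exports the misspelled name "WaveMelDatasets" while the class itself is `WaveMelDataset`, so import the class by name rather than via a star-import:

```python
from lt_tensor.processors import AudioProcessor
from lt_tensor.datasets.audio import WaveMelDataset

ap = AudioProcessor(sample_rate=24000, n_mels=80, n_fft=1024)
dataset = WaveMelDataset(
    audio_processor=ap,
    path="data/wavs",        # hypothetical directory scanned by ap.find_audios()
    limit_files=100,         # optional: shuffle the file list and keep only 100
    max_frame_length=24000,  # optional: split longer clips into fragments
)
loader = dataset.get_data_loader(batch_size=8, shuffle=True)
for mels, audios, files in loader:  # collate_fn pads each batch to its longest item
    pass
```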
lt_tensor/model_base.py
CHANGED
@@ -41,15 +41,15 @@ class Model(nn.Module, ABC):
     def device(self, device: Union[torch.device, str]):
         assert isinstance(device, (str, torch.device))
         self._device = torch.device(device) if isinstance(device, str) else device
-        self.
+        self._apply_device_to()
 
-    def
+    def _apply_device_to(self):
         """Add here components that are needed to have device applied to them,
         that usually the '.to()' function fails to apply
 
         example:
         ```
-        def
+        def _apply_device_to(self):
             self.my_tensor = self.my_tensor.to(device=self.device)
         ```
         """
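The renamed hook is meant to be overridden by subclasses that hold plain tensor attributes, which `nn.Module.to()` does not move. A small sketch of the intended pattern, following the docstring above (the subclass is hypothetical, and `Model`'s full interface is not shown in this diff):

```python
import torch
from lt_tensor.model_base import Model

class WindowedModel(Model):
    def __init__(self):
        super().__init__()
        # a plain tensor attribute; module-level .to() will not move it
        self.window = torch.hann_window(1024)

    def _apply_device_to(self):
        # invoked when the device setter runs, per the docstring above
        self.window = self.window.to(device=self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.window[: x.shape[-1]]
```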
lt_tensor/noise_tools.py
CHANGED
@@ -128,7 +128,7 @@ def apply_noise(
         on_error != "raise"
     ), f"Noise '{noise_type}' is not supported for {x.ndim}D input."
     if on_error == "return_unchanged":
-        return x
+        return x, None
     elif on_error == "try_others":
         remaining = [
             n
@@ -136,27 +136,33 @@ def apply_noise(
             if n not in last_tries and x.ndim in _NOISE_DIM_SUPPORT[n]
         ]
         if not remaining:
-            return x
+            return x, None
         new_type = random.choice(remaining)
         last_tries.append(new_type)
-        return
-
+        return (
+            apply_noise(
+                x, new_type, noise_level, seed, on_error, last_tries.copy()
+            ),
+            noise_type,
         )
     try:
         if isinstance(seed, int):
             set_seed(seed)
-        return _NOISE_MAP[noise_type](x, noise_level)
+        return _NOISE_MAP[noise_type](x, noise_level), noise_type
     except Exception as e:
         if on_error == "raise":
             raise e
         elif on_error == "return_unchanged":
-            return x
+            return x, None
         if len(last_tries) == len(_VALID_NOISES):
-            return x
+            return x, None
         remaining = [n for n in _VALID_NOISES if n not in last_tries]
         new_type = random.choice(remaining)
         last_tries.append(new_type)
-        return
+        return (
+            apply_noise(x, new_type, noise_level, seed, on_error, last_tries.copy()),
+            noise_type,
+        )
 
 
 class NoiseSchedulerA(nn.Module):
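The behavioral change here: `apply_noise` now returns a `(tensor, noise_type)` pair on every path, with `noise_type` as `None` whenever the input comes back unchanged, so existing callers must unpack two values. Note also that, as written above, the `try_others` fallback wraps the recursive call's result, so that path can yield a nested tuple. A hedged sketch of the new calling convention; the noise name "gaussian" is an assumption, not confirmed by this diff:

```python
import torch
from lt_tensor.noise_tools import apply_noise

x = torch.randn(1, 16000)
# "gaussian" stands in for whatever keys _NOISE_MAP actually registers
noisy, used = apply_noise(x, "gaussian", noise_level=0.05, on_error="return_unchanged")
if used is None:
    print("noise could not be applied; input returned unchanged")
else:
    print(f"applied noise type: {used}")
```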
lt_tensor/processors/audio.py
ADDED
@@ -0,0 +1,193 @@
+__all__ = ["AudioProcessor"]
+from ..torch_commons import *
+from lt_utils.common import *
+from lt_utils.type_utils import is_file, is_array
+from ..misc_utils import log_tensor
+import librosa
+import torchaudio
+
+from ..transform import InverseTransformConfig, InverseTransform
+from lt_utils.file_ops import FileScan, get_file_name, path_to_str
+
+from lt_tensor.model_base import Model
+
+
+class AudioProcessor(Model):
+    def __init__(
+        self,
+        sample_rate: int = 24000,
+        n_mels: int = 80,
+        n_fft: int = 1024,
+        win_length: Optional[int] = None,
+        hop_length: Optional[int] = None,
+        f_min: float = 0,
+        f_max: float | None = None,
+        n_iter: int = 32,
+        center: bool = True,
+        mel_scale: Literal["htk", "slaney"] = "htk",
+        std: int = 4,
+        mean: int = -4,
+        inverse_transform_config: Union[
+            Dict[str, Union[Number, Tensor, bool]], InverseTransformConfig
+        ] = dict(n_fft=16, hop_length=4, win_length=16, center=True),
+        *__,
+        **_,
+    ):
+        super().__init__()
+        assert isinstance(inverse_transform_config, (InverseTransformConfig, dict))
+        self.mean = mean
+        self.std = std
+        self.n_mels = n_mels
+        self.n_fft = n_fft
+        self.n_stft = n_fft // 2 + 1
+        self.f_min = f_min
+        self.f_max = f_max
+        self.n_iter = n_iter
+        self.hop_length = hop_length or n_fft // 4
+        self.win_length = win_length or n_fft
+        self.sample_rate = sample_rate
+        self.mel_spec = torchaudio.transforms.MelSpectrogram(
+            sample_rate=sample_rate,
+            n_mels=n_mels,
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            center=center,
+            f_min=f_min,
+            f_max=f_max,
+            mel_scale=mel_scale,
+        )
+        self.mel_rscale = torchaudio.transforms.InverseMelScale(
+            n_stft=self.n_stft,
+            n_mels=n_mels,
+            sample_rate=sample_rate,
+            f_min=f_min,
+            f_max=f_max,
+            mel_scale=mel_scale,
+        )
+        self.giffin_lim = torchaudio.transforms.GriffinLim(
+            n_fft=n_fft,
+            n_iter=n_iter,
+            win_length=win_length,
+            hop_length=hop_length,
+        )
+        if isinstance(inverse_transform_config, dict):
+            inverse_transform_config = InverseTransformConfig(
+                **inverse_transform_config
+            )
+        self._inv_transform = InverseTransform(**inverse_transform_config.to_dict())
+
+    def inverse_transform(self, spec: Tensor, phase: Tensor, *_, **kwargs):
+        return self._inv_transform(spec, phase, **kwargs)
+
+    def compute_mel(
+        self, wave: Tensor, base: float = 1e-6, add_base: bool = False
+    ) -> Tensor:
+        """Returns: [B, M, ML]"""
+        wave_device = wave.device
+        mel_tensor = self.mel_spec(wave.to(self.device))  # [M, ML]
+        if not add_base:
+            return (mel_tensor - self.mean) / self.std
+        return ((torch.log(base + mel_tensor.unsqueeze(0)) - self.mean) / self.std).to(
+            device=wave_device
+        )
+
+    def reverse_mel(self, mel: Tensor, n_iter: Optional[int] = None):
+        if isinstance(n_iter, int) and n_iter != self.n_iter:
+            self.giffin_lim = torchaudio.transforms.GriffinLim(
+                n_fft=self.n_fft,
+                n_iter=n_iter,
+                win_length=self.win_length,
+                hop_length=self.hop_length,
+            )
+            self.n_iter = n_iter
+        return self.giffin_lim.forward(
+            self.mel_rscale(mel),
+        )
+
+    def load_audio(
+        self,
+        path: PathLike,
+        top_db: float = 30,
+    ) -> Tensor:
+        is_file(path, True)
+        wave, sr = librosa.load(str(path), sr=self.sample_rate)
+        wave, _ = librosa.effects.trim(wave, top_db=top_db)
+        return (
+            torch.from_numpy(
+                librosa.resample(wave, orig_sr=sr, target_sr=self.sample_rate)
+                if sr != self.sample_rate
+                else wave
+            )
+            .float()
+            .unsqueeze(0)
+        )
+
+    def find_audios(self, path: PathLike, additional_extensions: List[str] = []):
+        extensions = [
+            "*.wav",
+            "*.aac",
+            "*.m4a",
+            "*.mp3",
+            "*.ogg",
+            "*.opus",
+            "*.flac",
+        ]
+        extensions.extend(
+            [x for x in additional_extensions if isinstance(x, str) and "*" in x]
+        )
+        return FileScan.files(
+            path,
+            extensions,
+        )
+
+    def find_audio_text_pairs(
+        self,
+        path,
+        additional_extensions: List[str] = [],
+        text_file_patterns: List[str] = [".normalized.txt", ".original.txt"],
+    ):
+        is_array(text_file_patterns, True, validate=True)  # Rases if empty or not valid
+        additional_extensions = [
+            x
+            for x in additional_extensions
+            if isinstance(x, str)
+            and "*" in x
+            and not any(list(map(lambda y: y in x), text_file_patterns))
+        ]
+        audio_files = self.find_audios(path, additional_extensions)
+        results = []
+        for audio in audio_files:
+            base_audio_dir = Path(audio).parent
+            audio_name = get_file_name(audio, False)
+            for pattern in text_file_patterns:
+                possible_txt_file = Path(base_audio_dir, audio_name + pattern)
+                if is_file(possible_txt_file):
+                    results.append((audio, path_to_str(possible_txt_file)))
+                    break
+        return results
+
+    def stft_loss(self, signal: Tensor, ground: Tensor, base: float = 1e-5):
+        sig_mel = self(signal, base)
+        gnd_mel = self(ground, base)
+        return torch.norm(gnd_mel - sig_mel, p=1) / torch.norm(gnd_mel, p=1)
+
+    # def forward(self, wave: Tensor, base: Optional[float] = None):
+    def forward(
+        self,
+        *inputs: Union[Tensor, float],
+        ap_task: Literal[
+            "get_mel", "get_loss", "inv_transform", "revert_mel"
+        ] = "get_mel",
+        **inputs_kwargs,
+    ):
+        if ap_task == "get_mel":
+            return self.compute_mel(*inputs, **inputs_kwargs)
+        elif ap_task == "get_loss":
+            return self.stft_loss(*inputs, **inputs_kwargs)
+        elif ap_task == "inv_transform":
+            return self.inverse_transform(*inputs, **inputs_kwargs)
+        elif ap_task == "revert_mel":
+            return self.reverse_mel(*inputs, **inputs_kwargs)
+        else:
+            raise ValueError(f"Invalid task '{ap_task}'")
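The new processor centralizes the audio I/O and mel machinery that datasets/audio.py previously built inline. A brief usage sketch based on the signatures above (the file path is hypothetical):

```python
from lt_tensor.processors import AudioProcessor

ap = AudioProcessor(sample_rate=24000, n_mels=80, n_fft=1024)
wave = ap.load_audio("speech.wav")         # hypothetical file; trimmed and resampled
mel = ap.compute_mel(wave, add_base=True)  # mean/std-normalized log-mel
# forward() dispatches on ap_task, so the processor is also callable directly:
same_mel = ap(wave, ap_task="get_mel", add_base=True)
```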
lt_tensor/transform.py
CHANGED
@@ -15,6 +15,8 @@ __all__ = [
     "normalize",
     "window_sumsquare",
     "inverse_transform",
+    "InverseTransformConfig",
+    "InverseTransform",
     "stft_istft_rebuild",
 ]
 
@@ -23,12 +25,15 @@ import torchaudio
 import math
 from .misc_utils import log_tensor
 from lt_utils.common import *
+from lt_utils.misc_utils import cache_wrapper, default
 import torch.nn.functional as F
+from .model_base import Model
+import warnings
 
 
 def to_mel_spectrogram(
     waveform: torch.Tensor,
-    sample_rate: int =
+    sample_rate: int = 24000,
     n_fft: int = 1024,
     hop_length: Optional[int] = None,
     win_length: Optional[int] = None,
@@ -198,8 +203,7 @@ def generate_window(
         return torch.ones(1, device=device)
 
     n = torch.arange(M, dtype=torch.float32, device=device)
-
-    return window
+    return alpha - (1.0 - alpha) * torch.cos(2.0 * math.pi * n / (M - 1))
 
 
 def pad_center(tensor: torch.Tensor, size: int, axis: int = -1) -> torch.Tensor:
@@ -260,8 +264,8 @@ def normalize(
 def window_sumsquare(
     window_spec: Union[str, int, float, Callable, List[Any], Tuple[Any, ...]],
     n_frames: int,
-    hop_length: int =
-    win_length: int =
+    hop_length: int = 256,
+    win_length: int = 1024,
     n_fft: int = 2048,
     dtype: torch.dtype = torch.float32,
     norm: Optional[Union[int, float]] = None,
@@ -294,9 +298,9 @@ def window_sumsquare(
 def inverse_transform(
     spec: Tensor,
     phase: Tensor,
-    n_fft: int =
-    hop_length: int =
-    win_length: int =
+    n_fft: int = 1024,
+    hop_length: Optional[int] = None,
+    win_length: Optional[int] = None,
     length: Optional[Any] = None,
     window: Optional[Tensor] = None,
 ):
@@ -310,3 +314,168 @@ def inverse_transform(
         window=window,
         length=length,
     )
+
+
+def is_nand(a: bool, b: bool):
+    """[a -> b = result]
+    ```
+    False -> False = True
+    False -> True = True
+    True -> False = True
+    True -> True = False
+    ```
+    """
+    return not (a and b)
+
+
+class InverseTransformConfig:
+    def __init__(
+        self,
+        n_fft: int = 1024,
+        hop_length: Optional[int] = None,
+        win_length: Optional[int] = None,
+        length: Optional[int] = None,
+        window: Optional[Tensor] = None,
+        onesided: Optional[bool] = None,
+        return_complex: bool = False,
+        normalized: bool = False,
+        center: bool = True,
+    ):
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.length = length
+        self.onesided = onesided
+        self.return_complex = return_complex
+        self.normalized = normalized
+        self.center = center
+        self.window = window
+
+    def to_dict(self):
+        return self.__dict__.copy()
+
+
+class InverseTransform(Model):
+    def __init__(
+        self,
+        n_fft: int = 1024,
+        hop_length: Optional[int] = None,
+        win_length: Optional[int] = None,
+        length: Optional[int] = None,
+        window: Optional[Tensor] = None,
+        onesided: Optional[bool] = None,
+        return_complex: bool = False,
+        normalized: bool = False,
+        center: bool = True,
+    ):
+        """
+        Module for inverting a magnitude + phase spectrogram to a waveform using ISTFT.
+
+        This class encapsulates common ISTFT parameters at initialization and applies
+        the inverse transformation in the `forward()` method with minimal per-call overhead.
+
+        Parameters
+        ----------
+        n_fft : int, optional
+            Size of FFT to use during inversion. Default is 1024.
+        hop_length : int, optional
+            Number of audio samples between STFT columns. Defaults to `n_fft`.
+        win_length : int, optional
+            Size of the window function. Defaults to `n_fft // 4`.
+        length : int, optional
+            Output waveform length. If not provided, length will be inferred.
+        window : Tensor, optional
+            Custom window tensor. If None, a Hann window is used.
+        onesided : bool, optional
+            Whether the input STFT was onesided. Required only for consistency checks.
+        return_complex : bool, default=False
+            Must be False if `onesided` is True. Not used internally.
+        normalized : bool, default=False
+            Whether the STFT was normalized.
+        center : bool, default=True
+            Whether the signal was padded during STFT.
+
+        Methods
+        -------
+        forward(spec, phase)
+            Applies ISTFT using stored settings on the given magnitude and phase tensors.
+        update_settings(...)
+            Updates ISTFT parameters dynamically (used internally during forward).
+        """
+        super().__init__()
+        assert window is None or isinstance(window, Tensor)
+        assert any(
+            (
+                (not return_complex and not onesided),
+                (not onesided and return_complex),
+                (not return_complex and onesided),
+            )
+        )
+        self.n_fft = n_fft
+        self.length = length
+        self.win_length = win_length or n_fft // 4
+        self.hop_length = hop_length or n_fft
+        self.center = center // 4
+        self.return_complex = return_complex
+        self.onesided = onesided
+        self.normalized = normalized
+        self.window = torch.hann_window(win_length) if window is None else window
+
+    def _apply_device_to(self):
+        """Applies to device while used with module `Model`"""
+        self.window = self.window.to(device=self.device)
+
+    def update_settings(
+        self,
+        *,
+        n_fft: Optional[int] = None,
+        hop_length: Optional[int] = None,
+        win_length: Optional[int] = None,
+        length: Optional[int] = None,
+        window: Optional[Tensor] = None,
+        onesided: Optional[bool] = None,
+        return_complex: Optional[bool] = None,
+        center: Optional[bool] = None,
+        normalized: Optional[bool] = None,
+        **_,
+    ):
+
+        self.kwargs = dict(
+            n_fft=default(n_fft, self.n_fft),
+            hop_length=default(hop_length, self.hop_length),
+            win_length=default(win_length, self.win_length),
+            length=default(length, self.length),
+            window=default(window, self.window),
+            onesided=default(onesided, self.onesided),
+            return_complex=default(return_complex, self.return_complex),
+            center=default(center, self.center),
+            normalized=default(normalized, self.normalized),
+        )
+        if self.kwargs["onesided"] and self.kwargs["return_complex"]:
+            warnings.warn(
+                "You cannot use return_complex with `onesided` enabled. `return_complex` is set to False."
+            )
+            self.kwargs["return_complex"] = False
+
+    def forward(self, spec: Tensor, phase: Tensor, **kwargs):
+        """
+        Perform the inverse short-time Fourier transform.
+
+        Parameters
+        ----------
+        spec : Tensor
+            Magnitude spectrogram of shape (batch, freq, time).
+        phase : Tensor
+            Phase angles tensor, same shape as `spec`, in radians.
+        **kwargs : dict, optional
+            Optional ISTFT override parameters (same as in `update_settings`).
+
+        Returns
+        -------
+        Tensor
+            Time-domain waveform reconstructed from `spec` and `phase`.
+        """
+        if kwargs:
+            self.update_settings(**kwargs)
+
+        return torch.istft(spec * torch.exp(phase * 1j), **self.kwargs)
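InverseTransform wraps `torch.istft` behind stored settings. Two details worth noting from the code above: `forward()` reads `self.kwargs`, which only `update_settings()` populates, and the constructor stores `center // 4`, so a first call that passes overrides (here `center=True` and `length`) avoids relying on the stored values. A hedged round-trip sketch on a synthetic signal:

```python
import torch
from lt_tensor.transform import InverseTransform

n_fft, hop = 1024, 256
inv = InverseTransform(n_fft=n_fft, hop_length=hop, win_length=n_fft)

signal = torch.randn(1, 24000)
stft = torch.stft(
    signal,
    n_fft,
    hop_length=hop,
    win_length=n_fft,
    window=torch.hann_window(n_fft),
    return_complex=True,
)
spec, phase = stft.abs(), stft.angle()

# kwargs trigger update_settings(), which builds the dict forward() hands to
# torch.istft; spec * exp(1j * phase) rebuilds the complex spectrogram
wave = inv(spec, phase, center=True, length=signal.shape[-1])
```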
{lt_tensor-0.0.1a6.dist-info → lt_tensor-0.0.1a7.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a6
+Version: 0.0.1a7
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
@@ -17,7 +17,7 @@ Requires-Dist: numpy>=1.26.4
 Requires-Dist: tokenizers
 Requires-Dist: pyyaml>=6.0.0
 Requires-Dist: numba>0.60.0
-Requires-Dist: lt-utils==0.0.1
+Requires-Dist: lt-utils==0.0.1
 Requires-Dist: librosa>=0.11.0
 Dynamic: author
 Dynamic: classifier
{lt_tensor-0.0.1a6.dist-info → lt_tensor-0.0.1a7.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
-lt_tensor/__init__.py,sha256=
+lt_tensor/__init__.py,sha256=uwJ7uiO18VYj8Z1V4KSOQ3ZrnowSgJWKCIiFBrzLMOI,429
 lt_tensor/losses.py,sha256=TinZJP2ypZ7Tdg6d9nnFWFkPyormfgQ0Z9P2ER3sqzE,4341
 lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
 lt_tensor/math_ops.py,sha256=ewIYkvxIy_Lab_9ExjFUgLs-oYLOu8IRRDo7f1pn3i8,2248
 lt_tensor/misc_utils.py,sha256=sjWUkUaHFhaCdN4rZ6X-cQDbPieimfKchKq9VtjiwEA,17029
-lt_tensor/model_base.py,sha256=
+lt_tensor/model_base.py,sha256=8qN7oklALFanOz-eqVzdnB9RD2kN_3ltynSMAPOl-TI,13413
 lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
-lt_tensor/noise_tools.py,sha256=
+lt_tensor/noise_tools.py,sha256=JkWw0-bCMRNNMShwXKKt5KbO3104tvNiBePt-ThPkEo,11366
 lt_tensor/torch_commons.py,sha256=fntsEU8lhBQo0ebonI1iXBkMbWMN3HpBsG13EWlP5s8,718
-lt_tensor/transform.py,sha256=
+lt_tensor/transform.py,sha256=va4bQjpfhH-tnaBDvJZpmYmfg9zwn5_Y6pPOoTswS-U,14471
 lt_tensor/datasets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-lt_tensor/datasets/audio.py,sha256=
+lt_tensor/datasets/audio.py,sha256=kVluXRbLX7C5W5aFN7CUMb-O1KTHjTiTvE4f7iBcnZk,3754
 lt_tensor/model_zoo/__init__.py,sha256=jipEk50_DTMQbGg8FnDDukxmh7Bcwvl_QVRS3rkb7aY,283
 lt_tensor/model_zoo/bsc.py,sha256=muxIR7dU-Pvf-HFE-iy3zmRb1sTJlcs1vqdlnbU1Hss,6307
 lt_tensor/model_zoo/disc.py,sha256=jZPhoSV1hlrba3ohXGutYAAcSl4pWkqGYFpOlOoN3eo,4740
@@ -19,8 +19,10 @@ lt_tensor/model_zoo/istft.py,sha256=0Xms2QNPAgz_ib8XTfaWl1SCHgS53oKC6-EkDkl_qe4,
 lt_tensor/model_zoo/pos.py,sha256=N28v-rF8CELouYxQ9r45Jbd4ri5DNydwDgg7nzmQ4Ig,4471
 lt_tensor/model_zoo/rsd.py,sha256=5bba50g1Hm5kMexuJ4SwOIJuyQ1qJd8Acrq-Ax6CqE8,6958
 lt_tensor/model_zoo/tfrms.py,sha256=kauh-A13pk08SZ5OspEE5a-gPKD4rZr6tqMKWu3KGhk,4237
-lt_tensor
-lt_tensor
-lt_tensor-0.0.
-lt_tensor-0.0.
-lt_tensor-0.0.
+lt_tensor/processors/__init__.py,sha256=4b9MxAJolXiJfSm20ZEspQTDm1tgLazwlPWA_jB1yLM,63
+lt_tensor/processors/audio.py,sha256=mU0usiagVyNPd0uEadL_lC4BFzSMNpjTIwth82gFJRI,6650
+lt_tensor-0.0.1a7.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+lt_tensor-0.0.1a7.dist-info/METADATA,sha256=5A9UrpFdhQikU44UlJbLQmQAVTceVhcbPV25jmLM9Os,965
+lt_tensor-0.0.1a7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a7.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a7.dist-info/RECORD,,
{lt_tensor-0.0.1a6.dist-info → lt_tensor-0.0.1a7.dist-info}/WHEEL
File without changes
{lt_tensor-0.0.1a6.dist-info → lt_tensor-0.0.1a7.dist-info}/licenses/LICENSE
File without changes
{lt_tensor-0.0.1a6.dist-info → lt_tensor-0.0.1a7.dist-info}/top_level.txt
File without changes