lt-tensor 0.0.1a18__py3-none-any.whl → 0.0.1a20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +0 -2
- lt_tensor/config_templates.py +29 -43
- lt_tensor/model_zoo/audio_models/diffwave/__init__.py +32 -14
- lt_tensor/model_zoo/audio_models/hifigan/__init__.py +28 -10
- lt_tensor/model_zoo/audio_models/istft/__init__.py +31 -5
- {lt_tensor-0.0.1a18.dist-info → lt_tensor-0.0.1a20.dist-info}/METADATA +2 -2
- {lt_tensor-0.0.1a18.dist-info → lt_tensor-0.0.1a20.dist-info}/RECORD +10 -12
- lt_tensor/datasets/__init__.py +0 -0
- lt_tensor/datasets/audio.py +0 -235
- {lt_tensor-0.0.1a18.dist-info → lt_tensor-0.0.1a20.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a18.dist-info → lt_tensor-0.0.1a20.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a18.dist-info → lt_tensor-0.0.1a20.dist-info}/top_level.txt +0 -0
lt_tensor/__init__.py
CHANGED
lt_tensor/config_templates.py
CHANGED
@@ -7,75 +7,58 @@ from lt_tensor.misc_utils import updateDict
|
|
7
7
|
|
8
8
|
class ModelConfig(ABC, OrderedDict):
|
9
9
|
_default_settings: Dict[str, Any] = {}
|
10
|
-
_forbidden_list: List[str] = [
|
10
|
+
_forbidden_list: List[str] = [
|
11
|
+
"_default_settings",
|
12
|
+
"_forbidden_list",
|
13
|
+
]
|
11
14
|
|
12
15
|
def __init__(
|
13
16
|
self,
|
14
|
-
settings
|
15
|
-
path_name: Optional[Union[str, PathLike]] = None,
|
17
|
+
**settings,
|
16
18
|
):
|
17
|
-
assert is_dict(settings, False)
|
18
19
|
self._default_settings = settings
|
19
|
-
|
20
|
-
if not str(path_name).endswith(".json"):
|
21
|
-
self.path_name = str(Path(path_name, "config.json")).replace("\\", "/")
|
22
|
-
else:
|
23
|
-
self.path_name = str(path_name).replace("\\", "/")
|
24
|
-
else:
|
25
|
-
self.path_name = "config.json"
|
26
|
-
self.reset_settings()
|
27
|
-
|
28
|
-
def _setup_path_name(self, path_name: Union[str, PathLike]):
|
29
|
-
if is_file(path_name):
|
30
|
-
self.from_path(path_name)
|
31
|
-
self.path_name = str(path_name).replace("\\", "/")
|
32
|
-
elif is_str(path_name):
|
33
|
-
self.path_name = str(path_name).replace("\\", "/")
|
34
|
-
if not self.path_name.endswith((".json")):
|
35
|
-
self.path_name += ".json"
|
20
|
+
self.set_state_dict(self._default_settings)
|
36
21
|
|
37
22
|
def reset_settings(self):
|
38
|
-
|
39
|
-
for s_name, setting in self._default_settings.items():
|
40
|
-
if s_name in self._forbidden_list or s_name not in dk_keys:
|
41
|
-
continue
|
42
|
-
updateDict(self, {s_name: setting})
|
23
|
+
raise NotImplementedError("Not implemented")
|
43
24
|
|
44
25
|
def save_config(
|
45
26
|
self,
|
46
|
-
|
27
|
+
path: str,
|
47
28
|
):
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
else:
|
54
|
-
self._setup_path_name(path_name)
|
29
|
+
base = {
|
30
|
+
k: v
|
31
|
+
for k, v in self.state_dict().items()
|
32
|
+
if isinstance(v, (str, int, float, list, tuple, dict, set, bytes))
|
33
|
+
}
|
55
34
|
|
56
|
-
base =
|
57
|
-
save_json(self.path_name, base, indent=2)
|
35
|
+
save_json(path, base, indent=4)
|
58
36
|
|
59
37
|
def set_value(self, var_name: str, value: str) -> None:
|
60
|
-
assert var_name in self.__dict__, "Value not registered!"
|
61
38
|
assert var_name not in self._forbidden_list, "Not allowed!"
|
62
39
|
updateDict(self, {var_name: value})
|
40
|
+
self.update({var_name: value})
|
63
41
|
|
64
42
|
def get_value(self, var_name: str) -> Any:
|
65
43
|
return self.__dict__.get(var_name)
|
66
44
|
|
67
|
-
def
|
68
|
-
|
45
|
+
def set_state_dict(self, new_state: dict[str, str]):
|
46
|
+
new_state = {
|
47
|
+
k: y for k, y in new_state.items() if k not in self._forbidden_list
|
48
|
+
}
|
49
|
+
updateDict(self, new_state)
|
50
|
+
self.update(**new_state)
|
69
51
|
|
70
|
-
def
|
52
|
+
def state_dict(self):
|
71
53
|
return {k: y for k, y in self.__dict__.items() if k not in self._forbidden_list}
|
72
54
|
|
73
55
|
@classmethod
|
74
56
|
def from_dict(
|
75
|
-
cls,
|
57
|
+
cls,
|
58
|
+
dictionary: Dict[str, Any],
|
76
59
|
) -> "ModelConfig":
|
77
60
|
assert is_dict(dictionary)
|
78
|
-
return ModelConfig(dictionary
|
61
|
+
return ModelConfig(**dictionary)
|
79
62
|
|
80
63
|
@classmethod
|
81
64
|
def from_path(cls, path_name: PathLike) -> "ModelConfig":
|
@@ -102,4 +85,7 @@ class ModelConfig(ABC, OrderedDict):
|
|
102
85
|
)
|
103
86
|
assert files, "No config file found in the provided directory!"
|
104
87
|
settings.update(load_json(files[-1], {}, errors="ignore"))
|
105
|
-
|
88
|
+
settings.pop("path", None)
|
89
|
+
settings.pop("path_name", None)
|
90
|
+
|
91
|
+
return ModelConfig(**settings)
|
@@ -1,9 +1,8 @@
|
|
1
1
|
__all__ = ["DiffWave", "DiffWaveConfig", "SpectrogramUpsample", "DiffusionEmbedding"]
|
2
2
|
|
3
3
|
import numpy as np
|
4
|
-
import
|
5
|
-
|
6
|
-
import torch.nn.functional as F
|
4
|
+
from lt_tensor.torch_commons import *
|
5
|
+
from torch.nn import functional as F
|
7
6
|
from lt_tensor.config_templates import ModelConfig
|
8
7
|
from lt_tensor.torch_commons import *
|
9
8
|
from lt_tensor.model_base import Model
|
@@ -12,16 +11,9 @@ from lt_utils.common import *
|
|
12
11
|
|
13
12
|
|
14
13
|
class DiffWaveConfig(ModelConfig):
|
15
|
-
#
|
16
|
-
batch_size = 16
|
17
|
-
learning_rate = 2e-4
|
18
|
-
max_grad_norm = None
|
19
|
-
# Data params
|
20
|
-
sample_rate = 24000
|
14
|
+
# Model params
|
21
15
|
n_mels = 80
|
22
|
-
n_fft = 1024
|
23
16
|
hop_samples = 256
|
24
|
-
# Model params
|
25
17
|
residual_layers = 30
|
26
18
|
residual_channels = 64
|
27
19
|
dilation_cycle_length = 10
|
@@ -35,10 +27,36 @@ class DiffWaveConfig(ModelConfig):
|
|
35
27
|
|
36
28
|
def __init__(
|
37
29
|
self,
|
38
|
-
|
39
|
-
|
30
|
+
n_mels=80,
|
31
|
+
hop_samples=256,
|
32
|
+
residual_layers=30,
|
33
|
+
residual_channels=64,
|
34
|
+
dilation_cycle_length=10,
|
35
|
+
unconditional=False,
|
36
|
+
noise_schedule: list[int] = np.linspace(1e-4, 0.05, 50).tolist(),
|
37
|
+
interpolate_cond=False,
|
38
|
+
interpolation_mode: Literal[
|
39
|
+
"nearest",
|
40
|
+
"linear",
|
41
|
+
"bilinear",
|
42
|
+
"bicubic",
|
43
|
+
"trilinear",
|
44
|
+
"area",
|
45
|
+
"nearest-exact",
|
46
|
+
] = "nearest",
|
40
47
|
):
|
41
|
-
|
48
|
+
settings = {
|
49
|
+
"n_mels": n_mels,
|
50
|
+
"hop_samples": hop_samples,
|
51
|
+
"residual_layers": residual_layers,
|
52
|
+
"dilation_cycle_length": dilation_cycle_length,
|
53
|
+
"residual_channels": residual_channels,
|
54
|
+
"unconditional": unconditional,
|
55
|
+
"noise_schedule": noise_schedule,
|
56
|
+
"interpolate": interpolate_cond,
|
57
|
+
"interpolation_mode": interpolation_mode,
|
58
|
+
}
|
59
|
+
super().__init__(**settings)
|
42
60
|
|
43
61
|
|
44
62
|
def Conv1d(*args, **kwargs):
|
@@ -4,10 +4,6 @@ from lt_tensor.torch_commons import *
|
|
4
4
|
from lt_tensor.model_zoo.residual import ConvNets
|
5
5
|
from torch.nn import functional as F
|
6
6
|
|
7
|
-
import torch
|
8
|
-
import torch.nn.functional as F
|
9
|
-
import torch.nn as nn
|
10
|
-
|
11
7
|
|
12
8
|
def get_padding(kernel_size, dilation=1):
|
13
9
|
return int((kernel_size * dilation - dilation) / 2)
|
@@ -21,7 +17,7 @@ class HifiganConfig(ModelConfig):
|
|
21
17
|
in_channels: int = 80
|
22
18
|
upsample_rates: List[Union[int, List[int]]] = [8, 8]
|
23
19
|
upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16]
|
24
|
-
upsample_initial_channel: int =
|
20
|
+
upsample_initial_channel: int = 512
|
25
21
|
resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11]
|
26
22
|
resblock_dilation_sizes: List[Union[int, List[int]]] = [
|
27
23
|
[1, 3, 5],
|
@@ -34,10 +30,32 @@ class HifiganConfig(ModelConfig):
|
|
34
30
|
|
35
31
|
def __init__(
|
36
32
|
self,
|
37
|
-
|
38
|
-
|
33
|
+
in_channels: int = 80,
|
34
|
+
upsample_rates: List[Union[int, List[int]]] = [8, 8],
|
35
|
+
upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16],
|
36
|
+
upsample_initial_channel: int = 512,
|
37
|
+
resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
|
38
|
+
resblock_dilation_sizes: List[Union[int, List[int]]] = [
|
39
|
+
[1, 3, 5],
|
40
|
+
[1, 3, 5],
|
41
|
+
[1, 3, 5],
|
42
|
+
],
|
43
|
+
activation: nn.Module = nn.LeakyReLU(0.1),
|
44
|
+
resblock: int = 0,
|
45
|
+
*args,
|
46
|
+
**kwargs,
|
39
47
|
):
|
40
|
-
|
48
|
+
settings = {
|
49
|
+
"in_channels": in_channels,
|
50
|
+
"upsample_rates": upsample_rates,
|
51
|
+
"upsample_kernel_sizes": upsample_kernel_sizes,
|
52
|
+
"upsample_initial_channel": upsample_initial_channel,
|
53
|
+
"resblock_kernel_sizes": resblock_kernel_sizes,
|
54
|
+
"resblock_dilation_sizes": resblock_dilation_sizes,
|
55
|
+
"activation": activation,
|
56
|
+
"resblock": resblock,
|
57
|
+
}
|
58
|
+
super().__init__(**settings)
|
41
59
|
|
42
60
|
|
43
61
|
class ResBlock1(ConvNets):
|
@@ -177,10 +195,10 @@ class HifiganGenerator(ConvNets):
|
|
177
195
|
self.conv_pre = weight_norm(
|
178
196
|
nn.Conv1d(cfg.in_channels, cfg.upsample_initial_channel, 7, 1, padding=3)
|
179
197
|
)
|
180
|
-
resblock = ResBlock1 if resblock == 0 else ResBlock2
|
198
|
+
resblock = ResBlock1 if cfg.resblock == 0 else ResBlock2
|
181
199
|
self.activation = cfg.activation
|
182
200
|
self.ups = nn.ModuleList()
|
183
|
-
for i, (u, k) in enumerate(zip(cfg.
|
201
|
+
for i, (u, k) in enumerate(zip(cfg.upsample_rates, cfg.upsample_kernel_sizes)):
|
184
202
|
self.ups.append(
|
185
203
|
weight_norm(
|
186
204
|
nn.ConvTranspose1d(
|
@@ -11,7 +11,7 @@ class iSTFTNetConfig(ModelConfig):
|
|
11
11
|
in_channels: int = 80
|
12
12
|
upsample_rates: List[Union[int, List[int]]] = [8, 8]
|
13
13
|
upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16]
|
14
|
-
upsample_initial_channel: int =
|
14
|
+
upsample_initial_channel: int = 512
|
15
15
|
resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11]
|
16
16
|
resblock_dilation_sizes: List[Union[int, List[int]]] = [
|
17
17
|
[1, 3, 5],
|
@@ -26,10 +26,36 @@ class iSTFTNetConfig(ModelConfig):
|
|
26
26
|
|
27
27
|
def __init__(
|
28
28
|
self,
|
29
|
-
|
30
|
-
|
29
|
+
in_channels: int = 80,
|
30
|
+
upsample_rates: List[Union[int, List[int]]] = [8, 8],
|
31
|
+
upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16],
|
32
|
+
upsample_initial_channel: int = 512,
|
33
|
+
resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
|
34
|
+
resblock_dilation_sizes: List[Union[int, List[int]]] = [
|
35
|
+
[1, 3, 5],
|
36
|
+
[1, 3, 5],
|
37
|
+
[1, 3, 5],
|
38
|
+
],
|
39
|
+
activation: nn.Module = nn.LeakyReLU(0.1),
|
40
|
+
resblock: int = 0,
|
41
|
+
gen_istft_n_fft: int = 16,
|
42
|
+
sampling_rate: Number = 24000,
|
43
|
+
*args,
|
44
|
+
**kwargs,
|
31
45
|
):
|
32
|
-
|
46
|
+
settings = {
|
47
|
+
"in_channels": in_channels,
|
48
|
+
"upsample_rates": upsample_rates,
|
49
|
+
"upsample_kernel_sizes": upsample_kernel_sizes,
|
50
|
+
"upsample_initial_channel": upsample_initial_channel,
|
51
|
+
"resblock_kernel_sizes": resblock_kernel_sizes,
|
52
|
+
"resblock_dilation_sizes": resblock_dilation_sizes,
|
53
|
+
"activation": activation,
|
54
|
+
"resblock": resblock,
|
55
|
+
"gen_istft_n_fft": gen_istft_n_fft,
|
56
|
+
"sampling_rate": sampling_rate,
|
57
|
+
}
|
58
|
+
super().__init__(**settings)
|
33
59
|
|
34
60
|
|
35
61
|
def get_padding(ks, d):
|
@@ -169,7 +195,7 @@ class iSTFTNetGenerator(ConvNets):
|
|
169
195
|
self.conv_pre = weight_norm(
|
170
196
|
nn.Conv1d(cfg.in_channels, cfg.upsample_initial_channel, 7, 1, padding=3)
|
171
197
|
)
|
172
|
-
resblock = ResBlock1 if resblock == 0 else ResBlock2
|
198
|
+
resblock = ResBlock1 if cfg.resblock == 0 else ResBlock2
|
173
199
|
|
174
200
|
self.ups = nn.ModuleList()
|
175
201
|
for i, (u, k) in enumerate(zip(cfg.upsample_rates, cfg.upsample_kernel_sizes)):
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: lt-tensor
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.1a20
|
4
4
|
Summary: General utilities for PyTorch and others. Built for general use.
|
5
5
|
Home-page: https://github.com/gr1336/lt-tensor/
|
6
6
|
Author: gr1336
|
@@ -17,7 +17,7 @@ Requires-Dist: numpy>=1.26.4
|
|
17
17
|
Requires-Dist: tokenizers
|
18
18
|
Requires-Dist: pyyaml>=6.0.0
|
19
19
|
Requires-Dist: numba>0.60.0
|
20
|
-
Requires-Dist: lt-utils==0.0.
|
20
|
+
Requires-Dist: lt-utils==0.0.2a3
|
21
21
|
Requires-Dist: librosa==0.11.*
|
22
22
|
Requires-Dist: einops
|
23
23
|
Requires-Dist: plotly
|
@@ -1,5 +1,5 @@
|
|
1
|
-
lt_tensor/__init__.py,sha256=
|
2
|
-
lt_tensor/config_templates.py,sha256=
|
1
|
+
lt_tensor/__init__.py,sha256=8FTxpJ6td2bMr_GqzW2tCV6Tr5CelbQle8N5JRWtx8M,439
|
2
|
+
lt_tensor/config_templates.py,sha256=RP7EFVRj6mRUj6xDLe7FMXgN5TIo8_o9h1Kb8epdmfo,2825
|
3
3
|
lt_tensor/losses.py,sha256=zvkCOnE5XpF3v6ymivRIdqPTsMM5zc94ZMom7YDi3zM,4946
|
4
4
|
lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
|
5
5
|
lt_tensor/math_ops.py,sha256=TkD4WQG42KsQ9Fg7FXOjf8f-ixtW0apf2XjaooecVx4,2257
|
@@ -9,8 +9,6 @@ lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,
|
|
9
9
|
lt_tensor/noise_tools.py,sha256=wFeAsHhLhSlEc5XU5LbFKaXoHeVxrWjiMeljjGdIKyM,11363
|
10
10
|
lt_tensor/torch_commons.py,sha256=fntsEU8lhBQo0ebonI1iXBkMbWMN3HpBsG13EWlP5s8,718
|
11
11
|
lt_tensor/transform.py,sha256=dZm8T_ov0blHMQu6nGiehsdG1VSB7bZBUVmTkT-PBdc,13257
|
12
|
-
lt_tensor/datasets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
|
-
lt_tensor/datasets/audio.py,sha256=5Wvz1BJ7xXkLYpVLLw9RY3X3RgMdPPeGiN0-MmJDQy0,8045
|
14
12
|
lt_tensor/model_zoo/__init__.py,sha256=ltVTvmOlbOCfDc5Trvg0-Ta_Ujgkw0UVF9V5rqHx-RI,378
|
15
13
|
lt_tensor/model_zoo/basic.py,sha256=pI8HyiHK-cmWcEEaVY_EduUJOjZW6HOtXvJd8Rbhq30,15452
|
16
14
|
lt_tensor/model_zoo/features.py,sha256=DO8dlE0kmPKTNC1Xkv9wKegOOYkQa_rkxM4hhcNwJWA,15655
|
@@ -19,13 +17,13 @@ lt_tensor/model_zoo/pos_encoder.py,sha256=3d1EYLinCU9UAy-WuEWeYMGhMqaGknCiQ5qEmh
|
|
19
17
|
lt_tensor/model_zoo/residual.py,sha256=i5V4ju7DB3WesKBVm6KH_LyPoKGDUOyo2Usfs-PyP58,9394
|
20
18
|
lt_tensor/model_zoo/transformer.py,sha256=HUFoFFh7EQJErxdd9XIxhssdjvNVx2tNGDJOTUfwG2A,4301
|
21
19
|
lt_tensor/model_zoo/audio_models/__init__.py,sha256=MoG9YjxLyvscq_6njK1ljGBletK9iedBXt66bplzW-s,83
|
22
|
-
lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=
|
23
|
-
lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=
|
24
|
-
lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=
|
20
|
+
lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=OUyh421xRCcxOMi_Ek6Ak3-FPe1k6WTDQ-6gd6OjaCU,8091
|
21
|
+
lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=JNebaYO3nsyyqpYCCOyL13zY2uxLY3NOCeNynF6-96k,13940
|
22
|
+
lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=JdFChpPhURaI2qb9mDV6vzDcZN757FBGGtgzN3vxtJ0,14821
|
25
23
|
lt_tensor/processors/__init__.py,sha256=4b9MxAJolXiJfSm20ZEspQTDm1tgLazwlPWA_jB1yLM,63
|
26
24
|
lt_tensor/processors/audio.py,sha256=SMqNSl4Den-x1awTCQ8-TcR-0jPiv5lDaUpU93SRRaw,14749
|
27
|
-
lt_tensor-0.0.
|
28
|
-
lt_tensor-0.0.
|
29
|
-
lt_tensor-0.0.
|
30
|
-
lt_tensor-0.0.
|
31
|
-
lt_tensor-0.0.
|
25
|
+
lt_tensor-0.0.1a20.dist-info/licenses/LICENSE,sha256=HUnu_iSPpnDfZS_PINhO3AoVizJD1A2vee8WX7D7uXo,11358
|
26
|
+
lt_tensor-0.0.1a20.dist-info/METADATA,sha256=y-nfT5fMD5c8JgD299060Aij5Ugz2xHlDYYNoPXtAy8,1033
|
27
|
+
lt_tensor-0.0.1a20.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
28
|
+
lt_tensor-0.0.1a20.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
|
29
|
+
lt_tensor-0.0.1a20.dist-info/RECORD,,
|
lt_tensor/datasets/__init__.py
DELETED
File without changes
|
lt_tensor/datasets/audio.py
DELETED
@@ -1,235 +0,0 @@
|
|
1
|
-
__all__ = ["WaveMelDataset"]
|
2
|
-
from lt_tensor.torch_commons import *
|
3
|
-
from lt_utils.common import *
|
4
|
-
from lt_utils.misc_utils import default
|
5
|
-
import random
|
6
|
-
from torch.utils.data import Dataset, DataLoader, Sampler
|
7
|
-
from lt_tensor.processors import AudioProcessor
|
8
|
-
import torch.nn.functional as FT
|
9
|
-
from lt_tensor.misc_utils import log_tensor
|
10
|
-
from tqdm import tqdm
|
11
|
-
|
12
|
-
|
13
|
-
DEFAULT_DEVICE = torch.tensor([0]).device
|
14
|
-
|
15
|
-
|
16
|
-
class WaveMelDataset(Dataset):
|
17
|
-
cached_data: Union[list[dict[str, Tensor]], Tuple[Tensor, Tensor]] = []
|
18
|
-
loaded_files: Dict[str, List[Dict[str, Tensor]]] = {}
|
19
|
-
normalize_waves: bool = False
|
20
|
-
randomize_ranges: bool = False
|
21
|
-
alpha_wv: float = 1.0
|
22
|
-
limit_files: Optional[int] = None
|
23
|
-
min_frame_length: Optional[int] = None
|
24
|
-
max_frame_length: Optional[int] = None
|
25
|
-
|
26
|
-
def __init__(
|
27
|
-
self,
|
28
|
-
audio_processor: AudioProcessor,
|
29
|
-
dataset_path: PathLike,
|
30
|
-
limit_files: Optional[int] = None,
|
31
|
-
min_frame_length: Optional[int] = None,
|
32
|
-
max_frame_length: Optional[int] = None,
|
33
|
-
randomize_ranges: Optional[bool] = None,
|
34
|
-
pre_load: bool = False,
|
35
|
-
normalize_waves: Optional[bool] = None,
|
36
|
-
alpha_wv: Optional[float] = None,
|
37
|
-
lib_norm: bool = True,
|
38
|
-
):
|
39
|
-
super().__init__()
|
40
|
-
assert max_frame_length is None or max_frame_length >= (
|
41
|
-
(audio_processor.n_fft // 2) + 1
|
42
|
-
)
|
43
|
-
self.ap = audio_processor
|
44
|
-
self.dataset_path = dataset_path
|
45
|
-
if limit_files:
|
46
|
-
self.limit_files = limit_files
|
47
|
-
if normalize_waves is not None:
|
48
|
-
self.normalize_waves = normalize_waves
|
49
|
-
if alpha_wv is not None:
|
50
|
-
self.alpha_wv = alpha_wv
|
51
|
-
if pre_load is not None:
|
52
|
-
self.pre_loaded = pre_load
|
53
|
-
if randomize_ranges is not None:
|
54
|
-
self.randomize_ranges = randomize_ranges
|
55
|
-
|
56
|
-
self.post_n_fft = (audio_processor.n_fft // 2) + 1
|
57
|
-
self.lib_norm = lib_norm
|
58
|
-
if max_frame_length is not None:
|
59
|
-
max_frame_length = max(self.post_n_fft + 1, max_frame_length)
|
60
|
-
self.r_range = max(self.post_n_fft + 1, max_frame_length // 3)
|
61
|
-
self.max_frame_length = max_frame_length
|
62
|
-
if min_frame_length is not None:
|
63
|
-
self.min_frame_length = max(
|
64
|
-
self.post_n_fft + 1, min(min_frame_length, max_frame_length)
|
65
|
-
)
|
66
|
-
|
67
|
-
self.files = self.ap.find_audios(dataset_path, maximum=None)
|
68
|
-
if limit_files:
|
69
|
-
random.shuffle(self.files)
|
70
|
-
self.files = self.files[-self.limit_files :]
|
71
|
-
if pre_load:
|
72
|
-
for file in tqdm(self.files, "Loading files"):
|
73
|
-
results = self.load_data(file)
|
74
|
-
if not results:
|
75
|
-
continue
|
76
|
-
self.cached_data.extend(results)
|
77
|
-
|
78
|
-
def renew_dataset(self, new_path: Optional[PathLike] = None):
|
79
|
-
new_path = default(new_path, self.dataset_path)
|
80
|
-
self.files = self.ap.find_audios(new_path, maximum=None)
|
81
|
-
random.shuffle(self.files)
|
82
|
-
for file in tqdm(self.files, "Loading files"):
|
83
|
-
results = self.load_data(file)
|
84
|
-
if not results:
|
85
|
-
continue
|
86
|
-
self.cached_data.extend(results)
|
87
|
-
|
88
|
-
def _add_dict(
|
89
|
-
self,
|
90
|
-
audio_wave: Tensor,
|
91
|
-
audio_mel: Tensor,
|
92
|
-
pitch: Tensor,
|
93
|
-
rms: Tensor,
|
94
|
-
file: PathLike,
|
95
|
-
):
|
96
|
-
return {
|
97
|
-
"wave": audio_wave,
|
98
|
-
"pitch": pitch,
|
99
|
-
"rms": rms,
|
100
|
-
"mel": audio_mel,
|
101
|
-
"file": file,
|
102
|
-
}
|
103
|
-
|
104
|
-
def load_data(self, file: PathLike):
|
105
|
-
initial_audio = self.ap.load_audio(
|
106
|
-
file, normalize=self.lib_norm, alpha=self.alpha_wv
|
107
|
-
)
|
108
|
-
if self.normalize_waves:
|
109
|
-
initial_audio = self.ap.normalize_audio(initial_audio)
|
110
|
-
if initial_audio.shape[-1] < self.post_n_fft:
|
111
|
-
return None
|
112
|
-
|
113
|
-
if self.min_frame_length is not None:
|
114
|
-
if self.min_frame_length > initial_audio.shape[-1]:
|
115
|
-
return None
|
116
|
-
if (
|
117
|
-
not self.max_frame_length
|
118
|
-
or initial_audio.shape[-1] <= self.max_frame_length
|
119
|
-
):
|
120
|
-
|
121
|
-
audio_rms = self.ap.compute_rms(initial_audio)
|
122
|
-
audio_pitch = self.ap.compute_pitch(initial_audio)
|
123
|
-
audio_mel = self.ap.compute_mel(initial_audio, add_base=True)
|
124
|
-
|
125
|
-
return [
|
126
|
-
self._add_dict(initial_audio, audio_mel, audio_pitch, audio_rms, file)
|
127
|
-
]
|
128
|
-
results = []
|
129
|
-
|
130
|
-
if self.randomize_ranges:
|
131
|
-
frame_limit = random.randint(self.r_range, self.max_frame_length)
|
132
|
-
else:
|
133
|
-
frame_limit = self.max_frame_length
|
134
|
-
|
135
|
-
fragments = list(
|
136
|
-
torch.split(initial_audio, split_size_or_sections=frame_limit, dim=-1)
|
137
|
-
)
|
138
|
-
random.shuffle(fragments)
|
139
|
-
for fragment in fragments:
|
140
|
-
if fragment.shape[-1] < self.post_n_fft:
|
141
|
-
# Too small
|
142
|
-
continue
|
143
|
-
if (
|
144
|
-
self.min_frame_length is not None
|
145
|
-
and self.min_frame_length > fragment.shape[-1]
|
146
|
-
):
|
147
|
-
continue
|
148
|
-
|
149
|
-
audio_rms = self.ap.compute_rms(fragment)
|
150
|
-
audio_pitch = self.ap.compute_pitch(fragment)
|
151
|
-
audio_mel = self.ap.compute_mel(fragment, add_base=True)
|
152
|
-
results.append(
|
153
|
-
self._add_dict(fragment, audio_mel, audio_pitch, audio_rms, file)
|
154
|
-
)
|
155
|
-
return results
|
156
|
-
|
157
|
-
def get_data_loader(
|
158
|
-
self,
|
159
|
-
batch_size: int = 1,
|
160
|
-
shuffle: Optional[bool] = None,
|
161
|
-
sampler: Optional[Union[Sampler, Iterable]] = None,
|
162
|
-
batch_sampler: Optional[Union[Sampler[list], Iterable[list]]] = None,
|
163
|
-
num_workers: int = 0,
|
164
|
-
pin_memory: bool = False,
|
165
|
-
drop_last: bool = False,
|
166
|
-
timeout: float = 0,
|
167
|
-
):
|
168
|
-
return DataLoader(
|
169
|
-
self,
|
170
|
-
batch_size=batch_size,
|
171
|
-
shuffle=shuffle,
|
172
|
-
sampler=sampler,
|
173
|
-
batch_sampler=batch_sampler,
|
174
|
-
num_workers=num_workers,
|
175
|
-
pin_memory=pin_memory,
|
176
|
-
drop_last=drop_last,
|
177
|
-
timeout=timeout,
|
178
|
-
collate_fn=self.collate_fn,
|
179
|
-
)
|
180
|
-
|
181
|
-
def collate_fn(self, batch: Sequence[Dict[str, Tensor]]):
|
182
|
-
mel = []
|
183
|
-
wave = []
|
184
|
-
file = []
|
185
|
-
rms = []
|
186
|
-
pitch = []
|
187
|
-
for x in batch:
|
188
|
-
mel.append(x["mel"])
|
189
|
-
wave.append(x["wave"])
|
190
|
-
file.append(x["file"])
|
191
|
-
rms.append(x["rms"])
|
192
|
-
pitch.append(x["pitch"])
|
193
|
-
# Find max time in mel (dim -1), and max audio length
|
194
|
-
max_mel_len = max([m.shape[-1] for m in mel])
|
195
|
-
max_audio_len = max([a.shape[-1] for a in wave])
|
196
|
-
max_pitch_len = max([a.shape[-1] for a in pitch])
|
197
|
-
max_rms_len = max([a.shape[-1] for a in rms])
|
198
|
-
|
199
|
-
padded_mel = torch.stack(
|
200
|
-
[FT.pad(m, (0, max_mel_len - m.shape[-1])) for m in mel]
|
201
|
-
) # shape: [B, 80, T_max]
|
202
|
-
|
203
|
-
padded_wave = torch.stack(
|
204
|
-
[FT.pad(a, (0, max_audio_len - a.shape[-1])) for a in wave]
|
205
|
-
) # shape: [B, L_max]
|
206
|
-
|
207
|
-
padded_pitch = torch.stack(
|
208
|
-
[FT.pad(a, (0, max_pitch_len - a.shape[-1])) for a in pitch]
|
209
|
-
) # shape: [B, L_max]
|
210
|
-
padded_rms = torch.stack(
|
211
|
-
[FT.pad(a, (0, max_rms_len - a.shape[-1])) for a in rms]
|
212
|
-
) # shape: [B, L_max]
|
213
|
-
return dict(
|
214
|
-
mel=padded_mel,
|
215
|
-
wave=padded_wave,
|
216
|
-
pitch=padded_pitch,
|
217
|
-
rms=padded_rms,
|
218
|
-
file=file,
|
219
|
-
)
|
220
|
-
|
221
|
-
def get_item(self, idx: int):
|
222
|
-
if self.pre_loaded:
|
223
|
-
return self.cached_data[idx]
|
224
|
-
file = self.files[idx]
|
225
|
-
if file not in self.loaded_files:
|
226
|
-
self.loaded_files[file] = self.load_data(file)
|
227
|
-
return random.choice(self.loaded_files[file])
|
228
|
-
|
229
|
-
def __len__(self):
|
230
|
-
if self.pre_loaded:
|
231
|
-
return len(self.cached_data)
|
232
|
-
return len(self.files)
|
233
|
-
|
234
|
-
def __getitem__(self, index: int):
|
235
|
-
return self.get_item(index)
|
File without changes
|
File without changes
|
File without changes
|