lt-tensor 0.0.1a18__py3-none-any.whl → 0.0.1a20__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear there.
lt_tensor/__init__.py CHANGED
@@ -11,7 +11,6 @@ from . import (
     noise_tools,
     losses,
     processors,
-    datasets,
     torch_commons,
 )
 
@@ -26,6 +25,5 @@ __all__ = [
     "noise_tools",
     "losses",
     "processors",
-    "datasets",
     "torch_commons",
 ]
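
Note on the change above: the datasets submodule is removed from the package in 0.0.1a20 (the deleted lt_tensor/datasets/audio.py appears at the end of this diff). A hedged compatibility sketch for downstream code that imported it; the fallback behavior shown here is illustrative, not part of the package:

    # Works on 0.0.1a18; on 0.0.1a20 the module no longer exists.
    try:
        from lt_tensor.datasets.audio import WaveMelDataset
    except ModuleNotFoundError:
        # Removed in 0.0.1a20: pin lt-tensor==0.0.1a18 or vendor the old class.
        WaveMelDataset = None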
lt_tensor/config_templates.py CHANGED
@@ -7,75 +7,58 @@ from lt_tensor.misc_utils import updateDict
 
 class ModelConfig(ABC, OrderedDict):
     _default_settings: Dict[str, Any] = {}
-    _forbidden_list: List[str] = ["_default_settings", "_forbidden_list" "path_name"]
+    _forbidden_list: List[str] = [
+        "_default_settings",
+        "_forbidden_list",
+    ]
 
     def __init__(
         self,
-        settings: Dict[str, Any] = {},
-        path_name: Optional[Union[str, PathLike]] = None,
+        **settings,
     ):
-        assert is_dict(settings, False)
         self._default_settings = settings
-        if path_name is not None and is_pathlike(path_name):
-            if not str(path_name).endswith(".json"):
-                self.path_name = str(Path(path_name, "config.json")).replace("\\", "/")
-            else:
-                self.path_name = str(path_name).replace("\\", "/")
-        else:
-            self.path_name = "config.json"
-        self.reset_settings()
-
-    def _setup_path_name(self, path_name: Union[str, PathLike]):
-        if is_file(path_name):
-            self.from_path(path_name)
-            self.path_name = str(path_name).replace("\\", "/")
-        elif is_str(path_name):
-            self.path_name = str(path_name).replace("\\", "/")
-            if not self.path_name.endswith((".json")):
-                self.path_name += ".json"
+        self.set_state_dict(self._default_settings)
 
     def reset_settings(self):
-        dk_keys = self.__dict__.keys()
-        for s_name, setting in self._default_settings.items():
-            if s_name in self._forbidden_list or s_name not in dk_keys:
-                continue
-            updateDict(self, {s_name: setting})
+        raise NotImplementedError("Not implemented")
 
     def save_config(
         self,
-        path_name: Optional[Union[PathLike, str]] = None,
+        path: str,
     ):
-        if not is_pathlike(path_name, True):
-            assert (
-                path_name is None
-            ), f"path_name should be a non-empty string or pathlike object! received instead: {path_name}."
-            path_name = self.path_name
-        else:
-            self._setup_path_name(path_name)
+        base = {
+            k: v
+            for k, v in self.state_dict().items()
+            if isinstance(v, (str, int, float, list, tuple, dict, set, bytes))
+        }
 
-        base = self.get_state_dict()
-        save_json(self.path_name, base, indent=2)
+        save_json(path, base, indent=4)
 
     def set_value(self, var_name: str, value: str) -> None:
-        assert var_name in self.__dict__, "Value not registered!"
         assert var_name not in self._forbidden_list, "Not allowed!"
         updateDict(self, {var_name: value})
+        self.update({var_name: value})
 
     def get_value(self, var_name: str) -> Any:
         return self.__dict__.get(var_name)
 
-    def __getattribute__(self, name):
-        return self.__dict__.get(name)
+    def set_state_dict(self, new_state: dict[str, str]):
+        new_state = {
+            k: y for k, y in new_state.items() if k not in self._forbidden_list
+        }
+        updateDict(self, new_state)
+        self.update(**new_state)
 
-    def get_state_dict(self):
+    def state_dict(self):
         return {k: y for k, y in self.__dict__.items() if k not in self._forbidden_list}
 
     @classmethod
     def from_dict(
-        cls, dictionary: Dict[str, Any], path: Optional[Union[str, PathLike]] = None
+        cls,
+        dictionary: Dict[str, Any],
     ) -> "ModelConfig":
         assert is_dict(dictionary)
-        return ModelConfig(dictionary, path)
+        return ModelConfig(**dictionary)
 
     @classmethod
     def from_path(cls, path_name: PathLike) -> "ModelConfig":
@@ -102,4 +85,7 @@ class ModelConfig(ABC, OrderedDict):
         )
         assert files, "No config file found in the provided directory!"
         settings.update(load_json(files[-1], {}, errors="ignore"))
-        return ModelConfig(settings, path_name)
+        settings.pop("path", None)
+        settings.pop("path_name", None)
+
+        return ModelConfig(**settings)
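
The ModelConfig changes above are breaking: the constructor now takes plain keyword arguments instead of a settings dict plus path_name, save_config requires an explicit path, get_state_dict is renamed to state_dict, and set_value no longer requires the key to already exist. A minimal migration sketch, assuming downstream code written against the old API (the settings and path shown are illustrative):

    from lt_tensor.config_templates import ModelConfig

    # 0.0.1a18: cfg = ModelConfig({"n_mels": 80}, "exp/config.json")
    # 0.0.1a20: settings are keyword arguments; no path is stored on the config.
    cfg = ModelConfig(n_mels=80)

    cfg.save_config("exp/config.json")  # a path argument is now mandatory
    state = cfg.state_dict()            # renamed from get_state_dict()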
lt_tensor/model_zoo/audio_models/diffwave/__init__.py CHANGED
@@ -1,9 +1,8 @@
 __all__ = ["DiffWave", "DiffWaveConfig", "SpectrogramUpsample", "DiffusionEmbedding"]
 
 import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from lt_tensor.torch_commons import *
+from torch.nn import functional as F
 from lt_tensor.config_templates import ModelConfig
 from lt_tensor.torch_commons import *
 from lt_tensor.model_base import Model
@@ -12,16 +11,9 @@ from lt_utils.common import *
 
 
 class DiffWaveConfig(ModelConfig):
-    # Training params
-    batch_size = 16
-    learning_rate = 2e-4
-    max_grad_norm = None
-    # Data params
-    sample_rate = 24000
+    # Model params
     n_mels = 80
-    n_fft = 1024
     hop_samples = 256
-    # Model params
     residual_layers = 30
     residual_channels = 64
     dilation_cycle_length = 10
@@ -35,10 +27,36 @@
 
     def __init__(
         self,
-        settings: Dict[str, Any] = {},
-        path_name: Optional[Union[str, PathLike]] = None,
+        n_mels=80,
+        hop_samples=256,
+        residual_layers=30,
+        residual_channels=64,
+        dilation_cycle_length=10,
+        unconditional=False,
+        noise_schedule: list[int] = np.linspace(1e-4, 0.05, 50).tolist(),
+        interpolate_cond=False,
+        interpolation_mode: Literal[
+            "nearest",
+            "linear",
+            "bilinear",
+            "bicubic",
+            "trilinear",
+            "area",
+            "nearest-exact",
+        ] = "nearest",
     ):
-        super().__init__(settings, path_name)
+        settings = {
+            "n_mels": n_mels,
+            "hop_samples": hop_samples,
+            "residual_layers": residual_layers,
+            "dilation_cycle_length": dilation_cycle_length,
+            "residual_channels": residual_channels,
+            "unconditional": unconditional,
+            "noise_schedule": noise_schedule,
+            "interpolate": interpolate_cond,
+            "interpolation_mode": interpolation_mode,
+        }
+        super().__init__(**settings)
 
 
 def Conv1d(*args, **kwargs):
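
DiffWaveConfig drops the training and data fields (batch_size, learning_rate, max_grad_norm, sample_rate, n_fft) and replaces the settings-dict constructor with explicit keyword arguments. A usage sketch; the values are the defaults from the diff, and note that interpolate_cond is stored under the "interpolate" key:

    from lt_tensor.model_zoo.audio_models.diffwave import DiffWaveConfig

    cfg = DiffWaveConfig(
        n_mels=80,
        hop_samples=256,
        interpolate_cond=True,        # saved in the config as "interpolate"
        interpolation_mode="linear",
    )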
lt_tensor/model_zoo/audio_models/hifigan/__init__.py CHANGED
@@ -4,10 +4,6 @@ from lt_tensor.torch_commons import *
 from lt_tensor.model_zoo.residual import ConvNets
 from torch.nn import functional as F
 
-import torch
-import torch.nn.functional as F
-import torch.nn as nn
-
 
 def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)
@@ -21,7 +17,7 @@ class HifiganConfig(ModelConfig):
     in_channels: int = 80
     upsample_rates: List[Union[int, List[int]]] = [8, 8]
     upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16]
-    upsample_initial_channel: int = (512,)
+    upsample_initial_channel: int = 512
     resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11]
     resblock_dilation_sizes: List[Union[int, List[int]]] = [
         [1, 3, 5],
@@ -34,10 +30,32 @@
 
     def __init__(
         self,
-        settings: Dict[str, Any] = {},
-        path_name: Optional[Union[str, PathLike]] = None,
+        in_channels: int = 80,
+        upsample_rates: List[Union[int, List[int]]] = [8, 8],
+        upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16],
+        upsample_initial_channel: int = 512,
+        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+        resblock_dilation_sizes: List[Union[int, List[int]]] = [
+            [1, 3, 5],
+            [1, 3, 5],
+            [1, 3, 5],
+        ],
+        activation: nn.Module = nn.LeakyReLU(0.1),
+        resblock: int = 0,
+        *args,
+        **kwargs,
     ):
-        super().__init__(settings, path_name)
+        settings = {
+            "in_channels": in_channels,
+            "upsample_rates": upsample_rates,
+            "upsample_kernel_sizes": upsample_kernel_sizes,
+            "upsample_initial_channel": upsample_initial_channel,
+            "resblock_kernel_sizes": resblock_kernel_sizes,
+            "resblock_dilation_sizes": resblock_dilation_sizes,
+            "activation": activation,
+            "resblock": resblock,
+        }
+        super().__init__(**settings)
 
 
 class ResBlock1(ConvNets):
@@ -177,10 +195,10 @@ class HifiganGenerator(ConvNets):
         self.conv_pre = weight_norm(
             nn.Conv1d(cfg.in_channels, cfg.upsample_initial_channel, 7, 1, padding=3)
         )
-        resblock = ResBlock1 if resblock == 0 else ResBlock2
+        resblock = ResBlock1 if cfg.resblock == 0 else ResBlock2
         self.activation = cfg.activation
         self.ups = nn.ModuleList()
-        for i, (u, k) in enumerate(zip(cfg.psample_rates, cfg.upsample_kernel_sizes)):
+        for i, (u, k) in enumerate(zip(cfg.upsample_rates, cfg.upsample_kernel_sizes)):
             self.ups.append(
                 weight_norm(
                     nn.ConvTranspose1d(
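
Besides the constructor change, this file fixes two real bugs: upsample_initial_channel was accidentally a one-element tuple ((512,)), and HifiganGenerator referenced an undefined resblock and a misspelled cfg.psample_rates. A usage sketch; that the generator takes the config object as its constructor argument is an assumption based on the cfg.* accesses visible in the diff:

    from lt_tensor.model_zoo.audio_models.hifigan import HifiganConfig, HifiganGenerator

    cfg = HifiganConfig(
        in_channels=80,
        upsample_rates=[8, 8],
        upsample_kernel_sizes=[16, 16],
        resblock=0,              # 0 selects ResBlock1, any other value ResBlock2
    )
    gen = HifiganGenerator(cfg)  # assumed signature; cfg.resblock is now actually used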
lt_tensor/model_zoo/audio_models/istft/__init__.py CHANGED
@@ -11,7 +11,7 @@ class iSTFTNetConfig(ModelConfig):
     in_channels: int = 80
     upsample_rates: List[Union[int, List[int]]] = [8, 8]
     upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16]
-    upsample_initial_channel: int = (512,)
+    upsample_initial_channel: int = 512
     resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11]
     resblock_dilation_sizes: List[Union[int, List[int]]] = [
         [1, 3, 5],
@@ -26,10 +26,36 @@
 
     def __init__(
         self,
-        settings: Dict[str, Any] = {},
-        path_name: Optional[Union[str, PathLike]] = None,
+        in_channels: int = 80,
+        upsample_rates: List[Union[int, List[int]]] = [8, 8],
+        upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16],
+        upsample_initial_channel: int = 512,
+        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+        resblock_dilation_sizes: List[Union[int, List[int]]] = [
+            [1, 3, 5],
+            [1, 3, 5],
+            [1, 3, 5],
+        ],
+        activation: nn.Module = nn.LeakyReLU(0.1),
+        resblock: int = 0,
+        gen_istft_n_fft: int = 16,
+        sampling_rate: Number = 24000,
+        *args,
+        **kwargs,
     ):
-        super().__init__(settings, path_name)
+        settings = {
+            "in_channels": in_channels,
+            "upsample_rates": upsample_rates,
+            "upsample_kernel_sizes": upsample_kernel_sizes,
+            "upsample_initial_channel": upsample_initial_channel,
+            "resblock_kernel_sizes": resblock_kernel_sizes,
+            "resblock_dilation_sizes": resblock_dilation_sizes,
+            "activation": activation,
+            "resblock": resblock,
+            "gen_istft_n_fft": gen_istft_n_fft,
+            "sampling_rate": sampling_rate,
+        }
+        super().__init__(**settings)
 
 
 def get_padding(ks, d):
@@ -169,7 +195,7 @@ class iSTFTNetGenerator(ConvNets):
         self.conv_pre = weight_norm(
             nn.Conv1d(cfg.in_channels, cfg.upsample_initial_channel, 7, 1, padding=3)
         )
-        resblock = ResBlock1 if resblock == 0 else ResBlock2
+        resblock = ResBlock1 if cfg.resblock == 0 else ResBlock2
 
         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(cfg.upsample_rates, cfg.upsample_kernel_sizes)):
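
iSTFTNetConfig receives the same treatment as HifiganConfig, with two extra fields (gen_istft_n_fft and sampling_rate), and iSTFTNetGenerator gets the same cfg.resblock fix. A usage sketch with the defaults from the diff:

    from lt_tensor.model_zoo.audio_models.istft import iSTFTNetConfig

    cfg = iSTFTNetConfig(
        gen_istft_n_fft=16,
        sampling_rate=24000,
    )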
lt_tensor-0.0.1a18.dist-info/METADATA → lt_tensor-0.0.1a20.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a18
+Version: 0.0.1a20
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
@@ -17,7 +17,7 @@ Requires-Dist: numpy>=1.26.4
 Requires-Dist: tokenizers
 Requires-Dist: pyyaml>=6.0.0
 Requires-Dist: numba>0.60.0
-Requires-Dist: lt-utils==0.0.2a2
+Requires-Dist: lt-utils==0.0.2a3
 Requires-Dist: librosa==0.11.*
 Requires-Dist: einops
 Requires-Dist: plotly
lt_tensor-0.0.1a18.dist-info/RECORD → lt_tensor-0.0.1a20.dist-info/RECORD RENAMED
@@ -1,5 +1,5 @@
-lt_tensor/__init__.py,sha256=XxNCGcVL-haJyMpifr-GRaamo32R6jmqe3iOuS4ecfs,469
-lt_tensor/config_templates.py,sha256=xWZhktYVlkwvJVreqyACpWo-lJ5htG9vTZyqZ6OexzA,3899
+lt_tensor/__init__.py,sha256=8FTxpJ6td2bMr_GqzW2tCV6Tr5CelbQle8N5JRWtx8M,439
+lt_tensor/config_templates.py,sha256=RP7EFVRj6mRUj6xDLe7FMXgN5TIo8_o9h1Kb8epdmfo,2825
 lt_tensor/losses.py,sha256=zvkCOnE5XpF3v6ymivRIdqPTsMM5zc94ZMom7YDi3zM,4946
 lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
 lt_tensor/math_ops.py,sha256=TkD4WQG42KsQ9Fg7FXOjf8f-ixtW0apf2XjaooecVx4,2257
@@ -9,8 +9,6 @@ lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,
 lt_tensor/noise_tools.py,sha256=wFeAsHhLhSlEc5XU5LbFKaXoHeVxrWjiMeljjGdIKyM,11363
 lt_tensor/torch_commons.py,sha256=fntsEU8lhBQo0ebonI1iXBkMbWMN3HpBsG13EWlP5s8,718
 lt_tensor/transform.py,sha256=dZm8T_ov0blHMQu6nGiehsdG1VSB7bZBUVmTkT-PBdc,13257
-lt_tensor/datasets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-lt_tensor/datasets/audio.py,sha256=5Wvz1BJ7xXkLYpVLLw9RY3X3RgMdPPeGiN0-MmJDQy0,8045
 lt_tensor/model_zoo/__init__.py,sha256=ltVTvmOlbOCfDc5Trvg0-Ta_Ujgkw0UVF9V5rqHx-RI,378
 lt_tensor/model_zoo/basic.py,sha256=pI8HyiHK-cmWcEEaVY_EduUJOjZW6HOtXvJd8Rbhq30,15452
 lt_tensor/model_zoo/features.py,sha256=DO8dlE0kmPKTNC1Xkv9wKegOOYkQa_rkxM4hhcNwJWA,15655
@@ -19,13 +17,13 @@ lt_tensor/model_zoo/pos_encoder.py,sha256=3d1EYLinCU9UAy-WuEWeYMGhMqaGknCiQ5qEmh
 lt_tensor/model_zoo/residual.py,sha256=i5V4ju7DB3WesKBVm6KH_LyPoKGDUOyo2Usfs-PyP58,9394
 lt_tensor/model_zoo/transformer.py,sha256=HUFoFFh7EQJErxdd9XIxhssdjvNVx2tNGDJOTUfwG2A,4301
 lt_tensor/model_zoo/audio_models/__init__.py,sha256=MoG9YjxLyvscq_6njK1ljGBletK9iedBXt66bplzW-s,83
-lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=R14hY-nCbCO-T3ox9f4MXCPgQQogFUKAJ2WtntLz09w,7393
-lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=6ZGYyNiTMGHnOjGU0gq_TSM8Y9LtYlP3neGwa01Ghyk,13135
-lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=noi4GLGZQ_qg5H-ipe5d7j8rvt4Hic_sXiME-TE-B2c,13783
+lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=OUyh421xRCcxOMi_Ek6Ak3-FPe1k6WTDQ-6gd6OjaCU,8091
+lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=JNebaYO3nsyyqpYCCOyL13zY2uxLY3NOCeNynF6-96k,13940
+lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=JdFChpPhURaI2qb9mDV6vzDcZN757FBGGtgzN3vxtJ0,14821
 lt_tensor/processors/__init__.py,sha256=4b9MxAJolXiJfSm20ZEspQTDm1tgLazwlPWA_jB1yLM,63
 lt_tensor/processors/audio.py,sha256=SMqNSl4Den-x1awTCQ8-TcR-0jPiv5lDaUpU93SRRaw,14749
-lt_tensor-0.0.1a18.dist-info/licenses/LICENSE,sha256=HUnu_iSPpnDfZS_PINhO3AoVizJD1A2vee8WX7D7uXo,11358
-lt_tensor-0.0.1a18.dist-info/METADATA,sha256=fgRzOiw5tMmkaEY9HrGEKNL2v9mN5JVbf9r-bf18Am0,1033
-lt_tensor-0.0.1a18.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-lt_tensor-0.0.1a18.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
-lt_tensor-0.0.1a18.dist-info/RECORD,,
+lt_tensor-0.0.1a20.dist-info/licenses/LICENSE,sha256=HUnu_iSPpnDfZS_PINhO3AoVizJD1A2vee8WX7D7uXo,11358
+lt_tensor-0.0.1a20.dist-info/METADATA,sha256=y-nfT5fMD5c8JgD299060Aij5Ugz2xHlDYYNoPXtAy8,1033
+lt_tensor-0.0.1a20.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a20.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a20.dist-info/RECORD,,
File without changes
lt_tensor/datasets/audio.py DELETED
@@ -1,235 +0,0 @@
-__all__ = ["WaveMelDataset"]
-from lt_tensor.torch_commons import *
-from lt_utils.common import *
-from lt_utils.misc_utils import default
-import random
-from torch.utils.data import Dataset, DataLoader, Sampler
-from lt_tensor.processors import AudioProcessor
-import torch.nn.functional as FT
-from lt_tensor.misc_utils import log_tensor
-from tqdm import tqdm
-
-
-DEFAULT_DEVICE = torch.tensor([0]).device
-
-
-class WaveMelDataset(Dataset):
-    cached_data: Union[list[dict[str, Tensor]], Tuple[Tensor, Tensor]] = []
-    loaded_files: Dict[str, List[Dict[str, Tensor]]] = {}
-    normalize_waves: bool = False
-    randomize_ranges: bool = False
-    alpha_wv: float = 1.0
-    limit_files: Optional[int] = None
-    min_frame_length: Optional[int] = None
-    max_frame_length: Optional[int] = None
-
-    def __init__(
-        self,
-        audio_processor: AudioProcessor,
-        dataset_path: PathLike,
-        limit_files: Optional[int] = None,
-        min_frame_length: Optional[int] = None,
-        max_frame_length: Optional[int] = None,
-        randomize_ranges: Optional[bool] = None,
-        pre_load: bool = False,
-        normalize_waves: Optional[bool] = None,
-        alpha_wv: Optional[float] = None,
-        lib_norm: bool = True,
-    ):
-        super().__init__()
-        assert max_frame_length is None or max_frame_length >= (
-            (audio_processor.n_fft // 2) + 1
-        )
-        self.ap = audio_processor
-        self.dataset_path = dataset_path
-        if limit_files:
-            self.limit_files = limit_files
-        if normalize_waves is not None:
-            self.normalize_waves = normalize_waves
-        if alpha_wv is not None:
-            self.alpha_wv = alpha_wv
-        if pre_load is not None:
-            self.pre_loaded = pre_load
-        if randomize_ranges is not None:
-            self.randomize_ranges = randomize_ranges
-
-        self.post_n_fft = (audio_processor.n_fft // 2) + 1
-        self.lib_norm = lib_norm
-        if max_frame_length is not None:
-            max_frame_length = max(self.post_n_fft + 1, max_frame_length)
-            self.r_range = max(self.post_n_fft + 1, max_frame_length // 3)
-            self.max_frame_length = max_frame_length
-        if min_frame_length is not None:
-            self.min_frame_length = max(
-                self.post_n_fft + 1, min(min_frame_length, max_frame_length)
-            )
-
-        self.files = self.ap.find_audios(dataset_path, maximum=None)
-        if limit_files:
-            random.shuffle(self.files)
-            self.files = self.files[-self.limit_files :]
-        if pre_load:
-            for file in tqdm(self.files, "Loading files"):
-                results = self.load_data(file)
-                if not results:
-                    continue
-                self.cached_data.extend(results)
-
-    def renew_dataset(self, new_path: Optional[PathLike] = None):
-        new_path = default(new_path, self.dataset_path)
-        self.files = self.ap.find_audios(new_path, maximum=None)
-        random.shuffle(self.files)
-        for file in tqdm(self.files, "Loading files"):
-            results = self.load_data(file)
-            if not results:
-                continue
-            self.cached_data.extend(results)
-
-    def _add_dict(
-        self,
-        audio_wave: Tensor,
-        audio_mel: Tensor,
-        pitch: Tensor,
-        rms: Tensor,
-        file: PathLike,
-    ):
-        return {
-            "wave": audio_wave,
-            "pitch": pitch,
-            "rms": rms,
-            "mel": audio_mel,
-            "file": file,
-        }
-
-    def load_data(self, file: PathLike):
-        initial_audio = self.ap.load_audio(
-            file, normalize=self.lib_norm, alpha=self.alpha_wv
-        )
-        if self.normalize_waves:
-            initial_audio = self.ap.normalize_audio(initial_audio)
-        if initial_audio.shape[-1] < self.post_n_fft:
-            return None
-
-        if self.min_frame_length is not None:
-            if self.min_frame_length > initial_audio.shape[-1]:
-                return None
-        if (
-            not self.max_frame_length
-            or initial_audio.shape[-1] <= self.max_frame_length
-        ):
-
-            audio_rms = self.ap.compute_rms(initial_audio)
-            audio_pitch = self.ap.compute_pitch(initial_audio)
-            audio_mel = self.ap.compute_mel(initial_audio, add_base=True)
-
-            return [
-                self._add_dict(initial_audio, audio_mel, audio_pitch, audio_rms, file)
-            ]
-        results = []
-
-        if self.randomize_ranges:
-            frame_limit = random.randint(self.r_range, self.max_frame_length)
-        else:
-            frame_limit = self.max_frame_length
-
-        fragments = list(
-            torch.split(initial_audio, split_size_or_sections=frame_limit, dim=-1)
-        )
-        random.shuffle(fragments)
-        for fragment in fragments:
-            if fragment.shape[-1] < self.post_n_fft:
-                # Too small
-                continue
-            if (
-                self.min_frame_length is not None
-                and self.min_frame_length > fragment.shape[-1]
-            ):
-                continue
-
-            audio_rms = self.ap.compute_rms(fragment)
-            audio_pitch = self.ap.compute_pitch(fragment)
-            audio_mel = self.ap.compute_mel(fragment, add_base=True)
-            results.append(
-                self._add_dict(fragment, audio_mel, audio_pitch, audio_rms, file)
-            )
-        return results
-
-    def get_data_loader(
-        self,
-        batch_size: int = 1,
-        shuffle: Optional[bool] = None,
-        sampler: Optional[Union[Sampler, Iterable]] = None,
-        batch_sampler: Optional[Union[Sampler[list], Iterable[list]]] = None,
-        num_workers: int = 0,
-        pin_memory: bool = False,
-        drop_last: bool = False,
-        timeout: float = 0,
-    ):
-        return DataLoader(
-            self,
-            batch_size=batch_size,
-            shuffle=shuffle,
-            sampler=sampler,
-            batch_sampler=batch_sampler,
-            num_workers=num_workers,
-            pin_memory=pin_memory,
-            drop_last=drop_last,
-            timeout=timeout,
-            collate_fn=self.collate_fn,
-        )
-
-    def collate_fn(self, batch: Sequence[Dict[str, Tensor]]):
-        mel = []
-        wave = []
-        file = []
-        rms = []
-        pitch = []
-        for x in batch:
-            mel.append(x["mel"])
-            wave.append(x["wave"])
-            file.append(x["file"])
-            rms.append(x["rms"])
-            pitch.append(x["pitch"])
-        # Find max time in mel (dim -1), and max audio length
-        max_mel_len = max([m.shape[-1] for m in mel])
-        max_audio_len = max([a.shape[-1] for a in wave])
-        max_pitch_len = max([a.shape[-1] for a in pitch])
-        max_rms_len = max([a.shape[-1] for a in rms])
-
-        padded_mel = torch.stack(
-            [FT.pad(m, (0, max_mel_len - m.shape[-1])) for m in mel]
-        )  # shape: [B, 80, T_max]
-
-        padded_wave = torch.stack(
-            [FT.pad(a, (0, max_audio_len - a.shape[-1])) for a in wave]
-        )  # shape: [B, L_max]
-
-        padded_pitch = torch.stack(
-            [FT.pad(a, (0, max_pitch_len - a.shape[-1])) for a in pitch]
-        )  # shape: [B, L_max]
-        padded_rms = torch.stack(
-            [FT.pad(a, (0, max_rms_len - a.shape[-1])) for a in rms]
-        )  # shape: [B, L_max]
-        return dict(
-            mel=padded_mel,
-            wave=padded_wave,
-            pitch=padded_pitch,
-            rms=padded_rms,
-            file=file,
-        )
-
-    def get_item(self, idx: int):
-        if self.pre_loaded:
-            return self.cached_data[idx]
-        file = self.files[idx]
-        if file not in self.loaded_files:
-            self.loaded_files[file] = self.load_data(file)
-        return random.choice(self.loaded_files[file])
-
-    def __len__(self):
-        if self.pre_loaded:
-            return len(self.cached_data)
-        return len(self.files)
-
-    def __getitem__(self, index: int):
-        return self.get_item(index)
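
For projects that depended on the removed WaveMelDataset, the options are to pin lt-tensor==0.0.1a18 or to vendor the class from the deleted file above. A minimal stand-in for its padded-batch collate_fn, distilled from the deleted code and assuming dict samples with "mel", "wave", "pitch", and "rms" tensors plus a "file" path:

    import torch
    import torch.nn.functional as FT

    def collate_fn(batch):
        # Pad every tensor field to the longest item in the batch (last dim),
        # then stack into [B, ...] tensors, as the deleted collate_fn did.
        out = {}
        for key in ("mel", "wave", "pitch", "rms"):
            items = [x[key] for x in batch]
            max_len = max(t.shape[-1] for t in items)
            out[key] = torch.stack(
                [FT.pad(t, (0, max_len - t.shape[-1])) for t in items]
            )
        out["file"] = [x["file"] for x in batch]
        return out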