lt-tensor 0.0.1a11__py3-none-any.whl → 0.0.1a13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/__init__.py CHANGED
@@ -12,6 +12,7 @@ from . import (
12
12
  losses,
13
13
  processors,
14
14
  datasets,
15
+ torch_commons,
15
16
  )
16
17
 
17
18
  __all__ = [
@@ -26,4 +27,5 @@ __all__ = [
26
27
  "losses",
27
28
  "processors",
28
29
  "datasets",
30
+ "torch_commons",
29
31
  ]
@@ -0,0 +1,97 @@
1
+ from lt_utils.common import *
2
+ from lt_utils.file_ops import load_json, save_json, FileScan
3
+ from lt_utils.misc_utils import log_traceback, get_current_time
4
+ from lt_utils.type_utils import is_pathlike, is_file, is_dir, is_dict, is_str
5
+ from lt_tensor.misc_utils import updateDict
6
+
7
+
8
class ModelConfig(ABC, OrderedDict):
    """Attribute-backed model configuration that can be saved to and loaded from JSON.

    Settings are applied as instance attributes through ``updateDict`` (which calls
    ``setattr``); they are not stored as mapping items, even though the class
    derives from ``OrderedDict``.
    """

    # Fallback settings used when ``__init__`` receives no ``settings``.
    _default_settings: Dict[str, Any] = {}
    # Attribute names that are bookkeeping rather than settings. They are skipped
    # by ``reset_settings`` and excluded from ``to_dict``/``save_config``.
    # "_settings" is kept for backward compatibility; "_default_settings" and
    # "path_name" are the bookkeeping attributes ``__init__`` actually sets, and
    # serializing them made a reloaded file overwrite those attributes.
    _forbidden_list: List[str] = [
        "_settings",
        "_default_settings",
        "path_name",
    ]

    def __init__(
        self,
        settings: Optional[Dict[str, Any]] = None,
        path_name: Optional[Union[str, PathLike]] = None,
    ):
        """Create a config and apply ``settings`` as attributes.

        Args:
            settings: mapping of setting name -> value. When omitted, a copy of the
                class-level ``_default_settings`` is used (previously the ``None``
                default always failed the ``is_dict`` assertion).
            path_name: where the config will be saved. A path ending in ``.json``
                is used as-is; any other path is treated as a directory and
                ``config.json`` is appended. Defaults to ``"config.json"``.
        """
        if settings is None:
            settings = dict(self._default_settings)
        assert is_dict(settings)
        self._default_settings = settings
        if path_name is not None and is_pathlike(path_name):
            if not str(path_name).endswith(".json"):
                self.path_name = str(Path(path_name, "config.json")).replace("\\", "/")
            else:
                self.path_name = str(path_name).replace("\\", "/")
        else:
            self.path_name = "config.json"
        self.reset_settings()

    def _setup_path_name(self, path_name: Union[str, PathLike]):
        """Point ``self.path_name`` at ``path_name``, normalized to forward slashes.

        An existing file is used as-is. Any other path (``str`` or path-like)
        gets a ``.json`` suffix when it lacks one. Previously a non-``str``
        path-like that did not exist yet was silently ignored, so the config was
        written to the stale ``path_name``.
        """
        normalized = str(path_name).replace("\\", "/")
        if not is_file(path_name) and not normalized.endswith(".json"):
            normalized += ".json"
        self.path_name = normalized

    def reset_settings(self):
        """Re-apply ``_default_settings`` as attributes, skipping forbidden names."""
        updateDict(
            self,
            {
                name: value
                for name, value in self._default_settings.items()
                if name not in self._forbidden_list
            },
        )

    def save_config(
        self,
        path_name: Union[PathLike, str],
    ):
        """Write the current settings (see ``to_dict``) as JSON to ``path_name``.

        Also updates ``self.path_name``. Raises AssertionError when ``path_name``
        is not accepted by ``is_pathlike(path_name, True)``.
        """
        assert is_pathlike(
            path_name, True
        ), f"path_name should be a non-empty string or pathlike object! received instead: {path_name}"
        self._setup_path_name(path_name)
        save_json(self.path_name, self.to_dict(), indent=2)

    def to_dict(self):
        """Return the instance attributes that are settings (forbidden names excluded)."""
        return {k: y for k, y in self.__dict__.items() if k not in self._forbidden_list}

    def set_value(self, var_name: str, value: Any) -> None:
        """Set a single setting as an attribute."""
        updateDict(self, {var_name: value})

    def get_value(self, var_name: str) -> Any:
        """Return the attribute ``var_name``, or None when it is not set."""
        return self.__dict__.get(var_name)

    @classmethod
    def from_dict(
        cls, dictionary: Dict[str, Any], path: Optional[Union[str, PathLike]] = None
    ) -> "ModelConfig":
        """Build a config from a settings mapping.

        NOTE(review): this returns a plain ``ModelConfig`` rather than ``cls``;
        kept because subclasses may define incompatible ``__init__`` signatures.
        """
        assert is_dict(dictionary)
        return ModelConfig(dictionary, path)

    @classmethod
    def from_path(cls, path_name: PathLike) -> "ModelConfig":
        """Load a config from a JSON file, or from a directory containing one.

        For a directory, ``FileScan.files`` is queried with common config file
        name patterns and the last returned match is loaded. The match order
        depends on ``FileScan.files``; confirm it before relying on it. Raises
        AssertionError when the path is neither a file nor a directory, or when
        the directory holds no matching file.
        """
        assert is_file(path_name) or is_dir(path_name)
        settings = {}

        if is_file(path_name):
            settings.update(load_json(path_name, {}, errors="ignore"))
        else:
            files = FileScan.files(
                path_name,
                [
                    "*_config.json",
                    "config_*.json",
                    "cfg_*.json",
                    "*_cfg.json",
                    "cfg.json",
                    "config.json",
                    "settings.json",
                    "settings_*.json",
                    "*_settings.json",
                ],
            )
            assert files, "No config file found in the provided directory!"
            settings.update(load_json(files[-1], {}, errors="ignore"))
        return ModelConfig(settings, path_name)
@@ -1,58 +1,140 @@
1
1
  __all__ = ["WaveMelDataset"]
2
2
  from lt_tensor.torch_commons import *
3
3
  from lt_utils.common import *
4
+ from lt_utils.misc_utils import default
4
5
  import random
5
6
  from torch.utils.data import Dataset, DataLoader, Sampler
6
7
  from lt_tensor.processors import AudioProcessor
7
8
  import torch.nn.functional as FT
8
9
  from lt_tensor.misc_utils import log_tensor
10
+ from tqdm import tqdm
9
11
 
10
12
 
11
- class WaveMelDataset(Dataset):
12
- """Untested!"""
13
# Device that a freshly created tensor lands on (presumably CPU unless a default
# device has been configured for torch) — TODO confirm intended use.
DEFAULT_DEVICE = torch.tensor(0).device
14
+
13
15
 
14
- data: Union[list[dict[str, Tensor]], Tuple[Tensor, Tensor]] = []
16
+ class WaveMelDataset(Dataset):
17
+ cached_data: Union[list[dict[str, Tensor]], Tuple[Tensor, Tensor]] = []
18
+ loaded_files: Dict[str, List[Dict[str, Tensor]]] = {}
19
+ normalize_waves: bool = False
20
+ randomize_ranges: bool = False
21
+ alpha_wv: float = 1.0
22
+ limit_files: Optional[int] = None
23
+ max_frame_length: Optional[int] = None
15
24
 
16
25
  def __init__(
17
26
  self,
18
27
  audio_processor: AudioProcessor,
19
- path: PathLike,
28
+ dataset_path: PathLike,
20
29
  limit_files: Optional[int] = None,
21
30
  max_frame_length: Optional[int] = None,
31
+ randomize_ranges: Optional[bool] = None,
32
+ pre_load: bool = False,
33
+ normalize_waves: Optional[bool] = None,
34
+ alpha_wv: Optional[float] = None,
35
+ n_noises: int = 0, # TODO: Implement the random noises into the dataset
22
36
  ):
23
37
  super().__init__()
24
38
  assert max_frame_length is None or max_frame_length >= (
25
39
  (audio_processor.n_fft // 2) + 1
26
40
  )
27
- self.post_n_fft = (audio_processor.n_fft // 2) + 1
28
41
  self.ap = audio_processor
29
- self.files = self.ap.find_audios(path)
42
+ self.dataset_path = dataset_path
43
+ if limit_files:
44
+ self.limit_files = limit_files
45
+ if normalize_waves is not None:
46
+ self.normalize_waves = normalize_waves
47
+ if alpha_wv is not None:
48
+ self.alpha_wv = alpha_wv
49
+ if pre_load is not None:
50
+ self.pre_loaded = pre_load
51
+ if randomize_ranges is not None:
52
+ self.randomize_ranges = randomize_ranges
53
+
54
+ self.post_n_fft = (audio_processor.n_fft // 2) + 1
55
+
56
+ if max_frame_length is not None:
57
+ max_frame_length = max(self.post_n_fft + 1, max_frame_length)
58
+ self.r_range = max(self.post_n_fft + 1, max_frame_length // 3)
59
+ self.max_frame_length = max_frame_length
60
+
61
+ self.files = self.ap.find_audios(dataset_path, maximum=None)
30
62
  if limit_files:
31
63
  random.shuffle(self.files)
32
- self.files = self.files[:limit_files]
33
- self.data = []
64
+ self.files = self.files[-self.limit_files :]
65
+ if pre_load:
66
+ for file in tqdm(self.files, "Loading files"):
67
+ results = self.load_data(file)
68
+ if not results:
69
+ continue
70
+ self.cached_data.extend(results)
34
71
 
35
- for file in self.files:
36
- results = self.load_data(file, max_frame_length)
37
- self.data.extend(results)
72
def renew_dataset(self, new_path: Optional[PathLike] = None):
    """Re-scan audio files and rebuild the in-memory caches.

    Args:
        new_path: directory to scan; defaults to ``self.dataset_path``.

    The caches are reset first, so repeated calls no longer accumulate
    duplicate fragments in ``cached_data`` or keep stale ``loaded_files``
    entries. The file list is shuffled and, like ``__init__``, truncated to
    ``self.limit_files`` when that is set.
    """
    new_path = default(new_path, self.dataset_path)
    self.cached_data = []
    self.loaded_files = {}
    self.files = self.ap.find_audios(new_path, maximum=None)
    random.shuffle(self.files)
    if self.limit_files:
        self.files = self.files[-self.limit_files :]
    for file in tqdm(self.files, "Loading files"):
        results = self.load_data(file)
        if not results:
            # Too short to produce any mel frame.
            continue
        self.cached_data.extend(results)
38
81
 
39
- def _add_dict(self, audio_raw: Tensor, audio_mel: Tensor, file: PathLike):
40
- return {"mel": audio_mel, "raw": audio_raw, "file": file}
82
def _add_dict(
    self,
    audio_wave: Tensor,
    audio_mel: Tensor,
    pitch: Tensor,
    rms: Tensor,
    file: PathLike,
):
    """Bundle one sample's features into the dict layout read by ``collate_fn``.

    Keys: ``wave``, ``pitch``, ``rms``, ``mel`` and ``file`` (the source path).
    """
    return dict(
        wave=audio_wave,
        pitch=pitch,
        rms=rms,
        mel=audio_mel,
        file=file,
    )
41
97
 
42
- def load_data(self, file: PathLike, audio_frames_limit: Optional[int] = None):
43
- initial_audio = self.ap.load_audio(file)
44
- if not audio_frames_limit or initial_audio.shape[-1] <= audio_frames_limit:
98
def load_data(self, file: PathLike):
    """Load ``file`` and turn it into a list of feature dicts.

    The waveform is loaded with ``self.normalize_waves``/``self.alpha_wv`` and
    then passed through ``normalize_audio``. A clip no longer than
    ``self.max_frame_length`` (or any clip when that is unset) yields a single
    entry. Longer clips are split into fragments of a fixed or randomly drawn
    size; the fragments are shuffled, and those shorter than ``self.post_n_fft``
    are dropped.

    Returns:
        A list of dicts built by ``_add_dict``, or None when the whole clip is
        shorter than ``self.post_n_fft`` samples.
    """
    loaded = self.ap.load_audio(
        file, normalize=self.normalize_waves, alpha=self.alpha_wv
    )
    wave = self.ap.normalize_audio(loaded)
    min_len = self.post_n_fft
    if wave.shape[-1] < min_len:
        return None

    def featurize(chunk):
        # rms -> pitch -> mel, then pack into the sample dict.
        chunk_rms = self.ap.compute_rms(chunk)
        chunk_pitch = self.ap.compute_pitch(chunk)
        chunk_mel = self.ap.compute_mel(chunk, add_base=True)
        return self._add_dict(chunk, chunk_mel, chunk_pitch, chunk_rms, file)

    limit = self.max_frame_length
    if not limit or wave.shape[-1] <= limit:
        return [featurize(wave)]

    chunk_size = (
        random.randint(self.r_range, limit) if self.randomize_ranges else limit
    )
    pieces = list(torch.split(wave, split_size_or_sections=chunk_size, dim=-1))
    random.shuffle(pieces)
    # Only the trailing remainder of torch.split can be shorter than chunk_size;
    # skip anything too small for a mel frame.
    return [featurize(piece) for piece in pieces if piece.shape[-1] >= min_len]
57
139
 
58
140
  def get_data_loader(
@@ -79,31 +161,58 @@ class WaveMelDataset(Dataset):
79
161
  collate_fn=self.collate_fn,
80
162
  )
81
163
 
82
- @staticmethod
83
- def collate_fn(batch: Sequence[Dict[str, Tensor]]):
84
- mels = []
85
- audios = []
86
- files = []
164
def collate_fn(self, batch: Sequence[Dict[str, Tensor]]):
    """Merge sample dicts into one batch dict.

    Each tensor field (``mel``, ``wave``, ``pitch``, ``rms``) is right-padded
    with zeros along its last dimension to the longest item and then stacked
    along a new leading batch dimension. ``file`` is returned as a plain list.
    """

    def pad_stack(tensors):
        longest = max(t.shape[-1] for t in tensors)
        return torch.stack([FT.pad(t, (0, longest - t.shape[-1])) for t in tensors])

    return dict(
        mel=pad_stack([item["mel"] for item in batch]),  # presumably [B, n_mels, T_max]
        wave=pad_stack([item["wave"] for item in batch]),
        pitch=pad_stack([item["pitch"] for item in batch]),
        rms=pad_stack([item["rms"] for item in batch]),
        file=[item["file"] for item in batch],
    )
102
203
 
103
- return padded_mels, padded_audios, files
204
def get_item(self, idx: int):
    """Return one sample for dataset position ``idx``.

    When pre-loaded, this is ``cached_data[idx]``. Otherwise the file at
    ``idx`` is loaded once (cached in ``loaded_files``) and a random fragment
    of it is returned. A file that yields no fragments (too short) previously
    crashed in ``random.choice(None)``. Such files are now cached as ``[]``
    and the next usable file (wrapping around) is used instead. Out-of-range
    indices still raise IndexError. RuntimeError is raised when no file is
    usable.
    """
    if self.pre_loaded:
        return self.cached_data[idx]
    total = len(self.files)
    # Validates ``idx`` (IndexError if out of range) and normalizes negatives,
    # exactly like list indexing.
    start = range(total)[idx]
    for offset in range(total):
        file = self.files[(start + offset) % total]
        if file not in self.loaded_files:
            self.loaded_files[file] = self.load_data(file) or []
        samples = self.loaded_files[file]
        if samples:
            return random.choice(samples)
    raise RuntimeError(
        "No usable audio found: every file is shorter than the minimum frame size."
    )
104
211
 
105
212
  def __len__(self):
106
- return len(self.data)
213
+ if self.pre_loaded:
214
+ return len(self.cached_data)
215
+ return len(self.files)
107
216
 
108
- def __getitem__(self, index):
109
- return self.data[index]
217
def __getitem__(self, index: int):
    """Delegate to ``get_item`` so ``dataset[i]`` and ``get_item(i)`` agree."""
    sample = self.get_item(index)
    return sample
lt_tensor/losses.py CHANGED
@@ -11,7 +11,7 @@ __all__ = [
11
11
  ]
12
12
  import math
13
13
  import random
14
- from .torch_commons import *
14
+ from lt_tensor.torch_commons import *
15
15
  from lt_utils.common import *
16
16
  import torch.nn.functional as F
17
17
 
lt_tensor/math_ops.py CHANGED
@@ -12,7 +12,7 @@ __all__ = [
12
12
  "phase",
13
13
  ]
14
14
 
15
- from .torch_commons import *
15
+ from lt_tensor.torch_commons import *
16
16
 
17
17
 
18
18
  def sin_tensor(x: Tensor, freq: float = 1.0) -> Tensor:
lt_tensor/misc_utils.py CHANGED
@@ -21,6 +21,9 @@ __all__ = [
21
21
  "Masking",
22
22
  "LogTensor",
23
23
  "get_losses",
24
+ "plot_view",
25
+ "get_weights",
26
+ "get_activated_conv",
24
27
  ]
25
28
 
26
29
  import re
@@ -28,13 +31,116 @@ import gc
28
31
  import sys
29
32
  import random
30
33
  import numpy as np
31
- from lt_utils.type_utils import is_str
32
- from .torch_commons import *
34
+ import warnings
35
+ from lt_utils.type_utils import is_str, is_dir, is_file, is_pathlike, is_path_valid
36
+ from lt_utils.file_ops import FileScan, find_files, path_to_str, load_json, load_yaml
37
+ from lt_tensor.torch_commons import *
33
38
  from lt_utils.misc_utils import cache_wrapper
34
39
  from lt_utils.common import *
35
40
  from lt_utils.misc_utils import ff_list
36
41
  import torch.nn.functional as F
37
42
 
43
# Lower-cased torch convolution class name -> class, for case-insensitive lookup.
# The previous version repeated the "conv1d" key three times, so "conv1d"
# resolved to nn.Conv3d and "conv2d"/"conv3d" were missing.
CONV_MAP = {
    "conv1d": nn.Conv1d,
    "conv2d": nn.Conv2d,
    "conv3d": nn.Conv3d,
    "convtranspose1d": nn.ConvTranspose1d,
    "convtranspose2d": nn.ConvTranspose2d,
    "convtranspose3d": nn.ConvTranspose3d,
}
51
+
52
+
53
def get_activated_conv(
    in_channels: int,
    out_channels: int,
    kernel_size: int = 1,
    stride: int = 1,
    padding: Union[int, str] = 0,
    groups: int = 1,
    conv_type: Literal[
        "Conv1d",
        "Conv2d",
        "Conv3d",
        "ConvTranspose1d",
        "ConvTranspose2d",
        "ConvTranspose3d",
    ] = "Conv1d",
    activation: Optional[nn.Module] = None,
    norm_fn: Callable[[nn.Module], nn.Module] = lambda x: x,
):
    """Build ``nn.Sequential(activation, norm_fn(conv))``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, groups:
            forwarded to the convolution (``stride``, ``padding`` and ``groups``
            by keyword, so ``groups`` no longer lands in the ``dilation`` /
            ``output_padding`` positional slot).
        conv_type: case-insensitive name of the torch convolution class. This
            value is now honored; previously ``nn.Conv2d`` was always built.
        activation: module applied before the convolution. Defaults to a new
            ``nn.LeakyReLU(0.2)`` per call, instead of one instance shared by
            every call.
        norm_fn: wrapper applied to the convolution (e.g. weight norm);
            identity by default.

    Raises:
        AssertionError: when ``conv_type`` is not a supported convolution.
    """
    # Local lookup so this function does not depend on module-level tables.
    conv_classes = {
        cls.__name__.lower(): cls
        for cls in (
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nn.ConvTranspose1d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
        )
    }
    conv_cls = conv_classes.get(conv_type.lower())
    assert conv_cls is not None, f"Invalid conv type: {conv_type}."
    if activation is None:
        activation = nn.LeakyReLU(0.2)
    conv = conv_cls(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        groups=groups,
    )
    return nn.Sequential(activation, norm_fn(conv))
78
+
79
+
80
def plot_view(
    data: Dict[str, List[Any]],
    title: str = "Loss",
    max_amount: int = 0,
    xaxis_title="Step/Epoch",
    yaxis_title="Loss",
    template="plotly_dark",
):
    """Plot every non-empty series in ``data`` as a line trace with plotly.

    Args:
        data: series name -> values; the capitalized name labels the trace.
        title, xaxis_title, yaxis_title, template: passed to ``update_layout``.
        max_amount: when > 0, only the last ``max_amount`` values of each
            series are plotted.

    Returns:
        The plotly Figure, or None (after a warning) when plotly is missing.
    """
    try:
        import plotly.graph_objs as go
    except ModuleNotFoundError:
        warnings.warn(
            "No installation of plotly was found. To use it use 'pip install plotly' and restart this application!"
        )
        return None

    figure = go.Figure()
    for label, series in data.items():
        if not series:
            continue
        shown = series[-max_amount:] if max_amount > 0 else series
        figure.add_trace(go.Scatter(y=shown, name=label.capitalize()))
    figure.update_layout(
        title=title,
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
        template=template,
    )
    return figure
107
+
108
+
109
def get_weights(directory: Union[str, PathLike]):
    """Locate a checkpoint file (``.pt``, ``.ckpt`` or ``.pth``).

    If ``directory`` is itself such a file, it is returned as a ``Path``. If it
    is any other file, its parent folder is searched. The last match in sorted
    order is returned, or None when nothing matches.
    """
    # With validate=True, is_path_valid is expected to raise on an invalid path.
    is_path_valid(directory, validate=True)
    target = Path(directory)
    if is_file(target):
        if target.name.endswith((".pt", ".ckpt", ".pth")):
            return target
        target = target.parent
    found = sorted(find_files(target, ["*.pt", "*.ckpt", "*.pth"]))
    if not found:
        return None
    return found[-1]
118
+
119
+
120
def get_config(directory: Union[str, PathLike], default: Optional[Any] = None):
    """Load a JSON or YAML config from a file or from a folder.

    If ``directory`` is a ``.json`` / ``.yaml`` / ``.yml`` file, it is loaded
    directly. Otherwise the folder (or the parent of a non-config file) is
    searched for ``*.json`` / ``*.yaml`` / ``*.yml`` files, and the last match
    in sorted order is loaded. Previously the search looked for weight files
    (``*.pt`` …) and then tried to load the folder itself.

    Args:
        directory: config file, or folder to search.
        default: value returned when nothing can be loaded. It is also passed
            to ``load_json`` / ``load_yaml`` as their fallback. When it is
            None, an invalid path is validated strictly (``is_path_valid``
            with ``validate=True``, expected to raise).
    """

    def _load(path):
        if Path(path).name.endswith(".json"):
            return load_json(path, default)
        return load_yaml(path, default)

    if not is_path_valid(directory, validate=default is None):
        return default
    directory = Path(directory)
    if is_file(directory):
        if directory.name.endswith((".json", ".yaml", ".yml")):
            return _load(directory)
        directory = directory.parent
    candidates = sorted(find_files(directory, ["*.json", "*.yaml", "*.yml"]))
    if not candidates:
        return default
    return _load(candidates[-1])
138
+
139
+
140
def updateDict(self, dct: dict[str, Any]):
    """Set each ``key: value`` pair of ``dct`` as an attribute on ``self``."""
    for name in dct:
        setattr(self, name, dct[name])
143
+
38
144
 
39
145
  def soft_restore(tensor, epsilon=1e-6):
40
146
  return torch.where(tensor == 0, torch.full_like(tensor, epsilon), tensor)