lt-tensor 0.0.1a11__py3-none-any.whl → 0.0.1a12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +2 -0
- lt_tensor/config_templates.py +97 -0
- lt_tensor/datasets/audio.py +21 -7
- lt_tensor/losses.py +1 -1
- lt_tensor/math_ops.py +1 -1
- lt_tensor/misc_utils.py +71 -2
- lt_tensor/model_base.py +157 -203
- lt_tensor/model_zoo/__init__.py +2 -2
- lt_tensor/model_zoo/bsc.py +6 -6
- lt_tensor/model_zoo/disc.py +1 -1
- lt_tensor/model_zoo/fsn.py +2 -2
- lt_tensor/model_zoo/gns.py +4 -4
- lt_tensor/model_zoo/istft/__init__.py +5 -0
- lt_tensor/model_zoo/istft/generator.py +150 -0
- lt_tensor/model_zoo/istft/trainer.py +450 -0
- lt_tensor/model_zoo/istft.py +508 -25
- lt_tensor/model_zoo/pos.py +2 -2
- lt_tensor/model_zoo/rsd.py +16 -146
- lt_tensor/model_zoo/tfrms.py +4 -4
- lt_tensor/noise_tools.py +2 -2
- lt_tensor/processors/audio.py +87 -16
- lt_tensor/transform.py +30 -37
- {lt_tensor-0.0.1a11.dist-info → lt_tensor-0.0.1a12.dist-info}/METADATA +3 -2
- lt_tensor-0.0.1a12.dist-info/RECORD +32 -0
- lt_tensor-0.0.1a11.dist-info/RECORD +0 -28
- {lt_tensor-0.0.1a11.dist-info → lt_tensor-0.0.1a12.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a11.dist-info → lt_tensor-0.0.1a12.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a11.dist-info → lt_tensor-0.0.1a12.dist-info}/top_level.txt +0 -0
lt_tensor/model_base.py
CHANGED
@@ -1,12 +1,14 @@
|
|
1
|
-
__all__ = ["Model", "
|
1
|
+
__all__ = ["Model", "LossTracker"]
|
2
2
|
|
3
3
|
import gc
|
4
4
|
import json
|
5
5
|
import math
|
6
6
|
import warnings
|
7
|
-
from .torch_commons import *
|
7
|
+
from lt_tensor.torch_commons import *
|
8
8
|
from lt_utils.common import *
|
9
|
+
from lt_tensor.misc_utils import plot_view, get_weights, get_config, updateDict
|
9
10
|
from lt_utils.misc_utils import log_traceback, get_current_time
|
11
|
+
from lt_utils.file_ops import load_json, save_json
|
10
12
|
|
11
13
|
T = TypeVar("T")
|
12
14
|
|
@@ -29,79 +31,47 @@ class LossTracker:
|
|
29
31
|
"eval": [],
|
30
32
|
}
|
31
33
|
|
32
|
-
def append(
|
33
|
-
|
34
|
+
def append(
|
35
|
+
self, loss: float, mode: Union[Literal["train", "eval"], "str"] = "train"
|
36
|
+
):
|
34
37
|
self.history[mode].append(float(loss))
|
35
38
|
if len(self.history[mode]) > self.max_len:
|
36
39
|
self.history[mode] = self.history[mode][-self.max_len :]
|
37
40
|
|
38
|
-
def get(self, mode: Literal["train", "eval"] = "train"):
|
41
|
+
def get(self, mode: Union[Literal["train", "eval"], "str"] = "train"):
|
39
42
|
return self.history.get(mode, [])
|
40
43
|
|
41
44
|
def save(self, path: Optional[PathLike] = None):
|
42
45
|
if path is None:
|
43
46
|
path = f"logs/history_{get_current_time()}.json"
|
44
|
-
|
45
|
-
Path(path).parent.mkdir(exist_ok=True, parents=True)
|
46
|
-
with open(path, "w") as f:
|
47
|
-
json.dump(self.history, f, indent=2)
|
48
|
-
|
47
|
+
save_json(path, self.history, indent=2)
|
49
48
|
self.last_file = path
|
50
49
|
|
51
50
|
def load(self, path: Optional[PathLike] = None):
|
52
51
|
if path is None:
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
with open(_path) as f:
|
57
|
-
self.history = json.load(f)
|
58
|
-
if path is not None:
|
59
|
-
self.last_file = path
|
60
|
-
|
61
|
-
def plot(self, backend: Literal["matplotlib", "plotly"] = "plotly"):
|
62
|
-
if backend == "plotly":
|
63
|
-
try:
|
64
|
-
import plotly.graph_objs as go
|
65
|
-
except ModuleNotFoundError:
|
66
|
-
warnings.warn(
|
67
|
-
"No installation of plotly was found. To use it use 'pip install plotly' and restart this application!"
|
68
|
-
)
|
69
|
-
return
|
70
|
-
fig = go.Figure()
|
71
|
-
for mode, losses in self.history.items():
|
72
|
-
if losses:
|
73
|
-
fig.add_trace(go.Scatter(y=losses, name=mode.capitalize()))
|
74
|
-
fig.update_layout(
|
75
|
-
title="Training vs Evaluation Loss",
|
76
|
-
xaxis_title="Step",
|
77
|
-
yaxis_title="Loss",
|
78
|
-
template="plotly_dark",
|
79
|
-
)
|
80
|
-
fig.show()
|
81
|
-
|
82
|
-
elif backend == "matplotlib":
|
83
|
-
import matplotlib.pyplot as plt
|
84
|
-
|
85
|
-
for mode, losses in self.history.items():
|
86
|
-
if losses:
|
87
|
-
plt.plot(losses, label=f"{mode.capitalize()} Loss")
|
88
|
-
plt.title("Loss over Time")
|
89
|
-
plt.xlabel("Step")
|
90
|
-
plt.ylabel("Loss")
|
91
|
-
plt.legend()
|
92
|
-
plt.grid(True)
|
93
|
-
plt.show()
|
52
|
+
path = self.last_file
|
53
|
+
self.history = load_json(path, [])
|
54
|
+
self.last_file = path
|
94
55
|
|
56
|
+
def plot(
|
57
|
+
self,
|
58
|
+
history_keys: Union[str, List[str]],
|
59
|
+
max_amount: int = 0,
|
60
|
+
title: str = "Loss",
|
61
|
+
):
|
62
|
+
if isinstance(history_keys, str):
|
63
|
+
history_keys = [history_keys]
|
64
|
+
return plot_view(
|
65
|
+
{k: v for k, v in self.history.items() if k in history_keys},
|
66
|
+
title,
|
67
|
+
max_amount,
|
68
|
+
)
|
95
69
|
|
96
|
-
class Model(nn.Module, ABC):
|
97
|
-
"""
|
98
|
-
This makes it easier to assign a device and retrieves it later
|
99
|
-
"""
|
100
70
|
|
71
|
+
class _Devices_Base(nn.Module):
|
101
72
|
_device: torch.device = ROOT_DEVICE
|
102
73
|
_autocast: bool = False
|
103
74
|
_loss_history: LossTracker = LossTracker(100_000)
|
104
|
-
_is_unfrozen: bool = False
|
105
75
|
|
106
76
|
@property
|
107
77
|
def autocast(self):
|
@@ -119,7 +89,96 @@ class Model(nn.Module, ABC):
|
|
119
89
|
def device(self, device: Union[torch.device, str]):
|
120
90
|
assert isinstance(device, (str, torch.device))
|
121
91
|
self._device = torch.device(device) if isinstance(device, str) else device
|
122
|
-
|
92
|
+
|
93
|
+
def to(self, *args, **kwargs):
|
94
|
+
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
|
95
|
+
*args, **kwargs
|
96
|
+
)
|
97
|
+
|
98
|
+
if dtype is not None:
|
99
|
+
if not (dtype.is_floating_point or dtype.is_complex):
|
100
|
+
raise TypeError(
|
101
|
+
"nn.Module.to only accepts floating point or complex "
|
102
|
+
f"dtypes, but got desired dtype={dtype}"
|
103
|
+
)
|
104
|
+
if dtype.is_complex:
|
105
|
+
warnings.warn(
|
106
|
+
"Complex modules are a new feature under active development whose design may change, "
|
107
|
+
"and some modules might not work as expected when using complex tensors as parameters or buffers. "
|
108
|
+
"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
|
109
|
+
"if a complex module does not work as expected."
|
110
|
+
)
|
111
|
+
|
112
|
+
def convert(t: Tensor):
|
113
|
+
try:
|
114
|
+
if convert_to_format is not None and t.dim() in (4, 5):
|
115
|
+
return t.to(
|
116
|
+
device,
|
117
|
+
dtype if t.is_floating_point() or t.is_complex() else None,
|
118
|
+
non_blocking,
|
119
|
+
memory_format=convert_to_format,
|
120
|
+
)
|
121
|
+
return t.to(
|
122
|
+
device,
|
123
|
+
dtype if t.is_floating_point() or t.is_complex() else None,
|
124
|
+
non_blocking,
|
125
|
+
)
|
126
|
+
except NotImplementedError as e:
|
127
|
+
if str(e) == "Cannot copy out of meta tensor; no data!":
|
128
|
+
raise NotImplementedError(
|
129
|
+
f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
|
130
|
+
f"when moving module from meta to a different device."
|
131
|
+
) from None
|
132
|
+
else:
|
133
|
+
raise
|
134
|
+
|
135
|
+
self._apply(convert)
|
136
|
+
self.device = device
|
137
|
+
return self
|
138
|
+
|
139
|
+
def _to_dvc(
|
140
|
+
self, device_name: str, device_id: Optional[Union[int, torch.device]] = None
|
141
|
+
):
|
142
|
+
device = device_name
|
143
|
+
if device_id is not None:
|
144
|
+
if isinstance(device_id, Number):
|
145
|
+
device += ":" + str(int(device_id))
|
146
|
+
elif hasattr(device_id, "index"):
|
147
|
+
device += ":" + str(device_id.index)
|
148
|
+
self.device = device
|
149
|
+
|
150
|
+
def ipu(self, device: Optional[Union[int, torch.device]] = None) -> T:
|
151
|
+
super().ipu(device)
|
152
|
+
self._to_dvc("ipu", device)
|
153
|
+
return self
|
154
|
+
|
155
|
+
def xpu(self, device: Optional[Union[int, torch.device]] = None) -> T:
|
156
|
+
super().xpu(device)
|
157
|
+
self._to_dvc("xpu", device)
|
158
|
+
return self
|
159
|
+
|
160
|
+
def cuda(self, device: Optional[Union[int, torch.device]] = None) -> T:
|
161
|
+
super().cuda(device)
|
162
|
+
self._to_dvc("cuda", device)
|
163
|
+
return self
|
164
|
+
|
165
|
+
def mtia(self, device: Optional[Union[int, torch.device]] = None) -> T:
|
166
|
+
super().mtia(device)
|
167
|
+
self._to_dvc("mtia", device)
|
168
|
+
return self
|
169
|
+
|
170
|
+
def cpu(self) -> T:
|
171
|
+
super().cpu()
|
172
|
+
self._to_dvc("cpu", None)
|
173
|
+
return self
|
174
|
+
|
175
|
+
|
176
|
+
class Model(_Devices_Base, ABC):
|
177
|
+
"""
|
178
|
+
This makes it easier to assign a device and retrieve it later
|
179
|
+
"""
|
180
|
+
|
181
|
+
_is_unfrozen: bool = False
|
123
182
|
|
124
183
|
def _apply_device_to(self):
|
125
184
|
"""Add here components that are needed to have device applied to them,
|
@@ -239,103 +298,6 @@ class Model(nn.Module, ABC):
|
|
239
298
|
not_unfrozen.append((name, "excluded"))
|
240
299
|
return dict(unfrozen=unfrozen, not_unfrozen=not_unfrozen)
|
241
300
|
|
242
|
-
def to(self, *args, **kwargs):
|
243
|
-
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
|
244
|
-
*args, **kwargs
|
245
|
-
)
|
246
|
-
|
247
|
-
if dtype is not None:
|
248
|
-
if not (dtype.is_floating_point or dtype.is_complex):
|
249
|
-
raise TypeError(
|
250
|
-
"nn.Module.to only accepts floating point or complex "
|
251
|
-
f"dtypes, but got desired dtype={dtype}"
|
252
|
-
)
|
253
|
-
if dtype.is_complex:
|
254
|
-
warnings.warn(
|
255
|
-
"Complex modules are a new feature under active development whose design may change, "
|
256
|
-
"and some modules might not work as expected when using complex tensors as parameters or buffers. "
|
257
|
-
"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
|
258
|
-
"if a complex module does not work as expected."
|
259
|
-
)
|
260
|
-
|
261
|
-
def convert(t: Tensor):
|
262
|
-
try:
|
263
|
-
if convert_to_format is not None and t.dim() in (4, 5):
|
264
|
-
return t.to(
|
265
|
-
device,
|
266
|
-
dtype if t.is_floating_point() or t.is_complex() else None,
|
267
|
-
non_blocking,
|
268
|
-
memory_format=convert_to_format,
|
269
|
-
)
|
270
|
-
return t.to(
|
271
|
-
device,
|
272
|
-
dtype if t.is_floating_point() or t.is_complex() else None,
|
273
|
-
non_blocking,
|
274
|
-
)
|
275
|
-
except NotImplementedError as e:
|
276
|
-
if str(e) == "Cannot copy out of meta tensor; no data!":
|
277
|
-
raise NotImplementedError(
|
278
|
-
f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
|
279
|
-
f"when moving module from meta to a different device."
|
280
|
-
) from None
|
281
|
-
else:
|
282
|
-
raise
|
283
|
-
|
284
|
-
self._apply(convert)
|
285
|
-
self.device = device
|
286
|
-
self._apply_device_to()
|
287
|
-
return self
|
288
|
-
|
289
|
-
def ipu(self, device: Optional[Union[int, torch.device]] = None) -> T:
|
290
|
-
super().ipu(device)
|
291
|
-
dvc = "ipu"
|
292
|
-
if device is not None:
|
293
|
-
dvc += (
|
294
|
-
":" + str(device) if isinstance(device, (int, float)) else device.index
|
295
|
-
)
|
296
|
-
self.device = dvc
|
297
|
-
self._apply_device_to()
|
298
|
-
return self
|
299
|
-
|
300
|
-
def xpu(self, device: Optional[Union[int, torch.device]] = None) -> T:
|
301
|
-
super().xpu(device)
|
302
|
-
dvc = "xpu"
|
303
|
-
if device is not None:
|
304
|
-
dvc += (
|
305
|
-
":" + str(device) if isinstance(device, (int, float)) else device.index
|
306
|
-
)
|
307
|
-
self.device = dvc
|
308
|
-
self._apply_device_to()
|
309
|
-
return self
|
310
|
-
|
311
|
-
def cuda(self, device: Optional[Union[int, torch.device]] = None) -> T:
|
312
|
-
super().cuda(device)
|
313
|
-
dvc = "cuda"
|
314
|
-
if device is not None:
|
315
|
-
dvc += (
|
316
|
-
":" + str(device) if isinstance(device, (int, float)) else device.index
|
317
|
-
)
|
318
|
-
self.device = dvc
|
319
|
-
self._apply_device_to()
|
320
|
-
return self
|
321
|
-
|
322
|
-
def mtia(self, device: Optional[Union[int, torch.device]] = None) -> T:
|
323
|
-
super().mtia(device)
|
324
|
-
dvc = "mtia"
|
325
|
-
if device is not None:
|
326
|
-
dvc += (
|
327
|
-
":" + str(device) if isinstance(device, (int, float)) else device.index
|
328
|
-
)
|
329
|
-
self.device = dvc
|
330
|
-
self._apply_device_to()
|
331
|
-
return self
|
332
|
-
|
333
|
-
def cpu(self) -> T:
|
334
|
-
super().cpu()
|
335
|
-
self.device = "cpu"
|
336
|
-
self._apply_device_to()
|
337
|
-
return self
|
338
|
-
|
339
301
|
def count_trainable_parameters(self, module_name: Optional[str] = None):
|
340
302
|
"""Gets the number of trainable parameters from either the entire model or from a specific module."""
|
341
303
|
if module_name is not None:
|
@@ -411,6 +373,19 @@ class Model(nn.Module, ABC):
|
|
411
373
|
else:
|
412
374
|
print(f"Non-Trainable Parameters: {params}")
|
413
375
|
|
376
|
+
@classmethod
|
377
|
+
def from_pretrained(
|
378
|
+
cls,
|
379
|
+
model_path: Union[str, PathLike],
|
380
|
+
model_config: Optional[Union[str, PathLike]] = None,
|
381
|
+
model_config_dict: Optional[Dict[str, Any]] = None,
|
382
|
+
):
|
383
|
+
"""TODO: Finish this implementation or leave it as an optional function."""
|
384
|
+
assert isinstance(model_path, (str, PathLike, bytes))
|
385
|
+
model_path = get_weights(model_path)
|
386
|
+
model_config = get_config(model_config, model_config_dict)
|
387
|
+
return model_path, model_config
|
388
|
+
|
414
389
|
def save_weights(
|
415
390
|
self,
|
416
391
|
path: Union[Path, str],
|
@@ -494,7 +469,9 @@ class Model(nn.Module, ABC):
|
|
494
469
|
pass
|
495
470
|
|
496
471
|
def add_loss(
|
497
|
-
self,
|
472
|
+
self,
|
473
|
+
loss: Union[float, list[float]],
|
474
|
+
mode: Union[Literal["train", "eval"]] = "train",
|
498
475
|
):
|
499
476
|
if isinstance(loss, Number) and loss:
|
500
477
|
self._loss_history.append(loss, mode)
|
@@ -513,8 +490,12 @@ class Model(nn.Module, ABC):
|
|
513
490
|
def load_loss_history(self, path: Optional[PathLike] = None):
|
514
491
|
self._loss_history.load(path)
|
515
492
|
|
516
|
-
def get_loss_avg(
|
517
|
-
|
493
|
+
def get_loss_avg(
|
494
|
+
self,
|
495
|
+
mode: Union[Literal["train", "eval"], str],
|
496
|
+
quantity: int = 0,
|
497
|
+
):
|
498
|
+
t_list = self._loss_history.get(mode)
|
518
499
|
if not t_list:
|
519
500
|
return float("nan")
|
520
501
|
if quantity > 0:
|
@@ -524,9 +505,10 @@ class Model(nn.Module, ABC):
|
|
524
505
|
def freeze_unfreeze_loss(
|
525
506
|
self,
|
526
507
|
losses: Optional[Union[float, List[float]]] = None,
|
527
|
-
trigger_loss: float = 0.1,
|
508
|
+
trigger_loss: Union[float, bool] = 0.1,
|
528
509
|
excluded_modules: Optional[List[str]] = None,
|
529
|
-
|
510
|
+
max_items: int = 1000,
|
511
|
+
loss_name: str = "train",
|
530
512
|
):
|
531
513
|
"""If a certain threshold is reached the weights will freeze or unfreeze the modules.
|
532
514
|
the biggest use-case for this function is when training GANs where the balance
|
@@ -534,18 +516,29 @@ class Model(nn.Module, ABC):
|
|
534
516
|
|
535
517
|
Args:
|
536
518
|
losses (Union[float, List[float]], Optional): The loss value or list of losses used to determine whether the threshold has been reached. Defaults to None.
|
537
|
-
trigger_loss (float, optional): The value where the weights will be either freeze or unfreeze. Defaults to 0.1.
|
519
|
+
trigger_loss (float, bool, optional): The value at which the weights will be either frozen or unfrozen. If set to a boolean, it will freeze or unfreeze immediately according to the value (True = Freeze, False = Unfreeze). Defaults to 0.1.
|
538
520
|
excluded_modules (list[str], optional): The list of modules (names) that are not to be changed by either freezing or unfreezing. Defaults to None.
|
539
|
-
|
540
|
-
|
521
|
+
max_items (int, optional): The number of previous losses to look back on when calculating the current average. Defaults to 1000.
|
522
|
+
loss_name (str, optional): The key used to retrieve the losses from the loss history. Defaults to "train".
|
541
523
|
returns:
|
542
524
|
bool: True when it's frozen and False when it's trainable.
|
543
525
|
"""
|
544
526
|
if losses is not None:
|
545
|
-
calculated = None
|
546
527
|
self.add_loss(losses)
|
547
528
|
|
548
|
-
|
529
|
+
if isinstance(trigger_loss, bool):
|
530
|
+
if trigger_loss:
|
531
|
+
if self._is_unfrozen:
|
532
|
+
self.freeze_all(excluded_modules)
|
533
|
+
self._is_unfrozen = False
|
534
|
+
return True
|
535
|
+
# else
|
536
|
+
if not self._is_unfrozen:
|
537
|
+
self.unfreeze_all(excluded_modules)
|
538
|
+
self._is_unfrozen = True
|
539
|
+
return False
|
540
|
+
|
541
|
+
value = self.get_loss_avg(loss_name, max_items)
|
549
542
|
|
550
543
|
if value <= trigger_loss:
|
551
544
|
if self._is_unfrozen:
|
@@ -557,42 +550,3 @@ class Model(nn.Module, ABC):
|
|
557
550
|
self.unfreeze_all(excluded_modules)
|
558
551
|
self._is_unfrozen = True
|
559
552
|
return False
|
560
|
-
|
561
|
-
|
562
|
-
class _ModelExtended(Model):
|
563
|
-
"""Planed, but not ready, maybe in the near future?"""
|
564
|
-
criterion: Optional[Callable[[Tensor, Tensor], Tensor]] = None
|
565
|
-
optimizer: Optional[optim.Optimizer] = None
|
566
|
-
|
567
|
-
def train_step(
|
568
|
-
self,
|
569
|
-
*inputs,
|
570
|
-
loss_label: Optional[Tensor] = None,
|
571
|
-
**kwargs,
|
572
|
-
):
|
573
|
-
if not self.training:
|
574
|
-
self.train()
|
575
|
-
if self.optimizer is not None:
|
576
|
-
self.optimizer.zero_grad()
|
577
|
-
if self.autocast:
|
578
|
-
if self.criterion is None:
|
579
|
-
raise RuntimeError(
|
580
|
-
"To use autocast during training, you must assign a criterion first!"
|
581
|
-
)
|
582
|
-
with torch.autocast(device_type=self.device.type):
|
583
|
-
out = self.forward(*loss_label, **kwargs)
|
584
|
-
loss = self.criterion(out, loss_label)
|
585
|
-
|
586
|
-
if self.optimizer is not None:
|
587
|
-
loss.backward()
|
588
|
-
self.optimizer.step()
|
589
|
-
return loss
|
590
|
-
elif self.criterion is not None:
|
591
|
-
out = self.forward(*loss_label, **kwargs)
|
592
|
-
loss = self.criterion(out, loss_label)
|
593
|
-
if self.optimizer is not None:
|
594
|
-
loss.backward()
|
595
|
-
self.optimizer.step()
|
596
|
-
return loss
|
597
|
-
else:
|
598
|
-
return self(*inputs, **kwargs)
|
lt_tensor/model_zoo/__init__.py
CHANGED
lt_tensor/model_zoo/bsc.py
CHANGED
@@ -10,9 +10,9 @@ __all__ = [
|
|
10
10
|
"MultiScaleEncoder1D",
|
11
11
|
]
|
12
12
|
|
13
|
-
from
|
14
|
-
from
|
15
|
-
from
|
13
|
+
from lt_tensor.torch_commons import *
|
14
|
+
from lt_tensor.model_base import Model
|
15
|
+
from lt_tensor.transform import get_sinusoidal_embedding
|
16
16
|
|
17
17
|
|
18
18
|
class FeedForward(Model):
|
@@ -208,10 +208,10 @@ class MultiScaleEncoder1D(Model):
|
|
208
208
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
209
209
|
# x: [B, C, T]
|
210
210
|
return self.net(x) # [B, hidden, T]
|
211
|
-
|
212
|
-
|
211
|
+
|
212
|
+
|
213
213
|
class AudioClassifier(Model):
|
214
|
-
def __init__(self, n_mels:int=80, num_classes=5):
|
214
|
+
def __init__(self, n_mels: int = 80, num_classes=5):
|
215
215
|
super().__init__()
|
216
216
|
self.model = nn.Sequential(
|
217
217
|
nn.Conv1d(n_mels, 256, kernel_size=3, padding=1),
|
lt_tensor/model_zoo/disc.py
CHANGED
lt_tensor/model_zoo/fsn.py
CHANGED
lt_tensor/model_zoo/gns.py
CHANGED
@@ -7,10 +7,10 @@ __all__ = [
|
|
7
7
|
"NoisePredictor1D",
|
8
8
|
]
|
9
9
|
|
10
|
-
from
|
11
|
-
from
|
12
|
-
from .rsd import ResBlock1D
|
13
|
-
from
|
10
|
+
from lt_tensor.torch_commons import *
|
11
|
+
from lt_tensor.model_base import Model
|
12
|
+
from lt_tensor.model_zoo.rsd import ResBlock1D
|
13
|
+
from lt_tensor.misc_utils import log_tensor
|
14
14
|
|
15
15
|
import torch.nn.functional as F
|
16
16
|
|
@@ -0,0 +1,150 @@
|
|
1
|
+
__all__ = ["iSTFTGenerator", "ResBlocks"]
|
2
|
+
import gc
|
3
|
+
import math
|
4
|
+
import itertools
|
5
|
+
from lt_utils.common import *
|
6
|
+
from lt_tensor.torch_commons import *
|
7
|
+
from lt_tensor.model_base import Model
|
8
|
+
from lt_tensor.misc_utils import log_tensor
|
9
|
+
from lt_tensor.model_zoo.rsd import ResBlock1D, ConvNets, get_weight_norm
|
10
|
+
from lt_utils.misc_utils import log_traceback
|
11
|
+
from lt_tensor.processors import AudioProcessor
|
12
|
+
from lt_utils.type_utils import is_dir, is_pathlike
|
13
|
+
from lt_tensor.misc_utils import set_seed, clear_cache
|
14
|
+
from lt_tensor.model_zoo.disc import MultiPeriodDiscriminator, MultiScaleDiscriminator
|
15
|
+
import torch.nn.functional as F
|
16
|
+
from lt_tensor.config_templates import updateDict, ModelConfig
|
17
|
+
|
18
|
+
|
19
|
+
class ResBlocks(ConvNets):
|
20
|
+
def __init__(
|
21
|
+
self,
|
22
|
+
channels: int,
|
23
|
+
resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
|
24
|
+
resblock_dilation_sizes: List[Union[int, List[int]]] = [
|
25
|
+
[1, 3, 5],
|
26
|
+
[1, 3, 5],
|
27
|
+
[1, 3, 5],
|
28
|
+
],
|
29
|
+
activation: nn.Module = nn.LeakyReLU(0.1),
|
30
|
+
):
|
31
|
+
super().__init__()
|
32
|
+
self.num_kernels = len(resblock_kernel_sizes)
|
33
|
+
self.rb = nn.ModuleList()
|
34
|
+
self.activation = activation
|
35
|
+
|
36
|
+
for k, j in zip(resblock_kernel_sizes, resblock_dilation_sizes):
|
37
|
+
self.rb.append(ResBlock1D(channels, k, j, activation))
|
38
|
+
|
39
|
+
self.rb.apply(self.init_weights)
|
40
|
+
|
41
|
+
def forward(self, x: torch.Tensor):
|
42
|
+
xs = None
|
43
|
+
for i, block in enumerate(self.rb):
|
44
|
+
if i == 0:
|
45
|
+
xs = block(x)
|
46
|
+
else:
|
47
|
+
xs += block(x)
|
48
|
+
x = xs / self.num_kernels
|
49
|
+
return self.activation(x)
|
50
|
+
|
51
|
+
|
52
|
+
class iSTFTGenerator(ConvNets):
|
53
|
+
def __init__(
|
54
|
+
self,
|
55
|
+
in_channels: int = 80,
|
56
|
+
upsample_rates: List[Union[int, List[int]]] = [8, 8],
|
57
|
+
upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16],
|
58
|
+
upsample_initial_channel: int = 512,
|
59
|
+
resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
|
60
|
+
resblock_dilation_sizes: List[Union[int, List[int]]] = [
|
61
|
+
[1, 3, 5],
|
62
|
+
[1, 3, 5],
|
63
|
+
[1, 3, 5],
|
64
|
+
],
|
65
|
+
n_fft: int = 16,
|
66
|
+
activation: nn.Module = nn.LeakyReLU(0.1),
|
67
|
+
hop_length: int = 256,
|
68
|
+
):
|
69
|
+
super().__init__()
|
70
|
+
self.num_kernels = len(resblock_kernel_sizes)
|
71
|
+
self.num_upsamples = len(upsample_rates)
|
72
|
+
self.hop_length = hop_length
|
73
|
+
self.conv_pre = weight_norm(
|
74
|
+
nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
75
|
+
)
|
76
|
+
self.blocks = nn.ModuleList()
|
77
|
+
self.activation = activation
|
78
|
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
79
|
+
self.blocks.append(
|
80
|
+
self._make_blocks(
|
81
|
+
(i, k, u),
|
82
|
+
upsample_initial_channel,
|
83
|
+
resblock_kernel_sizes,
|
84
|
+
resblock_dilation_sizes,
|
85
|
+
)
|
86
|
+
)
|
87
|
+
|
88
|
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
89
|
+
self.post_n_fft = n_fft // 2 + 1
|
90
|
+
self.conv_post = weight_norm(nn.Conv1d(ch, n_fft + 2, 7, 1, padding=3))
|
91
|
+
self.conv_post.apply(self.init_weights)
|
92
|
+
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
93
|
+
|
94
|
+
self.phase = nn.Sequential(
|
95
|
+
nn.LeakyReLU(0.2),
|
96
|
+
nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
|
97
|
+
nn.LeakyReLU(0.2),
|
98
|
+
nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
|
99
|
+
)
|
100
|
+
self.spec = nn.Sequential(
|
101
|
+
nn.LeakyReLU(0.2),
|
102
|
+
nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
|
103
|
+
nn.LeakyReLU(0.2),
|
104
|
+
nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
|
105
|
+
)
|
106
|
+
|
107
|
+
def _make_blocks(
|
108
|
+
self,
|
109
|
+
state: Tuple[int, int, int],
|
110
|
+
upsample_initial_channel: int,
|
111
|
+
resblock_kernel_sizes: List[Union[int, List[int]]],
|
112
|
+
resblock_dilation_sizes: List[int | List[int]],
|
113
|
+
):
|
114
|
+
i, k, u = state
|
115
|
+
channels = upsample_initial_channel // (2 ** (i + 1))
|
116
|
+
return nn.ModuleDict(
|
117
|
+
dict(
|
118
|
+
up=nn.Sequential(
|
119
|
+
self.activation,
|
120
|
+
weight_norm(
|
121
|
+
nn.ConvTranspose1d(
|
122
|
+
upsample_initial_channel // (2**i),
|
123
|
+
channels,
|
124
|
+
k,
|
125
|
+
u,
|
126
|
+
padding=(k - u) // 2,
|
127
|
+
)
|
128
|
+
).apply(self.init_weights),
|
129
|
+
),
|
130
|
+
residual=ResBlocks(
|
131
|
+
channels,
|
132
|
+
resblock_kernel_sizes,
|
133
|
+
resblock_dilation_sizes,
|
134
|
+
self.activation,
|
135
|
+
),
|
136
|
+
)
|
137
|
+
)
|
138
|
+
|
139
|
+
def forward(self, x):
|
140
|
+
x = self.conv_pre(x)
|
141
|
+
for block in self.blocks:
|
142
|
+
x = block["up"](x)
|
143
|
+
x = block["residual"](x)
|
144
|
+
|
145
|
+
x = self.reflection_pad(x)
|
146
|
+
x = self.conv_post(x)
|
147
|
+
spec = torch.exp(self.spec(x[:, : self.post_n_fft, :]))
|
148
|
+
phase = torch.sin(self.phase(x[:, self.post_n_fft :, :]))
|
149
|
+
|
150
|
+
return spec, phase
|