lt-tensor 0.0.1a10__py3-none-any.whl → 0.0.1a12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +2 -0
- lt_tensor/config_templates.py +97 -0
- lt_tensor/datasets/audio.py +21 -7
- lt_tensor/losses.py +98 -84
- lt_tensor/math_ops.py +1 -1
- lt_tensor/misc_utils.py +94 -7
- lt_tensor/model_base.py +298 -128
- lt_tensor/model_zoo/__init__.py +2 -2
- lt_tensor/model_zoo/bsc.py +25 -3
- lt_tensor/model_zoo/disc.py +55 -51
- lt_tensor/model_zoo/fsn.py +2 -2
- lt_tensor/model_zoo/gns.py +4 -4
- lt_tensor/model_zoo/istft/__init__.py +5 -0
- lt_tensor/model_zoo/istft/generator.py +150 -0
- lt_tensor/model_zoo/istft/trainer.py +450 -0
- lt_tensor/model_zoo/istft.py +508 -66
- lt_tensor/model_zoo/pos.py +2 -2
- lt_tensor/model_zoo/rsd.py +16 -146
- lt_tensor/model_zoo/tfrms.py +4 -4
- lt_tensor/noise_tools.py +3 -4
- lt_tensor/processors/audio.py +87 -16
- lt_tensor/transform.py +30 -61
- {lt_tensor-0.0.1a10.dist-info → lt_tensor-0.0.1a12.dist-info}/METADATA +3 -2
- lt_tensor-0.0.1a12.dist-info/RECORD +32 -0
- lt_tensor-0.0.1a10.dist-info/RECORD +0 -28
- {lt_tensor-0.0.1a10.dist-info → lt_tensor-0.0.1a12.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a10.dist-info → lt_tensor-0.0.1a12.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a10.dist-info → lt_tensor-0.0.1a12.dist-info}/top_level.txt +0 -0
lt_tensor/model_base.py
CHANGED
@@ -1,10 +1,14 @@
|
|
1
|
-
__all__ = ["Model"]
|
2
|
-
|
1
|
+
__all__ = ["Model", "LossTracker"]
|
3
2
|
|
3
|
+
import gc
|
4
|
+
import json
|
5
|
+
import math
|
4
6
|
import warnings
|
5
|
-
from .torch_commons import *
|
7
|
+
from lt_tensor.torch_commons import *
|
6
8
|
from lt_utils.common import *
|
7
|
-
from
|
9
|
+
from lt_tensor.misc_utils import plot_view, get_weights, get_config, updateDict
|
10
|
+
from lt_utils.misc_utils import log_traceback, get_current_time
|
11
|
+
from lt_utils.file_ops import load_json, save_json
|
8
12
|
|
9
13
|
T = TypeVar("T")
|
10
14
|
|
@@ -17,13 +21,57 @@ POSSIBLE_OUTPUT_TYPES: TypeAlias = Union[
|
|
17
21
|
]
|
18
22
|
|
19
23
|
|
20
|
-
class
|
21
|
-
""
|
22
|
-
|
23
|
-
|
24
|
+
class LossTracker:
    """Bounded, per-key history of scalar loss values.

    Values are stored as plain floats in ``self.history`` (a dict of lists keyed
    by mode, e.g. ``"train"``/``"eval"``; other keys are created on first use).

    Args:
        max_len: maximum number of values kept per key; the oldest values are
            dropped first. A value <= 0 disables trimming.
    """

    # Path of the last JSON file written by save() or read by load().
    # It stays None until one of those happens. It is deliberately not computed
    # at class-definition time, which would freeze a single import-time
    # timestamp shared by every tracker.
    last_file: Optional[Union[str, PathLike]] = None

    def __init__(self, max_len=50_000):
        self.max_len = max_len
        self.last_file = None
        self.history = {
            "train": [],
            "eval": [],
        }

    def append(
        self, loss: float, mode: Union[Literal["train", "eval"], str] = "train"
    ):
        """Record ``float(loss)`` under ``mode``, creating the key if needed."""
        values = self.history.setdefault(mode, [])
        values.append(float(loss))
        overflow = len(values) - self.max_len
        if self.max_len > 0 and overflow > 0:
            # Drop the oldest entries in place to honour max_len.
            del values[:overflow]

    def get(self, mode: Union[Literal["train", "eval"], str] = "train"):
        """Return the (live) list recorded under ``mode``, or ``[]`` if absent."""
        return self.history.get(mode, [])

    def save(self, path: Optional[PathLike] = None):
        """Write the history as JSON.

        Args:
            path: destination file; defaults to ``logs/history_<time>.json``.
        """
        if path is None:
            path = f"logs/history_{get_current_time()}.json"
        save_json(path, self.history, indent=2)
        self.last_file = path

    def load(self, path: Optional[PathLike] = None):
        """Replace the history with the contents of a JSON file.

        Args:
            path: file to read; defaults to the last file saved or loaded. When
                no path is given and none is known yet, nothing is loaded.
        """
        if path is None:
            path = self.last_file
        if path is None:
            return
        loaded = load_json(path, {})
        # Always keep a dict with the standard keys so append()/get() keep working,
        # even if the file is missing or does not hold a JSON object.
        history = {"train": [], "eval": []}
        if isinstance(loaded, dict):
            history.update(loaded)
        self.history = history
        self.last_file = path

    def plot(
        self,
        history_keys: Union[str, List[str]],
        max_amount: int = 0,
        title: str = "Loss",
    ):
        """Plot the histories whose keys are listed in ``history_keys``.

        ``title`` and ``max_amount`` are forwarded to ``plot_view``.
        NOTE(review): ``max_amount`` presumably limits the plotted points;
        confirm against plot_view.
        """
        if isinstance(history_keys, str):
            history_keys = [history_keys]
        return plot_view(
            {k: v for k, v in self.history.items() if k in history_keys},
            title,
            max_amount,
        )
|
69
|
+
|
24
70
|
|
71
|
+
class _Devices_Base(nn.Module):
|
25
72
|
_device: torch.device = ROOT_DEVICE
|
26
73
|
_autocast: bool = False
|
74
|
+
_loss_history: LossTracker = LossTracker(100_000)
|
27
75
|
|
28
76
|
@property
|
29
77
|
def autocast(self):
|
@@ -41,7 +89,96 @@ class Model(nn.Module, ABC):
|
|
41
89
|
def device(self, device: Union[torch.device, str]):
|
42
90
|
assert isinstance(device, (str, torch.device))
|
43
91
|
self._device = torch.device(device) if isinstance(device, str) else device
|
44
|
-
|
92
|
+
|
93
|
+
def to(self, *args, **kwargs):
    """Move and/or cast parameters and buffers, keeping ``self.device`` in sync.

    Accepts exactly the arguments of ``torch.nn.Module.to`` (device, dtype,
    tensor, ``non_blocking``, ``memory_format``). Argument validation, the
    floating-point/complex dtype check and the actual conversion are delegated
    to ``nn.Module.to``, so they stay in step with the installed PyTorch.

    Returns:
        self
    """
    # Only the target device is needed here; torch parses the rest again.
    device, _dtype, _non_blocking, _memory_format = torch._C._nn._parse_to(
        *args, **kwargs
    )
    super().to(*args, **kwargs)
    # A dtype-only call (e.g. ``model.to(torch.float16)``) yields device=None;
    # assigning None would fail the device setter, so keep the current device.
    if device is not None:
        self.device = device
    return self
|
138
|
+
|
139
|
+
def _to_dvc(
    self, device_name: str, device_id: Optional[Union[int, torch.device]] = None
):
    """Record ``device_name`` (plus an optional index) as ``self.device``.

    Args:
        device_name: backend name such as ``"cuda"`` or ``"cpu"``.
        device_id: numeric index, or an object with an ``index`` attribute
            (e.g. ``torch.device``). An index of ``None`` means "no index", so
            ``torch.device("cuda")`` records plain ``"cuda"`` rather than the
            invalid ``"cuda:None"``.
    """
    device = device_name
    if isinstance(device_id, Number):
        device = f"{device_name}:{int(device_id)}"
    else:
        index = getattr(device_id, "index", None)
        if index is not None:
            device = f"{device_name}:{index}"
    self.device = device
|
149
|
+
|
150
|
+
def ipu(self, device: Optional[Union[int, torch.device]] = None) -> T:
    """Move parameters and buffers to an IPU and record it as ``self.device``.

    Args:
        device: optional IPU index or ``torch.device``.

    Returns:
        self
    """
    super().ipu(device)
    self._to_dvc("ipu", device_id=device)
    return self
|
154
|
+
|
155
|
+
def xpu(self, device: Optional[Union[int, torch.device]] = None) -> T:
    """Move parameters and buffers to an XPU and record it as ``self.device``.

    Args:
        device: optional XPU index or ``torch.device``.

    Returns:
        self
    """
    super().xpu(device)
    self._to_dvc("xpu", device_id=device)
    return self
|
159
|
+
|
160
|
+
def cuda(self, device: Optional[Union[int, torch.device]] = None) -> T:
    """Move parameters and buffers to a CUDA device and record it as ``self.device``.

    Args:
        device: optional CUDA index or ``torch.device``.

    Returns:
        self
    """
    super().cuda(device)
    self._to_dvc("cuda", device_id=device)
    return self
|
164
|
+
|
165
|
+
def mtia(self, device: Optional[Union[int, torch.device]] = None) -> T:
    """Move parameters and buffers to an MTIA device and record it as ``self.device``.

    Args:
        device: optional MTIA index or ``torch.device``.

    Returns:
        self
    """
    super().mtia(device)
    self._to_dvc("mtia", device_id=device)
    return self
|
169
|
+
|
170
|
+
def cpu(self) -> T:
    """Move parameters and buffers to the CPU and record ``"cpu"`` as ``self.device``.

    Returns:
        self
    """
    super().cpu()
    self._to_dvc("cpu")
    return self
|
174
|
+
|
175
|
+
|
176
|
+
class Model(_Devices_Base, ABC):
|
177
|
+
"""
|
178
|
+
This makes it easier to assign a device and retrieves it later
|
179
|
+
"""
|
180
|
+
|
181
|
+
_is_unfrozen: bool = False
|
45
182
|
|
46
183
|
def _apply_device_to(self):
|
47
184
|
"""Add here components that are needed to have device applied to them,
|
@@ -61,6 +198,7 @@ class Model(nn.Module, ABC):
|
|
61
198
|
if hasattr(self, weight):
|
62
199
|
w = getattr(self, weight)
|
63
200
|
if isinstance(w, nn.Module):
|
201
|
+
|
64
202
|
w.requires_grad_(not freeze)
|
65
203
|
else:
|
66
204
|
weight.requires_grad_(not freeze)
|
@@ -112,21 +250,27 @@ class Model(nn.Module, ABC):
|
|
112
250
|
for name, param in self.named_parameters():
|
113
251
|
if no_exclusions:
|
114
252
|
try:
|
115
|
-
param.
|
116
|
-
|
253
|
+
if param.requires_grad:
|
254
|
+
param.requires_grad_(False)
|
255
|
+
frozen.append(name)
|
256
|
+
else:
|
257
|
+
not_frozen.append((name, "was_frozen"))
|
117
258
|
except Exception as e:
|
118
259
|
not_frozen.append((name, str(e)))
|
119
260
|
elif any(layer in name for layer in exclude):
|
120
261
|
try:
|
121
|
-
param.
|
122
|
-
|
262
|
+
if param.requires_grad:
|
263
|
+
param.requires_grad_(False)
|
264
|
+
frozen.append(name)
|
265
|
+
else:
|
266
|
+
not_frozen.append((name, "was_frozen"))
|
123
267
|
except Exception as e:
|
124
268
|
not_frozen.append((name, str(e)))
|
125
269
|
else:
|
126
|
-
not_frozen.append((name, "
|
270
|
+
not_frozen.append((name, "excluded"))
|
127
271
|
return dict(frozen=frozen, not_frozen=not_frozen)
|
128
272
|
|
129
|
-
def
|
273
|
+
def unfreeze_all(self, exclude: Optional[list[str]] = None):
|
130
274
|
"""Unfreezes all model parameters except specified layers."""
|
131
275
|
no_exclusions = not exclude
|
132
276
|
unfrozen = []
|
@@ -134,111 +278,26 @@ class Model(nn.Module, ABC):
|
|
134
278
|
for name, param in self.named_parameters():
|
135
279
|
if no_exclusions:
|
136
280
|
try:
|
137
|
-
param.
|
138
|
-
|
281
|
+
if not param.requires_grad:
|
282
|
+
param.requires_grad_(True)
|
283
|
+
unfrozen.append(name)
|
284
|
+
else:
|
285
|
+
not_unfrozen.append((name, "was_unfrozen"))
|
139
286
|
except Exception as e:
|
140
287
|
not_unfrozen.append((name, str(e)))
|
141
288
|
elif any(layer in name for layer in exclude):
|
142
289
|
try:
|
143
|
-
param.
|
144
|
-
|
290
|
+
if not param.requires_grad:
|
291
|
+
param.requires_grad_(True)
|
292
|
+
unfrozen.append(name)
|
293
|
+
else:
|
294
|
+
not_unfrozen.append((name, "was_unfrozen"))
|
145
295
|
except Exception as e:
|
146
296
|
not_unfrozen.append((name, str(e)))
|
147
297
|
else:
|
148
|
-
not_unfrozen.append((name, "
|
298
|
+
not_unfrozen.append((name, "excluded"))
|
149
299
|
return dict(unfrozen=unfrozen, not_unfrozen=not_unfrozen)
|
150
300
|
|
151
|
-
def to(self, *args, **kwargs):
|
152
|
-
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
|
153
|
-
*args, **kwargs
|
154
|
-
)
|
155
|
-
|
156
|
-
if dtype is not None:
|
157
|
-
if not (dtype.is_floating_point or dtype.is_complex):
|
158
|
-
raise TypeError(
|
159
|
-
"nn.Module.to only accepts floating point or complex "
|
160
|
-
f"dtypes, but got desired dtype={dtype}"
|
161
|
-
)
|
162
|
-
if dtype.is_complex:
|
163
|
-
warnings.warn(
|
164
|
-
"Complex modules are a new feature under active development whose design may change, "
|
165
|
-
"and some modules might not work as expected when using complex tensors as parameters or buffers. "
|
166
|
-
"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
|
167
|
-
"if a complex module does not work as expected."
|
168
|
-
)
|
169
|
-
|
170
|
-
def convert(t: Tensor):
|
171
|
-
try:
|
172
|
-
if convert_to_format is not None and t.dim() in (4, 5):
|
173
|
-
return t.to(
|
174
|
-
device,
|
175
|
-
dtype if t.is_floating_point() or t.is_complex() else None,
|
176
|
-
non_blocking,
|
177
|
-
memory_format=convert_to_format,
|
178
|
-
)
|
179
|
-
return t.to(
|
180
|
-
device,
|
181
|
-
dtype if t.is_floating_point() or t.is_complex() else None,
|
182
|
-
non_blocking,
|
183
|
-
)
|
184
|
-
except NotImplementedError as e:
|
185
|
-
if str(e) == "Cannot copy out of meta tensor; no data!":
|
186
|
-
raise NotImplementedError(
|
187
|
-
f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
|
188
|
-
f"when moving module from meta to a different device."
|
189
|
-
) from None
|
190
|
-
else:
|
191
|
-
raise
|
192
|
-
|
193
|
-
self._apply(convert)
|
194
|
-
self.device = device
|
195
|
-
return self
|
196
|
-
|
197
|
-
def ipu(self, device: Optional[Union[int, torch.device]] = None) -> T:
|
198
|
-
super().ipu(device)
|
199
|
-
dvc = "ipu"
|
200
|
-
if device is not None:
|
201
|
-
dvc += (
|
202
|
-
":" + str(device) if isinstance(device, (int, float)) else device.index
|
203
|
-
)
|
204
|
-
self.device = dvc
|
205
|
-
return self
|
206
|
-
|
207
|
-
def xpu(self, device: Optional[Union[int, torch.device]] = None) -> T:
|
208
|
-
super().xpu(device)
|
209
|
-
dvc = "xpu"
|
210
|
-
if device is not None:
|
211
|
-
dvc += (
|
212
|
-
":" + str(device) if isinstance(device, (int, float)) else device.index
|
213
|
-
)
|
214
|
-
self.device = dvc
|
215
|
-
return self
|
216
|
-
|
217
|
-
def cuda(self, device: Optional[Union[int, torch.device]] = None) -> T:
|
218
|
-
super().cuda(device)
|
219
|
-
dvc = "cuda"
|
220
|
-
if device is not None:
|
221
|
-
dvc += (
|
222
|
-
":" + str(device) if isinstance(device, (int, float)) else device.index
|
223
|
-
)
|
224
|
-
self.device = dvc
|
225
|
-
return self
|
226
|
-
|
227
|
-
def mtia(self, device: Optional[Union[int, torch.device]] = None) -> T:
|
228
|
-
super().mtia(device)
|
229
|
-
dvc = "mtia"
|
230
|
-
if device is not None:
|
231
|
-
dvc += (
|
232
|
-
":" + str(device) if isinstance(device, (int, float)) else device.index
|
233
|
-
)
|
234
|
-
self.device = dvc
|
235
|
-
return self
|
236
|
-
|
237
|
-
def cpu(self) -> T:
|
238
|
-
super().cpu()
|
239
|
-
self.device = "cpu"
|
240
|
-
return self
|
241
|
-
|
242
301
|
def count_trainable_parameters(self, module_name: Optional[str] = None):
|
243
302
|
"""Gets the number of trainable parameters from either the entire model or from a specific module."""
|
244
303
|
if module_name is not None:
|
@@ -314,15 +373,39 @@ class Model(nn.Module, ABC):
|
|
314
373
|
else:
|
315
374
|
print(f"Non-Trainable Parameters: {params}")
|
316
375
|
|
317
|
-
|
376
|
+
@classmethod
def from_pretrained(
    cls,
    model_path: Union[str, PathLike],
    model_config: Optional[Union[str, PathLike]] = None,
    model_config_dict: Optional[Dict[str, Any]] = None,
):
    """TODO: Finish this implementation or leave it as an optional function.

    Currently does not build a model. It resolves the weights via
    ``get_weights(model_path)`` and the config via
    ``get_config(model_config, model_config_dict)``, and returns both as a tuple.
    """
    assert isinstance(model_path, (str, PathLike, bytes))
    resolved_weights = get_weights(model_path)
    resolved_config = get_config(model_config, model_config_dict)
    return resolved_weights, resolved_config
|
388
|
+
|
389
|
+
def save_weights(
    self,
    path: Union[Path, str],
    replace: bool = False,
):
    """Save ``self.state_dict()`` with ``torch.save``.

    How the target file is chosen:
      * existing directory -> ``<path>/model_<time>.pt``
      * existing file -> overwritten when ``replace`` is True, otherwise a new
        ``model_<time>.pt`` is written next to it
      * non-existent path without a file suffix -> treated as a directory and
        ``<path>/model_<time>.pt`` is written
      * non-existent path with a suffix -> used as the file name

    Missing parent directories of the final file are created.
    """
    path = Path(path)
    target = path
    if path.exists():
        if path.is_dir():
            target = Path(path, f"model_{get_current_time()}.pt")
        elif path.is_file():
            if replace:
                path.unlink()
            else:
                target = Path(path.parent, f"model_{get_current_time()}.pt")
    elif not path.suffix:
        # Look at the final component's suffix, not any "." anywhere in the
        # path, so parents such as "runs.v2/ckpt" are still treated as directories.
        target = Path(path, f"model_{get_current_time()}.pt")
    # Create the directory that will actually hold the file (not just path.parent).
    target.parent.mkdir(exist_ok=True, parents=True)
    torch.save(obj=self.state_dict(), f=str(target))
|
326
409
|
|
327
410
|
def load_weights(
|
328
411
|
self,
|
@@ -338,7 +421,14 @@ class Model(nn.Module, ABC):
|
|
338
421
|
if not path.exists():
|
339
422
|
assert not raise_if_not_exists, "Path does not exists!"
|
340
423
|
return None
|
341
|
-
|
424
|
+
if path.is_dir():
|
425
|
+
possible_files = list(Path(path).rglob("*.pt"))
|
426
|
+
assert (
|
427
|
+
possible_files or not raise_if_not_exists
|
428
|
+
), "No model could be found in the given path!"
|
429
|
+
if not possible_files:
|
430
|
+
return None
|
431
|
+
path = sorted(possible_files)[-1]
|
342
432
|
state_dict = torch.load(
|
343
433
|
str(path), weights_only=weights_only, mmap=mmap, **torch_loader_kwargs
|
344
434
|
)
|
@@ -353,30 +443,110 @@ class Model(nn.Module, ABC):
|
|
353
443
|
def inference(self, *args, **kwargs):
    """Run a forward pass with the module switched to evaluation mode.

    The module is left in eval mode afterwards. Gradients are not disabled
    here; wrap the call in ``torch.no_grad()`` if they are not needed.
    """
    was_training = self.training
    if was_training:
        self.eval()
    return self(*args, **kwargs)
|
360
447
|
|
361
448
|
def train_step(
    self,
    *inputs,
    **kwargs,
):
    """Train Step: run a forward pass with the module switched to training mode.

    Only the forward call happens here; what is done with the outputs
    (loss, backward, optimizer step) is up to the caller.
    """
    if not self.training:
        self.train()
    outputs = self(*inputs, **kwargs)
    return outputs
|
374
457
|
|
375
458
|
def __call__(self, *args, **kwds) -> POSSIBLE_OUTPUT_TYPES:
    """Invoke the module, using ``torch.autocast`` only in eval mode with autocast on.

    In training mode, or when ``self.autocast`` is falsy, this is a plain
    ``nn.Module.__call__``.
    """
    use_autocast = self.autocast and not self.training
    if not use_autocast:
        return super().__call__(*args, **kwds)
    with torch.autocast(device_type=self.device.type):
        return super().__call__(*args, **kwds)
|
377
464
|
|
378
465
|
@abstractmethod
def forward(
    self, *args, **kwargs
) -> Union[Tensor, Sequence[Tensor], Dict[Any, Union[Any, Tensor]]]:
    """Compute the module output; every concrete subclass must override this."""
|
470
|
+
|
471
|
+
def add_loss(
    self,
    loss: Union[float, list[float]],
    mode: Union[Literal["train", "eval"], str] = "train",
):
    """Record a loss value in this model's loss history.

    Args:
        loss: a number (recorded as-is, including 0.0), a list/tuple of numbers
            (their mean is recorded; empty sequences are ignored) or a Tensor
            (its mean is recorded; failures are logged, not raised).
        mode: history key to append to, e.g. ``"train"`` or ``"eval"``.
    """
    if "_loss_history" not in vars(self):
        # The class-level default tracker is shared by every model instance
        # (e.g. a GAN's generator and discriminator); give each model its own.
        self._loss_history = LossTracker(100_000)
    if isinstance(loss, Tensor):
        try:
            self._loss_history.append(
                loss.detach().flatten().mean().item(), mode=mode
            )
        except Exception as e:
            log_traceback(e, "add_loss - Tensor")
    elif isinstance(loss, Number):
        self._loss_history.append(loss, mode=mode)
    elif isinstance(loss, (list, tuple)):
        if loss:
            self._loss_history.append(sum(loss) / len(loss), mode=mode)
|
486
|
+
|
487
|
+
def save_loss_history(self, path: Optional[PathLike] = None):
    """Persist the loss history as JSON (``path=None`` lets the tracker choose)."""
    tracker = self._loss_history
    tracker.save(path)
|
489
|
+
|
490
|
+
def load_loss_history(self, path: Optional[PathLike] = None):
    """Load a JSON loss history into this model's own tracker.

    Args:
        path: file to read; ``None`` lets the tracker fall back to its last file.
    """
    if "_loss_history" not in vars(self):
        # Load into a per-instance tracker rather than the class-level default,
        # which is shared by every model instance.
        self._loss_history = LossTracker(100_000)
    self._loss_history.load(path)
|
492
|
+
|
493
|
+
def get_loss_avg(
    self,
    mode: Union[Literal["train", "eval"], str],
    quantity: int = 0,
):
    """Mean of the recorded losses for ``mode``.

    Args:
        mode: history key to read.
        quantity: when > 0, average only the most recent ``quantity`` values.

    Returns:
        The mean as a float, or NaN when nothing has been recorded.
    """
    history = self._loss_history.get(mode)
    window = history[-quantity:] if quantity > 0 else history
    return sum(window) / len(window) if window else float("nan")
|
504
|
+
|
505
|
+
def freeze_unfreeze_loss(
    self,
    losses: Optional[Union[float, List[float]]] = None,
    trigger_loss: Union[float, bool] = 0.1,
    excluded_modules: Optional[List[str]] = None,
    max_items: int = 1000,
    loss_name: str = "train",
):
    """Freeze or unfreeze the model's weights depending on a loss threshold.

    The main use case is GAN training, where the balance between the
    discriminator and the generator must be kept.

    Args:
        losses (Union[float, List[float]], optional): loss value(s) to record
            under ``loss_name`` before deciding. Defaults to None.
        trigger_loss (float, bool, optional): freeze when the recent average
            loss is <= this value, otherwise unfreeze. A boolean forces the
            decision (True = freeze, False = unfreeze). Defaults to 0.1.
        excluded_modules (list[str], optional): module names that are neither
            frozen nor unfrozen. Defaults to None.
        max_items (int, optional): number of most recent losses averaged.
            Defaults to 1000.
        loss_name (str, optional): history key that ``losses`` are recorded
            under and that the average is read from. Defaults to "train".

    Returns:
        bool: True when the model is (now) frozen, False when it is trainable.
    """
    if losses is not None:
        self.add_loss(losses, mode=loss_name)

    if "_is_unfrozen" not in vars(self):
        # The class default (False) claims a fresh model is frozen, which made
        # the first freeze request a no-op; derive the real state once instead.
        self._is_unfrozen = any(p.requires_grad for p in self.parameters())

    if isinstance(trigger_loss, bool):
        should_freeze = trigger_loss
    else:
        # With no recorded losses the average is NaN; NaN <= x is False, so
        # the model stays trainable.
        should_freeze = self.get_loss_avg(loss_name, max_items) <= trigger_loss

    if should_freeze and self._is_unfrozen:
        self.freeze_all(excluded_modules)
        self._is_unfrozen = False
    elif not should_freeze and not self._is_unfrozen:
        self.unfreeze_all(excluded_modules)
        self._is_unfrozen = True
    return should_freeze
|
lt_tensor/model_zoo/__init__.py
CHANGED
lt_tensor/model_zoo/bsc.py
CHANGED
@@ -10,9 +10,9 @@ __all__ = [
|
|
10
10
|
"MultiScaleEncoder1D",
|
11
11
|
]
|
12
12
|
|
13
|
-
from
|
14
|
-
from
|
15
|
-
from
|
13
|
+
from lt_tensor.torch_commons import *
|
14
|
+
from lt_tensor.model_base import Model
|
15
|
+
from lt_tensor.transform import get_sinusoidal_embedding
|
16
16
|
|
17
17
|
|
18
18
|
class FeedForward(Model):
|
@@ -208,3 +208,25 @@ class MultiScaleEncoder1D(Model):
|
|
208
208
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
209
209
|
# x: [B, C, T]
|
210
210
|
return self.net(x) # [B, hidden, T]
|
211
|
+
|
212
|
+
|
213
|
+
class AudioClassifier(Model):
    """Small 1-D convolutional classifier over mel-spectrogram-like input.

    Args:
        n_mels: number of input channels of the first ``Conv1d``.
        num_classes: number of output logits.
        hidden: channel width of the convolutional stack. It must be divisible
            by 4 because of the grouped convolution. Defaults to 256 (the
            previously hard-coded width).

    Input is expected as ``[B, n_mels, T]`` (the layout ``nn.Conv1d`` consumes);
    output is ``[B, num_classes]``.
    """

    def __init__(self, n_mels: int = 80, num_classes=5, hidden: int = 256):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv1d(n_mels, hidden, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv1d(hidden, hidden, kernel_size=3, padding=1, groups=4),
            nn.BatchNorm1d(hidden),
            nn.LeakyReLU(0.2),
            nn.Conv1d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool1d(1),  # [B, hidden, T] -> [B, hidden, 1]
            nn.Flatten(),  # -> [B, hidden]
            nn.Linear(hidden, num_classes),  # -> [B, num_classes]
        )
        # Starts in eval mode (BatchNorm uses running stats); call .train()
        # (or train_step) before training.
        self.eval()

    def forward(self, x):
        return self.model(x)
|