lt-tensor 0.0.1a10__py3-none-any.whl → 0.0.1a12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/model_base.py CHANGED
@@ -1,10 +1,14 @@
- __all__ = ["Model"]
- 
+ __all__ = ["Model", "LossTracker"]
  
+ import gc
+ import json
+ import math
  import warnings
- from .torch_commons import *
+ from lt_tensor.torch_commons import *
  from lt_utils.common import *
- from lt_utils.misc_utils import log_traceback
+ from lt_tensor.misc_utils import plot_view, get_weights, get_config, updateDict
+ from lt_utils.misc_utils import log_traceback, get_current_time
+ from lt_utils.file_ops import load_json, save_json
  
  
  T = TypeVar("T")
  
@@ -17,13 +21,57 @@ POSSIBLE_OUTPUT_TYPES: TypeAlias = Union[
  ]
  
  
- class Model(nn.Module, ABC):
-     """
-     This makes it easier to assign a device and retrieves it later
-     """
+ class LossTracker:
+     last_file = f"logs/history_{get_current_time()}.json"
+ 
+     def __init__(self, max_len=50_000):
+         self.max_len = max_len
+         self.history = {
+             "train": [],
+             "eval": [],
+         }
+ 
+     def append(
+         self, loss: float, mode: Union[Literal["train", "eval"], "str"] = "train"
+     ):
+         self.history[mode].append(float(loss))
+         if len(self.history[mode]) > self.max_len:
+             self.history[mode] = self.history[mode][-self.max_len :]
+ 
+     def get(self, mode: Union[Literal["train", "eval"], "str"] = "train"):
+         return self.history.get(mode, [])
+ 
+     def save(self, path: Optional[PathLike] = None):
+         if path is None:
+             path = f"logs/history_{get_current_time()}.json"
+         save_json(path, self.history, indent=2)
+         self.last_file = path
+ 
+     def load(self, path: Optional[PathLike] = None):
+         if path is None:
+             path = self.last_file
+         self.history = load_json(path, [])
+         self.last_file = path
+ 
+     def plot(
+         self,
+         history_keys: Union[str, List[str]],
+         max_amount: int = 0,
+         title: str = "Loss",
+     ):
+         if isinstance(history_keys, str):
+             history_keys = [history_keys]
+         return plot_view(
+             {k: v for k, v in self.history.items() if k in history_keys},
+             title,
+             max_amount,
+         )
+ 
  
+ class _Devices_Base(nn.Module):
      _device: torch.device = ROOT_DEVICE
      _autocast: bool = False
+     _loss_history: LossTracker = LossTracker(100_000)
  
      @property
      def autocast(self):
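
The new LossTracker keeps a bounded per-mode history of scalar losses and persists it as JSON through lt_utils' save_json/load_json. A minimal usage sketch (the variable names and output path below are illustrative, not part of the package):

    from lt_tensor.model_base import LossTracker

    tracker = LossTracker(max_len=10_000)              # keep at most 10k entries per mode
    for step in range(100):
        tracker.append(1.0 / (step + 1), mode="train") # pretend training losses
    tracker.append(0.5, mode="eval")

    print(len(tracker.get("train")))                   # -> 100
    tracker.save("history_demo.json")                  # writes self.history via save_json
    # tracker.plot("train") hands the selected histories to plot_view to build a figure
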
@@ -41,7 +89,96 @@ class Model(nn.Module, ABC):
      def device(self, device: Union[torch.device, str]):
          assert isinstance(device, (str, torch.device))
          self._device = torch.device(device) if isinstance(device, str) else device
-         self._apply_device_to()
+ 
+     def to(self, *args, **kwargs):
+         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
+             *args, **kwargs
+         )
+ 
+         if dtype is not None:
+             if not (dtype.is_floating_point or dtype.is_complex):
+                 raise TypeError(
+                     "nn.Module.to only accepts floating point or complex "
+                     f"dtypes, but got desired dtype={dtype}"
+                 )
+             if dtype.is_complex:
+                 warnings.warn(
+                     "Complex modules are a new feature under active development whose design may change, "
+                     "and some modules might not work as expected when using complex tensors as parameters or buffers. "
+                     "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
+                     "if a complex module does not work as expected."
+                 )
+ 
+         def convert(t: Tensor):
+             try:
+                 if convert_to_format is not None and t.dim() in (4, 5):
+                     return t.to(
+                         device,
+                         dtype if t.is_floating_point() or t.is_complex() else None,
+                         non_blocking,
+                         memory_format=convert_to_format,
+                     )
+                 return t.to(
+                     device,
+                     dtype if t.is_floating_point() or t.is_complex() else None,
+                     non_blocking,
+                 )
+             except NotImplementedError as e:
+                 if str(e) == "Cannot copy out of meta tensor; no data!":
+                     raise NotImplementedError(
+                         f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
+                         f"when moving module from meta to a different device."
+                     ) from None
+                 else:
+                     raise
+ 
+         self._apply(convert)
+         self.device = device
+         return self
+ 
+     def _to_dvc(
+         self, device_name: str, device_id: Optional[Union[int, torch.device]] = None
+     ):
+         device = device_name
+         if device_id is not None:
+             if isinstance(device_id, Number):
+                 device += ":" + str(int(device_id))
+             elif hasattr(device_id, "index"):
+                 device += ":" + str(device_id.index)
+         self.device = device
+ 
+     def ipu(self, device: Optional[Union[int, torch.device]] = None) -> T:
+         super().ipu(device)
+         self._to_dvc("ipu", device)
+         return self
+ 
+     def xpu(self, device: Optional[Union[int, torch.device]] = None) -> T:
+         super().xpu(device)
+         self._to_dvc("xpu", device)
+         return self
+ 
+     def cuda(self, device: Optional[Union[int, torch.device]] = None) -> T:
+         super().cuda(device)
+         self._to_dvc("cuda", device)
+         return self
+ 
+     def mtia(self, device: Optional[Union[int, torch.device]] = None) -> T:
+         super().mtia(device)
+         self._to_dvc("mtia", device)
+         return self
+ 
+     def cpu(self) -> T:
+         super().cpu()
+         self._to_dvc("cpu", None)
+         return self
+ 
+ 
+ class Model(_Devices_Base, ABC):
+     """
+     This makes it easier to assign a device and retrieves it later
+     """
+ 
+     _is_unfrozen: bool = False
  
      def _apply_device_to(self):
          """Add here components that are needed to have device applied to them,
@@ -61,6 +198,7 @@ class Model(nn.Module, ABC):
          if hasattr(self, weight):
              w = getattr(self, weight)
              if isinstance(w, nn.Module):
+ 
                  w.requires_grad_(not freeze)
              else:
                  weight.requires_grad_(not freeze)
@@ -112,21 +250,27 @@ class Model(nn.Module, ABC):
          for name, param in self.named_parameters():
              if no_exclusions:
                  try:
-                     param.requires_grad_(False)
-                     frozen.append(name)
+                     if param.requires_grad:
+                         param.requires_grad_(False)
+                         frozen.append(name)
+                     else:
+                         not_frozen.append((name, "was_frozen"))
                  except Exception as e:
                      not_frozen.append((name, str(e)))
              elif any(layer in name for layer in exclude):
                  try:
-                     param.requires_grad_(False)
-                     frozen.append(name)
+                     if param.requires_grad:
+                         param.requires_grad_(False)
+                         frozen.append(name)
+                     else:
+                         not_frozen.append((name, "was_frozen"))
                  except Exception as e:
                      not_frozen.append((name, str(e)))
              else:
-                 not_frozen.append((name, "Excluded"))
+                 not_frozen.append((name, "excluded"))
          return dict(frozen=frozen, not_frozen=not_frozen)
  
-     def unfreeze_all_except(self, exclude: Optional[list[str]] = None):
+     def unfreeze_all(self, exclude: Optional[list[str]] = None):
          """Unfreezes all model parameters except specified layers."""
          no_exclusions = not exclude
          unfrozen = []
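
freeze_all now skips parameters that are already frozen and reports them as "was_frozen", and unfreeze_all_except has been renamed to unfreeze_all. Note that, as written, a parameter is frozen when its name matches an entry of the exclude-style list, while non-matching names are reported as "excluded". A small sketch of the returned report, reusing the hypothetical MyNet from the earlier sketch:

    net = MyNet()                        # hypothetical subclass defined above
    report = net.freeze_all(["proj"])    # freezes parameters whose name contains "proj"
    print(report["frozen"])              # e.g. ['proj.weight', 'proj.bias']
    print(report["not_frozen"])          # (name, reason) pairs: "excluded" or "was_frozen"

    report = net.unfreeze_all()          # no exclusions: everything becomes trainable again
    print(report["not_unfrozen"])        # already-trainable params would appear here as "was_unfrozen"
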
@@ -134,111 +278,26 @@ class Model(nn.Module, ABC):
          for name, param in self.named_parameters():
              if no_exclusions:
                  try:
-                     param.requires_grad_(True)
-                     unfrozen.append(name)
+                     if not param.requires_grad:
+                         param.requires_grad_(True)
+                         unfrozen.append(name)
+                     else:
+                         not_unfrozen.append((name, "was_unfrozen"))
                  except Exception as e:
                      not_unfrozen.append((name, str(e)))
              elif any(layer in name for layer in exclude):
                  try:
-                     param.requires_grad_(True)
-                     unfrozen.append(name)
+                     if not param.requires_grad:
+                         param.requires_grad_(True)
+                         unfrozen.append(name)
+                     else:
+                         not_unfrozen.append((name, "was_unfrozen"))
                  except Exception as e:
                      not_unfrozen.append((name, str(e)))
              else:
-                 not_unfrozen.append((name, "Excluded"))
+                 not_unfrozen.append((name, "excluded"))
          return dict(unfrozen=unfrozen, not_unfrozen=not_unfrozen)
  
-     def to(self, *args, **kwargs):
-         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
-             *args, **kwargs
-         )
- 
-         if dtype is not None:
-             if not (dtype.is_floating_point or dtype.is_complex):
-                 raise TypeError(
-                     "nn.Module.to only accepts floating point or complex "
-                     f"dtypes, but got desired dtype={dtype}"
-                 )
-             if dtype.is_complex:
-                 warnings.warn(
-                     "Complex modules are a new feature under active development whose design may change, "
-                     "and some modules might not work as expected when using complex tensors as parameters or buffers. "
-                     "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
-                     "if a complex module does not work as expected."
-                 )
- 
-         def convert(t: Tensor):
-             try:
-                 if convert_to_format is not None and t.dim() in (4, 5):
-                     return t.to(
-                         device,
-                         dtype if t.is_floating_point() or t.is_complex() else None,
-                         non_blocking,
-                         memory_format=convert_to_format,
-                     )
-                 return t.to(
-                     device,
-                     dtype if t.is_floating_point() or t.is_complex() else None,
-                     non_blocking,
-                 )
-             except NotImplementedError as e:
-                 if str(e) == "Cannot copy out of meta tensor; no data!":
-                     raise NotImplementedError(
-                         f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
-                         f"when moving module from meta to a different device."
-                     ) from None
-                 else:
-                     raise
- 
-         self._apply(convert)
-         self.device = device
-         return self
- 
-     def ipu(self, device: Optional[Union[int, torch.device]] = None) -> T:
-         super().ipu(device)
-         dvc = "ipu"
-         if device is not None:
-             dvc += (
-                 ":" + str(device) if isinstance(device, (int, float)) else device.index
-             )
-         self.device = dvc
-         return self
- 
-     def xpu(self, device: Optional[Union[int, torch.device]] = None) -> T:
-         super().xpu(device)
-         dvc = "xpu"
-         if device is not None:
-             dvc += (
-                 ":" + str(device) if isinstance(device, (int, float)) else device.index
-             )
-         self.device = dvc
-         return self
- 
-     def cuda(self, device: Optional[Union[int, torch.device]] = None) -> T:
-         super().cuda(device)
-         dvc = "cuda"
-         if device is not None:
-             dvc += (
-                 ":" + str(device) if isinstance(device, (int, float)) else device.index
-             )
-         self.device = dvc
-         return self
- 
-     def mtia(self, device: Optional[Union[int, torch.device]] = None) -> T:
-         super().mtia(device)
-         dvc = "mtia"
-         if device is not None:
-             dvc += (
-                 ":" + str(device) if isinstance(device, (int, float)) else device.index
-             )
-         self.device = dvc
-         return self
- 
-     def cpu(self) -> T:
-         super().cpu()
-         self.device = "cpu"
-         return self
- 
      def count_trainable_parameters(self, module_name: Optional[str] = None):
          """Gets the number of trainable parameters from either the entire model or from a specific module."""
          if module_name is not None:
@@ -314,15 +373,39 @@ class Model(nn.Module, ABC):
          else:
              print(f"Non-Trainable Parameters: {params}")
  
-     def save_weights(self, path: Union[Path, str]):
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_path: Union[str, PathLike],
+         model_config: Optional[Union[str, PathLike]] = None,
+         model_config_dict: Optional[Dict[str, Any]] = None,
+     ):
+         """TODO: Finish this implementation or leave it as an optional function."""
+         assert isinstance(model_path, (str, PathLike, bytes))
+         model_path = get_weights(model_path)
+         model_config = get_config(model_config, model_config_dict)
+         return model_path, model_config
+ 
+     def save_weights(
+         self,
+         path: Union[Path, str],
+         replace: bool = False,
+     ):
          path = Path(path)
+         model_dir = path
          if path.exists():
-             assert (
-                 path.is_file()
-             ), "The provided path exists but its a directory not a file!"
-             path.rmdir()
+             if path.is_dir():
+                 model_dir = Path(path, f"model_{get_current_time()}.pt")
+             elif path.is_file():
+                 if replace:
+                     path.unlink()
+                 else:
+                     model_dir = Path(path.parent, f"model_{get_current_time()}.pt")
+         else:
+             if not "." in str(path):
+                 model_dir = Path(path, f"model_{get_current_time()}.pt")
          path.parent.mkdir(exist_ok=True, parents=True)
-         torch.save(obj=self.state_dict(), f=str(path))
+         torch.save(obj=self.state_dict(), f=str(model_dir))
  
      def load_weights(
          self,
@@ -338,7 +421,14 @@ class Model(nn.Module, ABC):
          if not path.exists():
              assert not raise_if_not_exists, "Path does not exists!"
              return None
-         assert path.is_file(), "The provided path is not a valid file!"
+         if path.is_dir():
+             possible_files = list(Path(path).rglob("*.pt"))
+             assert (
+                 possible_files or not raise_if_not_exists
+             ), "No model could be found in the given path!"
+             if not possible_files:
+                 return None
+             path = sorted(possible_files)[-1]
          state_dict = torch.load(
              str(path), weights_only=weights_only, mmap=mmap, **torch_loader_kwargs
          )
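
save_weights now accepts a directory (or a dot-less, not-yet-existing path) and writes a timestamped model_<time>.pt instead of asserting on files, and load_weights can be pointed at a directory, where it loads the lexicographically last *.pt it finds. A sketch of the round trip, again with the hypothetical MyNet; note that only path.parent is created, so the target directory itself should already exist:

    from pathlib import Path

    net = MyNet()                                  # hypothetical subclass from the earlier sketch
    Path("checkpoints").mkdir(exist_ok=True)
    net.save_weights("checkpoints")                # -> checkpoints/model_<timestamp>.pt
    net.save_weights("checkpoints/best.pt", replace=True)   # exact file, unlinked first if present

    restored = MyNet()
    restored.load_weights("checkpoints")           # directory: picks sorted(rglob("*.pt"))[-1]
    restored.load_weights("checkpoints/best.pt")   # a plain file path still works
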
@@ -353,30 +443,110 @@ class Model(nn.Module, ABC):
      def inference(self, *args, **kwargs):
          if self.training:
              self.eval()
-         if self.autocast:
-             with torch.autocast(device_type=self.device.type):
-                 return self(*args, **kwargs)
          return self(*args, **kwargs)
  
      def train_step(
          self,
-         *args,
+         *inputs,
          **kwargs,
      ):
          """Train Step"""
          if not self.training:
              self.train()
-         return self(*args, **kwargs)
- 
-     @torch.autocast(device_type=_device.type)
-     def ac_forward(self, *args, **kwargs):
-         return
+         return self(*inputs, **kwargs)
  
      def __call__(self, *args, **kwds) -> POSSIBLE_OUTPUT_TYPES:
-         return super().__call__(*args, **kwds)
+         if self.autocast and not self.training:
+             with torch.autocast(device_type=self.device.type):
+                 return super().__call__(*args, **kwds)
+         else:
+             return super().__call__(*args, **kwds)
  
      @abstractmethod
      def forward(
          self, *args, **kwargs
      ) -> Union[Tensor, Sequence[Tensor], Dict[Any, Union[Any, Tensor]]]:
          pass
+ 
+     def add_loss(
+         self,
+         loss: Union[float, list[float]],
+         mode: Union[Literal["train", "eval"]] = "train",
+     ):
+         if isinstance(loss, Number) and loss:
+             self._loss_history.append(loss, mode)
+         elif isinstance(loss, (list, tuple)):
+             if loss:
+                 self._loss_history.append(sum(loss) / len(loss), mode=mode)
+         elif isinstance(loss, Tensor):
+             try:
+                 self._loss_history.append(loss.detach().flatten().mean().item())
+             except Exception as e:
+                 log_traceback(e, "add_loss - Tensor")
+ 
+     def save_loss_history(self, path: Optional[PathLike] = None):
+         self._loss_history.save(path)
+ 
+     def load_loss_history(self, path: Optional[PathLike] = None):
+         self._loss_history.load(path)
+ 
+     def get_loss_avg(
+         self,
+         mode: Union[Literal["train", "eval"], str],
+         quantity: int = 0,
+     ):
+         t_list = self._loss_history.get(mode)
+         if not t_list:
+             return float("nan")
+         if quantity > 0:
+             t_list = t_list[-quantity:]
+         return sum(t_list) / len(t_list)
+ 
+     def freeze_unfreeze_loss(
+         self,
+         losses: Optional[Union[float, List[float]]] = None,
+         trigger_loss: Union[float, bool] = 0.1,
+         excluded_modules: Optional[List[str]] = None,
+         max_items: int = 1000,
+         loss_name: str = "train",
+     ):
+         """If a certain threshold is reached the weights will freeze or unfreeze the modules.
+         the biggest use-case for this function is when training GANs where the balance
+         from the discriminator and generator must be kept.
+ 
+         Args:
+             losses (Union[float, List[float]], Optional): The loss value or a list of losses that will be used to determine if it has reached or not the threshold. Defaults to None.
+             trigger_loss (float, bool, optional): The value where the weights will be either freeze or unfreeze. If set to a boolean it will freeze or unfreeze immediately according to the value (True = Freeze, False = Unfreeze). Defaults to 0.1.
+             excluded_modules (list[str], optional): The list of modules (names) that is not to be changed by either freezing nor unfreezing. Defaults to None.
+             max_items (float, optional): The number of previous losses to be locked behind to calculate the current average. Default to 1000.
+             loss_name (str, optional): Responsible to define with key to recover the loss.
+         returns:
+             bool: True when its frozen and false when its trainable.
+         """
+         if losses is not None:
+             self.add_loss(losses)
+ 
+         if isinstance(trigger_loss, bool):
+             if trigger_loss:
+                 if self._is_unfrozen:
+                     self.freeze_all(excluded_modules)
+                     self._is_unfrozen = False
+                 return True
+             # else
+             if not self._is_unfrozen:
+                 self.unfreeze_all(excluded_modules)
+                 self._is_unfrozen = True
+             return False
+ 
+         value = self.get_loss_avg(loss_name, max_items)
+ 
+         if value <= trigger_loss:
+             if self._is_unfrozen:
+                 self.freeze_all(excluded_modules)
+                 self._is_unfrozen = False
+             return True
+         else:
+             if not self._is_unfrozen:
+                 self.unfreeze_all(excluded_modules)
+                 self._is_unfrozen = True
+             return False
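
freeze_unfreeze_loss ties the loss history to the freeze helpers: once the running average of the selected history reaches trigger_loss, the module is frozen via freeze_all, and it is unfrozen again while the average stays above it (a boolean trigger_loss forces the state directly). A hedged sketch of the GAN-balancing use case the docstring describes, using the hypothetical MyNet as a stand-in discriminator and synthetic loss values:

    discriminator = MyNet()                  # stand-in for a real discriminator Model
    synthetic_d_losses = [0.50, 0.30, 0.12, 0.09, 0.08]

    for d_loss in synthetic_d_losses:
        frozen = discriminator.freeze_unfreeze_loss(
            losses=d_loss,        # recorded into the "train" history via add_loss
            trigger_loss=0.1,     # freeze once the recent average drops to 0.1 or below
            max_items=3,          # average only the last 3 recorded values
        )
    print(frozen)                 # True on the last step: avg(0.12, 0.09, 0.08) <= 0.1
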
@@ -5,7 +5,7 @@ __all__ = [
      "pos",  # positional encoders
      "fsn",  # fusion
      "gns",  # generators
-     "disc", # discriminators
-     "istft" # self-explanatory
+     "disc",  # discriminators
+     "istft",  # self-explanatory
  ]
  from . import bsc, fsn, gns, istft, pos, rsd, tfrms, disc
@@ -10,9 +10,9 @@ __all__ = [
      "MultiScaleEncoder1D",
  ]
  
- from ..torch_commons import *
- from ..model_base import Model
- from ..transform import get_sinusoidal_embedding
+ from lt_tensor.torch_commons import *
+ from lt_tensor.model_base import Model
+ from lt_tensor.transform import get_sinusoidal_embedding
  
  
  class FeedForward(Model):
@@ -208,3 +208,25 @@ class MultiScaleEncoder1D(Model):
      def forward(self, x: torch.Tensor) -> torch.Tensor:
          # x: [B, C, T]
          return self.net(x)  # [B, hidden, T]
+ 
+ 
+ class AudioClassifier(Model):
+     def __init__(self, n_mels: int = 80, num_classes=5):
+         super().__init__()
+         self.model = nn.Sequential(
+             nn.Conv1d(n_mels, 256, kernel_size=3, padding=1),
+             nn.LeakyReLU(0.2),
+             nn.Conv1d(256, 256, kernel_size=3, padding=1, groups=4),
+             nn.BatchNorm1d(256),
+             nn.LeakyReLU(0.2),
+             nn.Conv1d(256, 256, kernel_size=3, padding=1),
+             nn.BatchNorm1d(256),
+             nn.LeakyReLU(0.2),
+             nn.AdaptiveAvgPool1d(1),  # Output shape: [B, 64, 1]
+             nn.Flatten(),  # -> [B, 64]
+             nn.Linear(256, num_classes),
+         )
+         self.eval()
+ 
+     def forward(self, x):
+         return self.model(x)
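
A quick shape check for the new AudioClassifier (its module's file header is not shown in this diff, so the import is left implicit; note it calls self.eval() in __init__, and the inline comments still say 64 channels although the layers use 256). The mel batch below is random data purely for illustration:

    import torch

    clf = AudioClassifier(n_mels=80, num_classes=5)   # class defined in the hunk above
    mel = torch.randn(2, 80, 120)                     # [batch, n_mels, frames]
    with torch.no_grad():
        logits = clf(mel)                             # pool -> flatten -> [2, 256] -> Linear
    print(logits.shape)                               # torch.Size([2, 5])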