lt-tensor 0.0.1a10__py3-none-any.whl → 0.0.1a11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/losses.py CHANGED
@@ -1,10 +1,106 @@
-__all__ = ["masked_cross_entropy"]
+__all__ = [
+    "masked_cross_entropy",
+    "adaptive_l1_loss",
+    "contrastive_loss",
+    "smooth_l1_loss",
+    "hybrid_loss",
+    "diff_loss",
+    "cosine_loss",
+    "gan_loss",
+    "ft_n_loss",
+]
 import math
 import random
 from .torch_commons import *
 from lt_utils.common import *
 import torch.nn.functional as F
 
+def ft_n_loss(output: Tensor, target: Tensor, weight: Optional[Tensor] = None):
+    if weight is not None:
+        return torch.mean((torch.abs(output - target) + weight) ** 0.5)
+    return torch.mean(torch.abs(output - target) ** 0.5)
+
+
+def adaptive_l1_loss(
+    inp: Tensor,
+    tgt: Tensor,
+    weight: Optional[Tensor] = None,
+    scale: float = 1.0,
+    inverted: bool = False,
+):
+    if weight is not None:
+        loss = torch.mean(torch.abs((inp - tgt) + weight.mean()))
+    else:
+        loss = torch.mean(torch.abs(inp - tgt))
+    loss *= scale
+    if inverted:
+        return -loss
+    return loss
+
+
+def smooth_l1_loss(inp: Tensor, tgt: Tensor, beta=1.0, weight=None):
+    diff = torch.abs(inp - tgt)
+    loss = torch.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)
+    if weight is not None:
+        loss *= weight
+    return loss.mean()
+
+
+def contrastive_loss(x1: Tensor, x2: Tensor, label: Tensor, margin: float = 1.0):
+    # label == 1: similar, label == 0: dissimilar
+    dist = torch.nn.functional.pairwise_distance(x1, x2)
+    loss = label * dist**2 + (1 - label) * torch.clamp(margin - dist, min=0.0) ** 2
+    return loss.mean()
+
+
+def cosine_loss(inp, tgt):
+    cos = torch.nn.functional.cosine_similarity(inp, tgt, dim=-1)
+    return 1 - cos.mean()  # Lower is better
+
+
+class GanLosses:
+    @staticmethod
+    def get_loss(
+        pred: Tensor,
+        target_is_real: bool,
+        loss_type: Literal["bce", "mse", "hinge", "wasserstein"] = "bce",
+    ) -> Tensor:
+        if loss_type == "bce":  # Standard GAN
+            target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
+            return F.binary_cross_entropy_with_logits(pred, target)
+
+        elif loss_type == "mse":  # LSGAN
+            target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
+            return F.mse_loss(torch.sigmoid(pred), target)
+
+        elif loss_type == "hinge":
+            if target_is_real:
+                return torch.mean(F.relu(1.0 - pred))
+            else:
+                return torch.mean(F.relu(1.0 + pred))
+
+        elif loss_type == "wasserstein":
+            return -pred.mean() if target_is_real else pred.mean()
+
+        else:
+            raise ValueError(f"Unknown loss_type: {loss_type}")
+
+    @staticmethod
+    def generator_loss(fake_pred: Tensor, loss_type: str = "bce") -> Tensor:
+        return GanLosses.get_loss(fake_pred, target_is_real=True, loss_type=loss_type)
+
+    @staticmethod
+    def discriminator_loss(
+        real_pred: Tensor, fake_pred: Tensor, loss_type: str = "bce"
+    ) -> Tensor:
+        real_loss = GanLosses.get_loss(
+            real_pred, target_is_real=True, loss_type=loss_type
+        )
+        fake_loss = GanLosses.get_loss(
+            fake_pred.detach(), target_is_real=False, loss_type=loss_type
+        )
+        return (real_loss + fake_loss) * 0.5
+
 
 def masked_cross_entropy(
     logits: torch.Tensor,  # [B, T, V]
@@ -61,85 +157,3 @@ def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
             torch.log(1 - fake + 1e-7)
         )
     return loss
-
-
-def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
-    loss = 0
-    for real, fake in zip(real_preds, fake_preds):
-        if use_lsgan:
-            loss += F.mse_loss(real, torch.ones_like(real)) + F.mse_loss(
-                fake, torch.zeros_like(fake)
-            )
-        else:
-            loss += -torch.mean(torch.log(real + 1e-7)) - torch.mean(
-                torch.log(1 - fake + 1e-7)
-            )
-    return loss
-
-
-def gan_g_loss(fake_preds, use_lsgan=True):
-    loss = 0
-    for fake in fake_preds:
-        if use_lsgan:
-            loss += F.mse_loss(fake, torch.ones_like(fake))
-        else:
-            loss += -torch.mean(torch.log(fake + 1e-7))
-    return loss
-
-
-def feature_matching_loss(real_feats, fake_feats):
-    """real_feats and fake_feats are lists of intermediate features"""
-    loss = 0
-    for real_layers, fake_layers in zip(real_feats, fake_feats):
-        for r, f in zip(real_layers, fake_layers):
-            loss += F.l1_loss(f, r.detach())
-    return loss
-
-
-def feature_loss(real_fmaps, fake_fmaps, weight=2.0):
-    loss = 0.0
-    for dr, dg in zip(real_fmaps, fake_fmaps):  # Each (layer list from a discriminator)
-        for r_feat, g_feat in zip(dr, dg):
-            loss += F.l1_loss(r_feat, g_feat)
-    return loss * weight
-
-
-def discriminator_loss(disc_real_outputs, disc_generated_outputs):
-    loss = 0.0
-    r_losses = []
-    g_losses = []
-
-    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-        r_loss = F.mse_loss(dr, torch.ones_like(dr))
-        g_loss = F.mse_loss(dg, torch.zeros_like(dg))
-        loss += r_loss + g_loss
-        r_losses.append(r_loss)
-        g_losses.append(g_loss)
-
-    return loss, r_losses, g_losses
-
-
-def generator_loss(fake_outputs):
-    total = 0.0
-    g_losses = []
-    for out in fake_outputs:
-        loss = F.mse_loss(out, torch.ones_like(out))
-        g_losses.append(loss)
-        total += loss
-    return total, g_losses
-
-
-def multi_resolution_stft_loss(y, y_hat, fft_sizes=[512, 1024, 2048]):
-    loss = 0
-    for fft_size in fft_sizes:
-        hop = fft_size // 4
-        win = fft_size
-        y_stft = torch.stft(
-            y, n_fft=fft_size, hop_length=hop, win_length=win, return_complex=True
-        )
-        y_hat_stft = torch.stft(
-            y_hat, n_fft=fft_size, hop_length=hop, win_length=win, return_complex=True
-        )
-
-        loss += F.l1_loss(torch.abs(y_stft), torch.abs(y_hat_stft))
-    return loss
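
A quick usage sketch for the new additions (the shapes and the hinge choice below are illustrative, not from the package):

    import torch
    from lt_tensor.losses import ft_n_loss, GanLosses

    pred, target = torch.randn(8, 100), torch.randn(8, 100)   # hypothetical shapes
    recon = ft_n_loss(pred, target)  # mean of |pred - target| ** 0.5

    real_logits = torch.randn(8, 1)  # stand-in discriminator outputs
    fake_logits = torch.randn(8, 1)
    d_loss = GanLosses.discriminator_loss(real_logits, fake_logits, loss_type="hinge")
    g_loss = GanLosses.generator_loss(fake_logits, loss_type="hinge")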
lt_tensor/misc_utils.py CHANGED
@@ -15,6 +15,7 @@ __all__ = [
     "TorchCacheUtils",
     "clear_cache",
     "default_device",
+    "soft_restore",
     "Packing",
     "Padding",
     "Masking",
@@ -35,14 +36,29 @@ from lt_utils.misc_utils import ff_list
 import torch.nn.functional as F
 
 
+def soft_restore(tensor, epsilon=1e-6):
+    return torch.where(tensor == 0, torch.full_like(tensor, epsilon), tensor)
+
+
 def try_torch(fn: str, *args, **kwargs):
+    tried_torch = False
+    not_present_message = (
+        f"Neither `torch` nor `torch.nn.functional` contains the function `{fn}`"
+    )
     try:
-        return getattr(F, fn)(*args, **kwargs)
-    except Exception as e:
-        try:
+        if hasattr(F, fn):
+            return getattr(F, fn)(*args, **kwargs)
+        elif hasattr(torch, fn):
+            tried_torch = True
             return getattr(torch, fn)(*args, **kwargs)
+        return not_present_message
+    except Exception as a:
+        try:
+            if not tried_torch and hasattr(torch, fn):
+                return getattr(torch, fn)(*args, **kwargs)
+            return str(a)
         except Exception as e:
-            return str(e)
+            return str(e) + " | " + str(a)
 
 
 def log_tensor(
@@ -192,6 +208,7 @@ class LogTensor:
             ("max", dict(dim=-1)),
         ],
         validate_item_type: bool = False,
+        exclude_invalid_losses: bool = True,
         **kwargs,
     ):
         invalid_type = not isinstance(inputs, (Tensor, np.ndarray, list, tuple))
@@ -243,12 +260,13 @@ class LogTensor:
             value = try_torch(log_fn, inputs)
         else:
             value = try_torch(log_fn, inputs, log_args)
+
         results = self._process(log_fn, value)
         current_register[log_fn] = results
         self.do_print = old_print
 
         if target is not None:
-            losses = get_losses(inputs, target, False)
+            losses = get_losses(inputs, target, exclude_invalid_losses)
             started_ls = False
             if self.do_print:
                 for loss, res in losses.items():
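
The reworked try_torch now probes with hasattr before calling, preferring torch.nn.functional over torch, and returns an error string instead of raising; soft_restore replaces exact zeros with a small epsilon. A minimal sketch of both:

    import torch
    from lt_tensor.misc_utils import soft_restore, try_torch

    x = torch.tensor([0.0, 2.0, 0.0])
    print(torch.log(soft_restore(x)))   # zeros become 1e-6, so no -inf entries

    print(try_torch("relu", torch.tensor([-1.0, 1.0])))  # resolved via torch.nn.functional
    print(try_torch("does_not_exist"))                   # returns the "not present" message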
lt_tensor/model_base.py CHANGED
@@ -1,10 +1,12 @@
-__all__ = ["Model"]
-
+__all__ = ["Model", "_ModelExtended", "LossTracker"]
 
+import gc
+import json
+import math
 import warnings
 from .torch_commons import *
 from lt_utils.common import *
-from lt_utils.misc_utils import log_traceback
+from lt_utils.misc_utils import log_traceback, get_current_time
 
 T = TypeVar("T")
 
@@ -17,6 +19,80 @@ POSSIBLE_OUTPUT_TYPES: TypeAlias = Union[
 ]
 
 
+class LossTracker:
+    last_file = f"logs/history_{get_current_time()}.json"
+
+    def __init__(self, max_len=50_000):
+        self.max_len = max_len
+        self.history = {
+            "train": [],
+            "eval": [],
+        }
+
+    def append(self, loss: float, mode: Literal["train", "eval"] = "train"):
+        assert mode in self.history, f"Invalid mode '{mode}'. Use 'train' or 'eval'."
+        self.history[mode].append(float(loss))
+        if len(self.history[mode]) > self.max_len:
+            self.history[mode] = self.history[mode][-self.max_len :]
+
+    def get(self, mode: Literal["train", "eval"] = "train"):
+        return self.history.get(mode, [])
+
+    def save(self, path: Optional[PathLike] = None):
+        if path is None:
+            path = f"logs/history_{get_current_time()}.json"
+
+        Path(path).parent.mkdir(exist_ok=True, parents=True)
+        with open(path, "w") as f:
+            json.dump(self.history, f, indent=2)
+
+        self.last_file = path
+
+    def load(self, path: Optional[PathLike] = None):
+        if path is None:
+            _path = self.last_file
+        else:
+            _path = path
+        with open(_path) as f:
+            self.history = json.load(f)
+        if path is not None:
+            self.last_file = path
+
+    def plot(self, backend: Literal["matplotlib", "plotly"] = "plotly"):
+        if backend == "plotly":
+            try:
+                import plotly.graph_objs as go
+            except ModuleNotFoundError:
+                warnings.warn(
+                    "No installation of plotly was found. To use it, run 'pip install plotly' and restart this application!"
+                )
+                return
+            fig = go.Figure()
+            for mode, losses in self.history.items():
+                if losses:
+                    fig.add_trace(go.Scatter(y=losses, name=mode.capitalize()))
+            fig.update_layout(
+                title="Training vs Evaluation Loss",
+                xaxis_title="Step",
+                yaxis_title="Loss",
+                template="plotly_dark",
+            )
+            fig.show()
+
+        elif backend == "matplotlib":
+            import matplotlib.pyplot as plt
+
+            for mode, losses in self.history.items():
+                if losses:
+                    plt.plot(losses, label=f"{mode.capitalize()} Loss")
+            plt.title("Loss over Time")
+            plt.xlabel("Step")
+            plt.ylabel("Loss")
+            plt.legend()
+            plt.grid(True)
+            plt.show()
+
+
 class Model(nn.Module, ABC):
     """
     This makes it easier to assign a device and retrieves it later
@@ -24,6 +100,8 @@ class Model(nn.Module, ABC):
 
     _device: torch.device = ROOT_DEVICE
     _autocast: bool = False
+    _loss_history: LossTracker = LossTracker(100_000)
+    _is_unfrozen: bool = False
 
     @property
    def autocast(self):
@@ -61,6 +139,7 @@ class Model(nn.Module, ABC):
         if hasattr(self, weight):
             w = getattr(self, weight)
             if isinstance(w, nn.Module):
+
                 w.requires_grad_(not freeze)
             else:
                 weight.requires_grad_(not freeze)
@@ -112,21 +191,27 @@ class Model(nn.Module, ABC):
         for name, param in self.named_parameters():
             if no_exclusions:
                 try:
-                    param.requires_grad_(False)
-                    frozen.append(name)
+                    if param.requires_grad:
+                        param.requires_grad_(False)
+                        frozen.append(name)
+                    else:
+                        not_frozen.append((name, "was_frozen"))
                 except Exception as e:
                     not_frozen.append((name, str(e)))
             elif any(layer in name for layer in exclude):
                 try:
-                    param.requires_grad_(False)
-                    frozen.append(name)
+                    if param.requires_grad:
+                        param.requires_grad_(False)
+                        frozen.append(name)
+                    else:
+                        not_frozen.append((name, "was_frozen"))
                 except Exception as e:
                     not_frozen.append((name, str(e)))
             else:
-                not_frozen.append((name, "Excluded"))
+                not_frozen.append((name, "excluded"))
         return dict(frozen=frozen, not_frozen=not_frozen)
 
-    def unfreeze_all_except(self, exclude: Optional[list[str]] = None):
+    def unfreeze_all(self, exclude: Optional[list[str]] = None):
         """Unfreezes all model parameters except specified layers."""
         no_exclusions = not exclude
         unfrozen = []
@@ -134,18 +219,24 @@ class Model(nn.Module, ABC):
         for name, param in self.named_parameters():
             if no_exclusions:
                 try:
-                    param.requires_grad_(True)
-                    unfrozen.append(name)
+                    if not param.requires_grad:
+                        param.requires_grad_(True)
+                        unfrozen.append(name)
+                    else:
+                        not_unfrozen.append((name, "was_unfrozen"))
                 except Exception as e:
                     not_unfrozen.append((name, str(e)))
             elif any(layer in name for layer in exclude):
                 try:
-                    param.requires_grad_(True)
-                    unfrozen.append(name)
+                    if not param.requires_grad:
+                        param.requires_grad_(True)
+                        unfrozen.append(name)
+                    else:
+                        not_unfrozen.append((name, "was_unfrozen"))
                 except Exception as e:
                     not_unfrozen.append((name, str(e)))
             else:
-                not_unfrozen.append((name, "Excluded"))
+                not_unfrozen.append((name, "excluded"))
         return dict(unfrozen=unfrozen, not_unfrozen=not_unfrozen)
 
     def to(self, *args, **kwargs):
@@ -192,6 +283,7 @@ class Model(nn.Module, ABC):
 
         self._apply(convert)
         self.device = device
+        self._apply_device_to()
         return self
 
     def ipu(self, device: Optional[Union[int, torch.device]] = None) -> T:
@@ -202,6 +294,7 @@ class Model(nn.Module, ABC):
             ":" + str(device) if isinstance(device, (int, float)) else device.index
         )
         self.device = dvc
+        self._apply_device_to()
         return self
 
     def xpu(self, device: Optional[Union[int, torch.device]] = None) -> T:
@@ -212,6 +305,7 @@ class Model(nn.Module, ABC):
             ":" + str(device) if isinstance(device, (int, float)) else device.index
         )
         self.device = dvc
+        self._apply_device_to()
         return self
 
     def cuda(self, device: Optional[Union[int, torch.device]] = None) -> T:
@@ -222,6 +316,7 @@ class Model(nn.Module, ABC):
             ":" + str(device) if isinstance(device, (int, float)) else device.index
         )
         self.device = dvc
+        self._apply_device_to()
         return self
 
     def mtia(self, device: Optional[Union[int, torch.device]] = None) -> T:
@@ -232,11 +327,13 @@ class Model(nn.Module, ABC):
             ":" + str(device) if isinstance(device, (int, float)) else device.index
         )
         self.device = dvc
+        self._apply_device_to()
         return self
 
     def cpu(self) -> T:
         super().cpu()
         self.device = "cpu"
+        self._apply_device_to()
         return self
 
     def count_trainable_parameters(self, module_name: Optional[str] = None):
@@ -314,15 +411,26 @@ class Model(nn.Module, ABC):
         else:
             print(f"Non-Trainable Parameters: {params}")
 
-    def save_weights(self, path: Union[Path, str]):
+    def save_weights(
+        self,
+        path: Union[Path, str],
+        replace: bool = False,
+    ):
         path = Path(path)
+        model_dir = path
         if path.exists():
-            assert (
-                path.is_file()
-            ), "The provided path exists but its a directory not a file!"
-            path.rmdir()
+            if path.is_dir():
+                model_dir = Path(path, f"model_{get_current_time()}.pt")
+            elif path.is_file():
+                if replace:
+                    path.unlink()
+                else:
+                    model_dir = Path(path.parent, f"model_{get_current_time()}.pt")
+        else:
+            if "." not in str(path):
+                model_dir = Path(path, f"model_{get_current_time()}.pt")
         path.parent.mkdir(exist_ok=True, parents=True)
-        torch.save(obj=self.state_dict(), f=str(path))
+        torch.save(obj=self.state_dict(), f=str(model_dir))
 
     def load_weights(
         self,
@@ -338,7 +446,14 @@ class Model(nn.Module, ABC):
         if not path.exists():
             assert not raise_if_not_exists, "Path does not exists!"
             return None
-        assert path.is_file(), "The provided path is not a valid file!"
+        if path.is_dir():
+            possible_files = list(Path(path).rglob("*.pt"))
+            assert (
+                possible_files or not raise_if_not_exists
+            ), "No model could be found in the given path!"
+            if not possible_files:
+                return None
+            path = sorted(possible_files)[-1]
         state_dict = torch.load(
             str(path), weights_only=weights_only, mmap=mmap, **torch_loader_kwargs
         )
@@ -353,30 +468,131 @@ class Model(nn.Module, ABC):
     def inference(self, *args, **kwargs):
         if self.training:
             self.eval()
-        if self.autocast:
-            with torch.autocast(device_type=self.device.type):
-                return self(*args, **kwargs)
         return self(*args, **kwargs)
 
     def train_step(
         self,
-        *args,
+        *inputs,
         **kwargs,
     ):
         """Train Step"""
         if not self.training:
             self.train()
-        return self(*args, **kwargs)
-
-    @torch.autocast(device_type=_device.type)
-    def ac_forward(self, *args, **kwargs):
-        return
+        return self(*inputs, **kwargs)
 
     def __call__(self, *args, **kwds) -> POSSIBLE_OUTPUT_TYPES:
-        return super().__call__(*args, **kwds)
+        if self.autocast and not self.training:
+            with torch.autocast(device_type=self.device.type):
+                return super().__call__(*args, **kwds)
+        else:
+            return super().__call__(*args, **kwds)
 
     @abstractmethod
     def forward(
         self, *args, **kwargs
     ) -> Union[Tensor, Sequence[Tensor], Dict[Any, Union[Any, Tensor]]]:
         pass
+
+    def add_loss(
+        self, loss: Union[float, list[float]], mode: Literal["train", "eval"] = "train"
+    ):
+        if isinstance(loss, Number) and loss:
+            self._loss_history.append(loss, mode)
+        elif isinstance(loss, (list, tuple)):
+            if loss:
+                self._loss_history.append(sum(loss) / len(loss), mode=mode)
+        elif isinstance(loss, Tensor):
+            try:
+                self._loss_history.append(loss.detach().flatten().mean().item(), mode)
+            except Exception as e:
+                log_traceback(e, "add_loss - Tensor")
+
+    def save_loss_history(self, path: Optional[PathLike] = None):
+        self._loss_history.save(path)
+
+    def load_loss_history(self, path: Optional[PathLike] = None):
+        self._loss_history.load(path)
+
+    def get_loss_avg(self, mode: Literal["train", "eval"], quantity: int = 0):
+        t_list = self._loss_history.get(mode)
+        if not t_list:
+            return float("nan")
+        if quantity > 0:
+            t_list = t_list[-quantity:]
+        return sum(t_list) / len(t_list)
+
+    def freeze_unfreeze_loss(
+        self,
+        losses: Optional[Union[float, List[float]]] = None,
+        trigger_loss: float = 0.1,
+        excluded_modules: Optional[List[str]] = None,
+        eval_last: int = 1000,
+    ):
+        """If a certain threshold is reached, the modules will be frozen or unfrozen.
+        The biggest use-case for this function is training GANs, where the balance
+        between the discriminator and generator must be kept.
+
+        Args:
+            losses (Union[float, List[float]], optional): The loss value or list of losses used to determine whether the threshold has been reached. Defaults to None.
+            trigger_loss (float, optional): The value at which the weights will be either frozen or unfrozen. Defaults to 0.1.
+            excluded_modules (list[str], optional): The list of module names that are not to be changed by either freezing or unfreezing. Defaults to None.
+            eval_last (int, optional): The number of most recent losses used to compute the current average. Defaults to 1000.
+
+        Returns:
+            bool: True when frozen, False when trainable.
+        """
+        if losses is not None:
+            self.add_loss(losses)
+
+        value = self.get_loss_avg("train", eval_last)
+
+        if value <= trigger_loss:
+            if self._is_unfrozen:
+                self.freeze_all(excluded_modules)
+                self._is_unfrozen = False
+            return True
+        else:
+            if not self._is_unfrozen:
+                self.unfreeze_all(excluded_modules)
+                self._is_unfrozen = True
+            return False
+
+
+class _ModelExtended(Model):
+    """Planned, but not ready; maybe in the near future?"""
+
+    criterion: Optional[Callable[[Tensor, Tensor], Tensor]] = None
+    optimizer: Optional[optim.Optimizer] = None
+
+    def train_step(
+        self,
+        *inputs,
+        loss_label: Optional[Tensor] = None,
+        **kwargs,
+    ):
+        if not self.training:
+            self.train()
+        if self.optimizer is not None:
+            self.optimizer.zero_grad()
+        if self.autocast:
+            if self.criterion is None:
+                raise RuntimeError(
+                    "To use autocast during training, you must assign a criterion first!"
+                )
+            with torch.autocast(device_type=self.device.type):
+                out = self.forward(*inputs, **kwargs)
+                loss = self.criterion(out, loss_label)
+
+            if self.optimizer is not None:
+                loss.backward()
+                self.optimizer.step()
+            return loss
+        elif self.criterion is not None:
+            out = self.forward(*inputs, **kwargs)
+            loss = self.criterion(out, loss_label)
+            if self.optimizer is not None:
+                loss.backward()
+                self.optimizer.step()
+            return loss
+        else:
+            return self(*inputs, **kwargs)
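
A minimal sketch of the new loss-history plumbing, assuming a hypothetical `Model` subclass (note that `_loss_history` is a class attribute, so it is shared across instances unless reassigned):

    import torch
    import torch.nn as nn
    from lt_tensor.model_base import Model

    class TinyNet(Model):  # hypothetical subclass for illustration
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 1)

        def forward(self, x):
            return self.fc(x)

    net = TinyNet()
    for _ in range(100):
        net.add_loss(torch.rand(1).item(), mode="train")  # stand-in loss values

    avg = net.get_loss_avg("train", quantity=50)            # mean of the last 50 losses
    frozen = net.freeze_unfreeze_loss(trigger_loss=0.5, eval_last=50)
    net.save_loss_history("logs/tinynet_history.json")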
lt_tensor/model_zoo/bsc.py CHANGED
@@ -208,3 +208,25 @@ class MultiScaleEncoder1D(Model):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # x: [B, C, T]
         return self.net(x)  # [B, hidden, T]
+
+
+class AudioClassifier(Model):
+    def __init__(self, n_mels: int = 80, num_classes=5):
+        super().__init__()
+        self.model = nn.Sequential(
+            nn.Conv1d(n_mels, 256, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(256, 256, kernel_size=3, padding=1, groups=4),
+            nn.BatchNorm1d(256),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(256, 256, kernel_size=3, padding=1),
+            nn.BatchNorm1d(256),
+            nn.LeakyReLU(0.2),
+            nn.AdaptiveAvgPool1d(1),  # Output shape: [B, 256, 1]
+            nn.Flatten(),  # -> [B, 256]
+            nn.Linear(256, num_classes),
+        )
+        self.eval()
+
+    def forward(self, x):
+        return self.model(x)
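
The new AudioClassifier maps a mel spectrogram [B, n_mels, T] to class logits; since __init__ ends with self.eval(), it must be switched back into training mode before fitting. A shape check with hypothetical sizes (module path assumed from the RECORD listing):

    import torch
    from lt_tensor.model_zoo.bsc import AudioClassifier

    clf = AudioClassifier(n_mels=80, num_classes=5)
    mel = torch.randn(2, 80, 120)   # [B, n_mels, T]
    logits = clf(mel)               # -> [2, 5]
    clf.train()                     # required before any training step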
lt_tensor/model_zoo/disc.py CHANGED
@@ -76,20 +76,6 @@ class PeriodDiscriminator(Model):
         return x.flatten(1, -1), f_map
 
 
-class MultiPeriodDiscriminator(Model):
-    def __init__(self, periods=[2, 3, 5, 7, 11]):
-        super().__init__()
-
-        self.discriminators = nn.ModuleList([PeriodDiscriminator(p) for p in periods])
-
-    def forward(self, x: torch.Tensor):
-        """
-        x: (B, T)
-        Returns: list of tuples of outputs from each period discriminator and the f_map.
-        """
-        return [d(x) for d in self.discriminators]
-
-
 class ScaleDiscriminator(nn.Module):
     def __init__(self, use_spectral_norm=False):
         super().__init__()
@@ -123,11 +109,11 @@ class ScaleDiscriminator(nn.Module):
 
 
 class MultiScaleDiscriminator(Model):
-    def __init__(self):
+    def __init__(self, layers: int = 3):
         super().__init__()
         self.pooling = nn.AvgPool1d(4, 2, padding=2)
         self.discriminators = nn.ModuleList(
-            [ScaleDiscriminator(i == 0) for i in range(3)]
+            [ScaleDiscriminator(i == 0) for i in range(layers)]
         )
 
     def forward(self, x: torch.Tensor):
@@ -136,57 +122,75 @@ class MultiScaleDiscriminator(Model):
         Returns: list of outputs from each scale discriminator
         """
         outputs = []
+        features = []
         for i, d in enumerate(self.discriminators):
             if i != 0:
                 x = self.pooling(x)
-            outputs.append(d(x))
-        return outputs
+            out, f_map = d(x)
+            outputs.append(out)
+            features.append(f_map)
+        return outputs, features
 
 
-class GeneralLossDescriminator(Model):
-    """TODO: build an unified loss for both mpd and msd here."""
-
-    def __init__(self):
+class MultiPeriodDiscriminator(Model):
+    def __init__(self, periods: List[int] = [2, 3, 5, 7, 11]):
         super().__init__()
-        self.mpd = MultiPeriodDiscriminator()
-        self.msd = MultiScaleDiscriminator()
-        self.print_trainable_parameters()
-
-    def _get_group_(self):
-        pass
+        self.discriminators = nn.ModuleList([PeriodDiscriminator(p) for p in periods])
 
-    def forward(self, x: Tensor, y_hat: Tensor):
-        return
+    def forward(self, x: torch.Tensor):
+        """
+        x: (B, T)
+        Returns: list of tuples of outputs from each period discriminator and the f_map.
+        """
+        # torch.log(torch.clip(x, min=clip_val))
+        out_map = []
+        feat_map = []
+        for d in self.discriminators:
+            out, feat = d(x)
+            out_map.append(out)
+            feat_map.append(feat)
+        return out_map, feat_map
 
 
-def discriminator_loss(d_outputs_real, d_outputs_fake):
+def discriminator_loss(real_out_map, fake_out_map):
     loss = 0.0
-    for real_out, fake_out in zip(d_outputs_real, d_outputs_fake):
-        real_score = real_out[0]
-        fake_score = fake_out[0]
-        loss += torch.mean(F.relu(1.0 - real_score)) + torch.mean(
-            F.relu(1.0 + fake_score)
-        )
-    return loss
+    rl, fl = [], []
+    for real_out, fake_out in zip(real_out_map, fake_out_map):
+        real_loss = torch.mean((1.0 - real_out) ** 2)
+        fake_loss = torch.mean(fake_out**2)
+        loss += real_loss + fake_loss
+        rl.append(real_loss.item())
+        fl.append(fake_loss.item())
+    return loss, sum(rl), sum(fl)
 
 
-def generator_adv_loss(d_outputs_fake):
+def generator_adv_loss(fake_disc_outputs: List[Tensor]):
     loss = 0.0
-    for fake_out in d_outputs_fake:
+    for fake_out in fake_disc_outputs:
         fake_score = fake_out[0]
         loss += -torch.mean(fake_score)
     return loss
 
 
-def feature_matching_loss(
-    d_outputs_real,
-    d_outputs_fake,
-    loss_fn: Callable[[Tensor, Tensor], Tensor] = F.mse_loss,
+def feature_loss(
+    fmap_r,
+    fmap_g,
+    weight=2.0,
+    loss_fn: Callable[[Tensor, Tensor], Tensor] = F.l1_loss,
 ):
     loss = 0.0
-    for real_out, fake_out in zip(d_outputs_real, d_outputs_fake):
-        real_feats = real_out[1]
-        fake_feats = fake_out[1]
-        for real_f, fake_f in zip(real_feats, fake_feats):
-            loss += loss_fn(fake_f, real_f)
-    return loss
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            loss += loss_fn(rl, gl)
+    return loss * weight
+
+
+def generator_loss(disc_generated_outputs):
+    loss = 0.0
+    gen_losses = []
+    for dg in disc_generated_outputs:
+        l = torch.mean((1.0 - dg) ** 2)
+        gen_losses.append(l.item())
+        loss += l
+
+    return loss, gen_losses
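
Both discriminators now return (outputs, feature_maps) pairs, and the loss helpers follow the LSGAN formulation. A sketch of one adversarial step with hypothetical waveforms (module path assumed from the RECORD listing):

    import torch
    from lt_tensor.model_zoo.disc import (
        MultiPeriodDiscriminator, discriminator_loss, feature_loss, generator_loss,
    )

    real, fake = torch.randn(4, 8192), torch.randn(4, 8192)  # [B, T]
    mpd = MultiPeriodDiscriminator()

    # Discriminator update: detach the generator output.
    real_out, real_fmaps = mpd(real)
    fake_out, _ = mpd(fake.detach())
    d_loss, d_real_sum, d_fake_sum = discriminator_loss(real_out, fake_out)

    # Generator update: adversarial + feature-matching terms.
    fake_out_g, fake_fmaps = mpd(fake)
    g_adv, per_disc = generator_loss(fake_out_g)
    g_fm = feature_loss(real_fmaps, fake_fmaps)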
lt_tensor/model_zoo/istft.py CHANGED
@@ -106,44 +106,3 @@ class Generator(Model):
         classname = m.__class__.__name__
         if "Conv" in classname:
             m.weight.data.normal_(mean, std)
-
-
-# Below are items found in the Rishikesh's repo that might work for this generator.
-# https://github.com/rishikksh20/iSTFTNet-pytorch/blob/781480e9563d4dff5a8cc9ef1af6c6e0cab025c8/models.py
-
-
-def feature_loss(fmap_r, fmap_g, weight=2.0):
-    """Feature matching loss between real and generated feature maps."""
-    loss = 0.0
-    for dr, dg in zip(fmap_r, fmap_g):
-        for rl, gl in zip(dr, dg):
-            loss += torch.mean(torch.abs(rl - gl))
-    return loss * weight
-
-
-def discriminator_loss(disc_real_outputs, disc_generated_outputs):
-    """LSGAN-style loss for real and fake predictions."""
-    loss = 0.0
-    r_losses, g_losses = [], []
-
-    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-        r_loss = torch.mean((1.0 - dr) ** 2)
-        g_loss = torch.mean(dg**2)
-        loss += r_loss + g_loss
-        r_losses.append(r_loss.item())
-        g_losses.append(g_loss.item())
-
-    return loss, r_losses, g_losses
-
-
-def generator_loss(disc_generated_outputs):
-    """LSGAN generator loss encouraging fake to look like real (close to 1)."""
-    loss = 0.0
-    gen_losses = []
-
-    for dg in disc_generated_outputs:
-        l = torch.mean((1.0 - dg) ** 2)
-        gen_losses.append(l.item())
-        loss += l
-
-    return loss, gen_losses
lt_tensor/noise_tools.py CHANGED
@@ -271,9 +271,8 @@ class NoiseSchedulerB(nn.Module):
     def forward(
         self, x_0: Tensor, t: int, noise: Optional[Union[Tensor, float]] = None
     ) -> Tensor:
-        apply_noise()
         assert (
-            0 >= t < self.timesteps
+            0 <= t < self.timesteps
        ), f"Time step t={t} is out of bounds for scheduler with {self.timesteps} steps."
 
         if noise is None:
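
The corrected assertion matters because Python chains comparisons: `0 >= t < self.timesteps` parses as `(0 >= t) and (t < self.timesteps)`, which rejects every positive step. A two-line illustration:

    t, timesteps = 5, 1000
    print(0 >= t < timesteps)  # False: the old check failed for all t > 0
    print(0 <= t < timesteps)  # True: the fixed bounds check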
lt_tensor/transform.py CHANGED
@@ -420,44 +420,11 @@ class InverseTransform(Model):
         self.onesided = onesided
         self.normalized = normalized
         self.window = torch.hann_window(win_length) if window is None else window
-        self.update_settings()
-
+
     def _apply_device_to(self):
         """Applies to device while used with module `Model`"""
         self.window = self.window.to(device=self.device)
 
-    def update_settings(
-        self,
-        *,
-        n_fft: Optional[int] = None,
-        hop_length: Optional[int] = None,
-        win_length: Optional[int] = None,
-        length: Optional[int] = None,
-        window: Optional[Tensor] = None,
-        onesided: Optional[bool] = None,
-        return_complex: Optional[bool] = None,
-        center: Optional[bool] = None,
-        normalized: Optional[bool] = None,
-        **_,
-    ):
-
-        self.kwargs = dict(
-            n_fft=default(n_fft, self.n_fft),
-            hop_length=default(hop_length, self.hop_length),
-            win_length=default(win_length, self.win_length),
-            length=default(length, self.length),
-            window=default(window, self.window),
-            onesided=default(onesided, self.onesided),
-            return_complex=default(return_complex, self.return_complex),
-            center=default(center, self.center),
-            normalized=default(normalized, self.normalized),
-        )
-        if self.kwargs["onesided"] and self.kwargs["return_complex"]:
-            warnings.warn(
-                "You cannot use return_complex with `onesided` enabled. `return_complex` is set to False."
-            )
-            self.kwargs["return_complex"] = False
-
     def forward(self, spec: Tensor, phase: Tensor, **kwargs):
         """
         Perform the inverse short-time Fourier transform.
@@ -476,7 +443,16 @@ class InverseTransform(Model):
         Tensor
             Time-domain waveform reconstructed from `spec` and `phase`.
         """
-        if kwargs:
-            self.update_settings(**kwargs)
 
-        return torch.istft(spec * torch.exp(phase * 1j), **self.kwargs)
+        return torch.istft(
+            spec * torch.exp(phase * 1j),
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            window=self.window,
+            center=self.center,
+            normalized=self.normalized,
+            onesided=self.onesided,
+            length=self.length,
+            return_complex=self.return_complex,
+        )
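
forward() now hands the constructor settings directly to torch.istft; per-call **kwargs are still accepted but no longer reconfigure the transform. A round-trip sketch with assumed STFT parameters (the constructor signature is not shown in this diff):

    import torch

    n_fft, hop, win = 1024, 256, 1024
    window = torch.hann_window(win)

    wave = torch.randn(1, 16384)
    stft = torch.stft(wave, n_fft, hop_length=hop, win_length=win,
                      window=window, return_complex=True)
    spec, phase = stft.abs(), stft.angle()

    # The same reconstruction the new forward() performs:
    recon = torch.istft(spec * torch.exp(phase * 1j), n_fft,
                        hop_length=hop, win_length=win, window=window)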
{lt_tensor-0.0.1a10.dist-info → lt_tensor-0.0.1a11.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a10
+Version: 0.0.1a11
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
{lt_tensor-0.0.1a10.dist-info → lt_tensor-0.0.1a11.dist-info}/RECORD RENAMED
@@ -1,28 +1,28 @@
 lt_tensor/__init__.py,sha256=uwJ7uiO18VYj8Z1V4KSOQ3ZrnowSgJWKCIiFBrzLMOI,429
-lt_tensor/losses.py,sha256=TinZJP2ypZ7Tdg6d9nnFWFkPyormfgQ0Z9P2ER3sqzE,4341
+lt_tensor/losses.py,sha256=1wrke1e68hUBNAoPdJgKni0pJvXKcieza_R8nwBzMW4,4937
 lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
 lt_tensor/math_ops.py,sha256=ewIYkvxIy_Lab_9ExjFUgLs-oYLOu8IRRDo7f1pn3i8,2248
-lt_tensor/misc_utils.py,sha256=N9Rf-i6m51Q3YYdmI5tI5Rb3wPz8OAJrTrLlqfCwWrk,24792
-lt_tensor/model_base.py,sha256=8qN7oklALFanOz-eqVzdnB9RD2kN_3ltynSMAPOl-TI,13413
+lt_tensor/misc_utils.py,sha256=8LqtpmLKqCo79NdH160ByQojG8YTDcw8aHKFgOFGVLI,25425
+lt_tensor/model_base.py,sha256=a2ogixC2fUyOLqz15TzCRcGXvBam--TdmpG83jw9Of8,21543
 lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
-lt_tensor/noise_tools.py,sha256=JkWw0-bCMRNNMShwXKKt5KbO3104tvNiBePt-ThPkEo,11366
+lt_tensor/noise_tools.py,sha256=rfFbPsrsycWVuH9G4zZCQC9Vgi9r8hDaECcB0TZYSYQ,11345
 lt_tensor/torch_commons.py,sha256=fntsEU8lhBQo0ebonI1iXBkMbWMN3HpBsG13EWlP5s8,718
-lt_tensor/transform.py,sha256=Bxh87vFRKuZay_g1Alf_ZtEo89CzmV3XUQDINwHB7iA,14505
+lt_tensor/transform.py,sha256=LZZ9G7ud1cojERC7N7hMAbH9GC3ImY1hBIY00kVMs-I,13492
 lt_tensor/datasets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 lt_tensor/datasets/audio.py,sha256=YREyRsCvy-KS5tE0JNMWEdlIJogE1khLqhiq4wOWXVg,3777
 lt_tensor/model_zoo/__init__.py,sha256=jipEk50_DTMQbGg8FnDDukxmh7Bcwvl_QVRS3rkb7aY,283
-lt_tensor/model_zoo/bsc.py,sha256=muxIR7dU-Pvf-HFE-iy3zmRb1sTJlcs1vqdlnbU1Hss,6307
-lt_tensor/model_zoo/disc.py,sha256=SphFVVPZLP96-mZPEvWD_up2aT63rqSPjnps1-j9D6w,5707
+lt_tensor/model_zoo/bsc.py,sha256=OQqsQDRBf6gWqoeGeEuIaTh96AqcDyTIbO8MAMNTtI4,7045
+lt_tensor/model_zoo/disc.py,sha256=9RxyHYH2nGhxLs_yoEFVgerBfH4-qdaL2Mu9akyG0_M,5841
 lt_tensor/model_zoo/fsn.py,sha256=5ySsg2OHjvTV_coPAdZQ0f7bz4ugJB8mDYsItmd61qA,2102
 lt_tensor/model_zoo/gns.py,sha256=Tirr_grONp_FFQ_L7K-zV2lvkaC39h8mMl4QDpx9vLQ,6028
-lt_tensor/model_zoo/istft.py,sha256=0Xms2QNPAgz_ib8XTfaWl1SCHgS53oKC6-EkDkl_qe4,4863
+lt_tensor/model_zoo/istft.py,sha256=RV7KVY7q4CYzzsWXH4NGJQwSqrYWwHh-16Q62lKoA2k,3594
 lt_tensor/model_zoo/pos.py,sha256=N28v-rF8CELouYxQ9r45Jbd4ri5DNydwDgg7nzmQ4Ig,4471
 lt_tensor/model_zoo/rsd.py,sha256=5bba50g1Hm5kMexuJ4SwOIJuyQ1qJd8Acrq-Ax6CqE8,6958
 lt_tensor/model_zoo/tfrms.py,sha256=kauh-A13pk08SZ5OspEE5a-gPKD4rZr6tqMKWu3KGhk,4237
 lt_tensor/processors/__init__.py,sha256=4b9MxAJolXiJfSm20ZEspQTDm1tgLazwlPWA_jB1yLM,63
 lt_tensor/processors/audio.py,sha256=2Sta_KytTqGZh-ZeHpcCbqP6O8VT6QQVkx-7szA3Itc,8830
-lt_tensor-0.0.1a10.dist-info/licenses/LICENSE,sha256=HUnu_iSPpnDfZS_PINhO3AoVizJD1A2vee8WX7D7uXo,11358
-lt_tensor-0.0.1a10.dist-info/METADATA,sha256=-VDQmGfkd5uW4_8B_TbwH-xvRivsGn3jWEtXTyeCT0s,966
-lt_tensor-0.0.1a10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-lt_tensor-0.0.1a10.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
-lt_tensor-0.0.1a10.dist-info/RECORD,,
+lt_tensor-0.0.1a11.dist-info/licenses/LICENSE,sha256=HUnu_iSPpnDfZS_PINhO3AoVizJD1A2vee8WX7D7uXo,11358
+lt_tensor-0.0.1a11.dist-info/METADATA,sha256=DNs5JZfr_mjve_GHy13Auics3BI_f1pNYBth-dQW04M,966
+lt_tensor-0.0.1a11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a11.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a11.dist-info/RECORD,,