lt-tensor 0.0.1a4__tar.gz → 0.0.1a7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/PKG-INFO +2 -2
  2. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/__init__.py +9 -1
  3. lt_tensor-0.0.1a7/lt_tensor/datasets/audio.py +109 -0
  4. lt_tensor-0.0.1a7/lt_tensor/losses.py +145 -0
  5. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/math_ops.py +7 -0
  6. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/misc_utils.py +10 -96
  7. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/model_base.py +105 -6
  8. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/model_zoo/disc.py +14 -14
  9. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/model_zoo/istft.py +41 -0
  10. lt_tensor-0.0.1a7/lt_tensor/noise_tools.py +368 -0
  11. lt_tensor-0.0.1a7/lt_tensor/processors/__init__.py +3 -0
  12. lt_tensor-0.0.1a7/lt_tensor/processors/audio.py +193 -0
  13. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/transform.py +190 -30
  14. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor.egg-info/PKG-INFO +2 -2
  15. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor.egg-info/SOURCES.txt +5 -1
  16. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor.egg-info/requires.txt +1 -1
  17. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/setup.py +2 -2
  18. lt_tensor-0.0.1a4/lt_tensor/datasets/audio.py +0 -110
  19. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/LICENSE +0 -0
  20. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/README.md +0 -0
  21. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/datasets/__init__.py +0 -0
  22. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/lr_schedulers.py +0 -0
  23. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/model_zoo/__init__.py +0 -0
  24. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/model_zoo/bsc.py +0 -0
  25. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/model_zoo/fsn.py +0 -0
  26. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/model_zoo/gns.py +0 -0
  27. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/model_zoo/pos.py +0 -0
  28. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/model_zoo/rsd.py +0 -0
  29. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/model_zoo/tfrms.py +0 -0
  30. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/monotonic_align.py +0 -0
  31. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor/torch_commons.py +0 -0
  32. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor.egg-info/dependency_links.txt +0 -0
  33. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/lt_tensor.egg-info/top_level.txt +0 -0
  34. {lt_tensor-0.0.1a4 → lt_tensor-0.0.1a7}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lt-tensor
3
- Version: 0.0.1a4
3
+ Version: 0.0.1a7
4
4
  Summary: General utilities for PyTorch and others. Built for general use.
5
5
  Home-page: https://github.com/gr1336/lt-tensor/
6
6
  Author: gr1336
@@ -17,7 +17,7 @@ Requires-Dist: numpy>=1.26.4
17
17
  Requires-Dist: tokenizers
18
18
  Requires-Dist: pyyaml>=6.0.0
19
19
  Requires-Dist: numba>0.60.0
20
- Requires-Dist: lt-utils==0.0.1.a3
20
+ Requires-Dist: lt-utils==0.0.1
21
21
  Requires-Dist: librosa>=0.11.0
22
22
  Dynamic: author
23
23
  Dynamic: classifier
@@ -1,13 +1,17 @@
1
1
  __version__ = "0.0.1a"
2
2
 
3
3
  from . import (
4
+ lr_schedulers,
4
5
  model_zoo,
5
6
  model_base,
6
7
  math_ops,
7
8
  misc_utils,
8
9
  monotonic_align,
9
10
  transform,
10
- lr_schedulers,
11
+ noise_tools,
12
+ losses,
13
+ processors,
14
+ datasets,
11
15
  )
12
16
 
13
17
  __all__ = [
@@ -18,4 +22,8 @@ __all__ = [
18
22
  "monotonic_align",
19
23
  "transform",
20
24
  "lr_schedulers",
25
+ "noise_tools",
26
+ "losses",
27
+ "processors",
28
+ "datasets",
21
29
  ]
@@ -0,0 +1,109 @@
1
# Public API of this module. The exported name must match the class defined
# below; a name that does not exist makes `from ... import *` raise AttributeError.
__all__ = ["WaveMelDataset"]
2
+ from ..torch_commons import *
3
+ from lt_utils.common import *
4
+ import random
5
+ from torch.utils.data import Dataset, DataLoader, Sampler
6
+ from ..processors import AudioProcessor
7
+ import torch.nn.functional as FT
8
+ from ..misc_utils import log_tensor
9
+
10
+
11
class WaveMelDataset(Dataset):
    """Untested!

    In-memory dataset of (waveform, mel) pairs built from the audio files
    found under a path. Every file is loaded eagerly in ``__init__``; files
    longer than ``max_frame_length`` are split into fragments along the last
    dimension.
    """

    data: Union[list[dict[str, Tensor]], Tuple[Tensor, Tensor]] = []

    def __init__(
        self,
        audio_processor: "AudioProcessor",
        path: PathLike,
        limit_files: Optional[int] = None,
        max_frame_length: Optional[int] = None,
    ):
        """Collect the audio files under ``path`` and pre-compute their mels.

        Args:
            audio_processor: object providing ``n_fft``, ``find_audios``,
                ``load_audio`` and ``compute_mel``.
            path: location handed to ``audio_processor.find_audios``.
            limit_files: when truthy, the file list is shuffled and cut to
                this many entries.
            max_frame_length: optional fragment size; must be at least
                ``n_fft // 2 + 1`` when given.
        """
        super().__init__()
        assert max_frame_length is None or max_frame_length >= (
            (audio_processor.n_fft // 2) + 1
        )
        # Smallest fragment length that is still passed to compute_mel.
        self.post_n_fft = (audio_processor.n_fft // 2) + 1
        self.ap = audio_processor
        self.files = self.ap.find_audios(path)
        if limit_files:
            random.shuffle(self.files)
            self.files = self.files[:limit_files]

        self.data = []
        for audio_file in self.files:
            self.data.extend(self.load_data(audio_file, max_frame_length))

    def _add_dict(self, audio_raw: Tensor, audio_mel: Tensor, file: PathLike):
        """Bundle one sample into the dict layout consumed by ``collate_fn``."""
        return dict(mel=audio_mel, raw=audio_raw, file=file)

    def load_data(self, file: PathLike, audio_frames_limit: Optional[int] = None):
        """Load ``file`` and return a list of sample dicts (one per fragment).

        Without a limit, or when the audio already fits in the limit, a single
        sample is returned. Otherwise the audio is split into chunks of
        ``audio_frames_limit`` along the last dim and chunks shorter than
        ``self.post_n_fft`` are dropped.
        """
        waveform = self.ap.load_audio(file)
        if not audio_frames_limit or waveform.shape[-1] <= audio_frames_limit:
            mel = self.ap.compute_mel(waveform, add_base=True)
            return [self._add_dict(waveform, mel, file)]

        pieces = torch.split(
            waveform, split_size_or_sections=audio_frames_limit, dim=-1
        )
        # sometimes the tensor will be too small to be able to pass on mel
        usable = (piece for piece in pieces if piece.shape[-1] >= self.post_n_fft)
        return [
            self._add_dict(piece, self.ap.compute_mel(piece, add_base=True), file)
            for piece in usable
        ]

    def get_data_loader(
        self,
        batch_size: int = 1,
        shuffle: Optional[bool] = None,
        sampler: Optional[Union[Sampler, Iterable]] = None,
        batch_sampler: Optional[Union[Sampler[list], Iterable[list]]] = None,
        num_workers: int = 0,
        pin_memory: bool = False,
        drop_last: bool = False,
        timeout: float = 0,
    ):
        """Return a ``DataLoader`` over this dataset using ``self.collate_fn``.

        All arguments are forwarded unchanged to ``DataLoader``.
        """
        loader_options = dict(
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            collate_fn=self.collate_fn,
        )
        return DataLoader(self, **loader_options)

    @staticmethod
    def collate_fn(batch: Sequence[Dict[str, Tensor]]):
        """Right-pad mels and waveforms to the batch maximum and stack them.

        Returns:
            ``(padded_mels, padded_audios, files)`` where padding is applied
            on the last dimension only.
        """
        mels = [sample["mel"] for sample in batch]
        audios = [sample["raw"] for sample in batch]
        files = [sample["file"] for sample in batch]

        # Longest item along the last dim decides the padded size.
        widest_mel = max(m.shape[-1] for m in mels)
        longest_audio = max(a.shape[-1] for a in audios)

        padded_mels = torch.stack(
            [FT.pad(m, (0, widest_mel - m.shape[-1])) for m in mels]
        )
        padded_audios = torch.stack(
            [FT.pad(a, (0, longest_audio - a.shape[-1])) for a in audios]
        )
        return padded_mels, padded_audios, files

    def __len__(self):
        """Number of samples (fragments) held in memory."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample dict stored at ``index``."""
        return self.data[index]
@@ -0,0 +1,145 @@
1
+ __all__ = ["masked_cross_entropy"]
2
+ import math
3
+ import random
4
+ from .torch_commons import *
5
+ from lt_utils.common import *
6
+ import torch.nn.functional as F
7
+
8
+
9
def masked_cross_entropy(
    logits: torch.Tensor,  # [B, T, V]
    targets: torch.Tensor,  # [B, T]
    lengths: torch.Tensor,  # [B]
    reduction: str = "mean",
) -> torch.Tensor:
    """Cross-entropy for variable-length sequences; padded steps are ignored.

    Args:
        logits: unnormalized scores, ``[B, T, V]``.
        targets: ground-truth class indices, ``[B, T]``.
        lengths: number of valid steps per sequence, ``[B]``.
        reduction: forwarded to ``F.cross_entropy`` (``"mean"``, ``"sum"`` or
            ``"none"``). With ``"none"`` the result holds one loss per
            *valid* step, flattened in (batch, time) order.

    Returns:
        The reduced loss, or the per-step losses for ``reduction="none"``.

    Raises:
        ValueError: if ``reduction`` is not accepted by ``F.cross_entropy``.
    """
    batch, steps, vocab = logits.size()
    # reshape instead of view: also works for non-contiguous inputs.
    flat_logits = logits.reshape(-1, vocab)
    flat_targets = targets.reshape(-1)

    # Build the mask on the logits' device so boolean indexing never mixes
    # devices when `lengths` lives elsewhere (e.g. on the CPU).
    positions = torch.arange(steps, device=logits.device).expand(batch, steps)
    mask = positions < lengths.to(logits.device).unsqueeze(1)
    mask = mask.reshape(-1)

    # Apply CE only where mask == True
    return F.cross_entropy(flat_logits[mask], flat_targets[mask], reduction=reduction)
36
+
37
+
38
def diff_loss(pred_noise, true_noise, mask=None):
    """Standard diffusion noise-prediction loss (e.g., DDPM).

    Mean-squared error between predicted and true noise. When ``mask`` is
    given, both tensors are multiplied by it first; the mean is still taken
    over every element, masked or not.
    """
    if mask is None:
        return F.mse_loss(pred_noise, true_noise)
    masked_pred = pred_noise * mask
    masked_true = true_noise * mask
    return F.mse_loss(masked_pred, masked_true)
43
+
44
+
45
def hybrid_diff_loss(pred_noise, true_noise, alpha=0.5):
    """Blend of L1 and L2 losses: ``alpha * L1 + (1 - alpha) * L2``."""
    abs_term = F.l1_loss(pred_noise, true_noise)
    sq_term = F.mse_loss(pred_noise, true_noise)
    l2_weight = 1 - alpha
    return alpha * abs_term + l2_weight * sq_term
50
+
51
+
52
def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
    """Discriminator loss summed over paired real/fake prediction tensors.

    Args:
        real_preds: iterable of discriminator outputs for real samples.
        fake_preds: iterable of discriminator outputs for generated samples,
            paired with ``real_preds`` by position (``zip`` semantics).
        use_lsgan: when True, use the least-squares objective (MSE of real
            against ones plus MSE of fake against zeros). Otherwise use the
            log objective ``-mean(log(real)) - mean(log(1 - fake))`` with a
            ``1e-7`` epsilon inside each log.
            NOTE(review): the log branch presumably expects outputs already
            in (0, 1) - confirm against the discriminators in use.

    Returns:
        The summed loss; the int ``0`` when there are no pairs.
    """
    # This function used to be defined twice, verbatim, in this module; the
    # second copy merely shadowed the first. A single definition is kept.
    loss = 0
    for real, fake in zip(real_preds, fake_preds):
        if use_lsgan:
            loss += F.mse_loss(real, torch.ones_like(real)) + F.mse_loss(
                fake, torch.zeros_like(fake)
            )
        else:
            loss += -torch.mean(torch.log(real + 1e-7)) - torch.mean(
                torch.log(1 - fake + 1e-7)
            )
    return loss
78
+
79
+
80
def gan_g_loss(fake_preds, use_lsgan=True):
    """Generator adversarial loss summed over discriminator outputs.

    With ``use_lsgan`` each term is the MSE of the fake prediction against
    ones; otherwise it is ``-mean(log(fake + 1e-7))``. Returns the int ``0``
    for an empty ``fake_preds``.
    """
    if use_lsgan:
        terms = (F.mse_loss(fake, torch.ones_like(fake)) for fake in fake_preds)
    else:
        terms = (-torch.mean(torch.log(fake + 1e-7)) for fake in fake_preds)
    return sum(terms, 0)
88
+
89
+
90
def feature_matching_loss(real_feats, fake_feats):
    """Sum of L1 distances between paired intermediate feature maps.

    ``real_feats`` and ``fake_feats`` are lists (one entry per
    discriminator) of lists of feature tensors. Real features are detached
    so no gradient flows through them.
    """
    layer_pairs = (
        (real, fake)
        for real_layers, fake_layers in zip(real_feats, fake_feats)
        for real, fake in zip(real_layers, fake_layers)
    )
    return sum((F.l1_loss(fake, real.detach()) for real, fake in layer_pairs), 0)
97
+
98
+
99
def feature_loss(real_fmaps, fake_fmaps, weight=2.0):
    """Weighted sum of L1 distances between real and generated feature maps.

    Each element of ``real_fmaps`` / ``fake_fmaps`` is the layer list of one
    discriminator; layers are paired by position and the accumulated L1 loss
    is multiplied by ``weight``.
    """
    total = sum(
        (
            F.l1_loss(real_layer, fake_layer)
            for real_group, fake_group in zip(real_fmaps, fake_fmaps)
            for real_layer, fake_layer in zip(real_group, fake_group)
        ),
        0.0,
    )
    return total * weight
105
+
106
+
107
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss over paired outputs.

    For every (real, generated) pair the real output is pushed towards ones
    and the generated output towards zeros, both with MSE.

    Returns:
        ``(total, real_losses, generated_losses)`` where the two lists hold
        the individual per-pair terms.
    """
    pairs = list(zip(disc_real_outputs, disc_generated_outputs))
    r_losses = [F.mse_loss(real, torch.ones_like(real)) for real, _ in pairs]
    g_losses = [F.mse_loss(gen, torch.zeros_like(gen)) for _, gen in pairs]

    total = sum((r + g for r, g in zip(r_losses, g_losses)), 0.0)
    return total, r_losses, g_losses
120
+
121
+
122
def generator_loss(fake_outputs):
    """Least-squares generator loss over discriminator outputs.

    Each output is compared against ones with MSE.

    Returns:
        ``(total, per_output_losses)``.
    """
    g_losses = [F.mse_loss(out, torch.ones_like(out)) for out in fake_outputs]
    total = sum(g_losses, 0.0)
    return total, g_losses
130
+
131
+
132
def multi_resolution_stft_loss(y, y_hat, fft_sizes=(512, 1024, 2048)):
    """Sum of L1 distances between STFT magnitudes at several resolutions.

    For each FFT size the hop is ``fft_size // 4`` and the window length is
    ``fft_size``. A Hann window is used: calling ``torch.stft`` without a
    window applies a rectangular one, which leaks energy across bins and
    triggers a warning in recent PyTorch releases.

    Args:
        y: reference waveform accepted by ``torch.stft``.
        y_hat: generated waveform with the same shape as ``y``.
        fft_sizes: iterable of FFT sizes (immutable default).

    Returns:
        The summed loss; the int ``0`` when ``fft_sizes`` is empty.
    """
    loss = 0
    for fft_size in fft_sizes:
        hop = fft_size // 4
        win = fft_size
        # One window per resolution, created where the reference signal lives.
        window = torch.hann_window(win, device=y.device, dtype=y.dtype)
        y_stft = torch.stft(
            y,
            n_fft=fft_size,
            hop_length=hop,
            win_length=win,
            window=window,
            return_complex=True,
        )
        y_hat_stft = torch.stft(
            y_hat,
            n_fft=fft_size,
            hop_length=hop,
            win_length=win,
            window=window.to(device=y_hat.device, dtype=y_hat.dtype),
            return_complex=True,
        )

        loss += F.l1_loss(torch.abs(y_stft), torch.abs(y_hat_stft))
    return loss
@@ -8,6 +8,7 @@ __all__ = [
8
8
  "dot_product",
9
9
  "normalize_tensor",
10
10
  "log_magnitude",
11
+ "shift_time",
11
12
  "phase",
12
13
  ]
13
14
 
@@ -50,6 +51,11 @@ def shift_ring(x: Tensor, dim: int = -1) -> Tensor:
50
51
  return torch.roll(x, shifts=1, dims=dim)
51
52
 
52
53
 
54
def shift_time(x: torch.Tensor, shift: int) -> torch.Tensor:
    """Circularly shift ``x`` by ``shift`` positions along its last dimension.

    Elements rolled past the end wrap around to the start.
    """
    time_axis = -1
    return x.roll(shifts=shift, dims=time_axis)
57
+
58
+
53
59
  def dot_product(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
54
60
  """Computes dot product along the specified dimension."""
55
61
  return torch.sum(x * y, dim=dim)
@@ -69,3 +75,4 @@ def log_magnitude(stft_complex: Tensor, eps: float = 1e-5) -> Tensor:
69
75
  def phase(stft_complex: Tensor) -> Tensor:
70
76
  """Returns phase from complex STFT."""
71
77
  return torch.angle(stft_complex)
78
+
@@ -8,7 +8,6 @@ __all__ = [
8
8
  "unfreeze_selected_weights",
9
9
  "clip_gradients",
10
10
  "detach_hidden",
11
- "tensor_summary",
12
11
  "one_hot",
13
12
  "safe_divide",
14
13
  "batch_pad",
@@ -18,22 +17,20 @@ __all__ = [
18
17
  "default_device",
19
18
  "Packing",
20
19
  "Padding",
21
- "MaskUtils",
22
- "masked_cross_entropy",
23
- "NoiseScheduler",
20
+ "Masking",
24
21
  ]
25
22
 
26
23
  import gc
24
+ import sys
27
25
  import random
28
26
  import numpy as np
29
27
  from lt_utils.type_utils import is_str
30
28
  from .torch_commons import *
31
- from lt_utils.misc_utils import log_traceback, cache_wrapper
32
- from lt_utils.file_ops import load_json, load_yaml, save_json, save_yaml
33
- import math
29
+ from lt_utils.misc_utils import cache_wrapper
34
30
  from lt_utils.common import *
35
31
  import torch.nn.functional as F
36
32
 
33
+
37
34
  def log_tensor(
38
35
  item: Union[Tensor, np.ndarray],
39
36
  title: Optional[str] = None,
@@ -64,10 +61,13 @@ def log_tensor(
64
61
  print(f"mean: {item.mean(dim=dim):.4f}")
65
62
  except:
66
63
  pass
67
- if print_tensor:
68
- print(item)
64
+ if print_tensor:
65
+ print(item)
69
66
  if has_title:
70
67
  print("".join(["-"] * _b), "\n")
68
+ else:
69
+ print("\n")
70
+ sys.stdout.flush()
71
71
 
72
72
 
73
73
  def set_seed(seed: int):
@@ -136,11 +136,6 @@ def detach_hidden(hidden):
136
136
  return tuple(detach_hidden(h) for h in hidden)
137
137
 
138
138
 
139
- def tensor_summary(tensor: torch.Tensor) -> str:
140
- """Prints min/max/mean/std of a tensor for debugging."""
141
- return f"Shape: {tuple(tensor.shape)}, dtype: {tensor.dtype}, min: {tensor.min():.4f}, max: {tensor.max():.4f}, mean: {tensor.mean():.4f}, std: {tensor.std():.4f}"
142
-
143
-
144
139
  def one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
145
140
  """One-hot encodes a tensor of labels."""
146
141
  return F.one_hot(labels, num_classes).float()
@@ -463,7 +458,7 @@ class Padding:
463
458
  return torch.stack(padded), lengths
464
459
 
465
460
 
466
- class MaskUtils:
461
+ class Masking:
467
462
 
468
463
  @staticmethod
469
464
  def apply_mask(x: Tensor, mask: Tensor, fill_value: Number = 0) -> Tensor:
@@ -546,84 +541,3 @@ class MaskUtils:
546
541
  return (
547
542
  causal_mask & pad_mask.unsqueeze(1).expand(-1, pad_mask.size(1), -1).bool()
548
543
  )
549
-
550
-
551
- def masked_cross_entropy(
552
- logits: torch.Tensor, # [B, T, V]
553
- targets: torch.Tensor, # [B, T]
554
- lengths: torch.Tensor, # [B]
555
- reduction: str = "mean",
556
- ) -> torch.Tensor:
557
- """
558
- CrossEntropyLoss with masking for variable-length sequences.
559
- - logits: unnormalized scores [B, T, V]
560
- - targets: ground truth indices [B, T]
561
- - lengths: actual sequence lengths [B]
562
- """
563
- B, T, V = logits.size()
564
- logits = logits.view(-1, V)
565
- targets = targets.view(-1)
566
-
567
- # Create mask
568
- mask = torch.arange(T, device=lengths.device).expand(B, T) < lengths.unsqueeze(1)
569
- mask = mask.reshape(-1)
570
-
571
- # Apply CE only where mask == True
572
- loss = F.cross_entropy(
573
- logits[mask], targets[mask], reduction="mean" if reduction == "mean" else "none"
574
- )
575
- if reduction == "none":
576
- return loss
577
- return loss
578
-
579
-
580
- class NoiseScheduler(nn.Module):
581
- def __init__(self, timesteps: int = 512):
582
- super().__init__()
583
-
584
- betas = torch.linspace(1e-4, 0.02, timesteps)
585
- alphas = 1.0 - betas
586
- alpha_cumprod = torch.cumprod(alphas, dim=0)
587
-
588
- self.register_buffer("sqrt_alpha_cumprod", torch.sqrt(alpha_cumprod))
589
- self.register_buffer(
590
- "sqrt_one_minus_alpha_cumprod", torch.sqrt(1.0 - alpha_cumprod)
591
- )
592
-
593
- self.timesteps = timesteps
594
- self.default_noise = math.sqrt(1.25)
595
-
596
- def get_random_noise(
597
- self, min_max: Tuple[float, float] = (-3, 3), seed: int = 0
598
- ) -> float:
599
- if seed > 0:
600
- random.seed(seed)
601
- return random.uniform(*min_max)
602
-
603
- def set_noise(
604
- self,
605
- seed: int = 0,
606
- min_max: Tuple[float, float] = (-3, 3),
607
- default: bool = False,
608
- ):
609
- self.default_noise = (
610
- math.sqrt(1.25) if default else self.get_random_noise(min_max, seed)
611
- )
612
-
613
- def forward(
614
- self, x_0: Tensor, t: int, noise: Optional[Union[Tensor, float]] = None
615
- ) -> Tensor:
616
- if t < 0 or t >= self.timesteps:
617
- raise ValueError(
618
- f"Time step t={t} is out of bounds for scheduler with {self.timesteps} steps."
619
- )
620
-
621
- if noise is None:
622
- noise = self.default_noise
623
-
624
- if isinstance(noise, (float, int)):
625
- noise = torch.randn_like(x_0) * noise
626
-
627
- alpha_term = self.sqrt_alpha_cumprod[t] * x_0
628
- noise_term = self.sqrt_one_minus_alpha_cumprod[t] * noise
629
- return alpha_term + noise_term
@@ -4,6 +4,7 @@ __all__ = ["Model"]
4
4
  import warnings
5
5
  from .torch_commons import *
6
6
  from lt_utils.common import *
7
+ from lt_utils.misc_utils import log_traceback
7
8
 
8
9
  T = TypeVar("T")
9
10
 
@@ -40,20 +41,113 @@ class Model(nn.Module, ABC):
40
41
  def device(self, device: Union[torch.device, str]):
41
42
  assert isinstance(device, (str, torch.device))
42
43
  self._device = torch.device(device) if isinstance(device, str) else device
43
- self.tp_apply_device_to()
44
+ self._apply_device_to()
44
45
 
45
- def tp_apply_device_to(self):
46
+ def _apply_device_to(self):
46
47
  """Add here components that are needed to have device applied to them,
47
- that usualy the '.to()' function fails to apply
48
+ that usually the '.to()' function fails to apply
48
49
 
49
50
  example:
50
51
  ```
51
- def tp_apply_device_to(self):
52
+ def _apply_device_to(self):
52
53
  self.my_tensor = self.my_tensor.to(device=self.device)
53
54
  ```
54
55
  """
55
56
  pass
56
57
 
58
def _freeze_unfreeze(
    self,
    weight: Union[str, nn.Module],
    task: Literal["freeze", "unfreeze"] = "freeze",
    _skip_except: bool = False,
):
    """Shared implementation behind ``freeze_weight`` / ``unfreeze_weight``.

    Args:
        weight: an ``nn.Module``, or the name of an attribute of ``self``
            that holds one.
        task: ``"freeze"`` sets ``requires_grad`` to False, ``"unfreeze"``
            sets it to True.
        _skip_except: when True, an exception is returned as its string
            form instead of being raised.

    Returns:
        A human-readable status message. A missing attribute or a
        non-module attribute is reported in the message, not raised.
    """
    requires_grad = task == "unfreeze"
    try:
        assert isinstance(weight, (str, nn.Module))
        if isinstance(weight, str):
            w_txt = f"Failed to {task} the module '{weight}'. Reason: is not a valid attribute of {self._get_name()}"
            if hasattr(self, weight):
                w_txt = f"Failed to {task} the module '{weight}'. Reason: is not a Module type."
                w = getattr(self, weight)
                if isinstance(w, nn.Module):
                    w_txt = f"Successfully {task} the module '{weight}'."
                    w.requires_grad_(requires_grad)
        else:
            # A module was passed directly: act on it. (This branch used to
            # reference the unbound local `w` and always failed.)
            weight.requires_grad_(requires_grad)
            w_txt = f"Successfully '{task}' the module '{weight}'."
        return w_txt
    except Exception as e:
        if not _skip_except:
            raise e
        return str(e)


def freeze_weight(
    self,
    weight: Union[str, nn.Module],
    _skip_except: bool = False,
):
    """Disable gradients for ``weight`` (a module or an attribute name).

    Returns the status message produced by ``_freeze_unfreeze``.
    """
    # An earlier `freeze_weight(self, weight, freeze)` definition in the same
    # class was shadowed by this one and could never be called; it is dropped.
    return self._freeze_unfreeze(weight, "freeze", _skip_except)


def unfreeze_weight(
    self,
    weight: Union[str, nn.Module],
    _skip_except: bool = False,
):
    """Enable gradients for ``weight`` (a module or an attribute name).

    Returns the status message produced by ``_freeze_unfreeze``.
    """
    # Must request "unfreeze"; passing "freeze" here froze the weights instead.
    return self._freeze_unfreeze(weight, "unfreeze", _skip_except)
107
+
108
def freeze_all(self, exclude: Optional[List[str]] = None):
    """Freeze every parameter except those whose name matches ``exclude``.

    Args:
        exclude: substrings of parameter names to leave untouched. A
            parameter is skipped when any entry occurs in its name.

    Returns:
        ``{"frozen": [names], "not_frozen": [(name, reason), ...]}`` where
        the reason is ``"Excluded"`` or the text of the error raised while
        freezing.
    """
    excluded_keys = exclude or []
    frozen = []
    not_frozen = []
    for name, param in self.named_parameters():
        # Matching names are the ones to keep trainable. The previous check
        # was inverted: it froze only the excluded parameters.
        if any(layer in name for layer in excluded_keys):
            not_frozen.append((name, "Excluded"))
            continue
        try:
            param.requires_grad_(False)
            frozen.append(name)
        except Exception as e:
            # Best effort: report the failure and continue with the rest.
            not_frozen.append((name, str(e)))
    return dict(frozen=frozen, not_frozen=not_frozen)
128
+
129
def unfreeze_all_except(self, exclude: Optional[list[str]] = None):
    """Unfreezes all model parameters except specified layers.

    Args:
        exclude: substrings of parameter names to leave untouched. A
            parameter is skipped when any entry occurs in its name.

    Returns:
        ``{"unfrozen": [names], "not_unfrozen": [(name, reason), ...]}``
        where the reason is ``"Excluded"`` or the text of the error raised
        while unfreezing.
    """
    excluded_keys = exclude or []
    unfrozen = []
    not_unfrozen = []
    for name, param in self.named_parameters():
        # Matching names stay as they are. The previous check was inverted:
        # it unfroze only the excluded parameters.
        if any(layer in name for layer in excluded_keys):
            not_unfrozen.append((name, "Excluded"))
            continue
        try:
            param.requires_grad_(True)
            unfrozen.append(name)
        except Exception as e:
            # Best effort: report the failure and continue with the rest.
            not_unfrozen.append((name, str(e)))
    return dict(unfrozen=unfrozen, not_unfrozen=not_unfrozen)
150
+
57
151
  def to(self, *args, **kwargs):
58
152
  device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
59
153
  *args, **kwargs
@@ -186,11 +280,16 @@ class Model(nn.Module, ABC):
186
280
  )
187
281
 
188
282
  def get_weights(self, module_name: Optional[str] = None) -> List[Tensor]:
189
- """Returns the weights of the model entrie model or from a specified module"""
283
+ """Returns the weights of the model entry model or from a specified module"""
190
284
  if module_name is not None:
191
285
  assert hasattr(self, module_name), f"Module {module_name} does not exits"
192
286
  module = getattr(self, module_name)
193
- return [x.data.detach() for x in module.parameters()]
287
+ params = []
288
+ if isinstance(module, nn.Module):
289
+ return [x.data.detach() for x in module.parameters()]
290
+ elif isinstance(module, (Tensor, nn.Parameter)):
291
+ return [module.data.detach()]
292
+ raise (f"{module_name} is has no weights")
194
293
  return [x.data.detach() for x in self.parameters()]
195
294
 
196
295
  def print_trainable_parameters(
@@ -11,37 +11,36 @@ class PeriodDiscriminator(Model):
11
11
  use_spectral_norm=False,
12
12
  kernel_size: int = 5,
13
13
  stride: int = 3,
14
- initial_s: int = 32,
15
14
  ):
16
15
  super().__init__()
17
16
  self.period = period
17
+ self.stride = stride
18
+ self.kernel_size = kernel_size
18
19
  self.norm_f = weight_norm if use_spectral_norm == False else spectral_norm
20
+
21
+ self.channels = [32, 128, 512, 1024, 1024]
19
22
  self.first_pass = nn.Sequential(
20
23
  self.norm_f(
21
24
  nn.Conv2d(
22
- 1, initial_s * 4, (kernel_size, 1), (stride, 1), padding=(2, 0)
25
+ 1, self.channels[0], (kernel_size, 1), (stride, 1), padding=(2, 0)
23
26
  )
24
27
  ),
25
28
  nn.LeakyReLU(0.1),
26
29
  )
27
- self._last_sz = initial_s * 4
28
30
 
29
- self.convs = nn.ModuleList([self._get_next(i == 3) for i in range(4)])
31
+
32
+ self.convs = nn.ModuleList([self._get_next(self.channels[i+1], self.channels[i], i == 3) for i in range(4)])
30
33
 
31
34
  self.post_conv = nn.Conv2d(1024, 1, (stride, 1), 1, padding=(1, 0))
32
- self.kernel_size = kernel_size
33
- self.stride = stride
34
35
 
35
- def _get_next(self, is_last: bool = False):
36
- in_dim = self._last_sz
37
- self._last_sz *= 4
38
- print(self._last_sz, "-----------------------")
36
+ def _get_next(self, out_dim:int, last_in:int, is_last: bool = False):
39
37
  stride = (self.stride, 1) if not is_last else 1
38
+
40
39
  return nn.Sequential(
41
40
  self.norm_f(
42
41
  nn.Conv2d(
43
- in_dim,
44
- self._last_sz,
42
+ last_in,
43
+ out_dim,
45
44
  (self.kernel_size, 1),
46
45
  stride,
47
46
  padding=(2, 0),
@@ -91,6 +90,7 @@ class ScaleDiscriminator(nn.Module):
91
90
  def __init__(self, use_spectral_norm=False):
92
91
  super().__init__()
93
92
  norm_f = weight_norm if use_spectral_norm == False else spectral_norm
93
+ self.activation = nn.LeakyReLU(0.1)
94
94
  self.convs = nn.ModuleList(
95
95
  [
96
96
  norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
@@ -103,7 +103,6 @@ class ScaleDiscriminator(nn.Module):
103
103
  ]
104
104
  )
105
105
  self.post_conv = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
106
- self.activation = nn.LeakyReLU(0.1)
107
106
 
108
107
  def forward(self, x: torch.Tensor):
109
108
  """
@@ -147,9 +146,10 @@ class GeneralLossDescriminator(Model):
147
146
  super().__init__()
148
147
  self.mpd = MultiPeriodDiscriminator()
149
148
  self.msd = MultiScaleDiscriminator()
149
+ self.print_trainable_parameters()
150
150
 
151
151
  def _get_group_(self):
152
152
  pass
153
153
 
154
154
  def forward(self, x: Tensor, y_hat: Tensor):
155
- return
155
+ return