lt-tensor 0.0.1a0__py3-none-any.whl → 0.0.1a3__py3-none-any.whl

This diff shows the content of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
lt_tensor/__init__.py CHANGED
@@ -1 +1,21 @@
-__version__ = "0.0.1dev3"
+__version__ = "0.0.1a"
+
+from . import (
+    model_zoo,
+    model_base,
+    math_ops,
+    misc_utils,
+    monotonic_align,
+    transform,
+    lr_schedulers,
+)
+
+__all__ = [
+    "model_zoo",
+    "model_base",
+    "math_ops",
+    "misc_utils",
+    "monotonic_align",
+    "transform",
+    "lr_schedulers",
+]
lt_tensor/datasets/audio.py ADDED
@@ -0,0 +1,111 @@
+__all__ = ["AudioProcessor"]
+from ..torch_commons import *
+import torchaudio
+from typing import TypeAlias, Union, Optional
+from lt_utils.common import PathLike
+import librosa
+from lt_utils.type_utils import is_file
+from torchaudio.functional import resample
+from ..transform import inverse_transform
+
+
+class AudioProcessor:
+
+    def __init__(
+        self,
+        sample_rate: int = 24000,
+        n_mels: int = 80,
+        n_fft: int = 2048,
+        win_length: int = 2048,
+        hop_length: int = 256,
+        f_min: float = 0,
+        f_max: float | None = None,
+        mean: int = -4,
+        std: int = 4,
+        n_iter: int = 32,
+        center: bool = True,
+        mel_scale: str = "htk",
+        inv_n_fft: int = 16,
+        inv_hop: int = 4,
+    ):
+        self.mean = mean
+        self.std = std
+        self.n_mels = n_mels
+        self.n_fft = n_fft
+        self.n_stft = n_fft // 2 + 1
+        self.f_min = f_min
+        self.f_max = f_max
+        self.n_iter = n_iter
+        self.hop_length = hop_length
+        self.sample_rate = sample_rate
+        self.mel_spec = torchaudio.transforms.MelSpectrogram(
+            sample_rate=sample_rate,
+            n_mels=n_mels,
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            center=center,
+            f_min=f_min,
+            f_max=f_max,
+            mel_scale=mel_scale,
+        )
+        self.mel_rscale = torchaudio.transforms.InverseMelScale(
+            n_stft=self.n_stft,
+            m_mels=n_mels,
+            sample_rate=sample_rate,
+            f_min=f_min,
+            f_max=f_max,
+            mel_scale=mel_scale,
+        )
+        self.giffin_lim = torchaudio.transforms.GriffinLim(
+            n_fft=n_fft,
+            n_iter=n_iter,
+            win_length=win_length,
+            hop_length=hop_length,
+        )
+        self._inverse_transform = lambda x, y: inverse_transform(
+            x, y, inv_n_fft, inv_hop, inv_n_fft
+        )
+
+    def inverse_transform(self, spec: Tensor, phase: Tensor):
+        return self._inverse_transform(spec, phase)
+
+    def compute_mel(
+        self,
+        wave: Tensor,
+    ) -> Tensor:
+        """Returns: [B, M, ML]"""
+        mel_tensor = self.mel_spec(wave)  # [M, ML]
+        mel_tensor = (mel_tensor - self.mean) / self.std
+        return mel_tensor  # [B, M, ML]
+
+    def reverse_mel(self, mel: Tensor, n_iter: Optional[int] = None):
+        if isinstance(n_iter, int) and n_iter != self.n_iter:
+            self.giffin_lim = torchaudio.transforms.GriffinLim(
+                n_fft=self.n_fft,
+                n_iter=n_iter,
+                win_length=self.win_length,
+                hop_length=self.hop_length,
+            )
+            self.n_iter = n_iter
+        return self.giffin_lim.forward(
+            self.mel_rscale(mel),
+        )
+
+    def load_audio(
+        self,
+        path: PathLike,
+        top_db: float = 30,
+    ) -> Tensor:
+        is_file(path, True)
+        wave, sr = librosa.load(str(path), sr=self.sample_rate)
+        wave, _ = librosa.effects.trim(wave, top_db=top_db)
+        return (
+            torch.from_numpy(
+                librosa.resample(wave, orig_sr=sr, target_sr=self.sample_rate)
+                if sr != self.sample_rate
+                else wave
+            )
+            .float()
+            .unsqueeze(0)
+        )
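For orientation, a minimal usage sketch of the new `AudioProcessor` (the path and parameters are illustrative, and `lt_utils` plus `librosa` must be installed). Note that `InverseMelScale` is constructed with `m_mels=`, a keyword torchaudio does not accept (it expects `n_mels=`), so instantiation may raise a `TypeError` as published:

```python
# Hypothetical usage sketch; "sample.wav" is a stand-in path.
from lt_tensor.datasets.audio import AudioProcessor

ap = AudioProcessor(sample_rate=24000, n_mels=80)  # may hit the m_mels typo above
wave = ap.load_audio("sample.wav")  # [1, T], trimmed and resampled via librosa
mel = ap.compute_mel(wave)          # normalized mel: (mel - mean) / std
wave_est = ap.reverse_mel(mel)      # Griffin-Lim estimate of the waveform
```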
lt_tensor/math_ops.py CHANGED
@@ -11,7 +11,7 @@ __all__ = [
     "phase",
 ]
 
-from ._torch_commons import *
+from .torch_commons import *
 
 
 def sin_tensor(x: Tensor, freq: float = 1.0) -> Tensor:
lt_tensor/misc_utils.py CHANGED
@@ -27,11 +27,12 @@ import gc
 import random
 import numpy as np
 from lt_utils.type_utils import is_str
-from ._torch_commons import *
+from .torch_commons import *
 from lt_utils.misc_utils import log_traceback, cache_wrapper
 from lt_utils.file_ops import load_json, load_yaml, save_json, save_yaml
 import math
-
+from lt_utils.common import *
+import torch.nn.functional as F
 
 def log_tensor(
     item: Union[Tensor, np.ndarray],
@@ -83,12 +84,12 @@ def set_seed(seed: int):
         torch.xpu.manual_seed_all(seed)
 
 
-def count_parameters(model: Module) -> int:
+def count_parameters(model: nn.Module) -> int:
     """Returns total number of trainable parameters."""
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
 
-def freeze_all_except(model: Module, except_layers: Optional[list[str]] = None):
+def freeze_all_except(model: nn.Module, except_layers: Optional[list[str]] = None):
     """Freezes all model parameters except specified layers."""
     no_exceptions = not except_layers
     for name, param in model.named_parameters():
@@ -98,14 +99,14 @@ def freeze_all_except(model: Module, except_layers: Optional[list[str]] = None):
             param.requires_grad_(False)
 
 
-def freeze_selected_weights(model: Module, target_layers: list[str]):
+def freeze_selected_weights(model: nn.Module, target_layers: list[str]):
     """Freezes only parameters on specified layers."""
     for name, param in model.named_parameters():
         if any(layer in name for layer in target_layers):
             param.requires_grad_(False)
 
 
-def unfreeze_all_except(model: Module, except_layers: Optional[list[str]] = None):
+def unfreeze_all_except(model: nn.Module, except_layers: Optional[list[str]] = None):
     """Unfreezes all model parameters except specified layers."""
     no_exceptions = not except_layers
     for name, param in model.named_parameters():
@@ -115,14 +116,14 @@ def unfreeze_all_except(model: Module, except_layers: Optional[list[str]] = None
             param.requires_grad_(True)
 
 
-def unfreeze_selected_weights(model: Module, target_layers: list[str]):
+def unfreeze_selected_weights(model: nn.Module, target_layers: list[str]):
     """Unfreezes only parameters on specified layers."""
     for name, param in model.named_parameters():
         if not any(layer in name for layer in target_layers):
             param.requires_grad_(True)
 
 
-def clip_gradients(model: Module, max_norm: float = 1.0):
+def clip_gradients(model: nn.Module, max_norm: float = 1.0):
     """Applies gradient clipping."""
     return nn.utils.clip_grad_norm_(model.parameters(), max_norm)
 
@@ -576,7 +577,7 @@ def masked_cross_entropy(
     return loss
 
 
-class NoiseScheduler(Module):
+class NoiseScheduler(nn.Module):
     def __init__(self, timesteps: int = 512):
         super().__init__()
 
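With the helpers now typed against `nn.Module`, a short sketch of the freeze/count utilities on a stand-in model (the `Sequential` below is illustrative, not part of the package):

```python
import torch.nn as nn
from lt_tensor.misc_utils import count_parameters, freeze_all_except

model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))
freeze_all_except(model, except_layers=["1"])  # freeze everything but the second Linear
print(count_parameters(model))  # trainable params only: 16 * 4 + 4 = 68
```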
lt_tensor/_basics.py → lt_tensor/model_base.py RENAMED
@@ -2,17 +2,35 @@ __all__ = ["Model"]
 
 
 import warnings
-from ._torch_commons import *
+from .torch_commons import *
+from lt_utils.common import *
 
-ROOT_DEVICE = torch.device(torch.zeros(1).device)
+T = TypeVar("T")
 
+ROOT_DEVICE = torch.zeros(1).device
 
-class _ModelDevice(nn.Module):
+POSSIBLE_OUTPUT_TYPES: TypeAlias = Union[
+    Tensor,
+    Sequence[Tensor],
+    Dict[Union[str, Tensor, Any], Union[Sequence[Tensor], Tensor, Any]],
+]
+
+
+class Model(nn.Module, ABC):
     """
     This makes it easier to assign a device and retrieves it later
     """
 
     _device: torch.device = ROOT_DEVICE
+    _autocast: bool = False
+
+    @property
+    def autocast(self):
+        return self._autocast
+
+    @autocast.setter
+    def autocast(self, value: bool):
+        self._autocast = value
 
     @property
     def device(self):
@@ -127,18 +145,6 @@ class _ModelDevice(nn.Module):
         self.device = "cpu"
         return self
 
-
-class Model(_ModelDevice, ABC):
-    _autocast: bool = False
-
-    @property
-    def autocast(self):
-        return self._autocast
-
-    @autocast.setter
-    def autocast(self, value: bool):
-        self._autocast = value
-
     def count_trainable_parameters(self, module_name: Optional[str] = None):
         """Gets the number of trainable parameters from either the entire model or from a specific module."""
         if module_name is not None:
@@ -263,6 +269,13 @@ class Model(_ModelDevice, ABC):
         self.train()
         return self(*args, **kwargs)
 
+    @torch.autocast(device_type=_device.type)
+    def ac_forward(self, *args, **kwargs):
+        return
+
+    def __call__(self, *args, **kwds) -> POSSIBLE_OUTPUT_TYPES:
+        return super().__call__(*args, **kwds)
+
     @abstractmethod
     def forward(
         self, *args, **kwargs
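A minimal sketch of subclassing the consolidated `Model` base (the layer size is illustrative); `forward` must be implemented since the class is abstract:

```python
from lt_tensor.model_base import Model
from lt_tensor.torch_commons import Tensor, nn

class TinyNet(Model):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 8)

    def forward(self, x: Tensor) -> Tensor:  # satisfies the abstract method
        return self.proj(x)

net = TinyNet()
net.autocast = True  # property merged in from the removed subclass
print(net.device)    # device tracking inherited from the old _ModelDevice logic
```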
lt_tensor/model_zoo/__init__.py CHANGED
@@ -1,9 +1,11 @@
 __all__ = [
     "bsc",  # basic
     "rsd",  # residual
-    "tfr",  # transformer
+    "tfrms",  # transformer
     "pos",  # positional encoders
     "fsn",  # fusion
-    "dfs",  # diffusion
+    "gns",  # generators
+    "disc",  # discriminators
+    "istft"  # self-explanatory
 ]
-from . import bsc, dfs, fsn, pos, rsd, tfr
+from . import bsc, fsn, gns, istft, pos, rsd, tfrms, disc
lt_tensor/model_zoo/bsc.py CHANGED
@@ -10,8 +10,8 @@ __all__ = [
     "MultiScaleEncoder1D",
 ]
 
-from .._torch_commons import *
-from .._basics import Model
+from ..torch_commons import *
+from ..model_base import Model
 from ..transform import get_sinusoidal_embedding
 
 
lt_tensor/model_zoo/disc.py ADDED
@@ -0,0 +1,155 @@
+from ..torch_commons import *
+import torch.nn.functional as F
+from lt_tensor.model_base import Model
+from lt_utils.common import *
+
+
+class PeriodDiscriminator(Model):
+    def __init__(
+        self,
+        period: int,
+        use_spectral_norm=False,
+        kernel_size: int = 5,
+        stride: int = 3,
+        initial_s: int = 32,
+    ):
+        super().__init__()
+        self.period = period
+        self.norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.first_pass = nn.Sequential(
+            self.norm_f(
+                nn.Conv2d(
+                    1, initial_s * 4, (kernel_size, 1), (stride, 1), padding=(2, 0)
+                )
+            ),
+            nn.LeakyReLU(0.1),
+        )
+        self._last_sz = initial_s * 4
+
+        self.convs = nn.ModuleList([self._get_next(i == 3) for i in range(4)])
+
+        self.post_conv = nn.Conv2d(1024, 1, (stride, 1), 1, padding=(1, 0))
+        self.kernel_size = kernel_size
+        self.stride = stride
+
+    def _get_next(self, is_last: bool = False):
+        in_dim = self._last_sz
+        self._last_sz *= 4
+        print(self._last_sz, "-----------------------")
+        stride = (self.stride, 1) if not is_last else 1
+        return nn.Sequential(
+            self.norm_f(
+                nn.Conv2d(
+                    in_dim,
+                    self._last_sz,
+                    (self.kernel_size, 1),
+                    stride,
+                    padding=(2, 0),
+                )
+            ),
+            nn.LeakyReLU(0.1),
+        )
+
+    def forward(self, x: torch.Tensor):
+        """
+        x: (B, T)
+        """
+        b, t = x.shape
+        if t % self.period != 0:
+            pad_len = self.period - (t % self.period)
+            x = F.pad(x, (0, pad_len), mode="reflect")
+            t = t + pad_len
+
+        x = x.view(b, 1, t // self.period, self.period)  # (B, 1, T//P, P)
+
+        f_map = []
+        x = self.first_pass(x)
+        f_map.append(x)
+        for conv in self.convs:
+            x = conv(x)
+            f_map.append(x)
+        x = self.post_conv(x)
+        f_map.append(x)
+        return x.flatten(1, -1), f_map
+
+
+class MultiPeriodDiscriminator(Model):
+    def __init__(self, periods=[2, 3, 5, 7, 11]):
+        super().__init__()
+
+        self.discriminators = nn.ModuleList([PeriodDiscriminator(p) for p in periods])
+
+    def forward(self, x: torch.Tensor):
+        """
+        x: (B, T)
+        Returns: list of tuples of outputs from each period discriminator and the f_map.
+        """
+        return [d(x) for d in self.discriminators]
+
+
+class ScaleDiscriminator(nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super().__init__()
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
+                norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+                norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+                norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+                norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
+            ]
+        )
+        self.post_conv = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
+        self.activation = nn.LeakyReLU(0.1)
+
+    def forward(self, x: torch.Tensor):
+        """
+        x: (B, T)
+        """
+        f_map = []
+        x = x.unsqueeze(1)  # (B, 1, T)
+        for conv in self.convs:
+            x = self.activation(conv(x))
+            f_map.append(x)
+        x = self.post_conv(x)
+        f_map.append(x)
+        return x.flatten(1, -1), f_map
+
+
+class MultiScaleDiscriminator(Model):
+    def __init__(self):
+        super().__init__()
+        self.pooling = nn.AvgPool1d(4, 2, padding=2)
+        self.discriminators = nn.ModuleList(
+            [ScaleDiscriminator(i == 0) for i in range(3)]
+        )
+
+    def forward(self, x: torch.Tensor):
+        """
+        x: (B, T)
+        Returns: list of outputs from each scale discriminator
+        """
+        outputs = []
+        for i, d in enumerate(self.discriminators):
+            if i != 0:
+                x = self.pooling(x)
+            outputs.append(d(x))
+        return outputs
+
+
+class GeneralLossDescriminator(Model):
+    """TODO: build an unified loss for both mpd and msd here."""
+
+    def __init__(self):
+        super().__init__()
+        self.mpd = MultiPeriodDiscriminator()
+        self.msd = MultiScaleDiscriminator()
+
+    def _get_group_(self):
+        pass
+
+    def forward(self, x: Tensor, y_hat: Tensor):
+        return
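A hedged sketch of the new multi-scale discriminator on raw audio (batch size and length are illustrative). Note that `PeriodDiscriminator._get_next` quadruples its channel count on every call (and leaves a debug `print` behind), which does not match the fixed 1024-channel `post_conv`, so the multi-period path may not run as published:

```python
import torch
from lt_tensor.model_zoo.disc import MultiScaleDiscriminator

wave = torch.randn(2, 8192)  # (B, T)
msd = MultiScaleDiscriminator()
outputs = msd(wave)          # one (logits, feature_map) pair per scale
logits, f_map = outputs[0]   # logits: (B, N); f_map: per-layer activations
```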
lt_tensor/model_zoo/fsn.py CHANGED
@@ -6,8 +6,8 @@ __all__ = [
     "GatedFusion",
 ]
 
-from .._torch_commons import *
-from .._basics import Model
+from ..torch_commons import *
+from ..model_base import Model
 
 
 class ConcatFusion(Model):
@@ -39,7 +39,7 @@ class BilinearFusion(Model):
         return self.bilinear(a, b)
 
 
-class CrossAttentionFusion(nn.Module):
+class CrossAttentionFusion(Model):
     def __init__(self, q_dim: int, kv_dim: int, n_heads: int = 4, d_model: int = 256):
         super().__init__()
         self.q_proj = nn.Linear(q_dim, d_model)
@@ -57,7 +57,7 @@ class CrossAttentionFusion(nn.Module):
         return output
 
 
-class GatedFusion(nn.Module):
+class GatedFusion(Model):
     def __init__(self, in_dim: int):
         super().__init__()
         self.gate = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.Sigmoid())
lt_tensor/model_zoo/dfs.py → lt_tensor/model_zoo/gns.py RENAMED
@@ -7,11 +7,13 @@ __all__ = [
     "NoisePredictor1D",
 ]
 
-from .._torch_commons import *
-from .._basics import Model
-from .rsd import ResBlock1D
+from ..torch_commons import *
+from ..model_base import Model
+from .rsd import ResBlock1D, ResBlocks
 from ..misc_utils import log_tensor
 
+import torch.nn.functional as F
+
 
 class Downsample1D(Model):
     def __init__(
@@ -179,3 +181,5 @@ class NoisePredictor1D(Model):
         cond_proj = self.proj(cond).unsqueeze(-1)  # [B, hidden, 1]
         x = x + cond_proj  # simple conditioning
         return self.net(x)  # [B, C, T]
+
+
lt_tensor/model_zoo/istft.py ADDED
@@ -0,0 +1,108 @@
+from ..torch_commons import *
+from ..model_base import Model
+from .rsd import ResBlocks
+from ..misc_utils import log_tensor
+
+import torch.nn.functional as F
+
+
+class Generator(Model):
+    """Based on the adaptation made by from Rishikesh
+    A Generator for audio processing, can be usd for tother things."""
+
+    def __init__(
+        self,
+        in_channels: int = 80,
+        upsample_rates: List[Union[int, List[int]]] = [8, 8],
+        upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16],
+        upsample_initial_channel: int = 512,
+        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+        resblock_dilation_sizes: List[Union[int, List[int]]] = [
+            [1, 3, 5],
+            [1, 3, 5],
+            [1, 3, 5],
+        ],
+        n_fft: int = 16,
+        activation: nn.Module = nn.LeakyReLU(0.1),
+    ):
+        super().__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.conv_pre = weight_norm(
+            nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
+        )
+        self.blocks = nn.ModuleList()
+        self.activation = activation
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.blocks.append(
+                self._make_blocks(
+                    (i, k, u),
+                    upsample_initial_channel,
+                    resblock_kernel_sizes,
+                    resblock_dilation_sizes,
+                )
+            )
+
+        ch = upsample_initial_channel // (2 ** (i + 1))
+        self.post_n_fft = n_fft // 2 + 1
+        self.conv_post = weight_norm(nn.Conv1d(ch, n_fft + 2, 7, 1, padding=3))
+        self.conv_post.apply(self.init_weights)
+        self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
+
+    def _make_blocks(
+        self,
+        state: Tuple[int, int, int],
+        upsample_initial_channel: int,
+        resblock_kernel_sizes: List[Union[int, List[int]]],
+        resblock_dilation_sizes: List[int | List[int]],
+    ):
+        i, k, u = state
+        channels = upsample_initial_channel // (2 ** (i + 1))
+        return nn.ModuleDict(
+            dict(
+                up=nn.Sequential(
+                    self.activation,
+                    weight_norm(
+                        nn.ConvTranspose1d(
+                            upsample_initial_channel // (2**i),
+                            channels,
+                            k,
+                            u,
+                            padding=(k - u) // 2,
+                        )
+                    ),
+                ),
+                residual=ResBlocks(
+                    channels,
+                    resblock_kernel_sizes,
+                    resblock_dilation_sizes,
+                    self.activation,
+                ),
+            )
+        )
+
+    def forward(self, x):
+        x = self.conv_pre(x)
+        for block in self.blocks:
+            x = block["up"](x)
+            x = block["residual"](x)
+
+        x = self.reflection_pad(x)
+        x = self.conv_post(x)
+        spec = torch.exp(x[:, : self.post_n_fft, :])
+        phase = torch.sin(x[:, self.post_n_fft :, :])
+
+        return spec, phase
+
+    def remove_weight_norm(self):
+        for module in self.modules():
+            try:
+                remove_weight_norm(module)
+            except ValueError:
+                pass  # Not normed, skip
+
+    @staticmethod
+    def init_weights(m, mean=0.0, std=0.01):
+        classname = m.__class__.__name__
+        if "Conv" in classname:
+            m.weight.data.normal_(mean, std)
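A hedged sketch tying the new `Generator` to `inverse_transform`: the network emits a magnitude/phase pair sized for `n_fft // 2 + 1` frequency bins, which `torch.istft` can invert (shapes and `n_fft` are illustrative):

```python
import torch
from lt_tensor.model_zoo.istft import Generator
from lt_tensor.transform import inverse_transform

gen = Generator(in_channels=80, n_fft=16)  # default upsample rates: 8 * 8 = 64x
mel = torch.randn(1, 80, 32)   # (B, mels, frames)
spec, phase = gen(mel)         # each (B, n_fft // 2 + 1, 32 * 64 + 1)
wave = inverse_transform(spec, phase, n_fft=16, hop_length=4, win_length=16)
```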
lt_tensor/model_zoo/pos.py CHANGED
@@ -5,11 +5,11 @@ __all__ = [
 ]
 
 import math
-from .._torch_commons import *
-from .._basics import Model
+from ..torch_commons import *
+from ..model_base import Model
 
 
-class RotaryEmbedding(Module):
+class RotaryEmbedding(nn.Module):
     def __init__(self, dim: int, base: int = 10000):
         """
         Rotary Positional Embedding Module.
@@ -76,7 +76,7 @@ class RotaryEmbedding(Module):
         return x_rotated.view(b, s, d)  # Back to [b, s, d]
 
 
-class PositionalEncoding(Module):
+class PositionalEncoding(nn.Module):
     def __init__(self, d_model: int, max_len: int = 8192):
         super().__init__()
         # create a matrix of [seq_len, hidden_dim] representing positional encoding for each token in sequence
@@ -100,7 +100,7 @@ class PositionalEncoding(Module):
         return x
 
 
-class LearnedPositionalEncoding(Module):
+class LearnedPositionalEncoding(nn.Module):
     def __init__(self, max_len: int, dim_model: int, dropout: float = 0.1):
         super().__init__()
         self.embedding = nn.Embedding(max_len, dim_model)
lt_tensor/model_zoo/rsd.py CHANGED
@@ -1,23 +1,24 @@
 __all__ = [
     "spectral_norm_select",
+    "ResBlock1D_BT",
     "ResBlock1D",
     "ResBlock2D",
-    "ResBlock1D_S",
+    "ResBlocks",
 ]
-
-from .._torch_commons import *
-from .._basics import Model
+from lt_utils.common import *
+from ..torch_commons import *
+from ..model_base import Model
 import math
 from ..misc_utils import log_tensor
 
 
-def spectral_norm_select(module: Module, enabled: bool):
+def spectral_norm_select(module: nn.Module, enabled: bool):
     if enabled:
         return spectral_norm(module)
     return module
 
 
-class ResBlock1D(Model):
+class ResBlock1D_BT(Model):
     def __init__(
         self,
         in_channels: int,
@@ -106,6 +107,103 @@ class ResBlock1D(Model):
             m.weight.data.normal_(mean, std)
 
 
+class ResBlock1D(Model):
+    def __init__(
+        self,
+        channels,
+        kernel_size=3,
+        dilation=(1, 3, 5),
+        activation: nn.Module = nn.LeakyReLU(0.1),
+    ):
+        super(ResBlock1D, self).__init__()
+        self.convs = nn.ModuleList(
+            [
+                self._get_conv_layer(i, channels, kernel_size, 1, dilation, activation)
+                for i in range(3)
+            ]
+        )
+        self.convs.apply(self.init_weights)
+
+    def _get_conv_layer(self, id, ch, k, stride, d, actv):
+        get_padding = lambda ks, d: int((ks * d - d) / 2)
+        return nn.Sequential(
+            actv,  # 1
+            weight_norm(
+                nn.Conv1d(
+                    ch, ch, k, stride, dilation=d[id], padding=get_padding(k, d[id])
+                )
+            ),  # 2
+            actv,  # 3
+            weight_norm(
+                nn.Conv1d(ch, ch, k, stride, dilation=1, padding=get_padding(k, 1))
+            ),  # 4
+        )
+
+    def forward(self, x: torch.Tensor):
+        for cnn in self.convs:
+            x = cnn(x) + x
+        return x
+
+    def remove_weight_norm(self):
+        for module in self.modules():
+            try:
+                remove_weight_norm(module)
+            except ValueError:
+                pass  # Not normed, skip
+
+    @staticmethod
+    def init_weights(m, mean=0.0, std=0.01):
+        classname = m.__class__.__name__
+        if "Conv" in classname:
+            m.weight.data.normal_(mean, std)
+
+
+class ResBlocks(Model):
+    def __init__(
+        self,
+        channels: int,
+        resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+        resblock_dilation_sizes: List[Union[int, List[int]]] = [
+            [1, 3, 5],
+            [1, 3, 5],
+            [1, 3, 5],
+        ],
+        activation: nn.Module = nn.LeakyReLU(0.1),
+    ):
+        super().__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.rb = nn.ModuleList()
+        self.activation = activation
+
+        for k, j in zip(resblock_kernel_sizes, resblock_dilation_sizes):
+            self.rb.append(ResBlock1D(channels, k, j, activation))
+
+        self.rb.apply(self.init_weights)
+
+    def forward(self, x: torch.Tensor):
+        xs = None
+        for i, block in enumerate(self.rb):
+            if i == 0:
+                xs = block(x)
+            else:
+                xs += block(x)
+        x = xs / self.num_kernels
+        return self.activation(x)
+
+    def remove_weight_norm(self):
+        for module in self.modules():
+            try:
+                remove_weight_norm(module)
+            except ValueError:
+                pass  # Not normed, skip
+
+    @staticmethod
+    def init_weights(m, mean=0.0, std=0.01):
+        classname = m.__class__.__name__
+        if "Conv" in classname:
+            m.weight.data.normal_(mean, std)
+
+
 class ResBlock2D(Model):
     def __init__(
         self,
@@ -137,22 +235,3 @@ class ResBlock2D(Model):
 
     def forward(self, x):
         return (self.block(x) + self.skip(x)) / self.sqrt_2
-
-
-class ResBlock1D_S(Model):
-    """Simplified version"""
-
-    def __init__(self, channels: int, kernel_size: int = 3, dilation: int = 1):
-        super().__init__()
-        padding = (kernel_size - 1) // 2 * dilation
-        self.net = nn.Sequential(
-            nn.Conv1d(
-                channels, channels, kernel_size, padding=padding, dilation=dilation
-            ),
-            nn.LeakyReLU(0.1),
-            nn.Conv1d(channels, channels, kernel_size, padding=padding, dilation=1),
-        )
-        self.activation = nn.LeakyReLU(0.1)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.activation(x + self.net(x))
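A short sketch of the HiFi-GAN-style blocks added above (channel count and length are illustrative): `ResBlock1D` preserves the input shape through dilated residual convolutions, and `ResBlocks` averages three such branches:

```python
import torch
from lt_tensor.model_zoo.rsd import ResBlock1D, ResBlocks

x = torch.randn(1, 64, 256)  # (B, C, T)
y = ResBlock1D(64, kernel_size=3, dilation=(1, 3, 5))(x)  # same shape as x
z = ResBlocks(64)(x)  # kernel sizes (3, 7, 11), branches summed, averaged, activated
```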
lt_tensor/model_zoo/tfr.py → lt_tensor/model_zoo/tfrms.py RENAMED
@@ -7,10 +7,10 @@ __all__ = [
 ]
 
 import math
-from .._torch_commons import *
-from .._basics import Model
+from ..torch_commons import *
+from ..model_base import Model
 from lt_utils.misc_utils import default
-
+from typing import Optional
 from .pos import *
 from .bsc import FeedForward
 
lt_tensor/torch_commons.py ADDED
@@ -0,0 +1,30 @@
+__all__ = [
+    "nn",
+    "torch",
+    "optim",
+    "Tensor",
+    "FloatTensor",
+    "LongTensor",
+    "HalfTensor",
+    "remove_weight_norm",
+    "remove_spectral_norm",
+    "weight_norm",
+    "spectral_norm",
+    "DeviceType",
+    # frequent typing
+    "Optional",
+    "List",
+    "Dict",
+    "Tuple",
+    "Union",
+    "TypeAlias",
+    "Sequence",
+    "Any",
+]
+import torch
+from torch.nn.utils import remove_weight_norm, remove_spectral_norm
+from torch.nn.utils.parametrizations import weight_norm, spectral_norm
+from torch import nn, optim, Tensor, FloatTensor, LongTensor, HalfTensor
+from typing import TypeAlias, Union, Optional, List, Dict, Tuple, Sequence, Any
+
+DeviceType: TypeAlias = Union[torch.device, str]
lt_tensor/transform.py CHANGED
@@ -20,10 +20,12 @@ __all__ = [
     "stft_istft_rebuild",
 ]
 
-from ._torch_commons import *
+from .torch_commons import *
 import torchaudio
 import math
 from .misc_utils import log_tensor
+from lt_utils.common import *
+import torch.nn.functional as F
 
 
 def to_mel_spectrogram(
@@ -196,7 +198,7 @@ def get_sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
     return emb
 
 
-def _generate_window(
+def generate_window(
     M: int, alpha: float = 0.5, device: Optional[DeviceType] = None
 ) -> Tensor:
     if M < 1:
@@ -281,7 +283,7 @@ def window_sumsquare(
     x = torch.zeros(total_length, dtype=dtype, device=device)
 
     # Get the window (from scipy for now)
-    win = _generate_window(window_spec, win_length, fftbins=True)
+    win = generate_window(window_spec, win_length, fftbins=True)
     win = torch.tensor(win, dtype=dtype, device=device)
 
     # Normalize and square
@@ -301,14 +303,14 @@ def window_sumsquare(
 def inverse_transform(
     spec: Tensor,
     phase: Tensor,
-    window: Optional[Tensor] = None,
     n_fft: int = 2048,
     hop_length: int = 300,
     win_length: int = 1200,
     length: Optional[Any] = None,
+    window: Optional[Tensor] = None,
 ):
     if window is None:
-        window = _generate_window(win_length)
+        window = torch.hann_window(win_length or n_fft).to(spec.device)
     return torch.istft(
         spec * torch.exp(phase * 1j),
         n_fft,
@@ -317,33 +319,3 @@ def inverse_transform(
         window=window,
         length=length,
     )
-
-
-def stft_istft_rebuild(
-    input_data: Tensor,
-    window: Optional[Tensor] = None,
-    n_fft: int = 2048,
-    hop_length: int = 300,
-    win_length: int = 1200,
-):
-    """
-    Perform STFT followed by ISTFT reconstruction using magnitude and phase.
-    """
-    if window is None:
-        window = _generate_window(win_length)
-    st = torch.stft(
-        input_data,
-        n_fft,
-        hop_length,
-        win_length,
-        window=window,
-        return_complex=True,
-    )
-    return torch.istft(
-        torch.abs(st) * torch.exp(1j * torch.angle(st)),
-        n_fft,
-        hop_length,
-        win_length,
-        window=window,
-        length=input_data.shape[-1],
-    ).squeeze(0)
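With `_generate_window` made public and the keyword order changed, `inverse_transform` now falls back to a Hann window on the spectrogram's device; a minimal sketch (values are illustrative):

```python
import torch
from lt_tensor.transform import inverse_transform

spec = torch.rand(1, 1025, 40)   # magnitude, (B, n_fft // 2 + 1, frames)
phase = torch.rand(1, 1025, 40)  # phase in radians
wave = inverse_transform(spec, phase, n_fft=2048, hop_length=300, win_length=1200)
```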
lt_tensor-0.0.1a3.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a0
+Version: 0.0.1a3
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
@@ -17,7 +17,8 @@ Requires-Dist: numpy>=1.26.4
 Requires-Dist: tokenizers
 Requires-Dist: pyyaml>=6.0.0
 Requires-Dist: numba>0.60.0
-Requires-Dist: lt-utils==0.0.1a0
+Requires-Dist: lt-utils==0.0.1.a3
+Requires-Dist: librosa>=0.11.0
 Dynamic: author
 Dynamic: classifier
 Dynamic: description
lt_tensor-0.0.1a3.dist-info/RECORD ADDED
@@ -0,0 +1,24 @@
+lt_tensor/__init__.py,sha256=bvCjaIsYjbGFbR5MNezgLyRgN4_CsyrjmVEvuClsgOU,303
+lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
+lt_tensor/math_ops.py,sha256=ZtnJ9WB-pbFQLsXuNfQl2dAaeob5BWfxmhkwpxITUZ4,2066
+lt_tensor/misc_utils.py,sha256=e44FCQbjNHP-4WOHIbtqqH0x590DzUE6CrD_4Vl_d38,19880
+lt_tensor/model_base.py,sha256=tmRu5pTcELKMFcybOiZ1thJPuJWRSPkbUUtp9Y1NJWw,9555
+lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
+lt_tensor/torch_commons.py,sha256=fntsEU8lhBQo0ebonI1iXBkMbWMN3HpBsG13EWlP5s8,718
+lt_tensor/transform.py,sha256=IVAaQlq12OvMVhX3lX4lgsTCJYJce5n5MtMy7IK_AU4,8892
+lt_tensor/datasets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+lt_tensor/datasets/audio.py,sha256=5Bn9Apb3K5QnRah2EfhztcatBRsnpQsdItm_jTaDrUs,3350
+lt_tensor/model_zoo/__init__.py,sha256=jipEk50_DTMQbGg8FnDDukxmh7Bcwvl_QVRS3rkb7aY,283
+lt_tensor/model_zoo/bsc.py,sha256=muxIR7dU-Pvf-HFE-iy3zmRb1sTJlcs1vqdlnbU1Hss,6307
+lt_tensor/model_zoo/disc.py,sha256=ND6JR_x6b2Y1VqxZejalv8Cz5_TO3H_Z-0x6UnACbBM,4740
+lt_tensor/model_zoo/fsn.py,sha256=5ySsg2OHjvTV_coPAdZQ0f7bz4ugJB8mDYsItmd61qA,2102
+lt_tensor/model_zoo/gns.py,sha256=Tirr_grONp_FFQ_L7K-zV2lvkaC39h8mMl4QDpx9vLQ,6028
+lt_tensor/model_zoo/istft.py,sha256=RV7KVY7q4CYzzsWXH4NGJQwSqrYWwHh-16Q62lKoA2k,3594
+lt_tensor/model_zoo/pos.py,sha256=N28v-rF8CELouYxQ9r45Jbd4ri5DNydwDgg7nzmQ4Ig,4471
+lt_tensor/model_zoo/rsd.py,sha256=5bba50g1Hm5kMexuJ4SwOIJuyQ1qJd8Acrq-Ax6CqE8,6958
+lt_tensor/model_zoo/tfrms.py,sha256=kauh-A13pk08SZ5OspEE5a-gPKD4rZr6tqMKWu3KGhk,4237
+lt_tensor-0.0.1a3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+lt_tensor-0.0.1a3.dist-info/METADATA,sha256=T5Gya3J6YebHzwR0gyvJ8lr5Rj9EJWtLSoo7--CSado,968
+lt_tensor-0.0.1a3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a3.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a3.dist-info/RECORD,,
lt_tensor/_torch_commons.py DELETED
@@ -1,12 +0,0 @@
-import torch
-from torch import nn, optim
-import torch.nn.functional as F
-from torch.optim import Optimizer
-from torch.nn import Module, L1Loss, MSELoss
-from torch.nn.utils import remove_weight_norm
-from torch import Tensor, FloatTensor, device, LongTensor
-from torch.nn.utils.parametrizations import weight_norm, spectral_norm
-
-from lt_utils.common import *
-
-DeviceType: TypeAlias = Union[device, str]
lt_tensor-0.0.1a0.dist-info/RECORD DELETED
@@ -1,20 +0,0 @@
-lt_tensor/__init__.py,sha256=pUB05ZkgkpP10ivzwoWdbq_HCxw-iOsbf6m8eFtx-YM,26
-lt_tensor/_basics.py,sha256=Zty5XZ5qeVFoZJRhtpGvOH7rg9hbAS7mIULOdrOKBDQ,9189
-lt_tensor/_torch_commons.py,sha256=_2Eck-MsQ46PxW5ku7NJvNSL5vg54_4GkLCqdzFevwA,402
-lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
-lt_tensor/math_ops.py,sha256=j4Arst-kOdm0bcZbXD4rzcVdiyYOJ59ZQQIyH7r0Wug,2067
-lt_tensor/misc_utils.py,sha256=3r6ikrBCj2IjSWZMRU1Lif0OgYTF3HExANG_IqhPtic,19799
-lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
-lt_tensor/transform.py,sha256=IYPT2YHT9NDvHrdtJvTLmxL9Cm26Ck2Uc9zE0k6l2aI,9504
-lt_tensor/model_zoo/__init__.py,sha256=ybyd3St8wiswnBGKFcy6FqRo5NlfGPJPC7jbRJlTlv8,205
-lt_tensor/model_zoo/bsc.py,sha256=6jBICcy8FT81EUiN9g1eZuHhPF4xA7gzS5kaVT3RngU,6305
-lt_tensor/model_zoo/dfs.py,sha256=0dTA1aveZT5OZu8eI6Cb8q8IGSjZyFYDcfc2FpDH5S8,5980
-lt_tensor/model_zoo/fsn.py,sha256=YDu1sbLwJwSKCPlmPlqQujivlgfNvwpwGa5q4SY9MYk,2108
-lt_tensor/model_zoo/pos.py,sha256=L2j6zYkdBWjrgROJt4cFOwdnne6j94m2lGi9m_QC7oc,4460
-lt_tensor/model_zoo/rsd.py,sha256=QGfkhoP7BVCGlCyBkIxHE7eWUp71JFkK6bM4dgBw1Hw,4720
-lt_tensor/model_zoo/tfr.py,sha256=mIwu6WqDxcLGlBfofIIspzGpUe2jsR0hrzT9mEW-MHE,4208
-lt_tensor-0.0.1a0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-lt_tensor-0.0.1a0.dist-info/METADATA,sha256=hQVkxd4J5C7KX1DRVVYkIVKK0MIlGf-0kSLQ--HkTdY,936
-lt_tensor-0.0.1a0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-lt_tensor-0.0.1a0.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
-lt_tensor-0.0.1a0.dist-info/RECORD,,