lt-tensor 0.0.1a39__tar.gz → 0.0.1a40__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/PKG-INFO +1 -1
  2. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/__init__.py +1 -1
  3. lt_tensor-0.0.1a40/lt_tensor/model_zoo/audio_models/bemaganv2/__init__.py +205 -0
  4. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +14 -39
  5. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py +20 -19
  6. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/hifigan/__init__.py +24 -44
  7. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/istft/__init__.py +15 -39
  8. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/convs.py +35 -4
  9. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/noise_tools.py +22 -13
  10. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/processors/audio.py +115 -61
  11. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor.egg-info/PKG-INFO +1 -1
  12. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor.egg-info/SOURCES.txt +1 -0
  13. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/setup.py +1 -1
  14. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/LICENSE +0 -0
  15. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/README.md +0 -0
  16. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/config_templates.py +0 -0
  17. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/losses.py +0 -0
  18. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/lr_schedulers.py +0 -0
  19. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/math_ops.py +0 -0
  20. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/misc_utils.py +0 -0
  21. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_base.py +0 -0
  22. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/__init__.py +0 -0
  23. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/activations/__init__.py +0 -0
  24. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/activations/alias_free/__init__.py +0 -0
  25. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/activations/alias_free/act.py +0 -0
  26. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/activations/alias_free/filter.py +0 -0
  27. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/activations/alias_free/resample.py +0 -0
  28. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/activations/snake/__init__.py +0 -0
  29. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/__init__.py +0 -0
  30. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/audio_models/resblocks.py +0 -0
  31. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/basic.py +0 -0
  32. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/features.py +0 -0
  33. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/fusion.py +0 -0
  34. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/losses/CQT/__init__.py +0 -0
  35. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/losses/CQT/transforms.py +0 -0
  36. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/losses/CQT/utils.py +0 -0
  37. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/losses/__init__.py +0 -0
  38. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/losses/_envelope_disc/__init__.py +0 -0
  39. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/losses/discriminators.py +0 -0
  40. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/pos_encoder.py +0 -0
  41. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/residual.py +0 -0
  42. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/model_zoo/transformer.py +0 -0
  43. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/monotonic_align.py +0 -0
  44. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/processors/__init__.py +0 -0
  45. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/torch_commons.py +0 -0
  46. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor/transform.py +0 -0
  47. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor.egg-info/dependency_links.txt +0 -0
  48. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor.egg-info/requires.txt +0 -0
  49. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/lt_tensor.egg-info/top_level.txt +0 -0
  50. {lt_tensor-0.0.1a39 → lt_tensor-0.0.1a40}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lt-tensor
3
- Version: 0.0.1a39
3
+ Version: 0.0.1a40
4
4
  Summary: General utilities for PyTorch and others. Built for general use.
5
5
  Home-page: https://github.com/gr1336/lt-tensor/
6
6
  Author: gr1336
@@ -1,4 +1,4 @@
1
- __version__ = "0.0.1a39"
1
+ __version__ = "0.0.1a40"
2
2
 
3
3
  from . import (
4
4
  lr_schedulers,
@@ -0,0 +1,205 @@
1
+ from lt_utils.common import *
2
+ from lt_tensor.torch_commons import *
3
+ from lt_tensor.model_zoo.convs import ConvNets
4
+ from lt_tensor.config_templates import ModelConfig
5
+ from lt_utils.file_ops import is_file, load_json
6
+ from lt_tensor.model_zoo.audio_models.resblocks import ResBlock1, ResBlock2
7
+ from lt_tensor.model_zoo.activations import snake, alias_free
8
+ from lt_tensor.model_zoo.audio_models.resblocks import AMPBlock1, AMPBlock2, get_snake
9
+
10
+
11
+ class BemaGANv2Config(ModelConfig):
12
+ # Training params
13
+ in_channels: int = 80
14
+ upsample_rates: List[Union[int, List[int]]] = [8, 8, 2, 2]
15
+ upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16, 4, 4]
16
+ upsample_initial_channel: int = 1536
17
+ resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11]
18
+ resblock_dilation_sizes: List[Union[int, List[int]]] = [
19
+ [1, 3, 5],
20
+ [1, 3, 5],
21
+ [1, 3, 5],
22
+ ]
23
+
24
+ activation: Literal["snake", "snakebeta"] = "snakebeta"
25
+ resblock_activation: Literal["snake", "snakebeta"] = "snakebeta"
26
+ resblock: int = 0
27
+ use_bias_at_final: bool = True
28
+ use_tanh_at_final: bool = True
29
+ snake_logscale: bool = True
30
+
31
+ def __init__(
32
+ self,
33
+ in_channels: int = 80,
34
+ upsample_rates: List[Union[int, List[int]]] = [8, 8, 2, 2],
35
+ upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16, 4, 4],
36
+ upsample_initial_channel: int = 1536,
37
+ resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
38
+ resblock_dilation_sizes: List[Union[int, List[int]]] = [
39
+ [1, 3, 5],
40
+ [1, 3, 5],
41
+ [1, 3, 5],
42
+ ],
43
+ activation: Literal["snake", "snakebeta"] = "snakebeta",
44
+ resblock_activation: Literal["snake", "snakebeta"] = "snakebeta",
45
+ resblock: Union[int, str] = "1",
46
+ use_bias_at_final: bool = False,
47
+ use_tanh_at_final: bool = False,
48
+ *args,
49
+ **kwargs,
50
+ ):
51
+ settings = {
52
+ "in_channels": in_channels,
53
+ "upsample_rates": upsample_rates,
54
+ "upsample_kernel_sizes": upsample_kernel_sizes,
55
+ "upsample_initial_channel": upsample_initial_channel,
56
+ "resblock_kernel_sizes": resblock_kernel_sizes,
57
+ "resblock_dilation_sizes": resblock_dilation_sizes,
58
+ "activation": activation,
59
+ "resblock_activation": resblock_activation,
60
+ "resblock": resblock,
61
+ "use_bias_at_final": use_bias_at_final,
62
+ "use_tanh_at_final": use_tanh_at_final,
63
+ }
64
+ super().__init__(**settings)
65
+
66
+ def post_process(self):
67
+ if isinstance(self.resblock, str):
68
+ self.resblock = 0 if self.resblock == "1" else 1
69
+
70
+
71
+ class BemaGANv2Generator(ConvNets):
72
+
73
+ def __init__(
74
+ self, cfg: Union[BemaGANv2Config, Dict[str, object]] = BemaGANv2Config()
75
+ ):
76
+ super().__init__()
77
+ cfg = cfg if isinstance(cfg, BemaGANv2Config) else BemaGANv2Config(**cfg)
78
+ self.cfg = cfg
79
+
80
+ actv = get_snake(self.cfg.activation)
81
+
82
+ self.num_kernels = len(cfg.resblock_kernel_sizes)
83
+ self.num_upsamples = len(cfg.upsample_rates)
84
+
85
+ self.conv_pre = weight_norm(
86
+ nn.Conv1d(cfg.num_mels, cfg.upsample_initial_channel, 7, 1, padding=3)
87
+ )
88
+
89
+ # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
90
+ resblock = AMPBlock1 if cfg.resblock == 0 else AMPBlock2
91
+
92
+ # transposed conv-based upsamplers. does not apply anti-aliasing
93
+ self.ups = nn.ModuleList()
94
+ for i, (u, k) in enumerate(zip(cfg.upsample_rates, cfg.upsample_kernel_sizes)):
95
+ self.ups.append(
96
+ nn.ModuleList(
97
+ [
98
+ weight_norm(
99
+ nn.ConvTranspose1d(
100
+ cfg.upsample_initial_channel // (2**i),
101
+ cfg.upsample_initial_channel // (2 ** (i + 1)),
102
+ k,
103
+ u,
104
+ padding=(k - u) // 2,
105
+ )
106
+ )
107
+ ]
108
+ )
109
+ )
110
+ # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
111
+ self.resblocks = nn.ModuleList()
112
+ for i in range(len(self.ups)):
113
+ ch = cfg.upsample_initial_channel // (2 ** (i + 1))
114
+ for k, d in zip(cfg.resblock_kernel_sizes, cfg.resblock_dilation_sizes):
115
+ self.resblocks.append(
116
+ resblock(
117
+ ch,
118
+ k,
119
+ d,
120
+ snake_logscale=cfg.snake_logscale,
121
+ activation=cfg.resblock_activation,
122
+ )
123
+ )
124
+
125
+ self.activation_post = actv(ch, alpha_logscale=cfg.snake_logscale)
126
+ # post conv
127
+
128
+ self.conv_post = weight_norm(
129
+ nn.Conv1d(ch, 1, 7, 1, padding=3, bias=self.cfg.use_bias_at_final)
130
+ )
131
+ self._use_tanh = self.cfg.use_tanh_at_final
132
+
133
+ # weight initialization
134
+ for i in range(len(self.ups)):
135
+ self.ups[i].apply(self.init_weights)
136
+ self.conv_post.apply(self.init_weights)
137
+
138
+ def forward(self, x: Tensor):
139
+ # pre conv
140
+ x = self.conv_pre(x)
141
+
142
+ for i in range(self.num_upsamples):
143
+ # upsampling
144
+ for i_up in range(len(self.ups[i])):
145
+ x = self.ups[i][i_up](x)
146
+ # AMP blocks
147
+ xs = None
148
+ for j in range(self.num_kernels):
149
+ if xs is None:
150
+ xs = self.resblocks[i * self.num_kernels + j](x)
151
+ else:
152
+ xs += self.resblocks[i * self.num_kernels + j](x)
153
+ x = xs / self.num_kernels
154
+
155
+ # post conv
156
+ x = self.activation_post(x)
157
+ x = self.conv_post(x)
158
+ if self._use_tanh:
159
+ return x.tanh()
160
+ return x
161
+
162
+ @classmethod
163
+ def from_pretrained(
164
+ cls,
165
+ model_file: PathLike,
166
+ model_config: Union[
167
+ BemaGANv2Config, Dict[str, Any], PathLike
168
+ ] = BemaGANv2Config(),
169
+ *,
170
+ remove_norms: bool = False,
171
+ strict: bool = True,
172
+ map_location: str = "cpu",
173
+ weights_only: bool = False,
174
+ **kwargs,
175
+ ):
176
+
177
+ is_file(model_file, validate=True)
178
+ model_state_dict = torch.load(
179
+ model_file,
180
+ weights_only=weights_only,
181
+ map_location=map_location,
182
+ )
183
+
184
+ if isinstance(model_config, BemaGANv2Config):
185
+ h = model_config
186
+ elif isinstance(model_config, dict):
187
+ h = BemaGANv2Config(**model_config)
188
+ elif isinstance(model_config, (str, Path, bytes)):
189
+ h = BemaGANv2Config(
190
+ **load_json(model_config, BemaGANv2Config().state_dict())
191
+ )
192
+
193
+ model = cls(h)
194
+ if remove_norms:
195
+ model.remove_norms()
196
+ try:
197
+ model.load_state_dict(model_state_dict, strict=strict)
198
+ return model
199
+ except RuntimeError:
200
+ print(
201
+ f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
202
+ )
203
+ model.remove_norms()
204
+ model.load_state_dict(model_state_dict, strict=strict)
205
+ return model
@@ -2,9 +2,9 @@ from lt_utils.common import *
2
2
  from lt_tensor.torch_commons import *
3
3
  from lt_tensor.model_zoo.convs import ConvNets
4
4
  from lt_tensor.config_templates import ModelConfig
5
- from lt_tensor.model_zoo.activations import snake, alias_free
5
+ from lt_tensor.model_zoo.activations import alias_free
6
6
  from lt_tensor.model_zoo.audio_models.resblocks import AMPBlock1, AMPBlock2, get_snake
7
- from lt_utils.file_ops import load_json, is_file, is_dir, is_path_valid
7
+ from lt_utils.file_ops import load_json, is_file
8
8
 
9
9
 
10
10
  class BigVGANConfig(ModelConfig):
@@ -78,8 +78,9 @@ class BigVGAN(ConvNets):
78
78
 
79
79
  """
80
80
 
81
- def __init__(self, cfg: BigVGANConfig):
81
+ def __init__(self, cfg: Union[BigVGANConfig, Dict[str, object]] = BigVGANConfig()):
82
82
  super().__init__()
83
+ cfg = cfg if isinstance(cfg, BigVGANConfig) else BigVGANConfig(**cfg)
83
84
  self.cfg = cfg
84
85
  actv = get_snake(self.cfg.activation)
85
86
 
@@ -173,46 +174,16 @@ class BigVGAN(ConvNets):
173
174
  return x.tanh()
174
175
  return x.clamp(min=-1.0, max=1.0)
175
176
 
176
- def load_weights(
177
- self,
178
- path,
179
- strict=False,
180
- assign=False,
181
- weights_only=False,
182
- mmap=None,
183
- raise_if_not_exists=False,
184
- **pickle_load_args,
185
- ):
186
- try:
187
- return super().load_weights(
188
- path,
189
- raise_if_not_exists,
190
- strict,
191
- assign,
192
- weights_only,
193
- mmap,
194
- **pickle_load_args,
195
- )
196
- except RuntimeError:
197
- self.remove_norms()
198
- return super().load_weights(
199
- path,
200
- raise_if_not_exists,
201
- strict,
202
- assign,
203
- weights_only,
204
- mmap,
205
- **pickle_load_args,
206
- )
207
-
208
177
  @classmethod
209
178
  def from_pretrained(
210
179
  cls,
211
180
  model_file: PathLike,
212
- model_config: Union[BigVGANConfig, Dict[str, Any]],
181
+ model_config: Union[
182
+ BigVGANConfig, Dict[str, Any], Dict[str, Any], PathLike
183
+ ] = BigVGANConfig(),
213
184
  *,
214
185
  remove_norms: bool = False,
215
- strict: bool = False,
186
+ strict: bool = True,
216
187
  map_location: str = "cpu",
217
188
  weights_only: bool = False,
218
189
  **kwargs,
@@ -220,13 +191,17 @@ class BigVGAN(ConvNets):
220
191
 
221
192
  is_file(model_file, validate=True)
222
193
  model_state_dict = torch.load(
223
- model_file, weights_only=weights_only, map_location=map_location
194
+ model_file,
195
+ weights_only=weights_only,
196
+ map_location=map_location,
224
197
  )
225
198
 
226
199
  if isinstance(model_config, BigVGANConfig):
227
200
  h = model_config
228
- else:
201
+ elif isinstance(model_config, dict):
229
202
  h = BigVGANConfig(**model_config)
203
+ elif isinstance(model_config, (str, Path, bytes)):
204
+ h = BigVGANConfig(**load_json(model_config, BigVGANConfig().state_dict()))
230
205
 
231
206
  model = cls(h)
232
207
  if remove_norms:
@@ -177,43 +177,44 @@ class ResidualBlock(Model):
177
177
 
178
178
 
179
179
  class DiffWave(Model):
180
- def __init__(self, params: DiffWaveConfig = DiffWaveConfig()):
180
+ def __init__(self, cfg: Union[DiffWaveConfig, dict[str, object]] = DiffWaveConfig()):
181
181
  super().__init__()
182
- self.params = params
183
- self.n_hop = self.params.hop_samples
182
+ cfg = cfg if isinstance(cfg, DiffWaveConfig) else DiffWaveConfig(**cfg)
183
+ self.cfg = cfg
184
+ self.n_hop = self.cfg.hop_samples
184
185
  self.input_projection = ConvEXT(
185
186
  in_channels=1,
186
- out_channels=params.residual_channels,
187
+ out_channels=cfg.residual_channels,
187
188
  kernel_size=1,
188
- apply_norm=self.params.apply_norm,
189
+ apply_norm=self.cfg.apply_norm,
189
190
  activation_out=nn.LeakyReLU(0.1),
190
191
  )
191
- self.diffusion_embedding = DiffusionEmbedding(len(params.noise_schedule))
192
+ self.diffusion_embedding = DiffusionEmbedding(len(cfg.noise_schedule))
192
193
  self.spectrogram_upsampler = (
193
- SpectrogramUpsampler() if not self.params.unconditional else None
194
+ SpectrogramUpsampler() if not self.cfg.unconditional else None
194
195
  )
195
196
 
196
197
  self.residual_layers = nn.ModuleList(
197
198
  [
198
199
  ResidualBlock(
199
- params.n_mels,
200
- params.residual_channels,
201
- 2 ** (i % params.dilation_cycle_length),
202
- uncond=params.unconditional,
203
- apply_norm=self.params.apply_norm_resblock,
200
+ cfg.n_mels,
201
+ cfg.residual_channels,
202
+ 2 ** (i % cfg.dilation_cycle_length),
203
+ uncond=cfg.unconditional,
204
+ apply_norm=self.cfg.apply_norm_resblock,
204
205
  )
205
- for i in range(params.residual_layers)
206
+ for i in range(cfg.residual_layers)
206
207
  ]
207
208
  )
208
209
  self.skip_projection = ConvEXT(
209
- in_channels=params.residual_channels,
210
- out_channels=params.residual_channels,
210
+ in_channels=cfg.residual_channels,
211
+ out_channels=cfg.residual_channels,
211
212
  kernel_size=1,
212
- apply_norm=self.params.apply_norm,
213
+ apply_norm=self.cfg.apply_norm,
213
214
  activation_out=nn.LeakyReLU(0.1),
214
215
  )
215
216
  self.output_projection = ConvEXT(
216
- params.residual_channels, 1, 1, apply_norm=self.params.apply_norm, init_weights=True,
217
+ cfg.residual_channels, 1, 1, apply_norm=self.cfg.apply_norm, init_weights=True,
217
218
  )
218
219
  self.activation = nn.LeakyReLU(0.1)
219
220
  self._res_d = sqrt(len(self.residual_layers))
@@ -224,7 +225,7 @@ class DiffWave(Model):
224
225
  diffusion_step: Tensor,
225
226
  spectrogram: Optional[Tensor] = None,
226
227
  ):
227
- if not self.params.unconditional:
228
+ if not self.cfg.unconditional:
228
229
  assert spectrogram is not None
229
230
  if audio.ndim < 3:
230
231
  if audio.ndim == 2:
@@ -234,7 +235,7 @@ class DiffWave(Model):
234
235
 
235
236
  x = self.input_projection(audio)
236
237
  diffusion_step = self.diffusion_embedding(diffusion_step)
237
- if not self.params.unconditional: # use conditional model
238
+ if not self.cfg.unconditional: # use conditional model
238
239
  spectrogram = self.spectrogram_upsampler(spectrogram)
239
240
 
240
241
  skip = torch.zeros_like(x, device=x.device)
@@ -5,7 +5,7 @@ from lt_utils.common import *
5
5
  from lt_tensor.torch_commons import *
6
6
  from lt_tensor.model_zoo.convs import ConvNets
7
7
  from lt_tensor.config_templates import ModelConfig
8
- from lt_utils.file_ops import is_file
8
+ from lt_utils.file_ops import is_file, load_json
9
9
  from lt_tensor.model_zoo.audio_models.resblocks import ResBlock1, ResBlock2
10
10
 
11
11
 
@@ -16,11 +16,15 @@ def get_padding(kernel_size, dilation=1):
16
16
  class HifiganConfig(ModelConfig):
17
17
  # Training params
18
18
  in_channels: int = 80
19
- upsample_rates: List[Union[int, List[int]]] = [8,8,2,2]
20
- upsample_kernel_sizes: List[Union[int, List[int]]] = [16,16,4,4]
19
+ upsample_rates: List[Union[int, List[int]]] = [8, 8, 2, 2]
20
+ upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16, 4, 4]
21
21
  upsample_initial_channel: int = 512
22
22
  resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11]
23
- resblock_dilation_sizes: List[Union[int, List[int]]] = [[1,3,5], [1,3,5], [1,3,5]]
23
+ resblock_dilation_sizes: List[Union[int, List[int]]] = [
24
+ [1, 3, 5],
25
+ [1, 3, 5],
26
+ [1, 3, 5],
27
+ ]
24
28
 
25
29
  activation: nn.Module = nn.LeakyReLU(0.1)
26
30
  resblock_activation: nn.Module = nn.LeakyReLU(0.1)
@@ -29,10 +33,10 @@ class HifiganConfig(ModelConfig):
29
33
  def __init__(
30
34
  self,
31
35
  in_channels: int = 80,
32
- upsample_rates: List[Union[int, List[int]]] = [8,8,2,2],
33
- upsample_kernel_sizes: List[Union[int, List[int]]] = [16,16,4,4],
36
+ upsample_rates: List[Union[int, List[int]]] = [8, 8, 2, 2],
37
+ upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16, 4, 4],
34
38
  upsample_initial_channel: int = 512,
35
- resblock_kernel_sizes: List[Union[int, List[int]]] = [3,7,11],
39
+ resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
36
40
  resblock_dilation_sizes: List[Union[int, List[int]]] = [
37
41
  [1, 3, 5],
38
42
  [1, 3, 5],
@@ -63,9 +67,11 @@ class HifiganConfig(ModelConfig):
63
67
 
64
68
 
65
69
  class HifiganGenerator(ConvNets):
66
- def __init__(self, cfg: HifiganConfig = HifiganConfig()):
70
+ def __init__(self, cfg: Union[HifiganConfig, Dict[str, object]] = HifiganConfig()):
67
71
  super().__init__()
72
+ cfg = cfg if isinstance(cfg, HifiganConfig) else HifiganConfig(**cfg)
68
73
  self.cfg = cfg
74
+
69
75
  self.num_kernels = len(cfg.resblock_kernel_sizes)
70
76
  self.num_upsamples = len(cfg.upsample_rates)
71
77
  self.conv_pre = weight_norm(
@@ -115,46 +121,16 @@ class HifiganGenerator(ConvNets):
115
121
  x = self.conv_post(self.activation(x))
116
122
  return x.tanh()
117
123
 
118
- def load_weights(
119
- self,
120
- path,
121
- strict=False,
122
- assign=False,
123
- weights_only=False,
124
- mmap=None,
125
- raise_if_not_exists=False,
126
- **pickle_load_args,
127
- ):
128
- try:
129
- return super().load_weights(
130
- path,
131
- raise_if_not_exists,
132
- strict,
133
- assign,
134
- weights_only,
135
- mmap,
136
- **pickle_load_args,
137
- )
138
- except RuntimeError:
139
- self.remove_norms()
140
- return super().load_weights(
141
- path,
142
- raise_if_not_exists,
143
- strict,
144
- assign,
145
- weights_only,
146
- mmap,
147
- **pickle_load_args,
148
- )
149
-
150
124
  @classmethod
151
125
  def from_pretrained(
152
126
  cls,
153
127
  model_file: PathLike,
154
- model_config: Union[HifiganConfig, Dict[str, Any]],
128
+ model_config: Union[
129
+ HifiganConfig, Dict[str, Any], Dict[str, Any], PathLike
130
+ ] = HifiganConfig(),
155
131
  *,
156
132
  remove_norms: bool = False,
157
- strict: bool = False,
133
+ strict: bool = True,
158
134
  map_location: str = "cpu",
159
135
  weights_only: bool = False,
160
136
  **kwargs,
@@ -162,13 +138,17 @@ class HifiganGenerator(ConvNets):
162
138
 
163
139
  is_file(model_file, validate=True)
164
140
  model_state_dict = torch.load(
165
- model_file, weights_only=weights_only, map_location=map_location
141
+ model_file,
142
+ weights_only=weights_only,
143
+ map_location=map_location,
166
144
  )
167
145
 
168
146
  if isinstance(model_config, HifiganConfig):
169
147
  h = model_config
170
- else:
148
+ elif isinstance(model_config, dict):
171
149
  h = HifiganConfig(**model_config)
150
+ elif isinstance(model_config, (str, Path, bytes)):
151
+ h = HifiganConfig(**load_json(model_config, HifiganConfig().state_dict()))
172
152
 
173
153
  model = cls(h)
174
154
  if remove_norms:
@@ -3,7 +3,7 @@ from lt_utils.common import *
3
3
  from lt_tensor.torch_commons import *
4
4
  from lt_tensor.model_zoo.convs import ConvNets
5
5
  from lt_tensor.config_templates import ModelConfig
6
- from lt_utils.file_ops import is_file
6
+ from lt_utils.file_ops import is_file, load_json
7
7
  from lt_tensor.model_zoo.audio_models.resblocks import ResBlock1, ResBlock2
8
8
 
9
9
 
@@ -67,8 +67,11 @@ class iSTFTNetConfig(ModelConfig):
67
67
 
68
68
 
69
69
  class iSTFTNetGenerator(ConvNets):
70
- def __init__(self, cfg: iSTFTNetConfig = iSTFTNetConfig()):
70
+ def __init__(
71
+ self, cfg: Union[iSTFTNetConfig, Dict[str, object]] = iSTFTNetConfig()
72
+ ):
71
73
  super().__init__()
74
+ cfg = cfg if isinstance(cfg, iSTFTNetConfig) else iSTFTNetConfig(**cfg)
72
75
  self.cfg = cfg
73
76
  self.num_kernels = len(cfg.resblock_kernel_sizes)
74
77
  self.num_upsamples = len(cfg.upsample_rates)
@@ -146,46 +149,16 @@ class iSTFTNetGenerator(ConvNets):
146
149
 
147
150
  return spec, phase
148
151
 
149
- def load_weights(
150
- self,
151
- path,
152
- strict=False,
153
- assign=False,
154
- weights_only=False,
155
- mmap=None,
156
- raise_if_not_exists=False,
157
- **pickle_load_args,
158
- ):
159
- try:
160
- return super().load_weights(
161
- path,
162
- raise_if_not_exists,
163
- strict,
164
- assign,
165
- weights_only,
166
- mmap,
167
- **pickle_load_args,
168
- )
169
- except RuntimeError:
170
- self.remove_norms()
171
- return super().load_weights(
172
- path,
173
- raise_if_not_exists,
174
- strict,
175
- assign,
176
- weights_only,
177
- mmap,
178
- **pickle_load_args,
179
- )
180
-
181
152
  @classmethod
182
153
  def from_pretrained(
183
154
  cls,
184
155
  model_file: PathLike,
185
- model_config: Union[iSTFTNetConfig, Dict[str, Any]],
156
+ model_config: Union[
157
+ iSTFTNetConfig, Dict[str, Any], Dict[str, Any], PathLike
158
+ ] = iSTFTNetConfig(),
186
159
  *,
187
160
  remove_norms: bool = False,
188
- strict: bool = False,
161
+ strict: bool = True,
189
162
  map_location: str = "cpu",
190
163
  weights_only: bool = False,
191
164
  **kwargs,
@@ -193,14 +166,17 @@ class iSTFTNetGenerator(ConvNets):
193
166
 
194
167
  is_file(model_file, validate=True)
195
168
  model_state_dict = torch.load(
196
- model_file, weights_only=weights_only, map_location=map_location
169
+ model_file,
170
+ weights_only=weights_only,
171
+ map_location=map_location,
197
172
  )
198
173
 
199
174
  if isinstance(model_config, iSTFTNetConfig):
200
175
  h = model_config
201
- else:
176
+ elif isinstance(model_config, dict):
202
177
  h = iSTFTNetConfig(**model_config)
203
-
178
+ elif isinstance(model_config, (str, Path, bytes)):
179
+ h = iSTFTNetConfig(**load_json(model_config, iSTFTNetConfig().state_dict()))
204
180
  model = cls(h)
205
181
  if remove_norms:
206
182
  model.remove_norms()
@@ -1,11 +1,7 @@
1
1
  __all__ = ["ConvNets", "ConvEXT"]
2
- import math
3
2
  from lt_utils.common import *
4
- import torch.nn.functional as F
5
3
  from lt_tensor.torch_commons import *
6
4
  from lt_tensor.model_base import Model
7
- from lt_tensor.misc_utils import log_tensor
8
- from lt_tensor.model_zoo.fusion import AdaFusion1D, AdaIN1D
9
5
  from lt_utils.misc_utils import default
10
6
 
11
7
 
@@ -52,6 +48,41 @@ class ConvNets(Model):
52
48
  if "Conv" in m.__class__.__name__:
53
49
  m.weight.data.normal_(mean, std)
54
50
 
51
+ def load_weights(
52
+ self,
53
+ path,
54
+ strict=False,
55
+ assign=False,
56
+ weights_only=False,
57
+ mmap=None,
58
+ raise_if_not_exists=False,
59
+ **pickle_load_args,
60
+ ):
61
+ try:
62
+ return super().load_weights(
63
+ path,
64
+ raise_if_not_exists,
65
+ strict,
66
+ assign,
67
+ weights_only,
68
+ mmap,
69
+ **pickle_load_args,
70
+ )
71
+ except RuntimeError as e:
72
+ try:
73
+ self.remove_norms()
74
+ return super().load_weights(
75
+ path,
76
+ raise_if_not_exists,
77
+ strict,
78
+ assign,
79
+ weights_only,
80
+ mmap,
81
+ **pickle_load_args,
82
+ )
83
+ except:
84
+ raise e
85
+
55
86
 
56
87
  class ConvEXT(ConvNets):
57
88
  def __init__(
@@ -13,6 +13,7 @@ __all__ = [
13
13
  ]
14
14
 
15
15
  from lt_utils.common import *
16
+ from lt_tensor.model_base import Model
16
17
  import torch.nn.functional as F
17
18
  from lt_tensor.torch_commons import *
18
19
  import math
@@ -20,17 +21,17 @@ import random
20
21
  from lt_tensor.misc_utils import set_seed
21
22
 
22
23
 
23
- def add_gaussian_noise(x: Tensor, noise_level=0.025):
24
+ def add_gaussian_noise(x: Tensor, noise_level: float = 0.025) -> Tensor:
24
25
  noise = torch.randn_like(x) * noise_level
25
26
  return x + noise
26
27
 
27
28
 
28
- def add_uniform_noise(x: Tensor, noise_level=0.025):
29
+ def add_uniform_noise(x: Tensor, noise_level: float = 0.025) -> Tensor:
29
30
  noise = (torch.rand_like(x) - 0.5) * 2 * noise_level
30
31
  return x + noise
31
32
 
32
33
 
33
- def add_linear_noise(x, noise_level=0.05):
34
+ def add_linear_noise(x, noise_level=0.05) -> Tensor:
34
35
  T = x.shape[-1]
35
36
  ramp = torch.linspace(0, noise_level, T, device=x.device)
36
37
  for _ in range(x.dim() - 1):
@@ -38,7 +39,7 @@ def add_linear_noise(x, noise_level=0.05):
38
39
  return x + ramp.expand_as(x)
39
40
 
40
41
 
41
- def add_impulse_noise(x: Tensor, noise_level=0.025):
42
+ def add_impulse_noise(x: Tensor, noise_level: float = 0.025) -> Tensor:
42
43
  # For image inputs
43
44
  probs = torch.rand_like(x)
44
45
  x_clone = x.detach().clone()
@@ -47,7 +48,7 @@ def add_impulse_noise(x: Tensor, noise_level=0.025):
47
48
  return x_clone
48
49
 
49
50
 
50
- def add_pink_noise(x: Tensor, noise_level=0.05):
51
+ def add_pink_noise(x: Tensor, noise_level: float = 0.05) -> Tensor:
51
52
  # pink noise: divide freq spectrum by sqrt(f)
52
53
  if x.ndim == 3:
53
54
  x = x.view(-1, x.shape[-1]) # flatten to 2D [B*M, T]
@@ -66,12 +67,12 @@ def add_pink_noise(x: Tensor, noise_level=0.05):
66
67
  return x + pink_noised * noise_level
67
68
 
68
69
 
69
- def add_clipped_gaussian_noise(x, noise_level=0.025):
70
+ def add_clipped_gaussian_noise(x: Tensor, noise_level: float = 0.025) -> Tensor:
70
71
  noise = torch.randn_like(x) * noise_level
71
72
  return torch.clamp(x + noise, 0.0, 1.0)
72
73
 
73
74
 
74
- def add_multiplicative_noise(x, noise_level=0.025):
75
+ def add_multiplicative_noise(x: Tensor, noise_level: float = 0.025) -> Tensor:
75
76
  noise = 1 + torch.randn_like(x) * noise_level
76
77
  return x * noise
77
78
 
@@ -109,7 +110,15 @@ _NOISE_DIM_SUPPORT = {
109
110
 
110
111
  def apply_noise(
111
112
  x: Tensor,
112
- noise_type: str = "gaussian",
113
+ noise_type: Literal[
114
+ "gaussian",
115
+ "uniform",
116
+ "linear",
117
+ "impulse",
118
+ "pink",
119
+ "clipped_gaussian",
120
+ "multiplicative",
121
+ ] = "gaussian",
113
122
  noise_level: float = 0.01,
114
123
  seed: Optional[int] = None,
115
124
  on_error: Literal["raise", "try_others", "return_unchanged"] = "raise",
@@ -229,11 +238,11 @@ class NoiseSchedulerA(nn.Module):
229
238
  return collected, noise_history
230
239
 
231
240
 
232
- class NoiseSchedulerB(nn.Module):
233
- def __init__(self, timesteps: int = 512):
241
+ class NoiseSchedulerB(Model):
242
+ def __init__(self, timesteps: int = 50, l_min: float = 0.0005, l_max: float = 0.05):
234
243
  super().__init__()
235
244
 
236
- betas = torch.linspace(1e-4, 0.02, timesteps)
245
+ betas = torch.linspace(l_min, l_max, timesteps)
237
246
  alphas = 1.0 - betas
238
247
  alpha_cumprod = torch.cumprod(alphas, dim=0)
239
248
 
@@ -272,7 +281,7 @@ class NoiseSchedulerB(nn.Module):
272
281
  self, x_0: Tensor, t: int, noise: Optional[Union[Tensor, float]] = None
273
282
  ) -> Tensor:
274
283
  assert (
275
- 0 <= t < self.timesteps
284
+ 0 <= t < self.timesteps
276
285
  ), f"Time step t={t} is out of bounds for scheduler with {self.timesteps} steps."
277
286
 
278
287
  if noise is None:
@@ -286,7 +295,7 @@ class NoiseSchedulerB(nn.Module):
286
295
  return alpha_term + noise_term
287
296
 
288
297
 
289
- class NoiseSchedulerC(nn.Module):
298
+ class NoiseSchedulerC(Model):
290
299
  def __init__(self, timesteps: int = 512):
291
300
  super().__init__()
292
301
 
@@ -92,9 +92,17 @@ def _comp_rms_helper(i: int, audio: Tensor, mel: Optional[Tensor]):
92
92
 
93
93
 
94
94
  class AudioProcessor(Model):
95
- def __init__(self, config: AudioProcessorConfig = AudioProcessorConfig()):
95
+ def __init__(
96
+ self,
97
+ config: Union[AudioProcessorConfig, Dict[str, Any]] = AudioProcessorConfig(),
98
+ ):
96
99
  super().__init__()
97
- self.cfg = config
100
+ assert isinstance(config, (AudioProcessorConfig, dict))
101
+ self.cfg = (
102
+ config
103
+ if isinstance(config, AudioProcessorConfig)
104
+ else AudioProcessorConfig(**config)
105
+ )
98
106
  self._mel_spec_torch = torchaudio.transforms.MelSpectrogram(
99
107
  sample_rate=self.cfg.sample_rate,
100
108
  n_mels=self.cfg.n_mels,
@@ -108,14 +116,6 @@ class AudioProcessor(Model):
108
116
  normalized=self.cfg.normalized,
109
117
  )
110
118
 
111
- self._mel_rscale = torchaudio.transforms.InverseMelScale(
112
- n_stft=self.cfg.n_stft,
113
- n_mels=self.cfg.n_mels,
114
- sample_rate=self.cfg.sample_rate,
115
- f_min=self.cfg.f_min,
116
- f_max=self.cfg.f_max,
117
- mel_scale=self.cfg.mel_scale,
118
- )
119
119
  self.mel_lib_padding = (self.cfg.n_fft - self.cfg.hop_length) // 2
120
120
  self.register_buffer(
121
121
  "window",
@@ -134,10 +134,10 @@ class AudioProcessor(Model):
134
134
  ).float(),
135
135
  )
136
136
 
137
- def spectral_norm(self, x: Tensor, c: int = 1, eps: float = 1e-5):
137
+ def spectral_norm(self, x: Tensor, c: int = 1, eps: float = 1e-5) -> Tensor:
138
138
  return torch.log(torch.clamp(x, min=eps) * c)
139
139
 
140
- def spectral_de_norm(self, x: Tensor, c: int = 1):
140
+ def spectral_de_norm(self, x: Tensor, c: int = 1) -> Tensor:
141
141
  return torch.exp(x) / c
142
142
 
143
143
  def log_norm(
@@ -201,7 +201,7 @@ class AudioProcessor(Model):
201
201
  spectral_norm: bool = False,
202
202
  *args,
203
203
  **kwargs,
204
- ):
204
+ ) -> Tensor:
205
205
  if wave.ndim == 1:
206
206
  wave = wave.unsqueeze(0)
207
207
  wave = torch.nn.functional.pad(
@@ -232,15 +232,6 @@ class AudioProcessor(Model):
232
232
  return self.spectral_norm(results, eps=eps).squeeze()
233
233
  return results.squeeze()
234
234
 
235
- def compute_inverse_mel(self, melspec: Tensor, *, _recall=False):
236
- try:
237
- return self._mel_rscale.forward(melspec.to(self.device)).squeeze()
238
- except RuntimeError as e:
239
- if not _recall:
240
- self._mel_rscale.to(self.device)
241
- return self.compute_inverse_mel(melspec, _recall=True)
242
- raise e
243
-
244
235
  def compute_rms(
245
236
  self,
246
237
  audio: Optional[Union[Tensor, np.ndarray]] = None,
@@ -248,7 +239,7 @@ class AudioProcessor(Model):
248
239
  frame_length: Optional[int] = None,
249
240
  hop_length: Optional[int] = None,
250
241
  center: Optional[int] = None,
251
- ):
242
+ ) -> Tensor:
252
243
  assert any([audio is not None, mel is not None])
253
244
  rms_kwargs = dict(
254
245
  frame_length=default(frame_length, self.cfg.n_fft),
@@ -297,7 +288,7 @@ class AudioProcessor(Model):
297
288
  audio: torch.Tensor,
298
289
  sample_rate: Optional[int] = None,
299
290
  n_steps: float = 2.0,
300
- ):
291
+ ) -> Tensor:
301
292
  """
302
293
  Shifts the pitch of an audio tensor by `n_steps` semitones.
303
294
 
@@ -327,24 +318,19 @@ class AudioProcessor(Model):
327
318
  device=src_device, dtype=src_dtype
328
319
  )
329
320
 
330
- @staticmethod
331
- def calc_pitch_fmin(sr: int, frame_length: float):
332
- """For pitch f_min"""
333
- return (sr / (frame_length - 1)) * 2
334
-
335
321
  def compute_pitch(
336
322
  self,
337
323
  audio: Tensor,
338
324
  *,
339
325
  pad_mode: str = "constant",
340
326
  trough_threshold: float = 0.1,
341
- fmin: Optional[float] = None,
342
- fmax: Optional[float] = None,
327
+ fmin: float = librosa.note_to_hz("C2"),
328
+ fmax: float = librosa.note_to_hz("C7"),
343
329
  sr: Optional[float] = None,
344
330
  frame_length: Optional[int] = None,
345
331
  hop_length: Optional[int] = None,
346
332
  center: Optional[bool] = None,
347
- ):
333
+ ) -> Tensor:
348
334
  default_dtype = audio.dtype
349
335
  default_device = audio.device
350
336
  if audio.ndim > 1:
@@ -353,10 +339,7 @@ class AudioProcessor(Model):
353
339
  B = 1
354
340
  sr = default(sr, self.cfg.sample_rate)
355
341
  frame_length = default(frame_length, self.cfg.n_fft)
356
- fmin = max(
357
- default(fmin, self.cfg.default_f_min), self.calc_pitch_fmin(sr, frame_length)
358
- )
359
- fmax = min(max(default(fmax, self.cfg.default_f_max), fmin + 1), sr // 2)
342
+ fmax = min(max(fmax, fmin + 1), sr // 2)
360
343
  hop_length = default(hop_length, self.cfg.hop_length)
361
344
  center = default(center, self.cfg.center)
362
345
  yn_kwargs = dict(
@@ -391,8 +374,8 @@ class AudioProcessor(Model):
391
374
  sr: Optional[float] = None,
392
375
  win_length: Optional[Number] = None,
393
376
  frame_length: Optional[Number] = None,
394
- ):
395
- sr = default(sr, self.sample_rate)
377
+ ) -> Tensor:
378
+ sr = default(sr, self.cfg.sample_rate)
396
379
  win_length = default(win_length, self.cfg.win_length)
397
380
  frame_length = default(frame_length, self.cfg.n_fft)
398
381
  fmin = default(fmin, self.calc_pitch_fmin(sr, frame_length))
@@ -411,7 +394,7 @@ class AudioProcessor(Model):
411
394
  array: np.ndarray,
412
395
  device: Optional[torch.device] = None,
413
396
  dtype: Optional[torch.dtype] = None,
414
- ):
397
+ ) -> Tensor:
415
398
  converted = torch.from_numpy(array)
416
399
  if device is None:
417
400
  device = self.device
@@ -422,13 +405,13 @@ class AudioProcessor(Model):
422
405
  arrays: List[np.ndarray],
423
406
  device: Optional[torch.device] = None,
424
407
  dtype: Optional[torch.dtype] = None,
425
- ):
408
+ ) -> Tensor:
426
409
  stacked = torch.stack([torch.from_numpy(x) for x in arrays])
427
410
  if device is None:
428
411
  device = self.device
429
412
  return stacked.to(device=device, dtype=dtype)
430
413
 
431
- def to_numpy_safe(self, tensor: Union[Tensor, np.ndarray]):
414
+ def to_numpy_safe(self, tensor: Union[Tensor, np.ndarray]) -> np.ndarray:
432
415
  if isinstance(tensor, np.ndarray):
433
416
  return tensor
434
417
  return tensor.detach().to(DEFAULT_DEVICE).numpy(force=True)
@@ -450,7 +433,7 @@ class AudioProcessor(Model):
450
433
  scale_factor: Optional[list[float]] = None,
451
434
  recompute_scale_factor: Optional[bool] = None,
452
435
  antialias: bool = False,
453
- ):
436
+ ) -> Tensor:
454
437
  """
455
438
  The modes available for upsampling are: `nearest`, `linear` (3D-only),
456
439
  `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only)
@@ -482,7 +465,7 @@ class AudioProcessor(Model):
482
465
  normalized: Optional[bool] = None,
483
466
  onesided: Optional[bool] = None,
484
467
  return_complex: bool = False,
485
- ):
468
+ ) -> Tensor:
486
469
  """Util for models that needs to reconstruct the audio using inverse stft"""
487
470
  window = (
488
471
  torch.hann_window(win_length, device=spec.device)
@@ -513,7 +496,7 @@ class AudioProcessor(Model):
513
496
  normalized: Optional[bool] = None,
514
497
  onesided: Optional[bool] = None,
515
498
  return_complex: bool = False,
516
- ):
499
+ ) -> Tensor:
517
500
  window = (
518
501
  torch.hann_window(win_length, device=wave.device)
519
502
  if win_length is not None and win_length != self.cfg.win_length
@@ -544,7 +527,7 @@ class AudioProcessor(Model):
544
527
  normalized: Optional[bool] = None,
545
528
  onesided: Optional[bool] = None,
546
529
  return_complex: bool = True,
547
- ):
530
+ ) -> Tensor:
548
531
 
549
532
  window = (
550
533
  torch.hann_window(win_length, device=wave.device)
@@ -579,7 +562,7 @@ class AudioProcessor(Model):
579
562
  normalized: Optional[bool] = None,
580
563
  onesided: Optional[bool] = None,
581
564
  return_complex: bool = False,
582
- ):
565
+ ) -> Tensor:
583
566
  window = (
584
567
  torch.hann_window(win_length, device=wave.device)
585
568
  if win_length is not None and win_length != self.cfg.win_length
@@ -619,12 +602,11 @@ class AudioProcessor(Model):
619
602
  self,
620
603
  path: PathLike,
621
604
  top_db: Optional[float] = None,
622
- normalize: bool = False,
623
605
  mono: bool = True,
606
+ istft_norm: bool = True,
607
+ lib_norm: bool = False,
624
608
  *,
625
609
  sample_rate: Optional[float] = None,
626
- hop_length: int = 512,
627
- frame_length: int = 2048,
628
610
  duration: Optional[float] = None,
629
611
  offset: float = 0.0,
630
612
  dtype: Any = np.float32,
@@ -649,14 +631,6 @@ class AudioProcessor(Model):
649
631
  dtype=dtype,
650
632
  res_type=res_type,
651
633
  )
652
- if top_db is not None:
653
- wave, _ = librosa.effects.trim(
654
- wave,
655
- top_db=top_db,
656
- ref=ref,
657
- frame_length=frame_length,
658
- hop_length=hop_length,
659
- )
660
634
  if sr != sample_rate:
661
635
  wave = librosa.resample(
662
636
  wave,
@@ -667,7 +641,9 @@ class AudioProcessor(Model):
667
641
  scale=scale,
668
642
  axis=axis,
669
643
  )
670
- if normalize:
644
+ if top_db is not None:
645
+ wave, _ = librosa.effects.trim(wave, top_db=top_db)
646
+ if lib_norm:
671
647
  wave = librosa.util.normalize(
672
648
  wave,
673
649
  norm=norm,
@@ -675,6 +651,9 @@ class AudioProcessor(Model):
675
651
  threshold=norm_threshold,
676
652
  fill=norm_fill,
677
653
  )
654
+ results = torch.from_numpy(wave).float().unsqueeze(0).to(self.device)
655
+ if istft_norm:
656
+ results = self.istft_norm(results)
678
657
  return torch.from_numpy(wave).float().unsqueeze(0).to(self.device)
679
658
 
680
659
  def find_audios(
@@ -701,9 +680,84 @@ class AudioProcessor(Model):
701
680
  maximum,
702
681
  )
703
682
 
683
+ def audio_to_half(self, audio: Tensor):
684
+ audio = self.to_numpy_safe(audio)
685
+ data: np.ndarray = audio / np.abs(audio).max()
686
+ data = (data * 32767.0).astype(np.int16)
687
+ return self.from_numpy(data, dtype=torch.float16)
688
+
704
689
  def forward(
705
690
  self,
706
- *inputs: Union[Tensor, float],
707
- **inputs_kwargs,
691
+ x: Union[str, Path, Tensor],
692
+ *,
693
+ spectral_norm: bool = False,
694
+ add_batch_to_all: bool = False,
695
+ wave_batch_dim: bool = False,
696
+ mel_batch_dim: bool = False,
697
+ pitch_batch_dim: bool = False,
698
+ rms_batch_dim: bool = False,
699
+ spec_phase_batch_dim: bool = False,
708
700
  ):
709
- return self.compute_mel(*inputs, **inputs_kwargs)
701
+ results = {
702
+ "wave": None,
703
+ "mel": None,
704
+ "pitch": None,
705
+ "rms": None,
706
+ "spec": None,
707
+ "phase": None,
708
+ }
709
+ results["wave"] = (
710
+ x.squeeze()
711
+ if isinstance(x, Tensor)
712
+ else self.load_audio(x, istft_norm=True).squeeze()
713
+ )
714
+ results["mel"] = self.compute_mel_librosa(
715
+ wave=(
716
+ results["wave"]
717
+ if results["wave"].ndim == 3
718
+ else results["wave"].unsqueeze(0)
719
+ ),
720
+ spectral_norm=spectral_norm,
721
+ ).squeeze()
722
+ try:
723
+ results["pitch"] = self.compute_pitch(results["wave"]).squeeze()
724
+ except Exception as e:
725
+ results["pitch"] = e
726
+ try:
727
+ results["rms"] = self.compute_rms(results["wave"], results["mel"]).squeeze()
728
+ except Exception as e:
729
+ results["rms"] = e
730
+ try:
731
+ sp_ph = self.stft(results["wave"], return_complex=False)
732
+ spec, phase = sp_ph.split(1, -1)
733
+ results["spec"] = spec.squeeze()
734
+ results["phase"] = phase.squeeze()
735
+ except Exception as e:
736
+ results["spec"] = e
737
+ results["phase"] = e
738
+
739
+ if (add_batch_to_all or wave_batch_dim) and results["wave"].ndim == 1:
740
+ results["wave"] = results["wave"].unsqueeze(0)
741
+ if (add_batch_to_all or mel_batch_dim) and results["mel"].ndim == 2:
742
+ results["mel"] = results["mel"].unsqueeze(0)
743
+ if (
744
+ isinstance(results["rms"], Tensor)
745
+ and (add_batch_to_all or rms_batch_dim)
746
+ and results["rms"].ndim == 1
747
+ ):
748
+ results["rms"] = results["rms"].unsqueeze(0)
749
+ if (
750
+ isinstance(results["pitch"], Tensor)
751
+ and (add_batch_to_all or pitch_batch_dim)
752
+ and results["pitch"].ndim == 1
753
+ ):
754
+ results["pitch"] = results["pitch"].unsqueeze(0)
755
+ if (
756
+ isinstance(results["spec"], Tensor)
757
+ and (add_batch_to_all or spec_phase_batch_dim)
758
+ and results["spec"].ndim == 2
759
+ ):
760
+ results["spec"] = results["spec"].unsqueeze(0)
761
+ results["phase"] = results["phase"].unsqueeze(0)
762
+
763
+ return results
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lt-tensor
3
- Version: 0.0.1a39
3
+ Version: 0.0.1a40
4
4
  Summary: General utilities for PyTorch and others. Built for general use.
5
5
  Home-page: https://github.com/gr1336/lt-tensor/
6
6
  Author: gr1336
@@ -33,6 +33,7 @@ lt_tensor/model_zoo/activations/alias_free/resample.py
33
33
  lt_tensor/model_zoo/activations/snake/__init__.py
34
34
  lt_tensor/model_zoo/audio_models/__init__.py
35
35
  lt_tensor/model_zoo/audio_models/resblocks.py
36
+ lt_tensor/model_zoo/audio_models/bemaganv2/__init__.py
36
37
  lt_tensor/model_zoo/audio_models/bigvgan/__init__.py
37
38
  lt_tensor/model_zoo/audio_models/diffwave/__init__.py
38
39
  lt_tensor/model_zoo/audio_models/hifigan/__init__.py
@@ -4,7 +4,7 @@ with open("README.md", "r", encoding="utf-8") as f:
4
4
  long_description = f.read()
5
5
 
6
6
  setup(
7
- version="0.0.1a39",
7
+ version="0.0.1a40",
8
8
  name="lt-tensor",
9
9
  description="General utilities for PyTorch and others. Built for general use.",
10
10
  long_description=long_description,
File without changes
File without changes
File without changes