lt-tensor 0.0.1a29__py3-none-any.whl → 0.0.1a31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/model_base.py CHANGED
@@ -70,6 +70,7 @@ class LossTracker:
70
70
 
71
71
  class _Devices_Base(nn.Module):
72
72
  _device: torch.device = ROOT_DEVICE
73
+ _setting_device: bool = False
73
74
 
74
75
  @property
75
76
  def device(self):
@@ -136,25 +137,35 @@ class _Devices_Base(nn.Module):
136
137
  f"Item '{module_or_name}' is not a valid module, parameter or tensor."
137
138
  )
138
139
 
139
- def _apply_device(self):
140
- """This may be seem as overkill, but its necessary"""
141
- for modules in self.modules():
142
- try:
143
- modules.to(self.device)
144
- except:
145
- pass
146
-
147
- for buffer in self.buffers():
148
- try:
149
- buffer.to(self.device)
150
- except:
151
- pass
152
-
153
- for tensor in self.parameters():
154
- try:
155
- tensor.to(self.device)
156
- except:
157
- pass
140
+ def apply_device(self):
141
+ """Helps to apply devices towards all the internal components"""
142
+ if self._setting_device:
143
+ return
144
+
145
+ self._setting_device = True
146
+
147
+ try:
148
+ for modules in self.modules():
149
+ try:
150
+ modules.to(self.device)
151
+ except:
152
+ pass
153
+
154
+ for buffer in self.buffers():
155
+ try:
156
+ buffer.to(self.device)
157
+ except:
158
+ pass
159
+
160
+ for tensor in self.parameters():
161
+ try:
162
+ tensor.to(self.device)
163
+ except:
164
+ pass
165
+ except:
166
+ pass
167
+ finally:
168
+ self._setting_device = False
158
169
 
159
170
  def _to_dvc(
160
171
  self, device_name: str, device_id: Optional[Union[int, torch.device]] = None
@@ -166,7 +177,11 @@ class _Devices_Base(nn.Module):
166
177
  elif hasattr(device_id, "index"):
167
178
  device += ":" + str(device_id.index)
168
179
  self.device = device
169
- self._apply_device()
180
+ if not self._setting_device:
181
+ self.apply_device()
182
+
183
+ def _to(self, *args, **kwargs):
184
+ self.to(*args, _internal=True, **kwargs)
170
185
 
171
186
  def to(self, *args, **kwargs):
172
187
  device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
@@ -378,6 +393,7 @@ class Model(_Devices_Base, ABC):
378
393
  state_dict = self.state_dict()
379
394
  if not save_with_adapters or isinstance(self.low_rank_adapter, nn.Identity):
380
395
  state_dict.pop("low_rank_adapter", None)
396
+ state_dict.pop("_setting_device", None)
381
397
  torch.save(obj=state_dict, f=str(model_dir))
382
398
 
383
399
  def save_lora(
@@ -20,6 +20,7 @@ def get_padding(kernel_size, dilation=1):
20
20
 
21
21
  class MultiDiscriminatorWrapper(ConvNets):
22
22
  """Base for all multi-steps type of discriminators"""
23
+
23
24
  def __init__(self, *args, **kwargs):
24
25
  super().__init__(*args, **kwargs)
25
26
  self.leaky_relu = nn.LeakyReLU(kwargs.get("negative_slope", 0.1))
@@ -207,7 +208,7 @@ class MultiPeriodDiscriminator(MultiDiscriminatorWrapper):
207
208
  return y_d_rs, y_d_gs, fmap_rs, fmap_gs
208
209
 
209
210
 
210
- class EnvelopeExtractor(nn.Module):
211
+ class EnvelopeExtractor(Model):
211
212
  """Extracts the amplitude envelope of the audio signal."""
212
213
 
213
214
  def __init__(self, kernel_size=101):
@@ -216,7 +217,7 @@ class EnvelopeExtractor(nn.Module):
216
217
  self.kernel_size = kernel_size
217
218
  self.register_buffer("kernel", torch.ones(1, 1, kernel_size) / kernel_size)
218
219
 
219
- def forward(self, x):
220
+ def forward(self, x: Tensor):
220
221
  # x: (B, 1, T) -> abs(x)
221
222
  envelope = torch.abs(x)
222
223
  # Apply low-pass smoothing (via conv1d)
@@ -274,7 +275,6 @@ class MultiEnvelopeDiscriminator(MultiDiscriminatorWrapper):
274
275
  def forward(self, y, y_hat):
275
276
  y_d_rs, y_d_gs = [], []
276
277
  fmap_rs, fmap_gs = [], []
277
-
278
278
  for i, d in enumerate(self.discriminators):
279
279
  if i != 0:
280
280
  y = self.meanpools[i - 1](y)
@@ -555,7 +555,19 @@ class MultiResolutionDiscriminator(MultiDiscriminatorWrapper):
555
555
 
556
556
 
557
557
  class MultiDiscriminatorStep(Model):
558
- def __init__(self, list_discriminator: List[MultiDiscriminatorWrapper]):
558
+ def __init__(
559
+ self, list_discriminator: List[MultiDiscriminatorWrapper]
560
+ ):
561
+ """Setup example:
562
+ model_d = MultiDiscriminatorStep(
563
+ [
564
+ MultiEnvelopeDiscriminator(),
565
+ MultiBandDiscriminator(),
566
+ MultiResolutionDiscriminator(),
567
+ MultiPeriodDiscriminator(0.5),
568
+ ]
569
+ )
570
+ """
559
571
  super().__init__()
560
572
  self.disc: Sequence[MultiDiscriminatorWrapper] = nn.ModuleList(
561
573
  list_discriminator
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lt-tensor
3
- Version: 0.0.1a29
3
+ Version: 0.0.1a31
4
4
  Summary: General utilities for PyTorch and others. Built for general use.
5
5
  Home-page: https://github.com/gr1336/lt-tensor/
6
6
  Author: gr1336
@@ -4,7 +4,7 @@ lt_tensor/losses.py,sha256=zvkCOnE5XpF3v6ymivRIdqPTsMM5zc94ZMom7YDi3zM,4946
4
4
  lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
5
5
  lt_tensor/math_ops.py,sha256=TkD4WQG42KsQ9Fg7FXOjf8f-ixtW0apf2XjaooecVx4,2257
6
6
  lt_tensor/misc_utils.py,sha256=N2r3UmxC4RM2BZBQhpjDZ_BKLrzsyIlKzopTzJbnjFU,28962
7
- lt_tensor/model_base.py,sha256=8qcFXe0_y8f1_tAwt18gwQjyyapbnVEKcjCMrKnQatw,17614
7
+ lt_tensor/model_base.py,sha256=5T4dbAh4MXbQmPRpihGtMYwTY8sJTQOhY6An3VboM58,18086
8
8
  lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
9
9
  lt_tensor/noise_tools.py,sha256=wFeAsHhLhSlEc5XU5LbFKaXoHeVxrWjiMeljjGdIKyM,11363
10
10
  lt_tensor/torch_commons.py,sha256=8l0bxmrAzwvyqjivCIVISXlbvKarlg4DdE0BOGSnMuQ,812
@@ -27,11 +27,11 @@ lt_tensor/model_zoo/audio_models/diffwave/__init__.py,sha256=PDuDYN1omD1RoAXcmxH
27
27
  lt_tensor/model_zoo/audio_models/hifigan/__init__.py,sha256=7GJqKLw7-juXpfp5IFzjASLut0uouDhjZ1CQknf3H68,16533
28
28
  lt_tensor/model_zoo/audio_models/istft/__init__.py,sha256=ltIuD9t1gmS3bTmCqZIwJHKrhC6DYya3OaXlskWX9kw,17606
29
29
  lt_tensor/model_zoo/losses/__init__.py,sha256=B9RAUxBiOZwooztnij1oLeRwZ7_MjnN3mPoum7saD6s,59
30
- lt_tensor/model_zoo/losses/discriminators.py,sha256=yYh7HzRTUtr0RVTG7cWpcYsJZsRCz6yzg6Loq8FtyOk,20405
30
+ lt_tensor/model_zoo/losses/discriminators.py,sha256=ZA7Qqrhe8kELrI1-IITadGSl8JCgpgPKFCW6qvSOk1E,20724
31
31
  lt_tensor/processors/__init__.py,sha256=Pvxhh0KR65zLCgUd53_k5Z0y5JWWcO0ZBXFK9rv0o5w,109
32
32
  lt_tensor/processors/audio.py,sha256=1JuxxexfUsXkLjVjWUk-oTRU-QNnCCwvKX3eP0m7LGE,16452
33
- lt_tensor-0.0.1a29.dist-info/licenses/LICENSE,sha256=tQHc38scHOba4kDBNG4U0U6PpObaloiZG-FvKSgv2b0,11336
34
- lt_tensor-0.0.1a29.dist-info/METADATA,sha256=F03dNMnEydcKjjZF3IntNaIj34FwLdoy-L0pBB_yz0E,1062
35
- lt_tensor-0.0.1a29.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
36
- lt_tensor-0.0.1a29.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
37
- lt_tensor-0.0.1a29.dist-info/RECORD,,
33
+ lt_tensor-0.0.1a31.dist-info/licenses/LICENSE,sha256=tQHc38scHOba4kDBNG4U0U6PpObaloiZG-FvKSgv2b0,11336
34
+ lt_tensor-0.0.1a31.dist-info/METADATA,sha256=qhs6RI_KE0LNjKiD7kky9d_4wcolq6PHyInLPVePzNo,1062
35
+ lt_tensor-0.0.1a31.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
36
+ lt_tensor-0.0.1a31.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
37
+ lt_tensor-0.0.1a31.dist-info/RECORD,,