lt-tensor 0.0.1a11__py3-none-any.whl → 0.0.1a12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/model_base.py CHANGED
@@ -1,12 +1,14 @@
- __all__ = ["Model", "_ModelExtended", "LossTracker"]
+ __all__ = ["Model", "LossTracker"]

  import gc
  import json
  import math
  import warnings
- from .torch_commons import *
+ from lt_tensor.torch_commons import *
  from lt_utils.common import *
+ from lt_tensor.misc_utils import plot_view, get_weights, get_config, updateDict
  from lt_utils.misc_utils import log_traceback, get_current_time
+ from lt_utils.file_ops import load_json, save_json

  T = TypeVar("T")

@@ -29,79 +31,47 @@ class LossTracker:
          "eval": [],
      }

-     def append(self, loss: float, mode: Literal["train", "eval"] = "train"):
-         assert mode in self.history, f"Invalid mode '{mode}'. Use 'train' or 'eval'."
+     def append(
+         self, loss: float, mode: Union[Literal["train", "eval"], "str"] = "train"
+     ):
          self.history[mode].append(float(loss))
          if len(self.history[mode]) > self.max_len:
              self.history[mode] = self.history[mode][-self.max_len :]

-     def get(self, mode: Literal["train", "eval"] = "train"):
+     def get(self, mode: Union[Literal["train", "eval"], "str"] = "train"):
          return self.history.get(mode, [])

      def save(self, path: Optional[PathLike] = None):
          if path is None:
              path = f"logs/history_{get_current_time()}.json"
-
-         Path(path).parent.mkdir(exist_ok=True, parents=True)
-         with open(path, "w") as f:
-             json.dump(self.history, f, indent=2)
-
+         save_json(path, self.history, indent=2)
          self.last_file = path

      def load(self, path: Optional[PathLike] = None):
          if path is None:
-             _path = self.last_file
-         else:
-             _path = path
-         with open(_path) as f:
-             self.history = json.load(f)
-         if path is not None:
-             self.last_file = path
-
-     def plot(self, backend: Literal["matplotlib", "plotly"] = "plotly"):
-         if backend == "plotly":
-             try:
-                 import plotly.graph_objs as go
-             except ModuleNotFoundError:
-                 warnings.warn(
-                     "No installation of plotly was found. To use it use 'pip install plotly' and restart this application!"
-                 )
-                 return
-             fig = go.Figure()
-             for mode, losses in self.history.items():
-                 if losses:
-                     fig.add_trace(go.Scatter(y=losses, name=mode.capitalize()))
-             fig.update_layout(
-                 title="Training vs Evaluation Loss",
-                 xaxis_title="Step",
-                 yaxis_title="Loss",
-                 template="plotly_dark",
-             )
-             fig.show()
-
-         elif backend == "matplotlib":
-             import matplotlib.pyplot as plt
-
-             for mode, losses in self.history.items():
-                 if losses:
-                     plt.plot(losses, label=f"{mode.capitalize()} Loss")
-             plt.title("Loss over Time")
-             plt.xlabel("Step")
-             plt.ylabel("Loss")
-             plt.legend()
-             plt.grid(True)
-             plt.show()
+             path = self.last_file
+         self.history = load_json(path, [])
+         self.last_file = path

+     def plot(
+         self,
+         history_keys: Union[str, List[str]],
+         max_amount: int = 0,
+         title: str = "Loss",
+     ):
+         if isinstance(history_keys, str):
+             history_keys = [history_keys]
+         return plot_view(
+             {k: v for k, v in self.history.items() if k in history_keys},
+             title,
+             max_amount,
+         )

- class Model(nn.Module, ABC):
-     """
-     This makes it easier to assign a device and retrieves it later
-     """

+ class _Devices_Base(nn.Module):
      _device: torch.device = ROOT_DEVICE
      _autocast: bool = False
      _loss_history: LossTracker = LossTracker(100_000)
-     _is_unfrozen: bool = False

      @property
      def autocast(self):
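
For reference, a minimal usage sketch of the reworked LossTracker (the tracker size and loss values below are hypothetical; save() and load() now delegate to lt_utils.file_ops, and plot() forwards the selected histories to plot_view):

    from lt_tensor.model_base import LossTracker

    tracker = LossTracker(1000)            # keep at most the last 1000 entries per key
    tracker.append(0.75, mode="train")     # any string key is accepted, not only "train"/"eval"
    tracker.append(0.80, mode="eval")
    tracker.save("logs/history.json")      # delegates to save_json
    tracker.load("logs/history.json")      # delegates to load_json
    view = tracker.plot(["train", "eval"], max_amount=100, title="Loss")  # returns whatever plot_view produces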
@@ -119,7 +89,96 @@ class Model(nn.Module, ABC):
      def device(self, device: Union[torch.device, str]):
          assert isinstance(device, (str, torch.device))
          self._device = torch.device(device) if isinstance(device, str) else device
-         self._apply_device_to()
+
+     def to(self, *args, **kwargs):
+         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
+             *args, **kwargs
+         )
+
+         if dtype is not None:
+             if not (dtype.is_floating_point or dtype.is_complex):
+                 raise TypeError(
+                     "nn.Module.to only accepts floating point or complex "
+                     f"dtypes, but got desired dtype={dtype}"
+                 )
+             if dtype.is_complex:
+                 warnings.warn(
+                     "Complex modules are a new feature under active development whose design may change, "
+                     "and some modules might not work as expected when using complex tensors as parameters or buffers. "
+                     "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
+                     "if a complex module does not work as expected."
+                 )
+
+         def convert(t: Tensor):
+             try:
+                 if convert_to_format is not None and t.dim() in (4, 5):
+                     return t.to(
+                         device,
+                         dtype if t.is_floating_point() or t.is_complex() else None,
+                         non_blocking,
+                         memory_format=convert_to_format,
+                     )
+                 return t.to(
+                     device,
+                     dtype if t.is_floating_point() or t.is_complex() else None,
+                     non_blocking,
+                 )
+             except NotImplementedError as e:
+                 if str(e) == "Cannot copy out of meta tensor; no data!":
+                     raise NotImplementedError(
+                         f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
+                         f"when moving module from meta to a different device."
+                     ) from None
+                 else:
+                     raise
+
+         self._apply(convert)
+         self.device = device
+         return self
+
+     def _to_dvc(
+         self, device_name: str, device_id: Optional[Union[int, torch.device]] = None
+     ):
+         device = device_name
+         if device_id is not None:
+             if isinstance(device_id, Number):
+                 device += ":" + str(int(device_id))
+             elif hasattr(device_id, "index"):
+                 device += ":" + str(device_id.index)
+         self.device = device
+
+     def ipu(self, device: Optional[Union[int, torch.device]] = None) -> T:
+         super().ipu(device)
+         self._to_dvc("ipu", device)
+         return self
+
+     def xpu(self, device: Optional[Union[int, torch.device]] = None) -> T:
+         super().xpu(device)
+         self._to_dvc("xpu", device)
+         return self
+
+     def cuda(self, device: Optional[Union[int, torch.device]] = None) -> T:
+         super().cuda(device)
+         self._to_dvc("cuda", device)
+         return self
+
+     def mtia(self, device: Optional[Union[int, torch.device]] = None) -> T:
+         super().mtia(device)
+         self._to_dvc("mtia", device)
+         return self
+
+     def cpu(self) -> T:
+         super().cpu()
+         self._to_dvc("cpu", None)
+         return self
+
+
+ class Model(_Devices_Base, ABC):
+     """
+     This makes it easier to assign a device and retrieves it later
+     """
+
+     _is_unfrozen: bool = False

      def _apply_device_to(self):
          """Add here components that are needed to have device applied to them,
@@ -239,103 +298,6 @@ class Model(nn.Module, ABC):
              not_unfrozen.append((name, "excluded"))
          return dict(unfrozen=unfrozen, not_unfrozen=not_unfrozen)

-     def to(self, *args, **kwargs):
-         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
-             *args, **kwargs
-         )
-
-         if dtype is not None:
-             if not (dtype.is_floating_point or dtype.is_complex):
-                 raise TypeError(
-                     "nn.Module.to only accepts floating point or complex "
-                     f"dtypes, but got desired dtype={dtype}"
-                 )
-             if dtype.is_complex:
-                 warnings.warn(
-                     "Complex modules are a new feature under active development whose design may change, "
-                     "and some modules might not work as expected when using complex tensors as parameters or buffers. "
-                     "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
-                     "if a complex module does not work as expected."
-                 )
-
-         def convert(t: Tensor):
-             try:
-                 if convert_to_format is not None and t.dim() in (4, 5):
-                     return t.to(
-                         device,
-                         dtype if t.is_floating_point() or t.is_complex() else None,
-                         non_blocking,
-                         memory_format=convert_to_format,
-                     )
-                 return t.to(
-                     device,
-                     dtype if t.is_floating_point() or t.is_complex() else None,
-                     non_blocking,
-                 )
-             except NotImplementedError as e:
-                 if str(e) == "Cannot copy out of meta tensor; no data!":
-                     raise NotImplementedError(
-                         f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
-                         f"when moving module from meta to a different device."
-                     ) from None
-                 else:
-                     raise
-
-         self._apply(convert)
-         self.device = device
-         self._apply_device_to()
-         return self
-
-     def ipu(self, device: Optional[Union[int, torch.device]] = None) -> T:
-         super().ipu(device)
-         dvc = "ipu"
-         if device is not None:
-             dvc += (
-                 ":" + str(device) if isinstance(device, (int, float)) else device.index
-             )
-         self.device = dvc
-         self._apply_device_to()
-         return self
-
-     def xpu(self, device: Optional[Union[int, torch.device]] = None) -> T:
-         super().xpu(device)
-         dvc = "xpu"
-         if device is not None:
-             dvc += (
-                 ":" + str(device) if isinstance(device, (int, float)) else device.index
-             )
-         self.device = dvc
-         self._apply_device_to()
-         return self
-
-     def cuda(self, device: Optional[Union[int, torch.device]] = None) -> T:
-         super().cuda(device)
-         dvc = "cuda"
-         if device is not None:
-             dvc += (
-                 ":" + str(device) if isinstance(device, (int, float)) else device.index
-             )
-         self.device = dvc
-         self._apply_device_to()
-         return self
-
-     def mtia(self, device: Optional[Union[int, torch.device]] = None) -> T:
-         super().mtia(device)
-         dvc = "mtia"
-         if device is not None:
-             dvc += (
-                 ":" + str(device) if isinstance(device, (int, float)) else device.index
-             )
-         self.device = dvc
-         self._apply_device_to()
-         return self
-
-     def cpu(self) -> T:
-         super().cpu()
-         self.device = "cpu"
-         self._apply_device_to()
-         return self
-
      def count_trainable_parameters(self, module_name: Optional[str] = None):
          """Gets the number of trainable parameters from either the entire model or from a specific module."""
          if module_name is not None:
@@ -411,6 +373,19 @@ class Model(nn.Module, ABC):
          else:
              print(f"Non-Trainable Parameters: {params}")

+     @classmethod
+     def from_pretrained(
+         cls,
+         model_path: Union[str, PathLike],
+         model_config: Optional[Union[str, PathLike]] = None,
+         model_config_dict: Optional[Dict[str, Any]] = None,
+     ):
+         """TODO: Finish this implementation or leave it as an optional function."""
+         assert isinstance(model_path, (str, PathLike, bytes))
+         model_path = get_weights(model_path)
+         model_config = get_config(model_config, model_config_dict)
+         return model_path, model_config
+
      def save_weights(
          self,
          path: Union[Path, str],
@@ -494,7 +469,9 @@ class Model(nn.Module, ABC):
          pass

      def add_loss(
-         self, loss: Union[float, list[float]], mode: Literal["train", "eval"] = "train"
+         self,
+         loss: Union[float, list[float]],
+         mode: Union[Literal["train", "eval"]] = "train",
      ):
          if isinstance(loss, Number) and loss:
              self._loss_history.append(loss, mode)
@@ -513,8 +490,12 @@ class Model(nn.Module, ABC):
      def load_loss_history(self, path: Optional[PathLike] = None):
          self._loss_history.load(path)

-     def get_loss_avg(self, mode: Literal["train", "eval"], quantity: int = 0):
-         t_list = self._loss_history.get("train")
+     def get_loss_avg(
+         self,
+         mode: Union[Literal["train", "eval"], str],
+         quantity: int = 0,
+     ):
+         t_list = self._loss_history.get(mode)
          if not t_list:
              return float("nan")
          if quantity > 0:
@@ -524,9 +505,10 @@ class Model(nn.Module, ABC):
      def freeze_unfreeze_loss(
          self,
          losses: Optional[Union[float, List[float]]] = None,
-         trigger_loss: float = 0.1,
+         trigger_loss: Union[float, bool] = 0.1,
          excluded_modules: Optional[List[str]] = None,
-         eval_last: int = 1000,
+         max_items: int = 1000,
+         loss_name: str = "train",
      ):
          """If a certain threshold is reached the weights will freeze or unfreeze the modules.
          the biggest use-case for this function is when training GANs where the balance
@@ -534,18 +516,29 @@ class Model(nn.Module, ABC):

          Args:
              losses (Union[float, List[float]], Optional): The loss value or a list of losses that will be used to determine if it has reached or not the threshold. Defaults to None.
-             trigger_loss (float, optional): The value where the weights will be either freeze or unfreeze. Defaults to 0.1.
+             trigger_loss (float, bool, optional): The value at which the weights will be frozen or unfrozen. If set to a boolean, it freezes or unfreezes immediately according to the value (True = freeze, False = unfreeze). Defaults to 0.1.
              excluded_modules (list[str], optional): The list of modules (names) that is not to be changed by either freezing nor unfreezing. Defaults to None.
-             eval_last (float, optional): The number of previous losses to be locked behind to calculate the current averange. Default to 1000.
-
+             max_items (int, optional): The number of most recent losses used to calculate the current average. Defaults to 1000.
+             loss_name (str, optional): The history key used to recover the loss values. Defaults to "train".
          returns:
              bool: True when its frozen and false when its trainable.
          """
          if losses is not None:
-             calculated = None
              self.add_loss(losses)

-         value = self.get_loss_avg("train", eval_last)
+         if isinstance(trigger_loss, bool):
+             if trigger_loss:
+                 if self._is_unfrozen:
+                     self.freeze_all(excluded_modules)
+                     self._is_unfrozen = False
+                 return True
+             # else
+             if not self._is_unfrozen:
+                 self.unfreeze_all(excluded_modules)
+                 self._is_unfrozen = True
+             return False
+
+         value = self.get_loss_avg(loss_name, max_items)

          if value <= trigger_loss:
              if self._is_unfrozen:
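
A short sketch of the new trigger_loss behavior (the discriminator object is hypothetical): a boolean freezes or unfreezes immediately, while a float keeps the original threshold-on-average logic, now computed over the history selected by loss_name:

    disc = SomeDiscriminator()                     # hypothetical Model subclass
    disc.freeze_unfreeze_loss(trigger_loss=True)   # freeze right away, returns True
    disc.freeze_unfreeze_loss(trigger_loss=False)  # unfreeze right away, returns False
    disc.freeze_unfreeze_loss(
        losses=0.08,        # appended to the loss history first
        trigger_loss=0.1,   # threshold compared against the running average
        max_items=500,      # average over the last 500 entries
        loss_name="train",  # which history key to average
    )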
@@ -557,42 +550,3 @@ class Model(nn.Module, ABC):
                  self.unfreeze_all(excluded_modules)
                  self._is_unfrozen = True
              return False
-
-
- class _ModelExtended(Model):
-     """Planed, but not ready, maybe in the near future?"""
-     criterion: Optional[Callable[[Tensor, Tensor], Tensor]] = None
-     optimizer: Optional[optim.Optimizer] = None
-
-     def train_step(
-         self,
-         *inputs,
-         loss_label: Optional[Tensor] = None,
-         **kwargs,
-     ):
-         if not self.training:
-             self.train()
-         if self.optimizer is not None:
-             self.optimizer.zero_grad()
-         if self.autocast:
-             if self.criterion is None:
-                 raise RuntimeError(
-                     "To use autocast during training, you must assign a criterion first!"
-                 )
-             with torch.autocast(device_type=self.device.type):
-                 out = self.forward(*loss_label, **kwargs)
-                 loss = self.criterion(out, loss_label)
-
-             if self.optimizer is not None:
-                 loss.backward()
-                 self.optimizer.step()
-             return loss
-         elif self.criterion is not None:
-             out = self.forward(*loss_label, **kwargs)
-             loss = self.criterion(out, loss_label)
-             if self.optimizer is not None:
-                 loss.backward()
-                 self.optimizer.step()
-             return loss
-         else:
-             return self(*inputs, **kwargs)
@@ -5,7 +5,7 @@ __all__ = [
      "pos", # positional encoders
      "fsn", # fusion
      "gns", # generators
-     "disc", # discriminators
-     "istft" # self-explanatory
+     "disc",  # discriminators
+     "istft",  # self-explanatory
  ]
  from . import bsc, fsn, gns, istft, pos, rsd, tfrms, disc
@@ -10,9 +10,9 @@ __all__ = [
      "MultiScaleEncoder1D",
  ]

- from ..torch_commons import *
- from ..model_base import Model
- from ..transform import get_sinusoidal_embedding
+ from lt_tensor.torch_commons import *
+ from lt_tensor.model_base import Model
+ from lt_tensor.transform import get_sinusoidal_embedding


  class FeedForward(Model):
@@ -208,10 +208,10 @@ class MultiScaleEncoder1D(Model):
      def forward(self, x: torch.Tensor) -> torch.Tensor:
          # x: [B, C, T]
          return self.net(x) # [B, hidden, T]
-
-
+
+
  class AudioClassifier(Model):
-     def __init__(self, n_mels:int=80, num_classes=5):
+     def __init__(self, n_mels: int = 80, num_classes=5):
          super().__init__()
          self.model = nn.Sequential(
              nn.Conv1d(n_mels, 256, kernel_size=3, padding=1),
@@ -1,4 +1,4 @@
- from ..torch_commons import *
+ from lt_tensor.torch_commons import *
  import torch.nn.functional as F
  from lt_tensor.model_base import Model
  from lt_utils.common import *
@@ -6,8 +6,8 @@ __all__ = [
      "GatedFusion",
  ]

- from ..torch_commons import *
- from ..model_base import Model
+ from lt_tensor.torch_commons import *
+ from lt_tensor.model_base import Model


  class ConcatFusion(Model):
@@ -7,10 +7,10 @@ __all__ = [
      "NoisePredictor1D",
  ]

- from ..torch_commons import *
- from ..model_base import Model
- from .rsd import ResBlock1D, ResBlocks
- from ..misc_utils import log_tensor
+ from lt_tensor.torch_commons import *
+ from lt_tensor.model_base import Model
+ from lt_tensor.model_zoo.rsd import ResBlock1D
+ from lt_tensor.misc_utils import log_tensor

  import torch.nn.functional as F

@@ -0,0 +1,5 @@
+ from .generator import iSTFTGenerator
+ from . import trainer
+
+
+ __all__ = ["iSTFTGenerator", "trainer"]
@@ -0,0 +1,150 @@
+ __all__ = ["iSTFTGenerator", "ResBlocks"]
+ import gc
+ import math
+ import itertools
+ from lt_utils.common import *
+ from lt_tensor.torch_commons import *
+ from lt_tensor.model_base import Model
+ from lt_tensor.misc_utils import log_tensor
+ from lt_tensor.model_zoo.rsd import ResBlock1D, ConvNets, get_weight_norm
+ from lt_utils.misc_utils import log_traceback
+ from lt_tensor.processors import AudioProcessor
+ from lt_utils.type_utils import is_dir, is_pathlike
+ from lt_tensor.misc_utils import set_seed, clear_cache
+ from lt_tensor.model_zoo.disc import MultiPeriodDiscriminator, MultiScaleDiscriminator
+ import torch.nn.functional as F
+ from lt_tensor.config_templates import updateDict, ModelConfig
+
+
+ class ResBlocks(ConvNets):
+     def __init__(
+         self,
+         channels: int,
+         resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+         resblock_dilation_sizes: List[Union[int, List[int]]] = [
+             [1, 3, 5],
+             [1, 3, 5],
+             [1, 3, 5],
+         ],
+         activation: nn.Module = nn.LeakyReLU(0.1),
+     ):
+         super().__init__()
+         self.num_kernels = len(resblock_kernel_sizes)
+         self.rb = nn.ModuleList()
+         self.activation = activation
+
+         for k, j in zip(resblock_kernel_sizes, resblock_dilation_sizes):
+             self.rb.append(ResBlock1D(channels, k, j, activation))
+
+         self.rb.apply(self.init_weights)
+
+     def forward(self, x: torch.Tensor):
+         xs = None
+         for i, block in enumerate(self.rb):
+             if i == 0:
+                 xs = block(x)
+             else:
+                 xs += block(x)
+         x = xs / self.num_kernels
+         return self.activation(x)
+
+
+ class iSTFTGenerator(ConvNets):
+     def __init__(
+         self,
+         in_channels: int = 80,
+         upsample_rates: List[Union[int, List[int]]] = [8, 8],
+         upsample_kernel_sizes: List[Union[int, List[int]]] = [16, 16],
+         upsample_initial_channel: int = 512,
+         resblock_kernel_sizes: List[Union[int, List[int]]] = [3, 7, 11],
+         resblock_dilation_sizes: List[Union[int, List[int]]] = [
+             [1, 3, 5],
+             [1, 3, 5],
+             [1, 3, 5],
+         ],
+         n_fft: int = 16,
+         activation: nn.Module = nn.LeakyReLU(0.1),
+         hop_length: int = 256,
+     ):
+         super().__init__()
+         self.num_kernels = len(resblock_kernel_sizes)
+         self.num_upsamples = len(upsample_rates)
+         self.hop_length = hop_length
+         self.conv_pre = weight_norm(
+             nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
+         )
+         self.blocks = nn.ModuleList()
+         self.activation = activation
+         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+             self.blocks.append(
+                 self._make_blocks(
+                     (i, k, u),
+                     upsample_initial_channel,
+                     resblock_kernel_sizes,
+                     resblock_dilation_sizes,
+                 )
+             )
+
+         ch = upsample_initial_channel // (2 ** (i + 1))
+         self.post_n_fft = n_fft // 2 + 1
+         self.conv_post = weight_norm(nn.Conv1d(ch, n_fft + 2, 7, 1, padding=3))
+         self.conv_post.apply(self.init_weights)
+         self.reflection_pad = nn.ReflectionPad1d((1, 0))
+
+         self.phase = nn.Sequential(
+             nn.LeakyReLU(0.2),
+             nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+             nn.LeakyReLU(0.2),
+             nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+         )
+         self.spec = nn.Sequential(
+             nn.LeakyReLU(0.2),
+             nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+             nn.LeakyReLU(0.2),
+             nn.Conv1d(self.post_n_fft, self.post_n_fft, kernel_size=3, padding=1),
+         )
+
+     def _make_blocks(
+         self,
+         state: Tuple[int, int, int],
+         upsample_initial_channel: int,
+         resblock_kernel_sizes: List[Union[int, List[int]]],
+         resblock_dilation_sizes: List[int | List[int]],
+     ):
+         i, k, u = state
+         channels = upsample_initial_channel // (2 ** (i + 1))
+         return nn.ModuleDict(
+             dict(
+                 up=nn.Sequential(
+                     self.activation,
+                     weight_norm(
+                         nn.ConvTranspose1d(
+                             upsample_initial_channel // (2**i),
+                             channels,
+                             k,
+                             u,
+                             padding=(k - u) // 2,
+                         )
+                     ).apply(self.init_weights),
+                 ),
+                 residual=ResBlocks(
+                     channels,
+                     resblock_kernel_sizes,
+                     resblock_dilation_sizes,
+                     self.activation,
+                 ),
+             )
+         )
+
+     def forward(self, x):
+         x = self.conv_pre(x)
+         for block in self.blocks:
+             x = block["up"](x)
+             x = block["residual"](x)
+
+         x = self.reflection_pad(x)
+         x = self.conv_post(x)
+         spec = torch.exp(self.spec(x[:, : self.post_n_fft, :]))
+         phase = torch.sin(self.phase(x[:, self.post_n_fft :, :]))
+
+         return spec, phase
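
To illustrate the new generator's interface, a minimal forward-pass sketch (the import path is assumed from the package layout implied by this diff; shapes depend on the upsample and n_fft settings above):

    import torch
    from lt_tensor.model_zoo.istft import iSTFTGenerator  # assumed module path

    gen = iSTFTGenerator(in_channels=80, n_fft=16, hop_length=256)
    mel = torch.randn(1, 80, 32)      # [batch, mel bands, frames]
    spec, phase = gen(mel)            # magnitude and phase, each with n_fft // 2 + 1 channels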