lt-tensor 0.0.1a20.tar.gz → 0.0.1a22.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/LICENSE +1 -1
  2. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/PKG-INFO +3 -2
  3. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_base.py +20 -41
  4. lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +1 -0
  5. lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/alias_free_torch/act.py +55 -0
  6. lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/alias_free_torch/filter.py +183 -0
  7. lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/alias_free_torch/resample.py +106 -0
  8. lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/snake/__init__.py +129 -0
  9. lt_tensor-0.0.1a22/lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +536 -0
  10. lt_tensor-0.0.1a22/lt_tensor/model_zoo/audio_models/bigvgan/cuda/__init__.py +160 -0
  11. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/audio_models/hifigan/__init__.py +130 -6
  12. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/audio_models/istft/__init__.py +132 -0
  13. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/residual.py +47 -11
  14. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/processors/audio.py +16 -14
  15. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/torch_commons.py +2 -0
  16. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor.egg-info/PKG-INFO +3 -2
  17. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor.egg-info/SOURCES.txt +7 -0
  18. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor.egg-info/requires.txt +2 -1
  19. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/setup.py +3 -2
  20. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/README.md +0 -0
  21. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/__init__.py +0 -0
  22. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/config_templates.py +0 -0
  23. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/losses.py +0 -0
  24. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/lr_schedulers.py +0 -0
  25. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/math_ops.py +0 -0
  26. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/misc_utils.py +0 -0
  27. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/__init__.py +0 -0
  28. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/audio_models/__init__.py +0 -0
  29. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py +0 -0
  30. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/basic.py +0 -0
  31. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/features.py +0 -0
  32. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/fusion.py +0 -0
  33. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/pos_encoder.py +0 -0
  34. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/transformer.py +0 -0
  35. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/monotonic_align.py +0 -0
  36. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/noise_tools.py +0 -0
  37. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/processors/__init__.py +0 -0
  38. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/transform.py +0 -0
  39. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor.egg-info/dependency_links.txt +0 -0
  40. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor.egg-info/top_level.txt +0 -0
  41. {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/setup.cfg +0 -0
{lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/LICENSE
@@ -186,7 +186,7 @@
        same "printed page" as the copyright notice for easier
        identification within third-party archives.

-     Copyright [2025] [gr1336 (Gabriel Ribeiro)]
+     Copyright 2025 gr1336

      Licensed under the Apache License, Version 2.0 (the "License");
      you may not use this file except in compliance with the License.
{lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lt-tensor
- Version: 0.0.1a20
+ Version: 0.0.1a22
  Summary: General utilities for PyTorch and others. Built for general use.
  Home-page: https://github.com/gr1336/lt-tensor/
  Author: gr1336
@@ -17,11 +17,12 @@ Requires-Dist: numpy>=1.26.4
  Requires-Dist: tokenizers
  Requires-Dist: pyyaml>=6.0.0
  Requires-Dist: numba>0.60.0
- Requires-Dist: lt-utils==0.0.2a3
+ Requires-Dist: lt-utils==0.0.2
  Requires-Dist: librosa==0.11.*
  Requires-Dist: einops
  Requires-Dist: plotly
  Requires-Dist: scipy
+ Requires-Dist: huggingface_hub
  Dynamic: author
  Dynamic: classifier
  Dynamic: description
{lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_base.py
@@ -194,10 +194,11 @@ class Model(_Devices_Base, ABC):
          None  # Example: lambda: nn.Linear(32, 32, True)
      )
      low_rank_adapter: Union[nn.Identity, nn.Module, nn.Sequential] = nn.Identity()
+     # never freeze:
+     _never_freeze_modules: List[str] = ["low_rank_adapter"]

      # dont save list:
      _dont_save_items: List[str] = []
-     _loss_history: LossTracker = LossTracker(20_000)

      @property
      def autocast(self):
@@ -207,12 +208,16 @@ class Model(_Devices_Base, ABC):
      def autocast(self, value: bool):
          self._autocast = value

-     def freeze_all(self, exclude: Optional[List[str]] = None):
+     def freeze_all(self, exclude: Optional[List[str]] = None, force: bool = False):
          no_exclusions = not exclude
          no_exclusions = not exclude
          results = []
          for name, module in self.named_modules():
-             if name not in self.registered_freezable_modules:
+             if (
+                 name in self._never_freeze_modules
+                 or not force
+                 and name not in self.registered_freezable_modules
+             ):
                  results.append(
                      (
                          name,
@@ -228,12 +233,17 @@ class Model(_Devices_Base, ABC):
              results.append((name, "excluded"))
          return results

-     def unfreeze_all(self, exclude: Optional[list[str]] = None):
+     def unfreeze_all(self, exclude: Optional[list[str]] = None, force: bool = False):
          """Unfreezes all model parameters except specified layers."""
          no_exclusions = not exclude
          results = []
          for name, module in self.named_modules():
-             if name not in self.registered_freezable_modules:
+             if (
+                 name in self._never_freeze_modules
+                 or not force
+                 and name not in self.registered_freezable_modules
+             ):
+
                  results.append(
                      (
                          name,
@@ -250,6 +260,9 @@ class Model(_Devices_Base, ABC):
          return results

      def change_frozen_state(self, freeze: bool, module: nn.Module):
+         assert isinstance(module, nn.Module)
+         if module.__class__.__name__ in self._never_freeze_modules:
+             return "Not Allowed"
          try:
              if isinstance(module, Model):
                  if module._can_be_frozen:
@@ -258,7 +271,7 @@ class Model(_Devices_Base, ABC):
                      return module.unfreeze_all()
                  else:
                      return "Not Allowed"
-             elif isinstance(module, nn.Module):
+             else:
                  module.requires_grad_(not freeze)
                  return not freeze
          except Exception as e:
@@ -451,7 +464,7 @@ class Model(_Devices_Base, ABC):
          raise_if_not_exists: bool = False,
          strict: bool = False,
          assign: bool = False,
-         weights_only: bool = True,
+         weights_only: bool = False,
          mmap: Optional[bool] = None,
          **pickle_load_args,
      ):
@@ -505,37 +518,3 @@ class Model(_Devices_Base, ABC):
          self, *args, **kwargs
      ) -> Union[Tensor, Sequence[Tensor], Dict[Any, Union[Any, Tensor]]]:
          pass
-
-     def add_loss(
-         self,
-         loss: Union[float, list[float]],
-         mode: Union[Literal["train", "eval"]] = "train",
-     ):
-         if isinstance(loss, Number) and loss:
-             self._loss_history.append(loss, mode)
-         elif isinstance(loss, (list, tuple)):
-             if loss:
-                 self._loss_history.append(sum(loss) / len(loss), mode=mode)
-         elif isinstance(loss, Tensor):
-             try:
-                 self._loss_history.append(loss.detach().flatten().mean().item())
-             except Exception as e:
-                 log_traceback(e, "add_loss - Tensor")
-
-     def save_loss_history(self, path: Optional[PathLike] = None):
-         self._loss_history.save(path)
-
-     def load_loss_history(self, path: Optional[PathLike] = None):
-         self._loss_history.load(path)
-
-     def get_loss_avg(
-         self,
-         mode: Union[Literal["train", "eval"], str],
-         quantity: int = 0,
-     ):
-         t_list = self._loss_history.get(mode)
-         if not t_list:
-             return float("nan")
-         if quantity > 0:
-             t_list = t_list[-quantity:]
-         return sum(t_list) / len(t_list)
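The new condition in freeze_all/unfreeze_all relies on Python operator precedence: `and` binds tighter than `or`, so it reads "name is in the never-freeze list, OR (force is off AND the module is not registered as freezable)". A never-freeze module is therefore skipped even when force=True. A minimal sketch isolating that boolean rule (should_skip is a hypothetical helper, not part of the package):

    def should_skip(name, force, never_freeze, freezable):
        # Same expression as the diff; `and` binds tighter than `or`.
        return name in never_freeze or not force and name not in freezable

    # The low-rank adapter is skipped even under force=True:
    assert should_skip("low_rank_adapter", True, ["low_rank_adapter"], [])
    # Registered modules are frozen normally:
    assert not should_skip("encoder", False, [], ["encoder"])
    # Unregistered modules are skipped unless force=True:
    assert should_skip("decoder", False, [], ["encoder"])
    assert not should_skip("decoder", True, [], ["encoder"])

Note also that load's weights_only default flips from True to False, which re-enables full pickle deserialization of checkpoints; that accommodates non-tensor state but should only be used with trusted files.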
lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/alias_free_torch/act.py
@@ -0,0 +1,55 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from .resample import UpSample1d, DownSample1d
+ from .resample import UpSample2d, DownSample2d
+
+
+ class Activation1d(nn.Module):
+
+     def __init__(
+         self,
+         activation,
+         up_ratio: int = 2,
+         down_ratio: int = 2,
+         up_kernel_size: int = 12,
+         down_kernel_size: int = 12,
+     ):
+         super().__init__()
+         self.up_ratio = up_ratio
+         self.down_ratio = down_ratio
+         self.act = activation
+         self.upsample = UpSample1d(up_ratio, up_kernel_size)
+         self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+     # x: [B,C,T]
+     def forward(self, x):
+         x = self.upsample(x)
+         x = self.act(x)
+         x = self.downsample(x)
+         return x
+
+
+ class Activation2d(nn.Module):
+
+     def __init__(
+         self,
+         activation,
+         up_ratio: int = 2,
+         down_ratio: int = 2,
+         up_kernel_size: int = 12,
+         down_kernel_size: int = 12,
+     ):
+         super().__init__()
+         self.up_ratio = up_ratio
+         self.down_ratio = down_ratio
+         self.act = activation
+         self.upsample = UpSample2d(up_ratio, up_kernel_size)
+         self.downsample = DownSample2d(down_ratio, down_kernel_size)
+
+     # x: [B,C,W,H]
+     def forward(self, x):
+         x = self.upsample(x)
+         x = self.act(x)
+         x = self.downsample(x)
+         return x
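Activation1d follows the alias-free recipe used by StyleGAN3 and BigVGAN: upsample by up_ratio, apply the pointwise nonlinearity at the higher rate, then low-pass and downsample back, so the harmonics the nonlinearity creates are attenuated instead of aliasing. With the default ratios the output length matches the input. A usage sketch (the import path is inferred from the file list above):

    import torch
    import torch.nn as nn
    from lt_tensor.model_zoo.activations.alias_free_torch.act import Activation1d

    act = Activation1d(activation=nn.SiLU())  # any pointwise nonlinearity
    x = torch.randn(1, 8, 1024)               # [B, C, T]
    y = act(x)
    print(y.shape)                            # torch.Size([1, 8, 1024])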
lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/alias_free_torch/filter.py
@@ -0,0 +1,183 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+ if "sinc" in dir(torch):
+     sinc = torch.sinc
+ else:
+     # This code is adopted from adefossez's julius.core.sinc
+     # https://adefossez.github.io/julius/julius/core.html
+     def sinc(x: torch.Tensor):
+         """
+         Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+         __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
+         """
+         return torch.where(
+             x == 0,
+             torch.tensor(1.0, device=x.device, dtype=x.dtype),
+             torch.sin(math.pi * x) / math.pi / x,
+         )
+
+
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters
+ # https://adefossez.github.io/julius/julius/lowpass.html
+
+
+ # return filter [1,1,kernel_size]
+ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
+     even = kernel_size % 2 == 0
+     half_size = kernel_size // 2
+
+     # For kaiser window
+     delta_f = 4 * half_width
+     A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+     if A > 50.0:
+         beta = 0.1102 * (A - 8.7)
+     elif A >= 21.0:
+         beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
+     else:
+         beta = 0.0
+     window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+     # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+     if even:
+         time = torch.arange(-half_size, half_size) + 0.5
+     else:
+         time = torch.arange(kernel_size) - half_size
+     if cutoff == 0:
+         filter_ = torch.zeros_like(time)
+     else:
+         filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+     # Normalize filter to have sum = 1, otherwise we will have a small leakage
+     # of the constant component in the input signal.
+     filter_ /= filter_.sum()
+     filter = filter_.view(1, 1, kernel_size)
+     return filter
+
+
+ def kaiser_sinc_filter2d(cutoff, half_width, kernel_size):
+     even = kernel_size % 2 == 0
+     half_size = kernel_size // 2
+     # For kaiser window
+     delta_f = 4 * half_width
+     A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+     if A > 50.0:
+         beta = 0.1102 * (A - 8.7)
+     elif A >= 21.0:
+         beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
+     else:
+         beta = 0.0
+
+     # rotation equivariant grid
+     if even:
+         time = torch.stack(
+             torch.meshgrid(
+                 torch.arange(-half_size, half_size) + 0.5,
+                 torch.arange(-half_size, half_size) + 0.5,
+             ),
+             dim=-1,
+         )
+     else:
+         time = torch.stack(
+             torch.meshgrid(
+                 torch.arange(kernel_size) - half_size,
+                 torch.arange(kernel_size) - half_size,
+             ),
+             dim=-1,
+         )
+
+     time = torch.norm(time, dim=-1)
+     # rotation equivariant window
+     window = torch.i0(
+         beta * torch.sqrt(1 - (time / half_size / 2**0.5) ** 2)
+     ) / torch.i0(torch.tensor([beta]))
+     # ratio = 0.5/cutroff
+     if cutoff == 0:
+         filter_ = torch.zeros_like(time)
+     else:
+         filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+     # Normalize filter to have sum = 1, otherwise we will have a small leakage
+     # of the constant component in the input signal.
+     filter_ /= filter_.sum()
+     filter = filter_.view(1, 1, kernel_size, kernel_size)
+     return filter
+
+
+ class LowPassFilter1d(nn.Module):
+
+     def __init__(
+         self,
+         cutoff=0.5,
+         half_width=0.6,
+         stride: int = 1,
+         padding: bool = True,
+         padding_mode: str = "replicate",
+         kernel_size: int = 12,
+     ):
+         # kernel_size should be even number for stylegan3 setup,
+         # in this implementation, odd number is also possible.
+         super().__init__()
+         if cutoff < -0.0:
+             raise ValueError("Minimum cutoff must be larger than zero.")
+         if cutoff > 0.5:
+             raise ValueError("A cutoff above 0.5 does not make sense.")
+         self.kernel_size = kernel_size
+         self.even = kernel_size % 2 == 0
+         self.pad_left = kernel_size // 2 - int(self.even)
+         self.pad_right = kernel_size // 2
+         self.stride = stride
+         self.padding = padding
+         self.padding_mode = padding_mode
+         filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+         self.register_buffer("filter", filter)
+
+     # input [B,C,T]
+     def forward(self, x):
+         _, C, _ = x.shape
+         if self.padding:
+             x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
+         out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+         return out
+
+
+ class LowPassFilter2d(nn.Module):
+
+     def __init__(
+         self,
+         cutoff=0.5,
+         half_width=0.6,
+         stride: int = 1,
+         padding: bool = True,
+         padding_mode: str = "replicate",
+         kernel_size: int = 12,
+     ):
+         # kernel_size should be even number for stylegan3 setup,
+         # in this implementation, odd number is also possible.
+         super().__init__()
+         if cutoff < -0.0:
+             raise ValueError("Minimum cutoff must be larger than zero.")
+         if cutoff > 0.5:
+             raise ValueError("A cutoff above 0.5 does not make sense.")
+         self.kernel_size = kernel_size
+         self.even = kernel_size % 2 == 0
+         self.pad_left = kernel_size // 2 - int(self.even)
+         self.pad_right = kernel_size // 2
+         self.stride = stride
+         self.padding = padding
+         self.padding_mode = padding_mode
+         filter = kaiser_sinc_filter2d(cutoff, half_width, kernel_size)
+         self.register_buffer("filter", filter)
+
+     # input [B,C,W,H]
+     def forward(self, x):
+         _, C, _, _ = x.shape
+         if self.padding:
+             x = F.pad(
+                 x,
+                 (self.pad_left, self.pad_right, self.pad_left, self.pad_right),
+                 mode=self.padding_mode,
+             )
+         out = F.conv2d(
+             x, self.filter.expand(C, -1, -1, -1), stride=self.stride, groups=C
+         )
+         return out
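kaiser_sinc_filter1d is a standard windowed-sinc design: the Kaiser beta is chosen from the estimated stopband attenuation A (the usual Kaiser window formula), the taps are 2·cutoff·sinc(2·cutoff·t) under that window, and the result is normalized to unit sum so the DC component passes unchanged. A quick sanity check of that normalization (import path inferred from the file list above):

    import torch
    from lt_tensor.model_zoo.activations.alias_free_torch.filter import (
        LowPassFilter1d,
        kaiser_sinc_filter1d,
    )

    taps = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
    print(taps.shape, float(taps.sum()))   # torch.Size([1, 1, 12]) ~1.0

    # A unit-sum kernel with replicate padding leaves a constant signal intact.
    lp = LowPassFilter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
    out = lp(torch.ones(1, 4, 256))        # [B, C, T]
    print(float((out - 1.0).abs().max()))  # ~0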
lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/alias_free_torch/resample.py
@@ -0,0 +1,106 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from .filter import LowPassFilter1d, LowPassFilter2d
+ from .filter import kaiser_sinc_filter1d, kaiser_sinc_filter2d
+
+
+ class UpSample1d(nn.Module):
+
+     def __init__(self, ratio=2, kernel_size=None):
+         super().__init__()
+         self.ratio = ratio
+         self.kernel_size = (
+             int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+         )
+         self.stride = ratio
+         self.pad = self.kernel_size // ratio - 1
+         self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+         self.pad_right = (
+             self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+         )
+         filter = kaiser_sinc_filter1d(
+             cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
+         )
+         self.register_buffer("filter", filter)
+
+     # x: [B,C,T]
+     def forward(self, x):
+         _, C, _ = x.shape
+         x = F.pad(x, (self.pad, self.pad), mode="replicate")
+         x = self.ratio * F.conv_transpose1d(
+             x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
+         )
+         x = x[..., self.pad_left : -self.pad_right]
+         return x
+
+
+ class DownSample1d(nn.Module):
+
+     def __init__(self, ratio=2, kernel_size=None):
+         super().__init__()
+         self.ratio = ratio
+         self.kernel_size = (
+             int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+         )
+         self.lowpass = LowPassFilter1d(
+             cutoff=0.5 / ratio,
+             half_width=0.6 / ratio,
+             stride=ratio,
+             kernel_size=self.kernel_size,
+         )
+
+     def forward(self, x):
+         xx = self.lowpass(x)
+         return xx
+
+
+ class UpSample2d(nn.Module):
+
+     def __init__(self, ratio=2, kernel_size=None):
+         super().__init__()
+         self.ratio = ratio
+         self.kernel_size = (
+             int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+         )
+         self.stride = ratio
+         self.pad = self.kernel_size // 2 - ratio // 2
+         self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+         self.pad_right = (
+             self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+         )
+         filter = kaiser_sinc_filter2d(
+             cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
+         )
+         self.register_buffer("filter", filter)
+
+     # x: [B,C,W,H]
+     def forward(self, x):
+         _, C, _, _ = x.shape
+         x = F.pad(x, (self.pad, self.pad, self.pad, self.pad), mode="replicate")
+         x = self.ratio**2 * F.conv_transpose2d(
+             x, self.filter.expand(C, -1, -1, -1), stride=self.stride, groups=C
+         )
+         x = x[..., self.pad_left : -self.pad_right, self.pad_left : -self.pad_right]
+         return x
+
+
+ class DownSample2d(nn.Module):
+
+     def __init__(self, ratio=2, kernel_size=None):
+         super().__init__()
+         self.ratio = ratio
+         self.kernel_size = (
+             int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+         )
+         self.lowpass = LowPassFilter2d(
+             cutoff=0.5 / ratio,
+             half_width=0.6 / ratio,
+             stride=ratio,
+             kernel_size=self.kernel_size,
+         )
+
+     # x: [B,C,W,H]
+     def forward(self, x):
+         xx = self.lowpass(x)
+         return xx
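UpSample1d replicate-pads, runs a transposed convolution with the Kaiser-sinc kernel (scaled by ratio so amplitude is preserved), and crops the filter transient at the edges; DownSample1d is simply the strided low-pass. The two are exact inverses in length, as this sketch shows (import path inferred from the file list above):

    import torch
    from lt_tensor.model_zoo.activations.alias_free_torch.resample import (
        DownSample1d,
        UpSample1d,
    )

    up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
    x = torch.randn(1, 2, 500)  # [B, C, T]
    y = up(x)                   # length doubled: [1, 2, 1000]
    z = down(y)                 # back to the original length: [1, 2, 500]
    print(y.shape, z.shape)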
lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/snake/__init__.py
@@ -0,0 +1,129 @@
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+
+ import torch
+ from torch import nn, sin, pow
+ from torch.nn import Parameter
+
+
+ class Snake(nn.Module):
+     """
+     Implementation of a sine-based periodic activation function
+     Shape:
+         - Input: (B, C, T)
+         - Output: (B, C, T), same shape as the input
+     Parameters:
+         - alpha - trainable parameter
+     References:
+         - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+         https://arxiv.org/abs/2006.08195
+     Examples:
+         >>> a1 = snake(256)
+         >>> x = torch.randn(256)
+         >>> x = a1(x)
+     """
+
+     def __init__(
+         self,
+         in_features,
+         alpha=1.0,
+         alpha_trainable=True,
+         alpha_logscale=False,
+     ):
+         """
+         Initialization.
+         INPUT:
+             - in_features: shape of the input
+             - alpha: trainable parameter
+             alpha is initialized to 1 by default, higher values = higher-frequency.
+             alpha will be trained along with the rest of your model.
+         """
+         super(Snake, self).__init__()
+         self.in_features = in_features
+
+         # initialize alpha
+         self.alpha_logscale = alpha_logscale
+         if self.alpha_logscale:  # log scale alphas initialized to zeros
+             self.alpha = Parameter(torch.zeros(in_features) * alpha)
+         else:  # linear scale alphas initialized to ones
+             self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+         self.alpha.requires_grad = alpha_trainable
+
+         self.no_div_by_zero = 1e-7
+
+     def forward(self, x):
+         """
+         Forward pass of the function.
+         Applies the function to the input elementwise.
+         Snake ∶= x + 1/a * sin^2 (xa)
+         """
+         alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+         if self.alpha_logscale:
+             alpha = torch.exp(alpha)
+         x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+         return x
+
+
+ class SnakeBeta(nn.Module):
+     """
+     A modified Snake function which uses separate parameters for the magnitude of the periodic components
+     Shape:
+         - Input: (B, C, T)
+         - Output: (B, C, T), same shape as the input
+     Parameters:
+         - alpha - trainable parameter that controls frequency
+         - beta - trainable parameter that controls magnitude
+     References:
+         - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+         https://arxiv.org/abs/2006.08195
+     Examples:
+         >>> a1 = snakebeta(256)
+         >>> x = torch.randn(256)
+         >>> x = a1(x)
+     """
+
+     def __init__(
+         self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
+     ):
+         """
+         Initialization.
+         INPUT:
+             - in_features: shape of the input
+             - alpha - trainable parameter that controls frequency
+             - beta - trainable parameter that controls magnitude
+             alpha is initialized to 1 by default, higher values = higher-frequency.
+             beta is initialized to 1 by default, higher values = higher-magnitude.
+             alpha will be trained along with the rest of your model.
+         """
+         super(SnakeBeta, self).__init__()
+         self.in_features = in_features
+
+         # initialize alpha
+         self.alpha_logscale = alpha_logscale
+         if self.alpha_logscale:  # log scale alphas initialized to zeros
+             self.alpha = Parameter(torch.zeros(in_features) * alpha)
+             self.beta = Parameter(torch.zeros(in_features) * alpha)
+         else:  # linear scale alphas initialized to ones
+             self.alpha = Parameter(torch.ones(in_features) * alpha)
+             self.beta = Parameter(torch.ones(in_features) * alpha)
+
+         self.alpha.requires_grad = alpha_trainable
+         self.beta.requires_grad = alpha_trainable
+
+         self.no_div_by_zero = 1e-7
+
+     def forward(self, x):
+         """
+         Forward pass of the function.
+         Applies the function to the input elementwise.
+         SnakeBeta ∶= x + 1/b * sin^2 (xa)
+         """
+         alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+         beta = self.beta.unsqueeze(0).unsqueeze(-1)
+         if self.alpha_logscale:
+             alpha = torch.exp(alpha)
+             beta = torch.exp(beta)
+         x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+         return x
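Snake computes x + (1/α)·sin²(αx) with one learnable α per channel, broadcast over [B, C, T]; SnakeBeta adds a separate magnitude parameter β. At the default α = 1 the function reduces to x + sin²(x), up to the 1e-7 division stabilizer. A short usage sketch (import path inferred from the file list above):

    import torch
    from lt_tensor.model_zoo.activations.snake import Snake, SnakeBeta

    snake = Snake(in_features=8)   # one alpha per channel
    x = torch.randn(4, 8, 100)     # [B, C, T]
    y = snake(x)
    # With alpha initialized to 1, Snake is x + sin(x)**2 (up to ~1e-7):
    print(torch.allclose(y, x + torch.sin(x) ** 2, atol=1e-5))  # True

    sb = SnakeBeta(in_features=8)  # separate beta controls magnitude
    print(sb(x).shape)             # torch.Size([4, 8, 100])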