lt-tensor 0.0.1a20__tar.gz → 0.0.1a22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/LICENSE +1 -1
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/PKG-INFO +3 -2
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_base.py +20 -41
- lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +1 -0
- lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/alias_free_torch/act.py +55 -0
- lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/alias_free_torch/filter.py +183 -0
- lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/alias_free_torch/resample.py +106 -0
- lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/snake/__init__.py +129 -0
- lt_tensor-0.0.1a22/lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +536 -0
- lt_tensor-0.0.1a22/lt_tensor/model_zoo/audio_models/bigvgan/cuda/__init__.py +160 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/audio_models/hifigan/__init__.py +130 -6
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/audio_models/istft/__init__.py +132 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/residual.py +47 -11
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/processors/audio.py +16 -14
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/torch_commons.py +2 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor.egg-info/PKG-INFO +3 -2
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor.egg-info/SOURCES.txt +7 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor.egg-info/requires.txt +2 -1
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/setup.py +3 -2
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/README.md +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/__init__.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/config_templates.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/losses.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/lr_schedulers.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/math_ops.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/misc_utils.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/__init__.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/audio_models/__init__.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/basic.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/features.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/fusion.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/pos_encoder.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_zoo/transformer.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/monotonic_align.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/noise_tools.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/processors/__init__.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/transform.py +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor.egg-info/dependency_links.txt +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor.egg-info/top_level.txt +0 -0
- {lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/setup.cfg +0 -0

{lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/LICENSE
@@ -186,7 +186,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.

-   Copyright
+   Copyright 2025 gr1336

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.

{lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a20
+Version: 0.0.1a22
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
@@ -17,11 +17,12 @@ Requires-Dist: numpy>=1.26.4
 Requires-Dist: tokenizers
 Requires-Dist: pyyaml>=6.0.0
 Requires-Dist: numba>0.60.0
-Requires-Dist: lt-utils==0.0.
+Requires-Dist: lt-utils==0.0.2
 Requires-Dist: librosa==0.11.*
 Requires-Dist: einops
 Requires-Dist: plotly
 Requires-Dist: scipy
+Requires-Dist: huggingface_hub
 Dynamic: author
 Dynamic: classifier
 Dynamic: description
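
The metadata changes bump the package version to 0.0.1a22, pin lt-utils to 0.0.2, and add huggingface_hub as a dependency. A hypothetical downstream setup.py mirroring the new pins (the project name and layout here are illustrative, not part of lt-tensor):

    # setup.py of a hypothetical downstream project, not this package's own setup.py
    from setuptools import setup

    setup(
        name="my-project",          # illustrative downstream package
        install_requires=[
            "lt-tensor==0.0.1a22",  # this release
            "lt-utils==0.0.2",      # pin updated in this release
            "huggingface_hub",      # dependency added in this release
        ],
    )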

{lt_tensor-0.0.1a20 → lt_tensor-0.0.1a22}/lt_tensor/model_base.py
@@ -194,10 +194,11 @@ class Model(_Devices_Base, ABC):
         None  # Example: lambda: nn.Linear(32, 32, True)
     )
     low_rank_adapter: Union[nn.Identity, nn.Module, nn.Sequential] = nn.Identity()
+    # never freeze:
+    _never_freeze_modules: List[str] = ["low_rank_adapter"]

     # dont save list:
     _dont_save_items: List[str] = []
-    _loss_history: LossTracker = LossTracker(20_000)

     @property
     def autocast(self):
@@ -207,12 +208,16 @@ class Model(_Devices_Base, ABC):
     def autocast(self, value: bool):
         self._autocast = value

-    def freeze_all(self, exclude: Optional[List[str]] = None):
+    def freeze_all(self, exclude: Optional[List[str]] = None, force: bool = False):
         no_exclusions = not exclude
         no_exclusions = not exclude
         results = []
         for name, module in self.named_modules():
-            if
+            if (
+                name in self._never_freeze_modules
+                or not force
+                and name not in self.registered_freezable_modules
+            ):
                 results.append(
                     (
                         name,
@@ -228,12 +233,17 @@ class Model(_Devices_Base, ABC):
                 results.append((name, "excluded"))
         return results

-    def unfreeze_all(self, exclude: Optional[list[str]] = None):
+    def unfreeze_all(self, exclude: Optional[list[str]] = None, force: bool = False):
         """Unfreezes all model parameters except specified layers."""
         no_exclusions = not exclude
         results = []
         for name, module in self.named_modules():
-            if
+            if (
+                name in self._never_freeze_modules
+                or not force
+                and name not in self.registered_freezable_modules
+            ):
+
                 results.append(
                     (
                         name,
@@ -250,6 +260,9 @@ class Model(_Devices_Base, ABC):
         return results

     def change_frozen_state(self, freeze: bool, module: nn.Module):
+        assert isinstance(module, nn.Module)
+        if module.__class__.__name__ in self._never_freeze_modules:
+            return "Not Allowed"
         try:
             if isinstance(module, Model):
                 if module._can_be_frozen:
@@ -258,7 +271,7 @@ class Model(_Devices_Base, ABC):
                     return module.unfreeze_all()
                 else:
                     return "Not Allowed"
-
+            else:
                 module.requires_grad_(not freeze)
                 return not freeze
         except Exception as e:
@@ -451,7 +464,7 @@ class Model(_Devices_Base, ABC):
         raise_if_not_exists: bool = False,
         strict: bool = False,
         assign: bool = False,
-        weights_only: bool =
+        weights_only: bool = False,
         mmap: Optional[bool] = None,
         **pickle_load_args,
     ):
@@ -505,37 +518,3 @@ class Model(_Devices_Base, ABC):
         self, *args, **kwargs
     ) -> Union[Tensor, Sequence[Tensor], Dict[Any, Union[Any, Tensor]]]:
         pass
-
-    def add_loss(
-        self,
-        loss: Union[float, list[float]],
-        mode: Union[Literal["train", "eval"]] = "train",
-    ):
-        if isinstance(loss, Number) and loss:
-            self._loss_history.append(loss, mode)
-        elif isinstance(loss, (list, tuple)):
-            if loss:
-                self._loss_history.append(sum(loss) / len(loss), mode=mode)
-        elif isinstance(loss, Tensor):
-            try:
-                self._loss_history.append(loss.detach().flatten().mean().item())
-            except Exception as e:
-                log_traceback(e, "add_loss - Tensor")
-
-    def save_loss_history(self, path: Optional[PathLike] = None):
-        self._loss_history.save(path)
-
-    def load_loss_history(self, path: Optional[PathLike] = None):
-        self._loss_history.load(path)
-
-    def get_loss_avg(
-        self,
-        mode: Union[Literal["train", "eval"], str],
-        quantity: int = 0,
-    ):
-        t_list = self._loss_history.get(mode)
-        if not t_list:
-            return float("nan")
-        if quantity > 0:
-            t_list = t_list[-quantity:]
-        return sum(t_list) / len(t_list)
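
In model_base.py, the freeze/unfreeze API gains a `force` flag and a `_never_freeze_modules` list, and the built-in loss-history helpers (`add_loss`, `save_loss_history`, `load_loss_history`, `get_loss_avg`) are removed along with `_loss_history`. A standalone sketch of the new guard condition (judging by the attribute name, a module matching it is left untouched by `freeze_all`/`unfreeze_all` unless forced; the helper below only mirrors the boolean expression from the diff and is not part of the package):

    def should_skip(name, never_freeze, registered_freezable, force):
        # Mirrors: name in self._never_freeze_modules
        #          or not force and name not in self.registered_freezable_modules
        # `and` binds tighter than `or`, so force=True only protects _never_freeze_modules.
        return name in never_freeze or (not force and name not in registered_freezable)

    never = ["low_rank_adapter"]
    registered = ["encoder"]

    print(should_skip("low_rank_adapter", never, registered, force=True))  # True: never frozen
    print(should_skip("decoder", never, registered, force=False))          # True: not registered
    print(should_skip("decoder", never, registered, force=True))           # False: force bypasses the registry
    print(should_skip("encoder", never, registered, force=False))          # False: registered as freezable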

lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/alias_free_torch/__init__.py (new file)
@@ -0,0 +1 @@
+from . import *

lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/alias_free_torch/act.py (new file)
@@ -0,0 +1,55 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .resample import UpSample1d, DownSample1d
+from .resample import UpSample2d, DownSample2d
+
+
+class Activation1d(nn.Module):
+
+    def __init__(
+        self,
+        activation,
+        up_ratio: int = 2,
+        down_ratio: int = 2,
+        up_kernel_size: int = 12,
+        down_kernel_size: int = 12,
+    ):
+        super().__init__()
+        self.up_ratio = up_ratio
+        self.down_ratio = down_ratio
+        self.act = activation
+        self.upsample = UpSample1d(up_ratio, up_kernel_size)
+        self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+    # x: [B,C,T]
+    def forward(self, x):
+        x = self.upsample(x)
+        x = self.act(x)
+        x = self.downsample(x)
+        return x
+
+
+class Activation2d(nn.Module):
+
+    def __init__(
+        self,
+        activation,
+        up_ratio: int = 2,
+        down_ratio: int = 2,
+        up_kernel_size: int = 12,
+        down_kernel_size: int = 12,
+    ):
+        super().__init__()
+        self.up_ratio = up_ratio
+        self.down_ratio = down_ratio
+        self.act = activation
+        self.upsample = UpSample2d(up_ratio, up_kernel_size)
+        self.downsample = DownSample2d(down_ratio, down_kernel_size)
+
+    # x: [B,C,W,H]
+    def forward(self, x):
+        x = self.upsample(x)
+        x = self.act(x)
+        x = self.downsample(x)
+        return x
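
Activation1d/Activation2d wrap any pointwise activation between an anti-aliased upsample and downsample, the alias-free-activation scheme popularized by StyleGAN3 and BigVGAN. A small usage sketch, assuming the import paths from the file listing above (SnakeBeta comes from the snake module shown later in this diff):

    import torch
    from lt_tensor.model_zoo.activations.alias_free_torch.act import Activation1d
    from lt_tensor.model_zoo.activations.snake import SnakeBeta

    channels = 64
    act = Activation1d(activation=SnakeBeta(channels))

    x = torch.randn(8, channels, 1024)   # [B, C, T]
    y = act(x)                           # upsample -> activation -> downsample
    print(y.shape)                       # with the default 2x ratios, same [B, C, T] as the input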

lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/alias_free_torch/filter.py (new file)
@@ -0,0 +1,183 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+if "sinc" in dir(torch):
+    sinc = torch.sinc
+else:
+    # This code is adopted from adefossez's julius.core.sinc
+    # https://adefossez.github.io/julius/julius/core.html
+    def sinc(x: torch.Tensor):
+        """
+        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
+        """
+        return torch.where(
+            x == 0,
+            torch.tensor(1.0, device=x.device, dtype=x.dtype),
+            torch.sin(math.pi * x) / math.pi / x,
+        )
+
+
+# This code is adopted from adefossez's julius.lowpass.LowPassFilters
+# https://adefossez.github.io/julius/julius/lowpass.html
+
+
+# return filter [1,1,kernel_size]
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
+    even = kernel_size % 2 == 0
+    half_size = kernel_size // 2
+
+    # For kaiser window
+    delta_f = 4 * half_width
+    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+    if A > 50.0:
+        beta = 0.1102 * (A - 8.7)
+    elif A >= 21.0:
+        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
+    else:
+        beta = 0.0
+    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+    if even:
+        time = torch.arange(-half_size, half_size) + 0.5
+    else:
+        time = torch.arange(kernel_size) - half_size
+    if cutoff == 0:
+        filter_ = torch.zeros_like(time)
+    else:
+        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+        # Normalize filter to have sum = 1, otherwise we will have a small leakage
+        # of the constant component in the input signal.
+        filter_ /= filter_.sum()
+        filter = filter_.view(1, 1, kernel_size)
+    return filter
+
+
+def kaiser_sinc_filter2d(cutoff, half_width, kernel_size):
+    even = kernel_size % 2 == 0
+    half_size = kernel_size // 2
+    # For kaiser window
+    delta_f = 4 * half_width
+    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+    if A > 50.0:
+        beta = 0.1102 * (A - 8.7)
+    elif A >= 21.0:
+        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
+    else:
+        beta = 0.0
+
+    # rotation equivariant grid
+    if even:
+        time = torch.stack(
+            torch.meshgrid(
+                torch.arange(-half_size, half_size) + 0.5,
+                torch.arange(-half_size, half_size) + 0.5,
+            ),
+            dim=-1,
+        )
+    else:
+        time = torch.stack(
+            torch.meshgrid(
+                torch.arange(kernel_size) - half_size,
+                torch.arange(kernel_size) - half_size,
+            ),
+            dim=-1,
+        )
+
+    time = torch.norm(time, dim=-1)
+    # rotation equivariant window
+    window = torch.i0(
+        beta * torch.sqrt(1 - (time / half_size / 2**0.5) ** 2)
+    ) / torch.i0(torch.tensor([beta]))
+    # ratio = 0.5/cutroff
+    if cutoff == 0:
+        filter_ = torch.zeros_like(time)
+    else:
+        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+        # Normalize filter to have sum = 1, otherwise we will have a small leakage
+        # of the constant component in the input signal.
+        filter_ /= filter_.sum()
+        filter = filter_.view(1, 1, kernel_size, kernel_size)
+    return filter
+
+
+class LowPassFilter1d(nn.Module):
+
+    def __init__(
+        self,
+        cutoff=0.5,
+        half_width=0.6,
+        stride: int = 1,
+        padding: bool = True,
+        padding_mode: str = "replicate",
+        kernel_size: int = 12,
+    ):
+        # kernel_size should be even number for stylegan3 setup,
+        # in this implementation, odd number is also possible.
+        super().__init__()
+        if cutoff < -0.0:
+            raise ValueError("Minimum cutoff must be larger than zero.")
+        if cutoff > 0.5:
+            raise ValueError("A cutoff above 0.5 does not make sense.")
+        self.kernel_size = kernel_size
+        self.even = kernel_size % 2 == 0
+        self.pad_left = kernel_size // 2 - int(self.even)
+        self.pad_right = kernel_size // 2
+        self.stride = stride
+        self.padding = padding
+        self.padding_mode = padding_mode
+        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+        self.register_buffer("filter", filter)
+
+    # input [B,C,T]
+    def forward(self, x):
+        _, C, _ = x.shape
+        if self.padding:
+            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
+        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+        return out
+
+
+class LowPassFilter2d(nn.Module):
+
+    def __init__(
+        self,
+        cutoff=0.5,
+        half_width=0.6,
+        stride: int = 1,
+        padding: bool = True,
+        padding_mode: str = "replicate",
+        kernel_size: int = 12,
+    ):
+        # kernel_size should be even number for stylegan3 setup,
+        # in this implementation, odd number is also possible.
+        super().__init__()
+        if cutoff < -0.0:
+            raise ValueError("Minimum cutoff must be larger than zero.")
+        if cutoff > 0.5:
+            raise ValueError("A cutoff above 0.5 does not make sense.")
+        self.kernel_size = kernel_size
+        self.even = kernel_size % 2 == 0
+        self.pad_left = kernel_size // 2 - int(self.even)
+        self.pad_right = kernel_size // 2
+        self.stride = stride
+        self.padding = padding
+        self.padding_mode = padding_mode
+        filter = kaiser_sinc_filter2d(cutoff, half_width, kernel_size)
+        self.register_buffer("filter", filter)
+
+    # input [B,C,W,H]
+    def forward(self, x):
+        _, C, _, _ = x.shape
+        if self.padding:
+            x = F.pad(
+                x,
+                (self.pad_left, self.pad_right, self.pad_left, self.pad_right),
+                mode=self.padding_mode,
+            )
+        out = F.conv2d(
+            x, self.filter.expand(C, -1, -1, -1), stride=self.stride, groups=C
+        )
+        return out
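
These filters are plain depthwise convolutions with a Kaiser-windowed sinc kernel: `kaiser_sinc_filter1d` builds the kernel and `LowPassFilter1d` applies it per channel. A short sketch of using the 1-D filter on its own (the cutoff/half_width values mirror the ratio-2 defaults used by the resamplers below; import paths follow the file listing):

    import torch
    from lt_tensor.model_zoo.activations.alias_free_torch.filter import (
        LowPassFilter1d,
        kaiser_sinc_filter1d,
    )

    # Kernel shaped [1, 1, kernel_size], normalized so its sum is ~1.0
    kernel = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
    print(kernel.shape, float(kernel.sum()))

    # Depthwise low-pass; stride=2 halves the time dimension
    lowpass = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
    x = torch.randn(4, 16, 256)           # [B, C, T]
    y = lowpass(x)
    print(y.shape)                        # [4, 16, 128]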

lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/alias_free_torch/resample.py (new file)
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .filter import LowPassFilter1d, LowPassFilter2d
+from .filter import kaiser_sinc_filter1d, kaiser_sinc_filter2d
+
+
+class UpSample1d(nn.Module):
+
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = (
+            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        )
+        self.stride = ratio
+        self.pad = self.kernel_size // ratio - 1
+        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+        self.pad_right = (
+            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+        )
+        filter = kaiser_sinc_filter1d(
+            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
+        )
+        self.register_buffer("filter", filter)
+
+    # x: [B,C,T]
+    def forward(self, x):
+        _, C, _ = x.shape
+        x = F.pad(x, (self.pad, self.pad), mode="replicate")
+        x = self.ratio * F.conv_transpose1d(
+            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
+        )
+        x = x[..., self.pad_left : -self.pad_right]
+        return x
+
+
+class DownSample1d(nn.Module):
+
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = (
+            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        )
+        self.lowpass = LowPassFilter1d(
+            cutoff=0.5 / ratio,
+            half_width=0.6 / ratio,
+            stride=ratio,
+            kernel_size=self.kernel_size,
+        )
+
+    def forward(self, x):
+        xx = self.lowpass(x)
+        return xx
+
+
+class UpSample2d(nn.Module):
+
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = (
+            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        )
+        self.stride = ratio
+        self.pad = self.kernel_size // 2 - ratio // 2
+        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+        self.pad_right = (
+            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+        )
+        filter = kaiser_sinc_filter2d(
+            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
+        )
+        self.register_buffer("filter", filter)
+
+    # x: [B,C,W,H]
+    def forward(self, x):
+        _, C, _, _ = x.shape
+        x = F.pad(x, (self.pad, self.pad, self.pad, self.pad), mode="replicate")
+        x = self.ratio**2 * F.conv_transpose2d(
+            x, self.filter.expand(C, -1, -1, -1), stride=self.stride, groups=C
+        )
+        x = x[..., self.pad_left : -self.pad_right, self.pad_left : -self.pad_right]
+        return x
+
+
+class DownSample2d(nn.Module):
+
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = (
+            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        )
+        self.lowpass = LowPassFilter2d(
+            cutoff=0.5 / ratio,
+            half_width=0.6 / ratio,
+            stride=ratio,
+            kernel_size=self.kernel_size,
+        )
+
+    # x: [B,C,W,H]
+    def forward(self, x):
+        xx = self.lowpass(x)
+        return xx
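
UpSample1d/DownSample1d are the resamplers used by Activation1d: a transposed depthwise convolution with the Kaiser-sinc kernel for upsampling, and the strided low-pass filter for downsampling. A quick shape check with the default ratio of 2 (import paths follow the file listing):

    import torch
    from lt_tensor.model_zoo.activations.alias_free_torch.resample import (
        UpSample1d,
        DownSample1d,
    )

    up = UpSample1d(ratio=2)      # kernel_size defaults to int(6 * ratio // 2) * 2 = 12
    down = DownSample1d(ratio=2)

    x = torch.randn(2, 8, 100)    # [B, C, T]
    u = up(x)                     # [2, 8, 200]
    d = down(u)                   # back to [2, 8, 100]
    print(u.shape, d.shape)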

lt_tensor-0.0.1a22/lt_tensor/model_zoo/activations/snake/__init__.py (new file)
@@ -0,0 +1,129 @@
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+
+
+class Snake(nn.Module):
+    """
+    Implementation of a sine-based periodic activation function
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter
+    References:
+        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+        https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = snake(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    """
+
+    def __init__(
+        self,
+        in_features,
+        alpha=1.0,
+        alpha_trainable=True,
+        alpha_logscale=False,
+    ):
+        """
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha: trainable parameter
+            alpha is initialized to 1 by default, higher values = higher-frequency.
+            alpha will be trained along with the rest of your model.
+        """
+        super(Snake, self).__init__()
+        self.in_features = in_features
+
+        # initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # log scale alphas initialized to zeros
+            self.alpha = Parameter(torch.zeros(in_features) * alpha)
+        else:  # linear scale alphas initialized to ones
+            self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 1e-7
+
+    def forward(self, x):
+        """
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        Snake ∶= x + 1/a * sin^2 (xa)
+        """
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
+
+
+class SnakeBeta(nn.Module):
+    """
+    A modified Snake function which uses separate parameters for the magnitude of the periodic components
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter that controls frequency
+        - beta - trainable parameter that controls magnitude
+    References:
+        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+        https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = snakebeta(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    """
+
+    def __init__(
+        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
+    ):
+        """
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha - trainable parameter that controls frequency
+            - beta - trainable parameter that controls magnitude
+            alpha is initialized to 1 by default, higher values = higher-frequency.
+            beta is initialized to 1 by default, higher values = higher-magnitude.
+            alpha will be trained along with the rest of your model.
+        """
+        super(SnakeBeta, self).__init__()
+        self.in_features = in_features
+
+        # initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # log scale alphas initialized to zeros
+            self.alpha = Parameter(torch.zeros(in_features) * alpha)
+            self.beta = Parameter(torch.zeros(in_features) * alpha)
+        else:  # linear scale alphas initialized to ones
+            self.alpha = Parameter(torch.ones(in_features) * alpha)
+            self.beta = Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 1e-7
+
+    def forward(self, x):
+        """
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        SnakeBeta ∶= x + 1/b * sin^2 (xa)
+        """
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+        beta = self.beta.unsqueeze(0).unsqueeze(-1)
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+            beta = torch.exp(beta)
+        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
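
Snake and SnakeBeta are elementwise activations of the form x + (1/α)·sin²(αx), with SnakeBeta using a separate β for the magnitude term, and one learnable parameter per channel. A minimal sketch of applying them to a [B, C, T] tensor, and of the BigVGAN-style pairing with the anti-aliased wrapper from act.py (import paths follow the file listing):

    import torch
    from lt_tensor.model_zoo.activations.snake import Snake, SnakeBeta
    from lt_tensor.model_zoo.activations.alias_free_torch.act import Activation1d

    x = torch.randn(4, 32, 800)           # [B, C, T]

    snake = Snake(in_features=32, alpha=1.0)
    y = snake(x)                          # same shape as x, one alpha per channel

    # SnakeBeta wrapped in the anti-aliased activation, as BigVGAN-style vocoders do
    act = Activation1d(activation=SnakeBeta(32, alpha_logscale=True))
    z = act(x)
    print(y.shape, z.shape)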