lt-tensor 0.0.1a21__tar.gz → 0.0.1a26__tar.gz
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/LICENSE +1 -1
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/PKG-INFO +2 -1
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/config_templates.py +9 -5
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/misc_utils.py +15 -3
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/model_base.py +17 -39
- lt_tensor-0.0.1a26/lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +1 -0
- lt_tensor-0.0.1a26/lt_tensor/model_zoo/activations/alias_free_torch/act.py +55 -0
- lt_tensor-0.0.1a26/lt_tensor/model_zoo/activations/alias_free_torch/filter.py +183 -0
- lt_tensor-0.0.1a26/lt_tensor/model_zoo/activations/alias_free_torch/resample.py +106 -0
- lt_tensor-0.0.1a26/lt_tensor/model_zoo/activations/snake/__init__.py +129 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/audio_models/diffwave/__init__.py +41 -14
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/audio_models/hifigan/__init__.py +90 -8
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/audio_models/istft/__init__.py +93 -3
- lt_tensor-0.0.1a26/lt_tensor/model_zoo/convs.py +124 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/residual.py +2 -101
- lt_tensor-0.0.1a26/lt_tensor/processors/__init__.py +3 -0
- lt_tensor-0.0.1a26/lt_tensor/processors/audio.py +527 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/torch_commons.py +2 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor.egg-info/PKG-INFO +2 -1
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor.egg-info/SOURCES.txt +6 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor.egg-info/requires.txt +1 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/setup.py +2 -1
- lt_tensor-0.0.1a21/lt_tensor/processors/__init__.py +0 -3
- lt_tensor-0.0.1a21/lt_tensor/processors/audio.py +0 -454
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/README.md +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/__init__.py +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/losses.py +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/lr_schedulers.py +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/math_ops.py +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/__init__.py +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/audio_models/__init__.py +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/basic.py +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/features.py +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/fusion.py +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/pos_encoder.py +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/model_zoo/transformer.py +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/monotonic_align.py +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/noise_tools.py +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor/transform.py +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor.egg-info/dependency_links.txt +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/lt_tensor.egg-info/top_level.txt +0 -0
- {lt_tensor-0.0.1a21 → lt_tensor-0.0.1a26}/setup.cfg +0 -0
```diff
--- lt_tensor-0.0.1a21/LICENSE
+++ lt_tensor-0.0.1a26/LICENSE
@@ -186,7 +186,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.
 
-   Copyright
+   Copyright 2025 gr1336
 
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
```
```diff
--- lt_tensor-0.0.1a21/PKG-INFO
+++ lt_tensor-0.0.1a26/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a21
+Version: 0.0.1a26
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
@@ -22,6 +22,7 @@ Requires-Dist: librosa==0.11.*
 Requires-Dist: einops
 Requires-Dist: plotly
 Requires-Dist: scipy
+Requires-Dist: huggingface_hub
 Dynamic: author
 Dynamic: classifier
 Dynamic: description
```
```diff
--- lt_tensor-0.0.1a21/lt_tensor/config_templates.py
+++ lt_tensor-0.0.1a26/lt_tensor/config_templates.py
@@ -6,9 +6,7 @@ from lt_tensor.misc_utils import updateDict
 
 
 class ModelConfig(ABC, OrderedDict):
-    _default_settings: Dict[str, Any] = {}
     _forbidden_list: List[str] = [
-        "_default_settings",
         "_forbidden_list",
     ]
 
@@ -16,12 +14,15 @@ class ModelConfig(ABC, OrderedDict):
         self,
         **settings,
     ):
-        self.
-        self.set_state_dict(self._default_settings)
+        self.set_state_dict(settings)
 
     def reset_settings(self):
         raise NotImplementedError("Not implemented")
-
+
+    def post_process(self):
+        """Implement the post process, to do a final check to the input data"""
+        pass
+
     def save_config(
         self,
         path: str,
@@ -48,6 +49,7 @@ class ModelConfig(ABC, OrderedDict):
         }
         updateDict(self, new_state)
         self.update(**new_state)
+        self.post_process()
 
     def state_dict(self):
         return {k: y for k, y in self.__dict__.items() if k not in self._forbidden_list}
@@ -89,3 +91,5 @@ class ModelConfig(ABC, OrderedDict):
         settings.pop("path_name", None)
 
         return ModelConfig(**settings)
+
+
```
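Taken together, these changes drop the shared `_default_settings` class dict (defaults now live directly on subclasses) and make `set_state_dict` finish with a `post_process()` call, giving subclasses a hook to validate or normalize incoming values. A minimal sketch of how that hook might be used; the `AudioConfig` class and its fields are hypothetical, not part of the package:

```python
# Hypothetical ModelConfig subclass exercising the new post_process() hook.
from lt_tensor.config_templates import ModelConfig


class AudioConfig(ModelConfig):
    # Class-level defaults replace the removed _default_settings dict.
    sample_rate: int = 24000
    hop_length: int = 256

    def post_process(self):
        # Called at the end of set_state_dict(): a final check on input data.
        if self.hop_length > self.sample_rate:
            raise ValueError("hop_length cannot exceed sample_rate")


cfg = AudioConfig(sample_rate=22050)  # __init__ forwards to set_state_dict()
print(cfg.state_dict())               # reflects the values applied above
```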
```diff
--- lt_tensor-0.0.1a21/lt_tensor/misc_utils.py
+++ lt_tensor-0.0.1a26/lt_tensor/misc_utils.py
@@ -111,10 +111,10 @@ def get_weights(directory: Union[str, PathLike]):
     directory = Path(directory)
     if is_file(directory):
         if directory.name.endswith((".pt", ".ckpt", ".pth")):
-            return directory
+            return [directory]
         directory = directory.parent
     res = sorted(find_files(directory, ["*.pt", "*.ckpt", "*.pth"]))
-    return res
+    return res
 
 
 def get_config(directory: Union[str, PathLike], default: Optional[Any] = None):
@@ -128,7 +128,19 @@ def get_config(directory: Union[str, PathLike], default: Optional[Any] = None):
             return load_json(directory, default)
         return load_yaml(directory, default)
     directory = directory.parent
-    res = sorted(
+    res = sorted(
+        find_files(
+            directory,
+            [
+                "config*.json",
+                "*config.json",
+                "config*.yml",
+                "*config.yml",
+                "*config.yaml",
+                "config*.yaml",
+            ],
+        )
+    )
     if res:
         res = res[-1]
         if Path(res).name.endswith(".json"):
```
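Two behavioral fixes here: `get_weights` now returns a list in both branches (a single checkpoint file used to come back as a bare path, forcing callers to special-case it), and `get_config` now actually passes its glob patterns to `find_files`, matching `config*`/`*config` names in both JSON and YAML variants and taking the last match after sorting. A short usage sketch; the paths below are illustrative:

```python
# Illustrative callers; the paths are made up.
from lt_tensor.misc_utils import get_config, get_weights

# Uniform return type: a lone file and a directory both yield a list,
# so the same loop handles either input.
for ckpt in get_weights("runs/exp1/model.pt"):
    print("checkpoint:", ckpt)

# Searches config*.json, *config.json, config*.yml, *config.yml,
# *config.yaml and config*.yaml, loading the last match after sorting.
cfg = get_config("runs/exp1", default={})
```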
```diff
--- lt_tensor-0.0.1a21/lt_tensor/model_base.py
+++ lt_tensor-0.0.1a26/lt_tensor/model_base.py
@@ -194,10 +194,11 @@ class Model(_Devices_Base, ABC):
         None  # Example: lambda: nn.Linear(32, 32, True)
     )
     low_rank_adapter: Union[nn.Identity, nn.Module, nn.Sequential] = nn.Identity()
+    # never freeze:
+    _never_freeze_modules: List[str] = ["low_rank_adapter"]
 
     # dont save list:
     _dont_save_items: List[str] = []
-    _loss_history: LossTracker = LossTracker(20_000)
 
     @property
     def autocast(self):
@@ -212,7 +213,11 @@ class Model(_Devices_Base, ABC):
         no_exclusions = not exclude
         results = []
         for name, module in self.named_modules():
-            if
+            if (
+                name in self._never_freeze_modules
+                or not force
+                and name not in self.registered_freezable_modules
+            ):
                 results.append(
                     (
                         name,
@@ -233,7 +238,11 @@ class Model(_Devices_Base, ABC):
         no_exclusions = not exclude
         results = []
         for name, module in self.named_modules():
-            if
+            if (
+                name in self._never_freeze_modules
+                or not force
+                and name not in self.registered_freezable_modules
+            ):
 
                 results.append(
                     (
@@ -251,6 +260,9 @@ class Model(_Devices_Base, ABC):
         return results
 
     def change_frozen_state(self, freeze: bool, module: nn.Module):
+        assert isinstance(module, nn.Module)
+        if module.__class__.__name__ in self._never_freeze_modules:
+            return "Not Allowed"
         try:
             if isinstance(module, Model):
                 if module._can_be_frozen:
@@ -259,7 +271,7 @@ class Model(_Devices_Base, ABC):
                     return module.unfreeze_all()
                 else:
                     return "Not Allowed"
-
+            else:
                 module.requires_grad_(not freeze)
                 return not freeze
         except Exception as e:
@@ -452,7 +464,7 @@ class Model(_Devices_Base, ABC):
         raise_if_not_exists: bool = False,
         strict: bool = False,
         assign: bool = False,
-        weights_only: bool =
+        weights_only: bool = False,
         mmap: Optional[bool] = None,
         **pickle_load_args,
     ):
@@ -506,37 +518,3 @@ class Model(_Devices_Base, ABC):
         self, *args, **kwargs
     ) -> Union[Tensor, Sequence[Tensor], Dict[Any, Union[Any, Tensor]]]:
         pass
-
-    def add_loss(
-        self,
-        loss: Union[float, list[float]],
-        mode: Union[Literal["train", "eval"]] = "train",
-    ):
-        if isinstance(loss, Number) and loss:
-            self._loss_history.append(loss, mode)
-        elif isinstance(loss, (list, tuple)):
-            if loss:
-                self._loss_history.append(sum(loss) / len(loss), mode=mode)
-        elif isinstance(loss, Tensor):
-            try:
-                self._loss_history.append(loss.detach().flatten().mean().item())
-            except Exception as e:
-                log_traceback(e, "add_loss - Tensor")
-
-    def save_loss_history(self, path: Optional[PathLike] = None):
-        self._loss_history.save(path)
-
-    def load_loss_history(self, path: Optional[PathLike] = None):
-        self._loss_history.load(path)
-
-    def get_loss_avg(
-        self,
-        mode: Union[Literal["train", "eval"], str],
-        quantity: int = 0,
-    ):
-        t_list = self._loss_history.get(mode)
-        if not t_list:
-            return float("nan")
-        if quantity > 0:
-            t_list = t_list[-quantity:]
-        return sum(t_list) / len(t_list)
```
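Two things happen in this file: the built-in loss tracking is removed as a unit (the class-level `_loss_history` together with `add_loss`, the save/load helpers, and `get_loss_avg`), and the freeze logic gains a never-freeze list. The new guard relies on Python operator precedence: `not` binds tighter than `and`, and `and` tighter than `or`, so the condition parses as `(name in _never_freeze_modules) or ((not force) and (name not in registered_freezable_modules))`. Modules on the never-freeze list are always skipped, and unregistered modules are skipped unless `force` is set. A standalone sketch of that truth table; the names mirror the diff, the values are made up:

```python
# Standalone demonstration of how the new freeze guard parses.
never_freeze = ["low_rank_adapter"]   # mirrors _never_freeze_modules
freezable = ["encoder", "decoder"]    # mirrors registered_freezable_modules

def skipped(name: str, force: bool) -> bool:
    # Same expression as in the diff; parses as A or ((not force) and B).
    return name in never_freeze or not force and name not in freezable

assert skipped("low_rank_adapter", force=True)  # never-freeze wins, even forced
assert skipped("head", force=False)             # unregistered and not forced
assert not skipped("head", force=True)          # force bypasses the registry check
assert not skipped("encoder", force=False)      # registered modules are eligible
```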
```diff
--- /dev/null
+++ lt_tensor-0.0.1a26/lt_tensor/model_zoo/activations/alias_free_torch/__init__.py
@@ -0,0 +1 @@
+from . import *
```
```diff
--- /dev/null
+++ lt_tensor-0.0.1a26/lt_tensor/model_zoo/activations/alias_free_torch/act.py
@@ -0,0 +1,55 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .resample import UpSample1d, DownSample1d
+from .resample import UpSample2d, DownSample2d
+
+
+class Activation1d(nn.Module):
+
+    def __init__(
+        self,
+        activation,
+        up_ratio: int = 2,
+        down_ratio: int = 2,
+        up_kernel_size: int = 12,
+        down_kernel_size: int = 12,
+    ):
+        super().__init__()
+        self.up_ratio = up_ratio
+        self.down_ratio = down_ratio
+        self.act = activation
+        self.upsample = UpSample1d(up_ratio, up_kernel_size)
+        self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+    # x: [B,C,T]
+    def forward(self, x):
+        x = self.upsample(x)
+        x = self.act(x)
+        x = self.downsample(x)
+        return x
+
+
+class Activation2d(nn.Module):
+
+    def __init__(
+        self,
+        activation,
+        up_ratio: int = 2,
+        down_ratio: int = 2,
+        up_kernel_size: int = 12,
+        down_kernel_size: int = 12,
+    ):
+        super().__init__()
+        self.up_ratio = up_ratio
+        self.down_ratio = down_ratio
+        self.act = activation
+        self.upsample = UpSample2d(up_ratio, up_kernel_size)
+        self.downsample = DownSample2d(down_ratio, down_kernel_size)
+
+    # x: [B,C,W,H]
+    def forward(self, x):
+        x = self.upsample(x)
+        x = self.act(x)
+        x = self.downsample(x)
+        return x
```
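`Activation1d`/`Activation2d` implement the anti-aliased activation pattern used by StyleGAN3 and BigVGAN: upsample the signal, apply the pointwise nonlinearity at the higher rate, then low-pass filter and downsample back, so the harmonics the nonlinearity creates do not fold over the Nyquist frequency. A minimal usage sketch, assuming the import paths match the package layout listed above; the shapes are illustrative:

```python
# Minimal sketch: apply Snake at 2x the sample rate to reduce aliasing.
import torch
from lt_tensor.model_zoo.activations.alias_free_torch.act import Activation1d
from lt_tensor.model_zoo.activations.snake import Snake

act = Activation1d(activation=Snake(in_features=64))
x = torch.randn(8, 64, 1024)  # [B, C, T]
y = act(x)                    # upsample -> Snake -> low-pass downsample
print(y.shape)                # torch.Size([8, 64, 1024]); length is preserved
```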
```diff
--- /dev/null
+++ lt_tensor-0.0.1a26/lt_tensor/model_zoo/activations/alias_free_torch/filter.py
@@ -0,0 +1,183 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+if "sinc" in dir(torch):
+    sinc = torch.sinc
+else:
+    # This code is adopted from adefossez's julius.core.sinc
+    # https://adefossez.github.io/julius/julius/core.html
+    def sinc(x: torch.Tensor):
+        """
+        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
+        """
+        return torch.where(
+            x == 0,
+            torch.tensor(1.0, device=x.device, dtype=x.dtype),
+            torch.sin(math.pi * x) / math.pi / x,
+        )
+
+
+# This code is adopted from adefossez's julius.lowpass.LowPassFilters
+# https://adefossez.github.io/julius/julius/lowpass.html
+
+
+# return filter [1,1,kernel_size]
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
+    even = kernel_size % 2 == 0
+    half_size = kernel_size // 2
+
+    # For kaiser window
+    delta_f = 4 * half_width
+    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+    if A > 50.0:
+        beta = 0.1102 * (A - 8.7)
+    elif A >= 21.0:
+        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
+    else:
+        beta = 0.0
+    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+    if even:
+        time = torch.arange(-half_size, half_size) + 0.5
+    else:
+        time = torch.arange(kernel_size) - half_size
+    if cutoff == 0:
+        filter_ = torch.zeros_like(time)
+    else:
+        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+        # Normalize filter to have sum = 1, otherwise we will have a small leakage
+        # of the constant component in the input signal.
+        filter_ /= filter_.sum()
+    filter = filter_.view(1, 1, kernel_size)
+    return filter
+
+
+def kaiser_sinc_filter2d(cutoff, half_width, kernel_size):
+    even = kernel_size % 2 == 0
+    half_size = kernel_size // 2
+    # For kaiser window
+    delta_f = 4 * half_width
+    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+    if A > 50.0:
+        beta = 0.1102 * (A - 8.7)
+    elif A >= 21.0:
+        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
+    else:
+        beta = 0.0
+
+    # rotation equivariant grid
+    if even:
+        time = torch.stack(
+            torch.meshgrid(
+                torch.arange(-half_size, half_size) + 0.5,
+                torch.arange(-half_size, half_size) + 0.5,
+            ),
+            dim=-1,
+        )
+    else:
+        time = torch.stack(
+            torch.meshgrid(
+                torch.arange(kernel_size) - half_size,
+                torch.arange(kernel_size) - half_size,
+            ),
+            dim=-1,
+        )
+
+    time = torch.norm(time, dim=-1)
+    # rotation equivariant window
+    window = torch.i0(
+        beta * torch.sqrt(1 - (time / half_size / 2**0.5) ** 2)
+    ) / torch.i0(torch.tensor([beta]))
+    # ratio = 0.5/cutroff
+    if cutoff == 0:
+        filter_ = torch.zeros_like(time)
+    else:
+        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+        # Normalize filter to have sum = 1, otherwise we will have a small leakage
+        # of the constant component in the input signal.
+        filter_ /= filter_.sum()
+    filter = filter_.view(1, 1, kernel_size, kernel_size)
+    return filter
+
+
+class LowPassFilter1d(nn.Module):
+
+    def __init__(
+        self,
+        cutoff=0.5,
+        half_width=0.6,
+        stride: int = 1,
+        padding: bool = True,
+        padding_mode: str = "replicate",
+        kernel_size: int = 12,
+    ):
+        # kernel_size should be even number for stylegan3 setup,
+        # in this implementation, odd number is also possible.
+        super().__init__()
+        if cutoff < -0.0:
+            raise ValueError("Minimum cutoff must be larger than zero.")
+        if cutoff > 0.5:
+            raise ValueError("A cutoff above 0.5 does not make sense.")
+        self.kernel_size = kernel_size
+        self.even = kernel_size % 2 == 0
+        self.pad_left = kernel_size // 2 - int(self.even)
+        self.pad_right = kernel_size // 2
+        self.stride = stride
+        self.padding = padding
+        self.padding_mode = padding_mode
+        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+        self.register_buffer("filter", filter)
+
+    # input [B,C,T]
+    def forward(self, x):
+        _, C, _ = x.shape
+        if self.padding:
+            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
+        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+        return out
+
+
+class LowPassFilter2d(nn.Module):
+
+    def __init__(
+        self,
+        cutoff=0.5,
+        half_width=0.6,
+        stride: int = 1,
+        padding: bool = True,
+        padding_mode: str = "replicate",
+        kernel_size: int = 12,
+    ):
+        # kernel_size should be even number for stylegan3 setup,
+        # in this implementation, odd number is also possible.
+        super().__init__()
+        if cutoff < -0.0:
+            raise ValueError("Minimum cutoff must be larger than zero.")
+        if cutoff > 0.5:
+            raise ValueError("A cutoff above 0.5 does not make sense.")
+        self.kernel_size = kernel_size
+        self.even = kernel_size % 2 == 0
+        self.pad_left = kernel_size // 2 - int(self.even)
+        self.pad_right = kernel_size // 2
+        self.stride = stride
+        self.padding = padding
+        self.padding_mode = padding_mode
+        filter = kaiser_sinc_filter2d(cutoff, half_width, kernel_size)
+        self.register_buffer("filter", filter)
+
+    # input [B,C,W,H]
+    def forward(self, x):
+        _, C, _, _ = x.shape
+        if self.padding:
+            x = F.pad(
+                x,
+                (self.pad_left, self.pad_right, self.pad_left, self.pad_right),
+                mode=self.padding_mode,
+            )
+        out = F.conv2d(
+            x, self.filter.expand(C, -1, -1, -1), stride=self.stride, groups=C
+        )
+        return out
```
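`kaiser_sinc_filter1d` is a textbook windowed-sinc design: the Kaiser β is estimated from the transition width (`delta_f = 4 * half_width`) via the standard attenuation formula, the sinc is scaled to the requested cutoff, and the kernel is normalized to sum to 1 so the DC component passes through without leakage. A quick check of both properties, assuming the import path matches the layout above; the parameter values are illustrative (they match the defaults a ratio-2 resampler would use):

```python
# Sanity checks for the windowed-sinc construction.
import torch
from lt_tensor.model_zoo.activations.alias_free_torch.filter import (
    LowPassFilter1d,
    kaiser_sinc_filter1d,
)

kernel = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(kernel.shape)         # torch.Size([1, 1, 12])
print(float(kernel.sum()))  # ~1.0: unit DC gain after normalization

# As a module, a constant signal should pass through (almost) unchanged.
lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
x = torch.ones(1, 1, 64)
print(torch.allclose(lpf(x), x, atol=1e-4))  # True
```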
```diff
--- /dev/null
+++ lt_tensor-0.0.1a26/lt_tensor/model_zoo/activations/alias_free_torch/resample.py
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .filter import LowPassFilter1d, LowPassFilter2d
+from .filter import kaiser_sinc_filter1d, kaiser_sinc_filter2d
+
+
+class UpSample1d(nn.Module):
+
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = (
+            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        )
+        self.stride = ratio
+        self.pad = self.kernel_size // ratio - 1
+        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+        self.pad_right = (
+            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+        )
+        filter = kaiser_sinc_filter1d(
+            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
+        )
+        self.register_buffer("filter", filter)
+
+    # x: [B,C,T]
+    def forward(self, x):
+        _, C, _ = x.shape
+        x = F.pad(x, (self.pad, self.pad), mode="replicate")
+        x = self.ratio * F.conv_transpose1d(
+            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
+        )
+        x = x[..., self.pad_left : -self.pad_right]
+        return x
+
+
+class DownSample1d(nn.Module):
+
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = (
+            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        )
+        self.lowpass = LowPassFilter1d(
+            cutoff=0.5 / ratio,
+            half_width=0.6 / ratio,
+            stride=ratio,
+            kernel_size=self.kernel_size,
+        )
+
+    def forward(self, x):
+        xx = self.lowpass(x)
+        return xx
+
+
+class UpSample2d(nn.Module):
+
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = (
+            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        )
+        self.stride = ratio
+        self.pad = self.kernel_size // 2 - ratio // 2
+        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+        self.pad_right = (
+            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+        )
+        filter = kaiser_sinc_filter2d(
+            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
+        )
+        self.register_buffer("filter", filter)
+
+    # x: [B,C,W,H]
+    def forward(self, x):
+        _, C, _, _ = x.shape
+        x = F.pad(x, (self.pad, self.pad, self.pad, self.pad), mode="replicate")
+        x = self.ratio**2 * F.conv_transpose2d(
+            x, self.filter.expand(C, -1, -1, -1), stride=self.stride, groups=C
+        )
+        x = x[..., self.pad_left : -self.pad_right, self.pad_left : -self.pad_right]
+        return x
+
+
+class DownSample2d(nn.Module):
+
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = (
+            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        )
+        self.lowpass = LowPassFilter2d(
+            cutoff=0.5 / ratio,
+            half_width=0.6 / ratio,
+            stride=ratio,
+            kernel_size=self.kernel_size,
+        )
+
+    # x: [B,C,W,H]
+    def forward(self, x):
+        xx = self.lowpass(x)
+        return xx
```
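`UpSample1d` interpolates via a strided transposed convolution with the kaiser-sinc kernel (scaled by `ratio` to preserve amplitude), while `DownSample1d` is simply the low-pass filter applied with `stride=ratio`; the default kernel is about six taps per unit of ratio, and the padding arithmetic makes output lengths scale exactly with the ratio. A round-trip sketch, assuming the import path above; the sizes are illustrative:

```python
# Round-trip sketch: sequence lengths scale exactly with the ratio.
import torch
from lt_tensor.model_zoo.activations.alias_free_torch.resample import (
    DownSample1d,
    UpSample1d,
)

up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
x = torch.randn(4, 32, 500)  # [B, C, T]
y = up(x)                    # -> torch.Size([4, 32, 1000])
z = down(y)                  # -> torch.Size([4, 32, 500])
print(y.shape, z.shape)
```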
```diff
--- /dev/null
+++ lt_tensor-0.0.1a26/lt_tensor/model_zoo/activations/snake/__init__.py
@@ -0,0 +1,129 @@
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+
+
+class Snake(nn.Module):
+    """
+    Implementation of a sine-based periodic activation function
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter
+    References:
+        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+        https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = snake(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    """
+
+    def __init__(
+        self,
+        in_features,
+        alpha=1.0,
+        alpha_trainable=True,
+        alpha_logscale=False,
+    ):
+        """
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha: trainable parameter
+            alpha is initialized to 1 by default, higher values = higher-frequency.
+            alpha will be trained along with the rest of your model.
+        """
+        super(Snake, self).__init__()
+        self.in_features = in_features
+
+        # initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # log scale alphas initialized to zeros
+            self.alpha = Parameter(torch.zeros(in_features) * alpha)
+        else:  # linear scale alphas initialized to ones
+            self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 1e-7
+
+    def forward(self, x):
+        """
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        Snake ∶= x + 1/a * sin^2 (xa)
+        """
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
+
+
+class SnakeBeta(nn.Module):
+    """
+    A modified Snake function which uses separate parameters for the magnitude of the periodic components
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter that controls frequency
+        - beta - trainable parameter that controls magnitude
+    References:
+        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+        https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = snakebeta(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    """
+
+    def __init__(
+        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
+    ):
+        """
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha - trainable parameter that controls frequency
+            - beta - trainable parameter that controls magnitude
+            alpha is initialized to 1 by default, higher values = higher-frequency.
+            beta is initialized to 1 by default, higher values = higher-magnitude.
+            alpha will be trained along with the rest of your model.
+        """
+        super(SnakeBeta, self).__init__()
+        self.in_features = in_features
+
+        # initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # log scale alphas initialized to zeros
+            self.alpha = Parameter(torch.zeros(in_features) * alpha)
+            self.beta = Parameter(torch.zeros(in_features) * alpha)
+        else:  # linear scale alphas initialized to ones
+            self.alpha = Parameter(torch.ones(in_features) * alpha)
+            self.beta = Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 1e-7
+
+    def forward(self, x):
+        """
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        SnakeBeta ∶= x + 1/b * sin^2 (xa)
+        """
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+        beta = self.beta.unsqueeze(0).unsqueeze(-1)
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+            beta = torch.exp(beta)
+        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
```
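Snake is the periodic activation x + (1/α)·sin²(αx) with one learnable α per channel; SnakeBeta splits the roles so that α controls the frequency and β the magnitude of the periodic term, optionally parameterized in log scale. A short usage sketch on a `[B, C, T]` tensor, assuming the import path above:

```python
# Usage sketch: per-channel periodic activations on [B, C, T] input.
import torch
from lt_tensor.model_zoo.activations.snake import Snake, SnakeBeta

x = torch.randn(2, 64, 100)   # [B, C, T]

snake = Snake(in_features=64) # one alpha per channel
y = snake(x)                  # x + (1/alpha) * sin^2(alpha * x)

snake_b = SnakeBeta(in_features=64, alpha_logscale=True)
z = snake_b(x)                # frequency (alpha) and magnitude (beta) decoupled
print(y.shape, z.shape)       # both torch.Size([2, 64, 100])
```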