homa 0.1.94__tar.gz → 0.1.99__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {homa-0.1.94 → homa-0.1.99}/PKG-INFO +1 -1
- {homa-0.1.94 → homa-0.1.99}/pyproject.toml +1 -1
- homa-0.1.99/src/homa/activations/classes/APLU.py +86 -0
- homa-0.1.99/src/homa/activations/classes/GALU.py +67 -0
- homa-0.1.99/src/homa/activations/classes/MELU.py +70 -0
- homa-0.1.99/src/homa/activations/classes/PDELU.py +54 -0
- homa-0.1.99/src/homa/activations/classes/SReLU.py +69 -0
- homa-0.1.99/src/homa/activations/classes/SmallGALU.py +58 -0
- homa-0.1.99/src/homa/activations/classes/WideMELU.py +90 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa.egg-info/PKG-INFO +1 -1
- homa-0.1.94/src/homa/activations/classes/APLU.py +0 -54
- homa-0.1.94/src/homa/activations/classes/GALU.py +0 -53
- homa-0.1.94/src/homa/activations/classes/MELU.py +0 -52
- homa-0.1.94/src/homa/activations/classes/PDELU.py +0 -41
- homa-0.1.94/src/homa/activations/classes/SReLU.py +0 -53
- homa-0.1.94/src/homa/activations/classes/SmallGALU.py +0 -41
- homa-0.1.94/src/homa/activations/classes/WideMELU.py +0 -63
- {homa-0.1.94 → homa-0.1.99}/README.md +0 -0
- {homa-0.1.94 → homa-0.1.99}/setup.cfg +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/__init__.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/activations/__init__.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/activations/classes/StochasticActivation.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/activations/classes/__init__.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/activations/utils.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/cli/HomaCommand.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/cli/namespaces/CacheNamespace.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/cli/namespaces/MakeNamespace.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/cli/namespaces/__init__.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/device.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/ensemble/Ensemble.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/ensemble/__init__.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/ensemble/concerns/CalculatesMetricNecessities.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/ensemble/concerns/PredictsProbabilities.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/ensemble/concerns/ReportsClassificationMetrics.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/ensemble/concerns/ReportsEnsembleAccuracy.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/ensemble/concerns/ReportsEnsembleF1.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/ensemble/concerns/ReportsEnsembleKappa.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/ensemble/concerns/ReportsLogits.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/ensemble/concerns/ReportsSize.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/ensemble/concerns/StoresModels.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/ensemble/concerns/__init__.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/loss/LogitNormLoss.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/loss/Loss.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/loss/__init__.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/settings.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/torch/__init__.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/torch/helpers.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/utils.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/ClassificationModel.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/Model.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/Resnet.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/StochasticResnet.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/StochasticSwin.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/Swin.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/__init__.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/concerns/HasLabels.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/concerns/HasLogits.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/concerns/HasProbabilities.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/concerns/ReportsAccuracy.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/concerns/ReportsMetrics.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/concerns/Trainable.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/concerns/__init__.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/modules/ResnetModule.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/modules/SwinModule.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/modules/__init__.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa/vision/utils.py +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa.egg-info/SOURCES.txt +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa.egg-info/dependency_links.txt +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa.egg-info/entry_points.txt +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa.egg-info/requires.txt +0 -0
- {homa-0.1.94 → homa-0.1.99}/src/homa.egg-info/top_level.txt +0 -0
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "homa"
|
|
7
|
-
version = "0.1.94"
|
|
7
|
+
version = "0.1.99"
|
|
8
8
|
description = "A curated list of machine learning and deep learning helpers."
|
|
9
9
|
authors = [
|
|
10
10
|
{ name="Taha Shieenavaz", email="tahashieenavaz@gmail.com" },
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
from torch.nn.parameter import Parameter, UninitializedParameter
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class APLU(nn.Module):
|
|
8
|
+
def __init__(self, max_input: float = 1.0):
|
|
9
|
+
super().__init__()
|
|
10
|
+
self.max_input = float(max_input)
|
|
11
|
+
self.alpha = UninitializedParameter()
|
|
12
|
+
self.beta = UninitializedParameter()
|
|
13
|
+
self.gamma = UninitializedParameter()
|
|
14
|
+
self.xi = UninitializedParameter()
|
|
15
|
+
self.psi = UninitializedParameter()
|
|
16
|
+
self.mu = UninitializedParameter()
|
|
17
|
+
self._num_channels = None
|
|
18
|
+
|
|
19
|
+
def _initialize_parameters(self, x: torch.Tensor):
|
|
20
|
+
if x.ndim < 2:
|
|
21
|
+
raise ValueError(
|
|
22
|
+
f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
channels = int(x.shape[1])
|
|
26
|
+
self._num_channels = channels
|
|
27
|
+
param_shape = [1] * x.ndim
|
|
28
|
+
param_shape[1] = channels
|
|
29
|
+
|
|
30
|
+
with torch.no_grad():
|
|
31
|
+
self.alpha = Parameter(
|
|
32
|
+
torch.zeros(param_shape, dtype=x.dtype, device=x.device)
|
|
33
|
+
)
|
|
34
|
+
self.beta = Parameter(
|
|
35
|
+
torch.zeros(param_shape, dtype=x.dtype, device=x.device)
|
|
36
|
+
)
|
|
37
|
+
self.gamma = Parameter(
|
|
38
|
+
torch.zeros(param_shape, dtype=x.dtype, device=x.device)
|
|
39
|
+
)
|
|
40
|
+
self.xi = Parameter(
|
|
41
|
+
torch.empty(param_shape, dtype=x.dtype, device=x.device).uniform_(
|
|
42
|
+
0.0, self.max_input
|
|
43
|
+
)
|
|
44
|
+
)
|
|
45
|
+
self.psi = Parameter(
|
|
46
|
+
torch.empty(param_shape, dtype=x.dtype, device=x.device).uniform_(
|
|
47
|
+
0.0, self.max_input
|
|
48
|
+
)
|
|
49
|
+
)
|
|
50
|
+
self.mu = Parameter(
|
|
51
|
+
torch.empty(param_shape, dtype=x.dtype, device=x.device).uniform_(
|
|
52
|
+
0.0, self.max_input
|
|
53
|
+
)
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
def reset_parameters(self):
|
|
57
|
+
if isinstance(self.alpha, UninitializedParameter):
|
|
58
|
+
return
|
|
59
|
+
|
|
60
|
+
with torch.no_grad():
|
|
61
|
+
self.alpha.zero_()
|
|
62
|
+
self.beta.zero_()
|
|
63
|
+
self.gamma.zero_()
|
|
64
|
+
self.xi.uniform_(0.0, self.max_input)
|
|
65
|
+
self.psi.uniform_(0.0, self.max_input)
|
|
66
|
+
self.mu.uniform_(0.0, self.max_input)
|
|
67
|
+
|
|
68
|
+
def forward(self, x: torch.Tensor):
|
|
69
|
+
if isinstance(self.alpha, UninitializedParameter):
|
|
70
|
+
self._initialize_parameters(x)
|
|
71
|
+
|
|
72
|
+
if x.ndim < 2:
|
|
73
|
+
raise ValueError(
|
|
74
|
+
f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
|
|
75
|
+
)
|
|
76
|
+
if self._num_channels is not None and x.shape[1] != self._num_channels:
|
|
77
|
+
raise RuntimeError(
|
|
78
|
+
f"APLU was initialized with C={self._num_channels} but got C={x.shape[1]}. "
|
|
79
|
+
"Create a new APLU for a different channel size."
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
a = F.relu(x)
|
|
83
|
+
b = self.alpha * F.relu(-x + self.xi)
|
|
84
|
+
c = self.beta * F.relu(-x + self.psi)
|
|
85
|
+
d = self.gamma * F.relu(-x + self.mu)
|
|
86
|
+
return a + b + c + d
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
from torch.nn.parameter import Parameter, UninitializedParameter
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class GALU(nn.Module):
|
|
8
|
+
def __init__(self, max_input: float = 1.0):
|
|
9
|
+
super().__init__()
|
|
10
|
+
if max_input <= 0:
|
|
11
|
+
raise ValueError("max_input must be positive.")
|
|
12
|
+
self.max_input = float(max_input)
|
|
13
|
+
self.alpha: torch.Tensor = UninitializedParameter()
|
|
14
|
+
self.beta: torch.Tensor = UninitializedParameter()
|
|
15
|
+
self.gamma: torch.Tensor = UninitializedParameter()
|
|
16
|
+
self.delta: torch.Tensor = UninitializedParameter()
|
|
17
|
+
|
|
18
|
+
def _initialize_parameters(self, x: torch.Tensor):
|
|
19
|
+
if x.ndim < 2:
|
|
20
|
+
raise ValueError(
|
|
21
|
+
f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
|
|
22
|
+
)
|
|
23
|
+
param_shape = [1] * x.ndim
|
|
24
|
+
param_shape[1] = int(x.shape[1])
|
|
25
|
+
zeros = torch.zeros(param_shape, dtype=x.dtype, device=x.device)
|
|
26
|
+
with torch.no_grad():
|
|
27
|
+
for name in ("alpha", "beta", "gamma", "delta"):
|
|
28
|
+
setattr(self, name, Parameter(zeros.clone()))
|
|
29
|
+
|
|
30
|
+
def reset_parameters(self):
|
|
31
|
+
for name in ("alpha", "beta", "gamma", "delta"):
|
|
32
|
+
p = getattr(self, name)
|
|
33
|
+
if not isinstance(p, UninitializedParameter):
|
|
34
|
+
with torch.no_grad():
|
|
35
|
+
p.zero_()
|
|
36
|
+
|
|
37
|
+
def forward(self, x: torch.Tensor):
|
|
38
|
+
if isinstance(self.alpha, UninitializedParameter):
|
|
39
|
+
self._initialize_parameters(x)
|
|
40
|
+
|
|
41
|
+
if x.ndim < 2:
|
|
42
|
+
raise ValueError(
|
|
43
|
+
f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
|
|
44
|
+
)
|
|
45
|
+
if not isinstance(self.alpha, UninitializedParameter) and x.shape[1] != self.alpha.shape[1]:
|
|
46
|
+
raise RuntimeError(
|
|
47
|
+
f"GALU was initialized with C={self.alpha.shape[1]} but got C={x.shape[1]}. "
|
|
48
|
+
"Create a new GALU for a different channel size."
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
x_norm = x / self.max_input
|
|
52
|
+
zero = x.new_zeros(1)
|
|
53
|
+
part_prelu = F.relu(x_norm) + self.alpha * torch.minimum(x_norm, zero)
|
|
54
|
+
part_beta = self.beta * (
|
|
55
|
+
F.relu(1.0 - torch.abs(x_norm - 1.0))
|
|
56
|
+
+ torch.minimum(torch.abs(x_norm - 3.0) - 1.0, zero)
|
|
57
|
+
)
|
|
58
|
+
part_gamma = self.gamma * (
|
|
59
|
+
F.relu(0.5 - torch.abs(x_norm - 0.5))
|
|
60
|
+
+ torch.minimum(torch.abs(x_norm - 1.5) - 0.5, zero)
|
|
61
|
+
)
|
|
62
|
+
part_delta = self.delta * (
|
|
63
|
+
F.relu(0.5 - torch.abs(x_norm - 2.5))
|
|
64
|
+
+ torch.minimum(torch.abs(x_norm - 3.5) - 0.5, zero)
|
|
65
|
+
)
|
|
66
|
+
z = part_prelu + part_beta + part_gamma + part_delta
|
|
67
|
+
return z * self.max_input
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MELU(nn.Module):
|
|
7
|
+
def __init__(self, maxInput: float = 1.0):
|
|
8
|
+
super().__init__()
|
|
9
|
+
self.maxInput = float(maxInput)
|
|
10
|
+
self._num_channels = None
|
|
11
|
+
self.register_parameter("alpha", None)
|
|
12
|
+
self.register_parameter("beta", None)
|
|
13
|
+
self.register_parameter("gamma", None)
|
|
14
|
+
self.register_parameter("delta", None)
|
|
15
|
+
self.register_parameter("xi", None)
|
|
16
|
+
self.register_parameter("psi", None)
|
|
17
|
+
|
|
18
|
+
def _ensure_parameters(self, x: torch.Tensor):
|
|
19
|
+
if x.dim() != 4:
|
|
20
|
+
raise ValueError(
|
|
21
|
+
f"Expected 4D input (N, C, H, W), got {x.dim()}D with shape {tuple(x.shape)}"
|
|
22
|
+
)
|
|
23
|
+
c = int(x.shape[1])
|
|
24
|
+
if self._num_channels is None:
|
|
25
|
+
self._num_channels = c
|
|
26
|
+
elif c != self._num_channels:
|
|
27
|
+
raise RuntimeError(
|
|
28
|
+
f"MELU was initialized with C={self._num_channels} but got C={c}. "
|
|
29
|
+
"Create a new MELU for a different channel size."
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
if self.alpha is None:
|
|
33
|
+
shape = (1, c, 1, 1)
|
|
34
|
+
device, dtype = x.device, x.dtype
|
|
35
|
+
for name in ("alpha", "beta", "gamma", "delta", "xi", "psi"):
|
|
36
|
+
setattr(
|
|
37
|
+
self,
|
|
38
|
+
name,
|
|
39
|
+
nn.Parameter(torch.zeros(shape, dtype=dtype, device=device)),
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
def reset_parameters(self):
|
|
43
|
+
for p in (self.alpha, self.beta, self.gamma, self.delta, self.xi, self.psi):
|
|
44
|
+
if p is not None:
|
|
45
|
+
with torch.no_grad():
|
|
46
|
+
p.zero_()
|
|
47
|
+
|
|
48
|
+
def forward(self, X: torch.Tensor) -> torch.Tensor:
|
|
49
|
+
self._ensure_parameters(X)
|
|
50
|
+
|
|
51
|
+
X_norm = X / self.maxInput
|
|
52
|
+
Y = torch.roll(X_norm, shifts=-1, dims=1)
|
|
53
|
+
|
|
54
|
+
term1 = F.relu(X_norm)
|
|
55
|
+
term2 = self.alpha * torch.clamp(X_norm, max=0)
|
|
56
|
+
|
|
57
|
+
dist_sq_beta = (X_norm - 2) ** 2 + (Y - 2) ** 2
|
|
58
|
+
dist_sq_gamma = (X_norm - 1) ** 2 + (Y - 1) ** 2
|
|
59
|
+
dist_sq_delta = (X_norm - 1) ** 2 + (Y - 3) ** 2
|
|
60
|
+
dist_sq_xi = (X_norm - 3) ** 2 + (Y - 1) ** 2
|
|
61
|
+
dist_sq_psi = (X_norm - 3) ** 2 + (Y - 3) ** 2
|
|
62
|
+
|
|
63
|
+
term3 = self.beta * torch.sqrt(F.relu(2 - dist_sq_beta))
|
|
64
|
+
term4 = self.gamma * torch.sqrt(F.relu(1 - dist_sq_gamma))
|
|
65
|
+
term5 = self.delta * torch.sqrt(F.relu(1 - dist_sq_delta))
|
|
66
|
+
term6 = self.xi * torch.sqrt(F.relu(1 - dist_sq_xi))
|
|
67
|
+
term7 = self.psi * torch.sqrt(F.relu(1 - dist_sq_psi))
|
|
68
|
+
|
|
69
|
+
Z_norm = term1 + term2 + term3 + term4 + term5 + term6 + term7
|
|
70
|
+
return Z_norm * self.maxInput
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class PDELU(nn.Module):
|
|
7
|
+
def __init__(self, theta: float = 0.5):
|
|
8
|
+
super().__init__()
|
|
9
|
+
if theta == 1.0:
|
|
10
|
+
raise ValueError(
|
|
11
|
+
"theta cannot be 1.0, as it would cause a division by zero."
|
|
12
|
+
)
|
|
13
|
+
self.theta = float(theta)
|
|
14
|
+
self._power_val = 1.0 / (1.0 - self.theta)
|
|
15
|
+
self.register_parameter("alpha", None)
|
|
16
|
+
self._num_channels = None
|
|
17
|
+
|
|
18
|
+
def _ensure_parameters(self, x: torch.Tensor):
|
|
19
|
+
if x.ndim < 2:
|
|
20
|
+
raise ValueError(
|
|
21
|
+
f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
c = int(x.shape[1])
|
|
25
|
+
if self._num_channels is None:
|
|
26
|
+
self._num_channels = c
|
|
27
|
+
elif c != self._num_channels:
|
|
28
|
+
raise RuntimeError(
|
|
29
|
+
f"PDELU was initialized with C={self._num_channels} but got C={c}. "
|
|
30
|
+
"Create a new PDELU for a different channel size."
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
if self.alpha is None:
|
|
34
|
+
param_shape = [1] * x.ndim
|
|
35
|
+
param_shape[1] = c
|
|
36
|
+
self.alpha = nn.Parameter(
|
|
37
|
+
torch.full(param_shape, 0.1, dtype=x.dtype, device=x.device)
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
def reset_parameters(self):
|
|
41
|
+
if self.alpha is not None:
|
|
42
|
+
with torch.no_grad():
|
|
43
|
+
self.alpha.fill_(0.1)
|
|
44
|
+
|
|
45
|
+
def forward(self, x: torch.Tensor):
|
|
46
|
+
self._ensure_parameters(x)
|
|
47
|
+
|
|
48
|
+
positive_part = F.relu(x)
|
|
49
|
+
inner_term = F.relu(1.0 + (1.0 - self.theta) * x)
|
|
50
|
+
powered_term = torch.pow(inner_term, self._power_val)
|
|
51
|
+
subtracted_term = powered_term - 1.0
|
|
52
|
+
zero = torch.zeros(1, dtype=x.dtype, device=x.device)
|
|
53
|
+
negative_part = self.alpha * torch.minimum(subtracted_term, zero)
|
|
54
|
+
return positive_part + negative_part
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class SReLU(nn.Module):
|
|
6
|
+
def __init__(
|
|
7
|
+
self,
|
|
8
|
+
alpha_init: float = 0.0,
|
|
9
|
+
beta_init: float = 0.0,
|
|
10
|
+
gamma_init: float = 1.0,
|
|
11
|
+
delta_init: float = 1.0,
|
|
12
|
+
):
|
|
13
|
+
super().__init__()
|
|
14
|
+
self.alpha_init_val = float(alpha_init)
|
|
15
|
+
self.beta_init_val = float(beta_init)
|
|
16
|
+
self.gamma_init_val = float(gamma_init)
|
|
17
|
+
self.delta_init_val = float(delta_init)
|
|
18
|
+
self._num_channels = None
|
|
19
|
+
self.register_parameter("alpha", None)
|
|
20
|
+
self.register_parameter("beta", None)
|
|
21
|
+
self.register_parameter("gamma", None)
|
|
22
|
+
self.register_parameter("delta", None)
|
|
23
|
+
|
|
24
|
+
def _ensure_parameters(self, x: torch.Tensor):
|
|
25
|
+
if x.dim() != 4:
|
|
26
|
+
raise ValueError(
|
|
27
|
+
f"Expected 4D input (N, C, H, W), got {x.dim()}D with shape {tuple(x.shape)}"
|
|
28
|
+
)
|
|
29
|
+
c = int(x.shape[1])
|
|
30
|
+
if self._num_channels is None:
|
|
31
|
+
self._num_channels = c
|
|
32
|
+
elif c != self._num_channels:
|
|
33
|
+
raise RuntimeError(
|
|
34
|
+
f"SReLU was initialized with C={self._num_channels} but got C={c}. "
|
|
35
|
+
"Create a new SReLU for different channel sizes."
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
if self.alpha is None:
|
|
39
|
+
shape = (1, c, 1, 1)
|
|
40
|
+
device, dtype = x.device, x.dtype
|
|
41
|
+
self.alpha = nn.Parameter(
|
|
42
|
+
torch.full(shape, self.alpha_init_val, dtype=dtype, device=device)
|
|
43
|
+
)
|
|
44
|
+
self.beta = nn.Parameter(
|
|
45
|
+
torch.full(shape, self.beta_init_val, dtype=dtype, device=device)
|
|
46
|
+
)
|
|
47
|
+
self.gamma = nn.Parameter(
|
|
48
|
+
torch.full(shape, self.gamma_init_val, dtype=dtype, device=device)
|
|
49
|
+
)
|
|
50
|
+
self.delta = nn.Parameter(
|
|
51
|
+
torch.full(shape, self.delta_init_val, dtype=dtype, device=device)
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
def reset_parameters(self):
|
|
55
|
+
if self.alpha is not None:
|
|
56
|
+
with torch.no_grad():
|
|
57
|
+
self.alpha.fill_(self.alpha_init_val)
|
|
58
|
+
self.beta.fill_(self.beta_init_val)
|
|
59
|
+
self.gamma.fill_(self.gamma_init_val)
|
|
60
|
+
self.delta.fill_(self.delta_init_val)
|
|
61
|
+
|
|
62
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
63
|
+
self._ensure_parameters(x)
|
|
64
|
+
|
|
65
|
+
start = self.beta + self.alpha * (x - self.beta)
|
|
66
|
+
finish = self.delta + self.gamma * (x - self.delta)
|
|
67
|
+
out = torch.where(x < self.beta, start, x)
|
|
68
|
+
out = torch.where(x > self.delta, finish, out)
|
|
69
|
+
return out
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
from torch.nn.parameter import Parameter
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SmallGALU(nn.Module):
|
|
8
|
+
def __init__(self, max_input: float = 1.0):
|
|
9
|
+
super().__init__()
|
|
10
|
+
if max_input <= 0:
|
|
11
|
+
raise ValueError("max_input must be positive.")
|
|
12
|
+
self.max_input = float(max_input)
|
|
13
|
+
self.register_parameter("alpha", None)
|
|
14
|
+
self.register_parameter("beta", None)
|
|
15
|
+
self._num_channels = None
|
|
16
|
+
|
|
17
|
+
def _initialize_parameters(self, x: torch.Tensor):
|
|
18
|
+
if x.ndim < 2:
|
|
19
|
+
raise ValueError(
|
|
20
|
+
f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
|
|
21
|
+
)
|
|
22
|
+
self._num_channels = int(x.shape[1])
|
|
23
|
+
param_shape = [1] * x.ndim
|
|
24
|
+
param_shape[1] = self._num_channels
|
|
25
|
+
device = x.device
|
|
26
|
+
dtype = x.dtype
|
|
27
|
+
self.alpha = Parameter(torch.zeros(param_shape, dtype=dtype, device=device))
|
|
28
|
+
self.beta = Parameter(torch.zeros(param_shape, dtype=dtype, device=device))
|
|
29
|
+
|
|
30
|
+
def reset_parameters(self):
|
|
31
|
+
if self.alpha is not None:
|
|
32
|
+
with torch.no_grad():
|
|
33
|
+
self.alpha.zero_()
|
|
34
|
+
self.beta.zero_()
|
|
35
|
+
|
|
36
|
+
def forward(self, x: torch.Tensor):
|
|
37
|
+
if self.alpha is None:
|
|
38
|
+
self._initialize_parameters(x)
|
|
39
|
+
else:
|
|
40
|
+
if x.ndim < 2:
|
|
41
|
+
raise ValueError(
|
|
42
|
+
f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
|
|
43
|
+
)
|
|
44
|
+
if x.shape[1] != self._num_channels:
|
|
45
|
+
raise RuntimeError(
|
|
46
|
+
f"SmallGALU was initialized with C={self._num_channels} but got C={x.shape[1]}. "
|
|
47
|
+
"Create a new SmallGALU for a different channel size."
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
x_norm = x / self.max_input
|
|
51
|
+
zero = torch.zeros(1, dtype=x.dtype, device=x.device)
|
|
52
|
+
part_prelu = F.relu(x_norm) + self.alpha * torch.minimum(x_norm, zero)
|
|
53
|
+
part_beta = self.beta * (
|
|
54
|
+
F.relu(1.0 - torch.abs(x_norm - 1.0))
|
|
55
|
+
+ torch.minimum(torch.abs(x_norm - 3.0) - 1.0, zero)
|
|
56
|
+
)
|
|
57
|
+
z = part_prelu + part_beta
|
|
58
|
+
return z * self.max_input
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class WideMELU(nn.Module):
|
|
7
|
+
def __init__(self, maxInput: float = 1.0):
|
|
8
|
+
super().__init__()
|
|
9
|
+
self.maxInput = float(maxInput)
|
|
10
|
+
self._num_channels = None
|
|
11
|
+
self.register_parameter("alpha", None)
|
|
12
|
+
self.register_parameter("beta", None)
|
|
13
|
+
self.register_parameter("gamma", None)
|
|
14
|
+
self.register_parameter("delta", None)
|
|
15
|
+
self.register_parameter("xi", None)
|
|
16
|
+
self.register_parameter("psi", None)
|
|
17
|
+
self.register_parameter("theta", None)
|
|
18
|
+
self.register_parameter("lam", None)
|
|
19
|
+
|
|
20
|
+
def _ensure_parameters(self, x: torch.Tensor):
|
|
21
|
+
if x.dim() != 4:
|
|
22
|
+
raise ValueError(
|
|
23
|
+
f"Expected 4D input (N, C, H, W), got {x.dim()}D with shape {tuple(x.shape)}"
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
c = int(x.shape[1])
|
|
27
|
+
if self._num_channels is None:
|
|
28
|
+
self._num_channels = c
|
|
29
|
+
elif c != self._num_channels:
|
|
30
|
+
raise RuntimeError(
|
|
31
|
+
f"WideMELU was initialized with C={self._num_channels} but got C={c}. "
|
|
32
|
+
"Create a new WideMELU for different channel sizes."
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
if self.alpha is None:
|
|
36
|
+
shape = (1, c, 1, 1)
|
|
37
|
+
device, dtype = x.device, x.dtype
|
|
38
|
+
for name in (
|
|
39
|
+
"alpha",
|
|
40
|
+
"beta",
|
|
41
|
+
"gamma",
|
|
42
|
+
"delta",
|
|
43
|
+
"xi",
|
|
44
|
+
"psi",
|
|
45
|
+
"theta",
|
|
46
|
+
"lam",
|
|
47
|
+
):
|
|
48
|
+
param = nn.Parameter(torch.zeros(shape, dtype=dtype, device=device))
|
|
49
|
+
setattr(self, name, param)
|
|
50
|
+
|
|
51
|
+
def reset_parameters(self):
|
|
52
|
+
params = (
|
|
53
|
+
self.alpha,
|
|
54
|
+
self.beta,
|
|
55
|
+
self.gamma,
|
|
56
|
+
self.delta,
|
|
57
|
+
self.xi,
|
|
58
|
+
self.psi,
|
|
59
|
+
self.theta,
|
|
60
|
+
self.lam,
|
|
61
|
+
)
|
|
62
|
+
for p in params:
|
|
63
|
+
if p is not None:
|
|
64
|
+
with torch.no_grad():
|
|
65
|
+
p.zero_()
|
|
66
|
+
|
|
67
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
68
|
+
self._ensure_parameters(x)
|
|
69
|
+
|
|
70
|
+
X_norm = x / self.maxInput
|
|
71
|
+
Y = torch.roll(X_norm, shifts=-1, dims=1)
|
|
72
|
+
|
|
73
|
+
term1 = F.relu(X_norm)
|
|
74
|
+
term2 = self.alpha * torch.clamp(X_norm, max=0)
|
|
75
|
+
dist_sq_beta = (X_norm - 2) ** 2 + (Y - 2) ** 2
|
|
76
|
+
dist_sq_gamma = (X_norm - 1) ** 2 + (Y - 1) ** 2
|
|
77
|
+
dist_sq_delta = (X_norm - 1) ** 2 + (Y - 3) ** 2
|
|
78
|
+
dist_sq_xi = (X_norm - 3) ** 2 + (Y - 1) ** 2
|
|
79
|
+
dist_sq_psi = (X_norm - 3) ** 2 + (Y - 3) ** 2
|
|
80
|
+
dist_sq_theta = (X_norm - 1) ** 2 + (Y - 2) ** 2
|
|
81
|
+
dist_sq_lambda = (X_norm - 3) ** 2 + (Y - 2) ** 2
|
|
82
|
+
term3 = self.beta * torch.sqrt(F.relu(2 - dist_sq_beta))
|
|
83
|
+
term4 = self.gamma * torch.sqrt(F.relu(1 - dist_sq_gamma))
|
|
84
|
+
term5 = self.delta * torch.sqrt(F.relu(1 - dist_sq_delta))
|
|
85
|
+
term6 = self.xi * torch.sqrt(F.relu(1 - dist_sq_xi))
|
|
86
|
+
term7 = self.psi * torch.sqrt(F.relu(1 - dist_sq_psi))
|
|
87
|
+
term8 = self.theta * torch.sqrt(F.relu(1 - dist_sq_theta))
|
|
88
|
+
term9 = self.lam * torch.sqrt(F.relu(1 - dist_sq_lambda))
|
|
89
|
+
Z_norm = term1 + term2 + term3 + term4 + term5 + term6 + term7 + term8 + term9
|
|
90
|
+
return Z_norm * self.maxInput
|
|
@@ -1,54 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from ...device import get_device
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
class APLU(torch.nn.Module):
|
|
6
|
-
def __init__(self, max_input: float = 1.0):
|
|
7
|
-
super(APLU, self).__init__()
|
|
8
|
-
self.max_input = max_input
|
|
9
|
-
self.alpha = None
|
|
10
|
-
self.beta = None
|
|
11
|
-
self.gamma = None
|
|
12
|
-
self.xi = None
|
|
13
|
-
self.psi = None
|
|
14
|
-
self.mu = None
|
|
15
|
-
self._num_channels = None
|
|
16
|
-
self.device = get_device()
|
|
17
|
-
|
|
18
|
-
def _initialize_parameters(self, x):
|
|
19
|
-
if x.ndim < 2:
|
|
20
|
-
raise ValueError(
|
|
21
|
-
f"Input tensor must have at least 2 dimensions (N, C), but got shape {x.shape}"
|
|
22
|
-
)
|
|
23
|
-
|
|
24
|
-
num_channels = x.shape[1]
|
|
25
|
-
self._num_channels = num_channels
|
|
26
|
-
|
|
27
|
-
param_shape = [1] * x.ndim
|
|
28
|
-
param_shape[1] = num_channels
|
|
29
|
-
|
|
30
|
-
self.alpha = torch.nn.Parameter(torch.zeros(param_shape), device=device())
|
|
31
|
-
self.beta = torch.nn.Parameter(torch.zeros(param_shape), device=device())
|
|
32
|
-
self.gamma = torch.nn.Parameter(torch.zeros(param_shape), device=device())
|
|
33
|
-
|
|
34
|
-
self.xi = torch.nn.Parameter(self.max_input * torch.rand(param_shape)).to(
|
|
35
|
-
self.device
|
|
36
|
-
)
|
|
37
|
-
self.psi = torch.nn.Parameter(self.max_input * torch.rand(param_shape)).to(
|
|
38
|
-
self.device
|
|
39
|
-
)
|
|
40
|
-
self.mu = torch.nn.Parameter(self.max_input * torch.rand(param_shape)).to(
|
|
41
|
-
self.device
|
|
42
|
-
)
|
|
43
|
-
|
|
44
|
-
def forward(self, x):
|
|
45
|
-
if self.alpha is None:
|
|
46
|
-
self._initialize_parameters(x)
|
|
47
|
-
a = torch.relu(x)
|
|
48
|
-
|
|
49
|
-
# following are called hinges
|
|
50
|
-
b = self.alpha * torch.relu(-x + self.xi)
|
|
51
|
-
c = self.beta * torch.relu(-x + self.psi)
|
|
52
|
-
d = self.gamma * torch.relu(-x + self.mu)
|
|
53
|
-
z = a + b + c + d
|
|
54
|
-
return z
|
|
@@ -1,53 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from ...device import get_device
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
class GALU(torch.nn.Module):
|
|
6
|
-
def __init__(self, max_input: float = 1.0):
|
|
7
|
-
super(GALU, self).__init__()
|
|
8
|
-
if max_input <= 0:
|
|
9
|
-
raise ValueError("max_input must be positive.")
|
|
10
|
-
self.max_input = max_input
|
|
11
|
-
self.alpha = None
|
|
12
|
-
self.beta = None
|
|
13
|
-
self.gamma = None
|
|
14
|
-
self.delta = None
|
|
15
|
-
self._num_channels = None
|
|
16
|
-
self.device = get_device()
|
|
17
|
-
|
|
18
|
-
def _initialize_parameters(self, x):
|
|
19
|
-
if x.ndim < 2:
|
|
20
|
-
raise ValueError(
|
|
21
|
-
f"Input tensor must have at least 2 dimensions (N, C), but got shape {x.shape}"
|
|
22
|
-
)
|
|
23
|
-
|
|
24
|
-
num_channels = x.shape[1]
|
|
25
|
-
self._num_channels = num_channels
|
|
26
|
-
param_shape = [1] * x.ndim
|
|
27
|
-
param_shape[1] = num_channels
|
|
28
|
-
self.alpha = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)
|
|
29
|
-
self.beta = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)
|
|
30
|
-
self.gamma = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)
|
|
31
|
-
self.delta = torch.nn.Parameter(torch.zeros(param_shape)).to(self.device)
|
|
32
|
-
|
|
33
|
-
def forward(self, x):
|
|
34
|
-
if self.alpha is None:
|
|
35
|
-
self._initialize_parameters(x)
|
|
36
|
-
|
|
37
|
-
zero = torch.tensor(0.0, device=x.device, dtype=x.dtype)
|
|
38
|
-
x_norm = x / self.max_input
|
|
39
|
-
part_prelu = torch.relu(x_norm) + self.alpha * torch.min(x_norm, zero)
|
|
40
|
-
part_beta = self.beta * (
|
|
41
|
-
torch.relu(1.0 - torch.abs(x_norm - 1.0))
|
|
42
|
-
+ torch.min(torch.abs(x_norm - 3.0) - 1.0, zero)
|
|
43
|
-
)
|
|
44
|
-
part_gamma = self.gamma * (
|
|
45
|
-
torch.relu(0.5 - torch.abs(x_norm - 0.5))
|
|
46
|
-
+ torch.min(torch.abs(x_norm - 1.5) - 0.5, zero)
|
|
47
|
-
)
|
|
48
|
-
part_delta = self.delta * (
|
|
49
|
-
torch.relu(0.5 - torch.abs(x_norm - 2.5))
|
|
50
|
-
+ torch.min(torch.abs(x_norm - 3.5) - 0.5, zero)
|
|
51
|
-
)
|
|
52
|
-
z = part_prelu + part_beta + part_gamma + part_delta
|
|
53
|
-
return z * self.max_input
|
|
@@ -1,52 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from ...device import get_device
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
class MELU(torch.nn.Module):
|
|
6
|
-
def __init__(self, maxInput: float = 1.0):
|
|
7
|
-
super().__init__()
|
|
8
|
-
self.maxInput = float(maxInput)
|
|
9
|
-
self.alpha = None
|
|
10
|
-
self.beta = None
|
|
11
|
-
self.gamma = None
|
|
12
|
-
self.delta = None
|
|
13
|
-
self.xi = None
|
|
14
|
-
self.psi = None
|
|
15
|
-
self._initialized = False
|
|
16
|
-
self.device = get_device()
|
|
17
|
-
|
|
18
|
-
def _initialize_parameters(self, X: torch.Tensor):
|
|
19
|
-
if X.dim() != 4:
|
|
20
|
-
raise ValueError(
|
|
21
|
-
f"Expected 4D input (B, C, H, W), but got {X.dim()}D input."
|
|
22
|
-
)
|
|
23
|
-
num_channels = X.shape[1]
|
|
24
|
-
shape = (1, num_channels, 1, 1)
|
|
25
|
-
self.alpha = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
26
|
-
self.beta = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
27
|
-
self.gamma = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
28
|
-
self.delta = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
29
|
-
self.xi = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
30
|
-
self.psi = torch.nn.Parameter(torch.zeros(shape)).to(self.device)
|
|
31
|
-
self._initialized = True
|
|
32
|
-
|
|
33
|
-
def forward(self, X: torch.Tensor) -> torch.Tensor:
|
|
34
|
-
if not self._initialized:
|
|
35
|
-
self._initialize_parameters(X)
|
|
36
|
-
X_norm = X / self.maxInput
|
|
37
|
-
Y = torch.roll(X_norm, shifts=-1, dims=1)
|
|
38
|
-
term1 = torch.relu(X_norm)
|
|
39
|
-
term2 = self.alpha * torch.clamp(X_norm, max=0)
|
|
40
|
-
dist_sq_beta = (X_norm - 2) ** 2 + (Y - 2) ** 2
|
|
41
|
-
dist_sq_gamma = (X_norm - 1) ** 2 + (Y - 1) ** 2
|
|
42
|
-
dist_sq_delta = (X_norm - 1) ** 2 + (Y - 3) ** 2
|
|
43
|
-
dist_sq_xi = (X_norm - 3) ** 2 + (Y - 1) ** 2
|
|
44
|
-
dist_sq_psi = (X_norm - 3) ** 2 + (Y - 3) ** 2
|
|
45
|
-
term3 = self.beta * torch.sqrt(torch.relu(2 - dist_sq_beta))
|
|
46
|
-
term4 = self.gamma * torch.sqrt(torch.relu(1 - dist_sq_gamma))
|
|
47
|
-
term5 = self.delta * torch.sqrt(torch.relu(1 - dist_sq_delta))
|
|
48
|
-
term6 = self.xi * torch.sqrt(torch.relu(1 - dist_sq_xi))
|
|
49
|
-
term7 = self.psi * torch.sqrt(torch.relu(1 - dist_sq_psi))
|
|
50
|
-
Z_norm = term1 + term2 + term3 + term4 + term5 + term6 + term7
|
|
51
|
-
Z = Z_norm * self.maxInput
|
|
52
|
-
return Z
|
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from ...device import get_device
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
class PDELU(torch.nn.Module):
    """PDELU activation with one learnable negative-side scale per channel.

    For an input ``x`` the module returns::

        relu(x) + alpha * min(relu(1 + (1 - theta) * x) ** (1 / (1 - theta)) - 1, 0)

    ``alpha`` is created lazily on the first forward pass, because its shape
    depends on the number of channels (dim 1) and the rank of the input.
    It is initialised to 0.1 and broadcasts over every other dimension.

    Note: because ``alpha`` only exists after the first forward pass, an
    optimizer built from ``parameters()`` before that pass will not see it.

    Args:
        theta: Deformation constant; must not be 1.0.

    Raises:
        ValueError: If ``theta`` is 1.0.
    """

    def __init__(self, theta: float = 0.5):
        super().__init__()
        if theta == 1.0:
            raise ValueError(
                "theta cannot be 1.0, as it would cause a division by zero."
            )
        self.theta = theta
        self._power_val = 1.0 / (1.0 - self.theta)
        # Placeholder that reserves the parameter slot until the input shape
        # is known.
        self.alpha = torch.nn.UninitializedParameter()
        self._num_channels = None
        # Kept for callers that read it; the parameter itself follows the
        # device of the first input.
        self.device = get_device()

    def _initialize_parameters(self, x: torch.Tensor):
        """Materialise ``alpha`` with shape ``(1, C, 1, ..., 1)`` for ``x``.

        Raises:
            ValueError: If ``x`` has fewer than two dimensions.
        """
        if x.ndim < 2:
            raise ValueError(
                f"Input tensor must have at least 2 dimensions (N, C), but got shape {x.shape}"
            )

        num_channels = x.shape[1]
        self._num_channels = num_channels
        param_shape = [1] * x.ndim
        param_shape[1] = num_channels
        # Build the tensor on the target device first and wrap it afterwards:
        # calling ``.to()`` on a Parameter may return a plain tensor, which
        # cannot be assigned to a registered parameter slot.
        init_tensor = torch.full(param_shape, 0.1, device=x.device)
        self.alpha = torch.nn.Parameter(init_tensor)

    def forward(self, x: torch.Tensor):
        """Apply the activation element-wise; the output has ``x``'s shape."""
        # ``alpha`` starts as an UninitializedParameter (never ``None``), so
        # the lazy-init guard has to test for that type.
        if isinstance(self.alpha, torch.nn.UninitializedParameter):
            self._initialize_parameters(x)

        positive_part = torch.relu(x)
        inner_term = torch.relu(1.0 + (1.0 - self.theta) * x)
        subtracted_term = torch.pow(inner_term, self._power_val) - 1.0
        # Only the non-positive part of the deformed term contributes.
        negative_part = self.alpha * torch.clamp(subtracted_term, max=0.0)
        return positive_part + negative_part
|
|
@@ -1,53 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from ...device import get_device
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
class SReLU(torch.nn.Module):
    """Piecewise-linear activation with learnable per-channel thresholds.

    For an input ``x`` the output is:

    * ``beta + alpha * (x - beta)`` where ``x < beta``,
    * ``delta + gamma * (x - delta)`` where ``x > delta``,
    * ``x`` otherwise (the upper branch wins if both conditions hold).

    All four parameters are created lazily on the first forward pass with
    shape ``(1, C, 1, ..., 1)``, where ``C`` is the size of dim 1 of the
    input, and are filled with the corresponding ``*_init`` value.

    Note: because the parameters only exist after the first forward pass, an
    optimizer built from ``parameters()`` before that pass will not see them.

    Args:
        alpha_init: Initial slope used below ``beta``.
        beta_init: Initial lower threshold.
        gamma_init: Initial slope used above ``delta``.
        delta_init: Initial upper threshold.
    """

    _PARAM_NAMES = ("alpha", "beta", "gamma", "delta")

    def __init__(
        self,
        alpha_init: float = 0.0,
        beta_init: float = 0.0,
        gamma_init: float = 1.0,
        delta_init: float = 1.0,
    ):
        super().__init__()
        self.alpha_init_val = alpha_init
        self.beta_init_val = beta_init
        self.gamma_init_val = gamma_init
        self.delta_init_val = delta_init
        # Placeholders that reserve the parameter slots until the input
        # shape is known.
        self.alpha = torch.nn.UninitializedParameter()
        self.beta = torch.nn.UninitializedParameter()
        self.gamma = torch.nn.UninitializedParameter()
        self.delta = torch.nn.UninitializedParameter()
        # Kept for callers that read it; the parameters themselves follow
        # the device of the first input.
        self.device = get_device()

    def _initialize_parameters(self, x: torch.Tensor):
        """Materialise the four parameters for ``x``; no-op once done.

        Raises:
            ValueError: If ``x`` has fewer than two dimensions.
        """
        if not isinstance(self.alpha, torch.nn.UninitializedParameter):
            return
        if x.dim() < 2:
            raise ValueError(
                f"Input tensor must have at least 2 dimensions (N, C), but got {x.dim()}"
            )

        param_shape = [1] * x.dim()
        param_shape[1] = x.shape[1]
        for name in self._PARAM_NAMES:
            # ``float(...)`` keeps the tensor floating point even when an int
            # was passed as the init value; integer tensors cannot require
            # gradients and would be rejected by Parameter.
            fill_value = float(getattr(self, f"{name}_init_val"))
            # Create on the target device first and wrap afterwards: calling
            # ``.to()`` on a Parameter may return a plain tensor, which cannot
            # be assigned to a registered parameter slot.
            values = torch.full(param_shape, fill_value, device=x.device)
            setattr(self, name, torch.nn.Parameter(values))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation element-wise; the output has ``x``'s shape."""
        self._initialize_parameters(x)
        start = self.beta + self.alpha * (x - self.beta)
        finish = self.delta + self.gamma * (x - self.delta)
        out = torch.where(x < self.beta, start, x)
        out = torch.where(x > self.delta, finish, out)
        return out
|
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from ...device import get_device
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
class SmallGALU(torch.nn.Module):
    """Activation made of a PReLU-like part plus one learnable bump term.

    With ``x_norm = x / max_input`` the module returns ``max_input`` times::

        relu(x_norm) + alpha * min(x_norm, 0)
        + beta * (relu(1 - |x_norm - 1|) + min(|x_norm - 3| - 1, 0))

    ``alpha`` and ``beta`` are created lazily on the first forward pass with
    shape ``(1, C, 1, ..., 1)``, where ``C`` is the size of dim 1 of the
    input, and start at zero.

    Note: because the parameters only exist after the first forward pass, an
    optimizer built from ``parameters()`` before that pass will not see them.

    Args:
        max_input: Positive scale the input is divided by before, and the
            result multiplied by after, the activation.

    Raises:
        ValueError: If ``max_input`` is not positive.
    """

    def __init__(self, max_input: float = 1.0):
        super().__init__()
        if max_input <= 0:
            raise ValueError("max_input must be positive.")
        self.max_input = max_input
        self.alpha = None
        self.beta = None
        self._num_channels = None
        # Kept for callers that read it; the parameters themselves follow
        # the device of the first input.
        self.device = get_device()

    def _initialize_parameters(self, x):
        """Create zero-initialised ``alpha`` and ``beta`` matching ``x``.

        Raises:
            ValueError: If ``x`` has fewer than two dimensions.
        """
        if x.ndim < 2:
            raise ValueError(
                f"Input tensor must have at least 2 dimensions (N, C), but got shape {x.shape}"
            )

        num_channels = x.shape[1]
        self._num_channels = num_channels
        param_shape = [1] * x.ndim
        param_shape[1] = num_channels
        # Create on the target device first and wrap afterwards. Calling
        # ``.to()`` on a Parameter may return a plain tensor; assigning that
        # would leave the value out of ``parameters()`` and ``state_dict()``.
        self.alpha = torch.nn.Parameter(torch.zeros(param_shape, device=x.device))
        self.beta = torch.nn.Parameter(torch.zeros(param_shape, device=x.device))

    def forward(self, x):
        """Apply the activation element-wise; the output has ``x``'s shape."""
        if self.alpha is None:
            self._initialize_parameters(x)

        x_norm = x / self.max_input
        part_prelu = torch.relu(x_norm) + self.alpha * torch.clamp(x_norm, max=0.0)
        part_beta = self.beta * (
            torch.relu(1.0 - torch.abs(x_norm - 1.0))
            + torch.clamp(torch.abs(x_norm - 3.0) - 1.0, max=0.0)
        )
        return (part_prelu + part_beta) * self.max_input
|
|
@@ -1,63 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from ...device import get_device
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
class WideMELU(torch.nn.Module):
    """Activation combining a PReLU-like part with seven learnable bumps.

    The input ``X`` (4D, ``(B, C, H, W)``) is divided by ``maxInput`` and
    paired with ``Y``, a copy rolled by one position along the channel
    dimension. The normalised output is::

        relu(X) + alpha * min(X, 0)
        + sum_k  w_k * sqrt(relu(r_k - (X - cx_k)^2 - (Y - cy_k)^2))

    where the ``(w_k, cx_k, cy_k, r_k)`` tuples are listed in ``_BUMPS``.
    The result is multiplied by ``maxInput`` before being returned.

    All eight parameters have shape ``(1, C, 1, 1)``, start at zero and are
    created lazily on the first forward pass.

    Note: because the parameters only exist after the first forward pass, an
    optimizer built from ``parameters()`` before that pass will not see them.

    Args:
        maxInput: Scale the input is divided by before, and the result
            multiplied by after, the activation.
    """

    _PARAM_NAMES = ("alpha", "beta", "gamma", "delta", "xi", "psi", "theta", "lam")

    # (parameter name, centre for X, centre for Y, squared radius), in the
    # order the terms are summed.
    _BUMPS = (
        ("beta", 2, 2, 2),
        ("gamma", 1, 1, 1),
        ("delta", 1, 3, 1),
        ("xi", 3, 1, 1),
        ("psi", 3, 3, 1),
        ("theta", 1, 2, 1),
        ("lam", 3, 2, 1),
    )

    def __init__(self, maxInput: float = 1.0):
        super().__init__()
        self.maxInput = float(maxInput)
        for name in self._PARAM_NAMES:
            setattr(self, name, None)
        self._initialized = False
        # Kept for callers that read it; the parameters themselves follow
        # the device of the first input.
        self.device = get_device()

    def _initialize_parameters(self, X: torch.Tensor):
        """Create the zero-initialised per-channel parameters for ``X``.

        Raises:
            ValueError: If ``X`` is not 4-dimensional.
        """
        if X.dim() != 4:
            raise ValueError(
                f"Expected 4D input (B, C, H, W), but got {X.dim()}D input."
            )

        shape = (1, X.shape[1], 1, 1)
        for name in self._PARAM_NAMES:
            # Create on the target device first and wrap afterwards. Calling
            # ``.to()`` on a Parameter may return a plain tensor; assigning
            # that would leave the value out of ``parameters()`` and
            # ``state_dict()``.
            setattr(self, name, torch.nn.Parameter(torch.zeros(shape, device=X.device)))
        self._initialized = True

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply the activation; the output has the same shape as ``X``."""
        if not self._initialized:
            self._initialize_parameters(X)

        X_norm = X / self.maxInput
        # Pair every channel with the next one (wrapping around).
        Y = torch.roll(X_norm, shifts=-1, dims=1)

        Z_norm = torch.relu(X_norm) + self.alpha * torch.clamp(X_norm, max=0)
        for name, cx, cy, radius_sq in self._BUMPS:
            dist_sq = (X_norm - cx) ** 2 + (Y - cy) ** 2
            Z_norm = Z_norm + getattr(self, name) * torch.sqrt(
                torch.relu(radius_sq - dist_sq)
            )
        return Z_norm * self.maxInput
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|