homa 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. homa/activations/classes/APLU.py +69 -31
  2. homa/activations/classes/GALU.py +44 -28
  3. homa/activations/classes/MELU.py +51 -31
  4. homa/activations/classes/PDELU.py +33 -18
  5. homa/activations/classes/SReLU.py +46 -26
  6. homa/activations/classes/SmallGALU.py +37 -18
  7. homa/activations/classes/WideMELU.py +71 -42
  8. homa/activations/classes/__init__.py +0 -1
  9. homa/ensemble/Ensemble.py +2 -4
  10. homa/ensemble/concerns/CalculatesMetricNecessities.py +14 -10
  11. homa/ensemble/concerns/PredictsProbabilities.py +4 -0
  12. homa/ensemble/concerns/ReportsClassificationMetrics.py +1 -1
  13. homa/ensemble/concerns/ReportsEnsembleAccuracy.py +3 -2
  14. homa/ensemble/concerns/ReportsLogits.py +4 -0
  15. homa/ensemble/concerns/ReportsSize.py +2 -2
  16. homa/ensemble/concerns/StoresModels.py +29 -0
  17. homa/ensemble/concerns/__init__.py +1 -2
  18. homa/loss/LogitNormLoss.py +12 -0
  19. homa/loss/Loss.py +2 -0
  20. homa/loss/__init__.py +2 -0
  21. homa/torch/__init__.py +0 -1
  22. homa/vision/ClassificationModel.py +5 -0
  23. homa/vision/Resnet.py +6 -5
  24. homa/vision/StochasticClassifier.py +28 -0
  25. homa/vision/StochasticResnet.py +6 -5
  26. homa/vision/StochasticSwin.py +9 -0
  27. homa/vision/Swin.py +12 -0
  28. homa/vision/__init__.py +1 -0
  29. homa/vision/concerns/HasLabels.py +13 -0
  30. homa/vision/concerns/HasLogits.py +12 -0
  31. homa/vision/concerns/HasProbabilities.py +9 -0
  32. homa/vision/concerns/ReportsAccuracy.py +27 -0
  33. homa/vision/concerns/ReportsMetrics.py +6 -0
  34. homa/vision/concerns/Trainable.py +5 -2
  35. homa/vision/concerns/__init__.py +5 -0
  36. homa/vision/modules/SwinModule.py +23 -0
  37. homa/vision/modules/__init__.py +1 -1
  38. homa-0.2.0.dist-info/METADATA +75 -0
  39. homa-0.2.0.dist-info/RECORD +58 -0
  40. homa/activations/classes/StochasticActivation.py +0 -20
  41. homa/ensemble/concerns/HasNetwork.py +0 -5
  42. homa/ensemble/concerns/HasStateDicts.py +0 -8
  43. homa/ensemble/concerns/RecordsStateDictionaries.py +0 -23
  44. homa/torch/Module.py +0 -8
  45. homa/vision/modules/StochasticResnetModule.py +0 -9
  46. homa/vision/utils.py +0 -21
  47. homa-0.1.0.dist-info/METADATA +0 -21
  48. homa-0.1.0.dist-info/RECORD +0 -51
  49. {homa-0.1.0.dist-info → homa-0.2.0.dist-info}/WHEEL +0 -0
  50. {homa-0.1.0.dist-info → homa-0.2.0.dist-info}/entry_points.txt +0 -0
  51. {homa-0.1.0.dist-info → homa-0.2.0.dist-info}/top_level.txt +0 -0
--- a/homa/activations/classes/APLU.py
+++ b/homa/activations/classes/APLU.py
@@ -1,48 +1,86 @@
 import torch
+from torch import nn
+from torch.nn.parameter import Parameter, UninitializedParameter
+import torch.nn.functional as F
 
 
-class APLU(torch.nn.Module):
+class APLU(nn.Module):
     def __init__(self, max_input: float = 1.0):
-        super(APLU, self).__init__()
-        self.max_input = max_input
-        self.alpha = None
-        self.beta = None
-        self.gamma = None
-        self.xi = None
-        self.psi = None
-        self.mu = None
+        super().__init__()
+        self.max_input = float(max_input)
+        self.alpha = UninitializedParameter()
+        self.beta = UninitializedParameter()
+        self.gamma = UninitializedParameter()
+        self.xi = UninitializedParameter()
+        self.psi = UninitializedParameter()
+        self.mu = UninitializedParameter()
         self._num_channels = None
 
-    def _initialize_parameters(self, x):
+    def _initialize_parameters(self, x: torch.Tensor):
         if x.ndim < 2:
             raise ValueError(
-                f"Input tensor must have at least 2 dimensions (N, C), but got shape {x.shape}"
+                f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
             )
 
-        num_channels = x.shape[1]
-        self._num_channels = num_channels
-
+        channels = int(x.shape[1])
+        self._num_channels = channels
         param_shape = [1] * x.ndim
-        param_shape[1] = num_channels
+        param_shape[1] = channels
 
-        self.alpha = torch.nn.Parameter(torch.zeros(param_shape))
-        self.beta = torch.nn.Parameter(torch.zeros(param_shape))
-        self.gamma = torch.nn.Parameter(torch.zeros(param_shape))
+        with torch.no_grad():
+            self.alpha = Parameter(
+                torch.zeros(param_shape, dtype=x.dtype, device=x.device)
+            )
+            self.beta = Parameter(
+                torch.zeros(param_shape, dtype=x.dtype, device=x.device)
+            )
+            self.gamma = Parameter(
+                torch.zeros(param_shape, dtype=x.dtype, device=x.device)
+            )
+            self.xi = Parameter(
+                torch.empty(param_shape, dtype=x.dtype, device=x.device).uniform_(
+                    0.0, self.max_input
+                )
+            )
+            self.psi = Parameter(
+                torch.empty(param_shape, dtype=x.dtype, device=x.device).uniform_(
+                    0.0, self.max_input
+                )
+            )
+            self.mu = Parameter(
+                torch.empty(param_shape, dtype=x.dtype, device=x.device).uniform_(
+                    0.0, self.max_input
+                )
+            )
 
-        self.xi = torch.nn.Parameter(self.max_input * torch.rand(param_shape))
-        self.psi = torch.nn.Parameter(self.max_input * torch.rand(param_shape))
-        self.mu = torch.nn.Parameter(self.max_input * torch.rand(param_shape))
+    def reset_parameters(self):
+        if isinstance(self.alpha, UninitializedParameter):
+            return
 
-    def forward(self, x):
-        if self.alpha is None:
-            self._initialize_parameters(x)
+        with torch.no_grad():
+            self.alpha.zero_()
+            self.beta.zero_()
+            self.gamma.zero_()
+            self.xi.uniform_(0.0, self.max_input)
+            self.psi.uniform_(0.0, self.max_input)
+            self.mu.uniform_(0.0, self.max_input)
 
-        a = torch.relu(x)
+    def forward(self, x: torch.Tensor):
+        if isinstance(self.alpha, UninitializedParameter):
+            self._initialize_parameters(x)
 
-        # following are called hinges
-        b = self.alpha * torch.relu(-x + self.xi)
-        c = self.beta * torch.relu(-x + self.psi)
-        d = self.gamma * torch.relu(-x + self.mu)
-        z = a + b + c + d
+        if x.ndim < 2:
+            raise ValueError(
+                f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
+            )
+        if self._num_channels is not None and x.shape[1] != self._num_channels:
+            raise RuntimeError(
+                f"APLU was initialized with C={self._num_channels} but got C={x.shape[1]}. "
+                "Create a new APLU for a different channel size."
+            )
 
-        return z
+        a = F.relu(x)
+        b = self.alpha * F.relu(-x + self.xi)
+        c = self.beta * F.relu(-x + self.psi)
+        d = self.gamma * F.relu(-x + self.mu)
+        return a + b + c + d
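
The rewritten APLU defers parameter creation to the first forward pass: three zero-initialized hinge gains plus three thresholds drawn uniformly from [0, max_input], all shaped per channel, and later calls reject a changed channel count. A minimal usage sketch of that behavior (the import path follows the file list above; the shapes and printed values are illustrative, not taken from the package):

import torch
from homa.activations.classes.APLU import APLU

aplu = APLU(max_input=1.0)
x = torch.randn(2, 3, 8, 8)        # any (N, C, ...) tensor with at least 2 dims
y = aplu(x)                        # first call materializes the six parameters

print(y.shape)                     # torch.Size([2, 3, 8, 8])
print(aplu.alpha.shape)            # torch.Size([1, 3, 1, 1]) -- one value per channel

try:
    aplu(torch.randn(2, 5, 8, 8))  # a different channel count is rejected after init
except RuntimeError as err:
    print(err)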
--- a/homa/activations/classes/GALU.py
+++ b/homa/activations/classes/GALU.py
@@ -1,51 +1,67 @@
 import torch
+from torch import nn
+from torch.nn.parameter import Parameter, UninitializedParameter
+import torch.nn.functional as F
 
 
-class GALU(torch.nn.Module):
+class GALU(nn.Module):
     def __init__(self, max_input: float = 1.0):
-        super(GALU, self).__init__()
+        super().__init__()
         if max_input <= 0:
             raise ValueError("max_input must be positive.")
-        self.max_input = max_input
-        self.alpha = None
-        self.beta = None
-        self.gamma = None
-        self.delta = None
-        self._num_channels = None
+        self.max_input = float(max_input)
+        self.alpha: torch.Tensor = UninitializedParameter()
+        self.beta: torch.Tensor = UninitializedParameter()
+        self.gamma: torch.Tensor = UninitializedParameter()
+        self.delta: torch.Tensor = UninitializedParameter()
 
-    def _initialize_parameters(self, x):
+    def _initialize_parameters(self, x: torch.Tensor):
         if x.ndim < 2:
             raise ValueError(
-                f"Input tensor must have at least 2 dimensions (N, C), but got shape {x.shape}"
+                f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
             )
-
-        num_channels = x.shape[1]
-        self._num_channels = num_channels
         param_shape = [1] * x.ndim
-        param_shape[1] = num_channels
-        self.alpha = torch.nn.Parameter(torch.zeros(param_shape))
-        self.beta = torch.nn.Parameter(torch.zeros(param_shape))
-        self.gamma = torch.nn.Parameter(torch.zeros(param_shape))
-        self.delta = torch.nn.Parameter(torch.zeros(param_shape))
+        param_shape[1] = int(x.shape[1])
+        zeros = torch.zeros(param_shape, dtype=x.dtype, device=x.device)
+        with torch.no_grad():
+            for name in ("alpha", "beta", "gamma", "delta"):
+                setattr(self, name, Parameter(zeros.clone()))
+
+    def reset_parameters(self):
+        for name in ("alpha", "beta", "gamma", "delta"):
+            p = getattr(self, name)
+            if not isinstance(p, UninitializedParameter):
+                with torch.no_grad():
+                    p.zero_()
 
-    def forward(self, x):
-        if self.alpha is None:
+    def forward(self, x: torch.Tensor):
+        if isinstance(self.alpha, UninitializedParameter):
            self._initialize_parameters(x)
 
-        zero = torch.tensor(0.0, device=x.device, dtype=x.dtype)
+        if x.ndim < 2:
+            raise ValueError(
+                f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
+            )
+        if not isinstance(self.alpha, UninitializedParameter) and x.shape[1] != self.alpha.shape[1]:
+            raise RuntimeError(
+                f"GALU was initialized with C={self.alpha.shape[1]} but got C={x.shape[1]}. "
+                "Create a new GALU for a different channel size."
+            )
+
         x_norm = x / self.max_input
-        part_prelu = torch.relu(x_norm) + self.alpha * torch.min(x_norm, zero)
+        zero = x.new_zeros(1)
+        part_prelu = F.relu(x_norm) + self.alpha * torch.minimum(x_norm, zero)
         part_beta = self.beta * (
-            torch.relu(1.0 - torch.abs(x_norm - 1.0))
-            + torch.min(torch.abs(x_norm - 3.0) - 1.0, zero)
+            F.relu(1.0 - torch.abs(x_norm - 1.0))
+            + torch.minimum(torch.abs(x_norm - 3.0) - 1.0, zero)
         )
         part_gamma = self.gamma * (
-            torch.relu(0.5 - torch.abs(x_norm - 0.5))
-            + torch.min(torch.abs(x_norm - 1.5) - 0.5, zero)
+            F.relu(0.5 - torch.abs(x_norm - 0.5))
+            + torch.minimum(torch.abs(x_norm - 1.5) - 0.5, zero)
         )
         part_delta = self.delta * (
-            torch.relu(0.5 - torch.abs(x_norm - 2.5))
-            + torch.min(torch.abs(x_norm - 3.5) - 0.5, zero)
+            F.relu(0.5 - torch.abs(x_norm - 2.5))
+            + torch.minimum(torch.abs(x_norm - 3.5) - 0.5, zero)
        )
         z = part_prelu + part_beta + part_gamma + part_delta
         return z * self.max_input
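
Because every gate in a freshly built GALU is created as zeros, the module initially reduces to a plain ReLU on the normalized input; only training moves the extra piecewise terms away from zero. A quick sanity check under that assumption (import path from the file list, shapes illustrative):

import torch
from homa.activations.classes.GALU import GALU

galu = GALU(max_input=2.0)
x = torch.randn(4, 6, 16, 16)
y = galu(x)                                # first call creates alpha..delta as zeros

# With zero gates only the PReLU-like term survives:
# relu(x / max_input) * max_input == relu(x)
print(torch.allclose(y, torch.relu(x)))    # True
print(galu.alpha.shape)                    # torch.Size([1, 6, 1, 1])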
--- a/homa/activations/classes/MELU.py
+++ b/homa/activations/classes/MELU.py
@@ -1,50 +1,70 @@
 import torch
+from torch import nn
+import torch.nn.functional as F
 
 
-class MELU(torch.nn.Module):
+class MELU(nn.Module):
     def __init__(self, maxInput: float = 1.0):
         super().__init__()
         self.maxInput = float(maxInput)
-        self.alpha = None
-        self.beta = None
-        self.gamma = None
-        self.delta = None
-        self.xi = None
-        self.psi = None
-        self._initialized = False
-
-    def _initialize_parameters(self, X: torch.Tensor):
-        if X.dim() != 4:
+        self._num_channels = None
+        self.register_parameter("alpha", None)
+        self.register_parameter("beta", None)
+        self.register_parameter("gamma", None)
+        self.register_parameter("delta", None)
+        self.register_parameter("xi", None)
+        self.register_parameter("psi", None)
+
+    def _ensure_parameters(self, x: torch.Tensor):
+        if x.dim() != 4:
             raise ValueError(
-                f"Expected 4D input (B, C, H, W), but got {X.dim()}D input."
+                f"Expected 4D input (N, C, H, W), got {x.dim()}D with shape {tuple(x.shape)}"
+            )
+        c = int(x.shape[1])
+        if self._num_channels is None:
+            self._num_channels = c
+        elif c != self._num_channels:
+            raise RuntimeError(
+                f"MELU was initialized with C={self._num_channels} but got C={c}. "
+                "Create a new MELU for a different channel size."
             )
-        num_channels = X.shape[1]
-        shape = (1, num_channels, 1, 1)
-        self.alpha = torch.nn.Parameter(torch.zeros(shape))
-        self.beta = torch.nn.Parameter(torch.zeros(shape))
-        self.gamma = torch.nn.Parameter(torch.zeros(shape))
-        self.delta = torch.nn.Parameter(torch.zeros(shape))
-        self.xi = torch.nn.Parameter(torch.zeros(shape))
-        self.psi = torch.nn.Parameter(torch.zeros(shape))
-        self._initialized = True
+
+        if self.alpha is None:
+            shape = (1, c, 1, 1)
+            device, dtype = x.device, x.dtype
+            for name in ("alpha", "beta", "gamma", "delta", "xi", "psi"):
+                setattr(
+                    self,
+                    name,
+                    nn.Parameter(torch.zeros(shape, dtype=dtype, device=device)),
+                )
+
+    def reset_parameters(self):
+        for p in (self.alpha, self.beta, self.gamma, self.delta, self.xi, self.psi):
+            if p is not None:
+                with torch.no_grad():
+                    p.zero_()
 
     def forward(self, X: torch.Tensor) -> torch.Tensor:
-        if not self._initialized:
-            self._initialize_parameters(X)
+        self._ensure_parameters(X)
+
         X_norm = X / self.maxInput
         Y = torch.roll(X_norm, shifts=-1, dims=1)
-        term1 = torch.relu(X_norm)
+
+        term1 = F.relu(X_norm)
         term2 = self.alpha * torch.clamp(X_norm, max=0)
+
         dist_sq_beta = (X_norm - 2) ** 2 + (Y - 2) ** 2
         dist_sq_gamma = (X_norm - 1) ** 2 + (Y - 1) ** 2
         dist_sq_delta = (X_norm - 1) ** 2 + (Y - 3) ** 2
         dist_sq_xi = (X_norm - 3) ** 2 + (Y - 1) ** 2
         dist_sq_psi = (X_norm - 3) ** 2 + (Y - 3) ** 2
-        term3 = self.beta * torch.sqrt(torch.relu(2 - dist_sq_beta))
-        term4 = self.gamma * torch.sqrt(torch.relu(1 - dist_sq_gamma))
-        term5 = self.delta * torch.sqrt(torch.relu(1 - dist_sq_delta))
-        term6 = self.xi * torch.sqrt(torch.relu(1 - dist_sq_xi))
-        term7 = self.psi * torch.sqrt(torch.relu(1 - dist_sq_psi))
+
+        term3 = self.beta * torch.sqrt(F.relu(2 - dist_sq_beta))
+        term4 = self.gamma * torch.sqrt(F.relu(1 - dist_sq_gamma))
+        term5 = self.delta * torch.sqrt(F.relu(1 - dist_sq_delta))
+        term6 = self.xi * torch.sqrt(F.relu(1 - dist_sq_xi))
+        term7 = self.psi * torch.sqrt(F.relu(1 - dist_sq_psi))
+
         Z_norm = term1 + term2 + term3 + term4 + term5 + term6 + term7
-        Z = Z_norm * self.maxInput
-        return Z
+        return Z_norm * self.maxInput
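
MELU now registers its six gates as None placeholders, so they show up in parameters() and in the state dict only after the first 4D forward pass, and reset_parameters() zeroes them in place. A small sketch of that life cycle (import path from the file list, shapes illustrative):

import torch
from homa.activations.classes.MELU import MELU

melu = MELU(maxInput=1.0)
print(len(list(melu.parameters())))    # 0 -- placeholders only, nothing to optimize yet

x = torch.randn(2, 3, 32, 32)          # MELU insists on a 4D (N, C, H, W) input
y = melu(x)

print(len(list(melu.parameters())))    # 6: alpha, beta, gamma, delta, xi, psi
print(melu.alpha.shape)                # torch.Size([1, 3, 1, 1])
melu.reset_parameters()                # zeroes all six gates in place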
--- a/homa/activations/classes/PDELU.py
+++ b/homa/activations/classes/PDELU.py
@@ -1,39 +1,54 @@
 import torch
+from torch import nn
+import torch.nn.functional as F
 
 
-class PDELU(torch.nn.Module):
+class PDELU(nn.Module):
     def __init__(self, theta: float = 0.5):
-        super(PDELU, self).__init__()
+        super().__init__()
         if theta == 1.0:
             raise ValueError(
                 "theta cannot be 1.0, as it would cause a division by zero."
             )
-        self.theta = theta
+        self.theta = float(theta)
         self._power_val = 1.0 / (1.0 - self.theta)
-        self.alpha = torch.nn.UninitializedParameter()
+        self.register_parameter("alpha", None)
         self._num_channels = None
 
-    def _initialize_parameters(self, x: torch.Tensor):
+    def _ensure_parameters(self, x: torch.Tensor):
         if x.ndim < 2:
             raise ValueError(
-                f"Input tensor must have at least 2 dimensions (N, C), but got shape {x.shape}"
+                f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
             )
 
-        num_channels = x.shape[1]
-        self._num_channels = num_channels
-        param_shape = [1] * x.ndim
-        param_shape[1] = num_channels
-        init_tensor = torch.zeros(param_shape) + 0.1
-        self.alpha = torch.nn.Parameter(init_tensor)
+        c = int(x.shape[1])
+        if self._num_channels is None:
+            self._num_channels = c
+        elif c != self._num_channels:
+            raise RuntimeError(
+                f"PDELU was initialized with C={self._num_channels} but got C={c}. "
+                "Create a new PDELU for a different channel size."
+            )
 
-    def forward(self, x: torch.Tensor):
         if self.alpha is None:
-            self._initialize_parameters(x)
+            param_shape = [1] * x.ndim
+            param_shape[1] = c
+            self.alpha = nn.Parameter(
+                torch.full(param_shape, 0.1, dtype=x.dtype, device=x.device)
+            )
+
+    def reset_parameters(self):
+        if self.alpha is not None:
+            with torch.no_grad():
+                self.alpha.fill_(0.1)
+
+    def forward(self, x: torch.Tensor):
+        self._ensure_parameters(x)
 
-        zero = torch.tensor(0.0, device=x.device, dtype=x.dtype)
-        positive_part = torch.relu(x)
-        inner_term = torch.relu(1.0 + (1.0 - self.theta) * x)
+        positive_part = F.relu(x)
+        inner_term = F.relu(1.0 + (1.0 - self.theta) * x)
         powered_term = torch.pow(inner_term, self._power_val)
         subtracted_term = powered_term - 1.0
-        negative_part = self.alpha * torch.min(subtracted_term, zero)
+        zero = torch.zeros(1, dtype=x.dtype, device=x.device)
+        negative_part = self.alpha * torch.minimum(subtracted_term, zero)
         return positive_part + negative_part
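
For non-negative inputs PDELU is the identity; only negative inputs are bent by the per-channel alpha, which is created on first use and filled with 0.1. A short check of that property, assuming nothing beyond the hunk above (import path from the file list, shapes illustrative):

import torch
from homa.activations.classes.PDELU import PDELU

pdelu = PDELU(theta=0.5)
x = torch.randn(8, 4)                    # any (N, C, ...) tensor with >= 2 dims
y = pdelu(x)

# Negative inputs follow alpha * ((1 + (1 - theta) * x) ** (1 / (1 - theta)) - 1),
# while non-negative inputs pass through unchanged.
mask = x >= 0
print(torch.allclose(y[mask], x[mask]))  # True
print(pdelu.alpha.shape)                 # torch.Size([1, 4]), filled with 0.1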
--- a/homa/activations/classes/SReLU.py
+++ b/homa/activations/classes/SReLU.py
@@ -1,7 +1,8 @@
 import torch
+from torch import nn
 
 
-class SReLU(torch.nn.Module):
+class SReLU(nn.Module):
     def __init__(
         self,
         alpha_init: float = 0.0,
@@ -10,38 +11,57 @@ class SReLU(torch.nn.Module):
         delta_init: float = 1.0,
     ):
         super().__init__()
-        self.alpha_init_val = alpha_init
-        self.beta_init_val = beta_init
-        self.gamma_init_val = gamma_init
-        self.delta_init_val = delta_init
-        self.alpha = torch.nn.UninitializedParameter()
-        self.beta = torch.nn.UninitializedParameter()
-        self.gamma = torch.nn.UninitializedParameter()
-        self.delta = torch.nn.UninitializedParameter()
+        self.alpha_init_val = float(alpha_init)
+        self.beta_init_val = float(beta_init)
+        self.gamma_init_val = float(gamma_init)
+        self.delta_init_val = float(delta_init)
+        self._num_channels = None
+        self.register_parameter("alpha", None)
+        self.register_parameter("beta", None)
+        self.register_parameter("gamma", None)
+        self.register_parameter("delta", None)
 
-    def _initialize_parameters(self, x: torch.Tensor):
-        if isinstance(self.alpha, torch.nn.UninitializedParameter):
-            if x.dim() < 2:
-                raise ValueError(
-                    f"Input tensor must have at least 2 dimensions (N, C), but got {x.dim()}"
-                )
+    def _ensure_parameters(self, x: torch.Tensor):
+        if x.dim() != 4:
+            raise ValueError(
+                f"Expected 4D input (N, C, H, W), got {x.dim()}D with shape {tuple(x.shape)}"
+            )
+        c = int(x.shape[1])
+        if self._num_channels is None:
+            self._num_channels = c
+        elif c != self._num_channels:
+            raise RuntimeError(
+                f"SReLU was initialized with C={self._num_channels} but got C={c}. "
+                "Create a new SReLU for different channel sizes."
+            )
 
-            num_channels = x.shape[1]
-            param_shape = [1] * x.dim()
-            param_shape[1] = num_channels
-            self.alpha = torch.nn.Parameter(
-                torch.full(param_shape, self.alpha_init_val)
+        if self.alpha is None:
+            shape = (1, c, 1, 1)
+            device, dtype = x.device, x.dtype
+            self.alpha = nn.Parameter(
+                torch.full(shape, self.alpha_init_val, dtype=dtype, device=device)
+            )
+            self.beta = nn.Parameter(
+                torch.full(shape, self.beta_init_val, dtype=dtype, device=device)
            )
-            self.beta = torch.nn.Parameter(torch.full(param_shape, self.beta_init_val))
-            self.gamma = torch.nn.Parameter(
-                torch.full(param_shape, self.gamma_init_val)
+            self.gamma = nn.Parameter(
+                torch.full(shape, self.gamma_init_val, dtype=dtype, device=device)
            )
-            self.delta = torch.nn.Parameter(
-                torch.full(param_shape, self.delta_init_val)
+            self.delta = nn.Parameter(
+                torch.full(shape, self.delta_init_val, dtype=dtype, device=device)
            )
 
+    def reset_parameters(self):
+        if self.alpha is not None:
+            with torch.no_grad():
+                self.alpha.fill_(self.alpha_init_val)
+                self.beta.fill_(self.beta_init_val)
+                self.gamma.fill_(self.gamma_init_val)
+                self.delta.fill_(self.delta_init_val)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        self._initialize_parameters(x)
+        self._ensure_parameters(x)
+
         start = self.beta + self.alpha * (x - self.beta)
         finish = self.delta + self.gamma * (x - self.delta)
         out = torch.where(x < self.beta, start, x)
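
SReLU now builds its four per-channel parameters lazily from the first 4D batch and, like the other activations, refuses a later change in channel count. A minimal sketch of that contract (import path from the file list, shapes illustrative; the tail of forward lies outside the hunk shown above, so only the shown guarantees are exercised):

import torch
from homa.activations.classes.SReLU import SReLU

srelu = SReLU(alpha_init=0.0, beta_init=0.0, gamma_init=1.0, delta_init=1.0)
x = torch.randn(2, 3, 8, 8)          # 4D (N, C, H, W) input is now required
y = srelu(x)

print(srelu.beta.shape)              # torch.Size([1, 3, 1, 1]) -- per-channel thresholds
srelu.reset_parameters()             # restores the *_init values in place

try:
    srelu(torch.randn(2, 5, 8, 8))   # channel count changed after initialization
except RuntimeError as err:
    print(err)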
--- a/homa/activations/classes/SmallGALU.py
+++ b/homa/activations/classes/SmallGALU.py
@@ -1,39 +1,58 @@
 import torch
+from torch import nn
+from torch.nn.parameter import Parameter
+import torch.nn.functional as F
 
 
-class SmallGALU(torch.nn.Module):
+class SmallGALU(nn.Module):
     def __init__(self, max_input: float = 1.0):
-        super(SmallGALU, self).__init__()
+        super().__init__()
         if max_input <= 0:
             raise ValueError("max_input must be positive.")
-        self.max_input = max_input
-        self.alpha = None
-        self.beta = None
+        self.max_input = float(max_input)
+        self.register_parameter("alpha", None)
+        self.register_parameter("beta", None)
         self._num_channels = None
 
-    def _initialize_parameters(self, x):
+    def _initialize_parameters(self, x: torch.Tensor):
         if x.ndim < 2:
             raise ValueError(
-                f"Input tensor must have at least 2 dimensions (N, C), but got shape {x.shape}"
+                f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
             )
-
-        num_channels = x.shape[1]
-        self._num_channels = num_channels
+        self._num_channels = int(x.shape[1])
         param_shape = [1] * x.ndim
-        param_shape[1] = num_channels
-        self.alpha = torch.nn.Parameter(torch.zeros(param_shape))
-        self.beta = torch.nn.Parameter(torch.zeros(param_shape))
+        param_shape[1] = self._num_channels
+        device = x.device
+        dtype = x.dtype
+        self.alpha = Parameter(torch.zeros(param_shape, dtype=dtype, device=device))
+        self.beta = Parameter(torch.zeros(param_shape, dtype=dtype, device=device))
+
+    def reset_parameters(self):
+        if self.alpha is not None:
+            with torch.no_grad():
+                self.alpha.zero_()
+                self.beta.zero_()
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         if self.alpha is None:
             self._initialize_parameters(x)
+        else:
+            if x.ndim < 2:
+                raise ValueError(
+                    f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
+                )
+            if x.shape[1] != self._num_channels:
+                raise RuntimeError(
+                    f"SmallGALU was initialized with C={self._num_channels} but got C={x.shape[1]}. "
+                    "Create a new SmallGALU for a different channel size."
+                )
 
-        zero = torch.tensor(0.0, device=x.device, dtype=x.dtype)
         x_norm = x / self.max_input
-        part_prelu = torch.relu(x_norm) + self.alpha * torch.min(x_norm, zero)
+        zero = torch.zeros(1, dtype=x.dtype, device=x.device)
+        part_prelu = F.relu(x_norm) + self.alpha * torch.minimum(x_norm, zero)
         part_beta = self.beta * (
-            torch.relu(1.0 - torch.abs(x_norm - 1.0))
-            + torch.min(torch.abs(x_norm - 3.0) - 1.0, zero)
+            F.relu(1.0 - torch.abs(x_norm - 1.0))
+            + torch.minimum(torch.abs(x_norm - 3.0) - 1.0, zero)
         )
         z = part_prelu + part_beta
         return z * self.max_input
  return z * self.max_input