hyper-connections 0.0.20__tar.gz → 0.0.21__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.20
3
+ Version: 0.0.21
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -155,3 +155,10 @@ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
155
155
  url = {https://api.semanticscholar.org/CorpusID:272987528}
156
156
  }
157
157
  ```
158
+
159
+ ```bibtex
160
+ @misc{Rubin2024,
161
+ author = {Ohad Rubin},
162
+ url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
163
+ }
164
+ ```
@@ -112,3 +112,10 @@ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
112
112
  url = {https://api.semanticscholar.org/CorpusID:272987528}
113
113
  }
114
114
  ```
115
+
116
+ ```bibtex
117
+ @misc{Rubin2024,
118
+ author = {Ohad Rubin},
119
+ url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
120
+ }
121
+ ```
@@ -31,6 +31,17 @@ def default(v, d):
31
31
  def identity(t):
32
32
  return t
33
33
 
34
+ # norms
35
+
36
class RMSNorm(Module):
    """RMS normalization over the last dimension with a learned gain.

    The learned parameter ``gamma`` is initialized to zeros and applied as
    ``gamma + 1``, so the effective gain starts at exactly 1.
    NOTE(review): presumably this parameterization is chosen so weight decay
    on ``gamma`` pulls the gain toward 1 rather than 0 (cf. the Rubin 2024
    reference added to the README in the same release) - confirm.

    Args:
        dim: size of the feature dimension; ``gamma`` has shape ``(dim,)``,
            so the last dimension of the input must broadcast against it.
    """

    def __init__(self, dim):
        super().__init__()
        # F.normalize yields unit L2 norm; multiplying by sqrt(dim) turns
        # that into unit root-mean-square
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        """Return ``x`` L2-normalized along its last dim, rescaled by
        ``sqrt(dim)`` and the learned gain ``gamma + 1``."""
        return F.normalize(x, dim = -1) * self.scale * (self.gamma + 1)
44
+
34
45
  # main classes
35
46
 
36
47
  # residual base class
@@ -108,7 +119,7 @@ class HyperConnections(Module):
108
119
 
109
120
  self.act = nn.Tanh() if tanh else nn.Identity()
110
121
 
111
- self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
122
+ self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
112
123
 
113
124
  assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
114
125
 
@@ -22,7 +22,7 @@ br - branch functions
22
22
  t - residual streams + num branch inputs
23
23
  """
24
24
 
25
- from hyper_connections.hyper_connections import Residual, StreamEmbed
25
+ from hyper_connections.hyper_connections import Residual, StreamEmbed, RMSNorm
26
26
 
27
27
  # helper functions
28
28
 
@@ -74,7 +74,7 @@ class HyperConnections(Module):
74
74
 
75
75
  self.act = nn.Tanh() if tanh else nn.Identity()
76
76
 
77
- self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
77
+ self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
78
78
 
79
79
  self.num_residual_streams = num_residual_streams
80
80
  self.num_branch_inputs = num_branch_inputs
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.0.20"
3
+ version = "0.0.21"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }