hyper-connections 0.0.19__tar.gz → 0.0.21__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.19
3
+ Version: 0.0.21
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -155,3 +155,10 @@ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
155
155
  url = {https://api.semanticscholar.org/CorpusID:272987528}
156
156
  }
157
157
  ```
158
+
159
+ ```bibtex
160
+ @misc{Rubin2024,
161
+ author = {Ohad Rubin},
162
+ url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
163
+ }
164
+ ```
@@ -112,3 +112,10 @@ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
112
112
  url = {https://api.semanticscholar.org/CorpusID:272987528}
113
113
  }
114
114
  ```
115
+
116
+ ```bibtex
117
+ @misc{Rubin2024,
118
+ author = {Ohad Rubin},
119
+ url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
120
+ }
121
+ ```
@@ -12,6 +12,14 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
12
12
 
13
13
  from einops import rearrange, repeat, reduce, einsum
14
14
 
15
+ """
16
+ ein notation:
17
+ b - batch
18
+ d - feature dimension
19
+ s - residual streams
20
+ t - residual streams + num branch inputs
21
+ """
22
+
15
23
  # helper functions
16
24
 
17
25
  def exists(v):
@@ -23,6 +31,17 @@ def default(v, d):
23
31
  def identity(t):
24
32
  return t
25
33
 
34
+ # norms
35
+
36
+ class RMSNorm(Module):
37
+ def __init__(self, dim):
38
+ super().__init__()
39
+ self.scale = dim ** 0.5
40
+ self.gamma = nn.Parameter(torch.zeros(dim))
41
+
42
+ def forward(self, x):
43
+ return F.normalize(x, dim = -1) * self.scale * (self.gamma + 1)
44
+
26
45
  # main classes
27
46
 
28
47
  # residual base class
@@ -100,7 +119,9 @@ class HyperConnections(Module):
100
119
 
101
120
  self.act = nn.Tanh() if tanh else nn.Identity()
102
121
 
103
- self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
122
+ self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
123
+
124
+ assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
104
125
 
105
126
  self.num_residual_streams = num_residual_streams
106
127
  init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
@@ -12,7 +12,17 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
12
12
 
13
13
  from einops import rearrange, repeat, reduce, einsum
14
14
 
15
- from hyper_connections.hyper_connections import Residual, StreamEmbed
15
+ """
16
+ ein notation:
17
+ b - batch
18
+ d - feature dimension
19
+ s - residual streams
20
+ i - branch inputs
21
+ br - branch functions
22
+ t - residual streams + num branch inputs
23
+ """
24
+
25
+ from hyper_connections.hyper_connections import Residual, StreamEmbed, RMSNorm
16
26
 
17
27
  # helper functions
18
28
 
@@ -64,7 +74,7 @@ class HyperConnections(Module):
64
74
 
65
75
  self.act = nn.Tanh() if tanh else nn.Identity()
66
76
 
67
- self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
77
+ self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
68
78
 
69
79
  self.num_residual_streams = num_residual_streams
70
80
  self.num_branch_inputs = num_branch_inputs
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.0.19"
3
+ version = "0.0.21"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }