hyper-connections 0.0.20__py3-none-any.whl → 0.0.22__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
--- a/hyper_connections/__init__.py
+++ b/hyper_connections/__init__.py
@@ -1,5 +1,7 @@
 from hyper_connections.hyper_connections import (
     HyperConnections,
+    get_expand_reduce_stream_functions,
+    get_init_and_expand_reduce_stream_functions,
     Residual,
     StreamEmbed,
     AttentionPoolReduceStream
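With this change the stream helpers are importable directly from the package root, matching the README changes in the METADATA diff further down. A minimal sketch of what the new exports return, based on the function bodies shown in the `hyper_connections.py` hunks that follow (the stream count of 4 is just the value used in the README examples):

```python
# sketch: what the new top-level exports provide
from hyper_connections import (
    get_expand_reduce_stream_functions,
    get_init_and_expand_reduce_stream_functions,
)

# a partial constructor plus the expand/reduce helpers
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)

# with disable = True the triple falls back to a plain Residual and two identity functions
init_residual, expand_noop, reduce_noop = get_init_and_expand_reduce_stream_functions(4, disable = True)
```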
--- a/hyper_connections/hyper_connections.py
+++ b/hyper_connections/hyper_connections.py
@@ -12,6 +12,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from einops import rearrange, repeat, reduce, einsum
 
+from beartype import beartype
+
 """
 ein notation:
 b - batch
@@ -31,15 +33,48 @@ def default(v, d):
 def identity(t):
     return t
 
+# main functions
+
+def get_expand_reduce_stream_functions(num_streams, disable = False):
+
+    if disable:
+        return (identity, identity)
+
+    expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
+    reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
+
+    return expand_fn, reduce_fn
+
+def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
+
+    hyper_conn_klass = HyperConnections if not disable else Residual
+
+    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
+    expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
+
+    return (init_hyper_conn_fn, *expand_reduce_fns)
+
+# norms
+
+class RMSNorm(Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.zeros(dim))
+
+    def forward(self, x):
+        return F.normalize(x, dim = -1) * self.scale * (self.gamma + 1)
+
 # main classes
 
 # residual base class
 
 class Residual(Module):
+    @beartype
     def __init__(
         self,
         *args,
-        branch = None,
+        branch: Module | None = None,
         **kwargs
     ):
         super().__init__()
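Two of the additions above are worth unpacking. `get_expand_reduce_stream_functions` builds its pair of helpers from `einops.repeat` and `einops.reduce`: expansion tiles each batch element across the stream axis, reduction sums the streams back. The new `RMSNorm` zero-initializes `gamma` and applies it as `(gamma + 1)`, so at initialization it is exactly an unscaled RMS norm (the reparameterization behind the Ohad Rubin reference added to the README later in this diff). A small sanity-check sketch, with illustrative tensor sizes:

```python
# sketch: behaviour of the new helpers, using illustrative sizes
import torch
import torch.nn.functional as F

from hyper_connections import get_expand_reduce_stream_functions
from hyper_connections.hyper_connections import RMSNorm

num_streams, dim = 4, 512

# expand copies the batch across streams, reduce sums them back
expand_stream, reduce_stream = get_expand_reduce_stream_functions(num_streams)

x = torch.randn(2, 16, dim)                        # (batch, seq, dim)
streams = expand_stream(x)                         # 'b ... -> (b s) ...' -> (8, 16, 512)
summed = reduce_stream(streams)                    # '(b s) ... -> b ...' -> (2, 16, 512)
assert torch.allclose(summed, x * num_streams)

# RMSNorm starts with gamma = 0, so (gamma + 1) leaves a plain RMS norm at init
norm = RMSNorm(dim)
reference = F.normalize(x, dim = -1) * dim ** 0.5  # x / ||x|| * sqrt(dim)
assert torch.allclose(norm(x), reference)
```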
@@ -86,6 +121,7 @@ class Residual(Module):
 # hyper connection residual streams
 
 class HyperConnections(Module):
+    @beartype
     def __init__(
         self,
         num_residual_streams,
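`@beartype` on the constructors (here and on `Residual` earlier), the new `branch: Module | None = None` annotation, and the `beartype` entry added to the package requirements in the METADATA diff below mean an invalid `branch` is now rejected at construction time. A hedged sketch; it assumes `Residual` simply ignores the extra positional stream count, as its `*args, **kwargs` signature shown earlier suggests:

```python
# sketch: runtime type checking introduced by @beartype
from torch import nn
from beartype.roar import BeartypeCallHintParamViolation

from hyper_connections import Residual

Residual(4, branch = nn.Identity())        # fine: branch is an nn.Module

try:
    Residual(4, branch = lambda t: t)      # a bare callable is not an nn.Module
except BeartypeCallHintParamViolation:
    print('beartype rejected a non-Module branch')
```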
@@ -108,7 +144,7 @@ class HyperConnections(Module):
 
         self.act = nn.Tanh() if tanh else nn.Identity()
 
-        self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
+        self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
 
         assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
 
@@ -135,27 +171,6 @@ class HyperConnections(Module):
 
         self.channel_first = channel_first
 
-    @classmethod
-    def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
-
-        if disable:
-            return (identity, identity)
-
-        expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
-        reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
-
-        return expand_fn, reduce_fn
-
-    @classmethod
-    def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
-
-        hyper_conn_klass = cls if not disable else Residual
-
-        init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
-        expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams, disable = disable)
-
-        return (init_hyper_conn_fn, *expand_reduce_fns)
-
     def width_connection(self, residuals):
         # width connection
 
@@ -233,6 +248,9 @@ class HyperConnections(Module):
 
         return add_residual_fn(branch_output)
 
+HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
+HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
+
 # stream embed
 
 class StreamEmbed(Module):
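Re-attaching the lifted functions to `HyperConnections` as staticmethods keeps the old class-level call sites working alongside the new module-level ones; both paths resolve to the same functions:

```python
# sketch: old and new call styles after the staticmethod re-attachment
from hyper_connections import HyperConnections, get_init_and_expand_reduce_stream_functions

new_style = get_init_and_expand_reduce_stream_functions(4)
old_style = HyperConnections.get_init_and_expand_reduce_stream_functions(4)   # still supported
```

One nuance: the removed classmethods built `cls`, so a subclass would construct itself, whereas the module-level function always constructs `HyperConnections` (or `Residual` when disabled).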
--- a/hyper_connections/hyper_connections_with_multi_branch_inputs.py
+++ b/hyper_connections/hyper_connections_with_multi_branch_inputs.py
@@ -12,6 +12,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from einops import rearrange, repeat, reduce, einsum
 
+from beartype import beartype
+
 """
 ein notation:
 b - batch
@@ -22,7 +24,7 @@ br - branch functions
 t - residual streams + num branch inputs
 """
 
-from hyper_connections.hyper_connections import Residual, StreamEmbed
+from hyper_connections.hyper_connections import Residual, StreamEmbed, RMSNorm
 
 # helper functions
 
@@ -38,11 +40,32 @@ def divisible_by(num, den):
 def identity(t):
     return t
 
+# main functions
+
+def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
+    if disable:
+        return (identity, identity)
+
+    expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
+    reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
+
+    return expand_fn, reduce_fn
+
+def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
+
+    hyper_conn_klass = HyperConnections if not disable else Residual
+
+    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
+    expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
+
+    return (init_hyper_conn_fn, *expand_reduce_fns)
+
 # main classes
 
 # hyper connection residual streams
 
 class HyperConnections(Module):
+    @beartype
     def __init__(
         self,
         num_residual_streams,
@@ -74,7 +97,7 @@ class HyperConnections(Module):
 
         self.act = nn.Tanh() if tanh else nn.Identity()
 
-        self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
+        self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
 
         self.num_residual_streams = num_residual_streams
         self.num_branch_inputs = num_branch_inputs
@@ -108,26 +131,6 @@ class HyperConnections(Module):
 
         self.channel_first = channel_first
 
-    @classmethod
-    def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
-        if disable:
-            return (identity, identity)
-
-        expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
-        reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
-
-        return expand_fn, reduce_fn
-
-    @classmethod
-    def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
-
-        hyper_conn_klass = cls if not disable else Residual
-
-        init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
-        expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams, disable = disable)
-
-        return (init_hyper_conn_fn, *expand_reduce_fns)
-
     def width_connection(self, residuals):
         num_streams, num_branch_inputs = self.num_residual_streams, self.num_branch_inputs
 
@@ -225,3 +228,6 @@ class HyperConnections(Module):
         branch_output = torch.cat(branch_outputs)
 
         return add_residual_fn(branch_output)
+
+HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
+HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
--- a/hyper_connections-0.0.20.dist-info/METADATA
+++ b/hyper_connections-0.0.22.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hyper-connections
-Version: 0.0.20
+Version: 0.0.22
 Summary: Hyper-Connections
 Project-URL: Homepage, https://pypi.org/project/hyper-connections/
 Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -34,6 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
+Requires-Dist: beartype
 Requires-Dist: einops>=0.8.0
 Requires-Dist: torch>=2.3
 Provides-Extra: examples
@@ -71,9 +72,9 @@ residual = branch(residual) + residual
 
 # after, say 4 streams in paper
 
-from hyper_connections import HyperConnections
+from hyper_connections import get_init_and_expand_reduce_stream_functions
 
-init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
+init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
 
 # 1. wrap your branch function
 
@@ -110,9 +111,9 @@ residual = branch(residual) + residual
 
 # after, say 4 streams in paper
 
-from hyper_connections import HyperConnections
+from hyper_connections import get_init_and_expand_reduce_stream_functions
 
-init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
+init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
 
 # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
 
@@ -140,7 +141,7 @@ residual = reduce_stream(residual)
 To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
 
 ```python
-HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
+get_init_and_expand_reduce_stream_functions(4, disable = True)
 ```
 
 ## Citation
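The README fragments above show the pieces individually: fetch the triple, wrap a branch, expand before the trunk, reduce after it. Stitched together with the updated import, the flow is roughly as follows; the `dim = 512` feature size, the stand-in branch, and the exact `init_hyper_conn(dim = ..., branch = ...)` keywords are illustrative assumptions rather than quotes from the README:

```python
# sketch: end-to-end wiring with the new module-level helpers (names/sizes are illustrative)
import torch
from torch import nn
from hyper_connections import get_init_and_expand_reduce_stream_functions

init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)

branch = nn.Linear(512, 512)                 # stand-in for an attention / feedforward block
hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)

residual = torch.randn(2, 1024, 512)         # (batch, seq, dim)

residual = expand_stream(residual)           # 1. expand to 4 streams before the trunk
residual = hyper_conn_branch(residual)       # 2. wrapped branch handles the hyper connection residual
residual = reduce_stream(residual)           # 3. sum the streams back after the trunk
```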
@@ -155,3 +156,10 @@ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
     url = {https://api.semanticscholar.org/CorpusID:272987528}
 }
 ```
+
+```bibtex
+@misc{Rubin2024,
+    author = {Ohad Rubin},
+    url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
+}
+```
--- /dev/null
+++ b/hyper_connections-0.0.22.dist-info/RECORD
@@ -0,0 +1,7 @@
+hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
+hyper_connections/hyper_connections.py,sha256=HyMz-jmICBC6L8QT-LA3EdY8djqG5XkOV7mi-i420mI,9993
+hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=x4it5IGllpZGYank8PBHCRzFeozgZxUY7UYo6YkSkcg,7778
+hyper_connections-0.0.22.dist-info/METADATA,sha256=uMrTDUeNCoLpQs89yjMvadzz8r4JLQpky0zQ_Di2H7I,5315
+hyper_connections-0.0.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+hyper_connections-0.0.22.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
+hyper_connections-0.0.22.dist-info/RECORD,,
--- a/hyper_connections-0.0.20.dist-info/RECORD
+++ /dev/null
@@ -1,7 +0,0 @@
-hyper_connections/__init__.py,sha256=wJxbrEXRGmOIjPw8fWP-cUq6CE8bvx95mIlhWifNvYc,135
-hyper_connections/hyper_connections.py,sha256=ElPtieRLvVKaVg2Attx1k6esKq1SY2X4AVZbZmsQAOM,9486
-hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=HbLpt79xcMv_os6brMvDd90t2GOPceliE1YFusR2eJI,7553
-hyper_connections-0.0.20.dist-info/METADATA,sha256=erA-d7KNNdzPY76x8IWKd2trv2WuBO9-C2DtH-SoQ_Y,5076
-hyper_connections-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-hyper_connections-0.0.20.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
-hyper_connections-0.0.20.dist-info/RECORD,,