hyper-connections 0.0.20__tar.gz → 0.0.22__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.20
3
+ Version: 0.0.22
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -34,6 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
34
34
  Classifier: Programming Language :: Python :: 3.9
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
+ Requires-Dist: beartype
37
38
  Requires-Dist: einops>=0.8.0
38
39
  Requires-Dist: torch>=2.3
39
40
  Provides-Extra: examples
@@ -71,9 +72,9 @@ residual = branch(residual) + residual
71
72
 
72
73
  # after, say 4 streams in paper
73
74
 
74
- from hyper_connections import HyperConnections
75
+ from hyper_connections import get_init_and_expand_reduce_stream_functions
75
76
 
76
- init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
77
+ init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
77
78
 
78
79
  # 1. wrap your branch function
79
80
 
@@ -110,9 +111,9 @@ residual = branch(residual) + residual
110
111
 
111
112
  # after, say 4 streams in paper
112
113
 
113
- from hyper_connections import HyperConnections
114
+ from hyper_connections import get_init_and_expand_reduce_stream_functions
114
115
 
115
- init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
116
+ init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
116
117
 
117
118
  # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
118
119
 
@@ -140,7 +141,7 @@ residual = reduce_stream(residual)
140
141
  To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
141
142
 
142
143
  ```python
143
- HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
144
+ get_init_and_expand_reduce_stream_functions(4, disable = True)
144
145
  ```
145
146
 
146
147
  ## Citation
@@ -155,3 +156,10 @@ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
155
156
  url = {https://api.semanticscholar.org/CorpusID:272987528}
156
157
  }
157
158
  ```
159
+
160
+ ```bibtex
161
+ @misc{Rubin2024,
162
+ author = {Ohad Rubin},
163
+ url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
164
+ }
165
+ ```
@@ -28,9 +28,9 @@ residual = branch(residual) + residual
28
28
 
29
29
  # after, say 4 streams in paper
30
30
 
31
- from hyper_connections import HyperConnections
31
+ from hyper_connections import get_init_and_expand_reduce_stream_functions
32
32
 
33
- init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
33
+ init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
34
34
 
35
35
  # 1. wrap your branch function
36
36
 
@@ -67,9 +67,9 @@ residual = branch(residual) + residual
67
67
 
68
68
  # after, say 4 streams in paper
69
69
 
70
- from hyper_connections import HyperConnections
70
+ from hyper_connections import get_init_and_expand_reduce_stream_functions
71
71
 
72
- init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
72
+ init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
73
73
 
74
74
  # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
75
75
 
@@ -97,7 +97,7 @@ residual = reduce_stream(residual)
97
97
  To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
98
98
 
99
99
  ```python
100
- HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
100
+ get_init_and_expand_reduce_stream_functions(4, disable = True)
101
101
  ```
102
102
 
103
103
  ## Citation
@@ -112,3 +112,10 @@ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
112
112
  url = {https://api.semanticscholar.org/CorpusID:272987528}
113
113
  }
114
114
  ```
115
+
116
+ ```bibtex
117
+ @misc{Rubin2024,
118
+ author = {Ohad Rubin},
119
+ url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
120
+ }
121
+ ```
@@ -1,5 +1,7 @@
1
1
  from hyper_connections.hyper_connections import (
2
2
  HyperConnections,
3
+ get_expand_reduce_stream_functions,
4
+ get_init_and_expand_reduce_stream_functions,
3
5
  Residual,
4
6
  StreamEmbed,
5
7
  AttentionPoolReduceStream
@@ -12,6 +12,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
12
12
 
13
13
  from einops import rearrange, repeat, reduce, einsum
14
14
 
15
+ from beartype import beartype
16
+
15
17
  """
16
18
  ein notation:
17
19
  b - batch
@@ -31,15 +33,48 @@ def default(v, d):
31
33
  def identity(t):
32
34
  return t
33
35
 
36
+ # main functions
37
+
38
+ def get_expand_reduce_stream_functions(num_streams, disable = False):
39
+
40
+ if disable:
41
+ return (identity, identity)
42
+
43
+ expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
44
+ reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
45
+
46
+ return expand_fn, reduce_fn
47
+
48
+ def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
49
+
50
+ hyper_conn_klass = HyperConnections if not disable else Residual
51
+
52
+ init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
53
+ expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
54
+
55
+ return (init_hyper_conn_fn, *expand_reduce_fns)
56
+
57
+ # norms
58
+
59
+ class RMSNorm(Module):
60
+ def __init__(self, dim):
61
+ super().__init__()
62
+ self.scale = dim ** 0.5
63
+ self.gamma = nn.Parameter(torch.zeros(dim))
64
+
65
+ def forward(self, x):
66
+ return F.normalize(x, dim = -1) * self.scale * (self.gamma + 1)
67
+
34
68
  # main classes
35
69
 
36
70
  # residual base class
37
71
 
38
72
  class Residual(Module):
73
+ @beartype
39
74
  def __init__(
40
75
  self,
41
76
  *args,
42
- branch = None,
77
+ branch: Module | None = None,
43
78
  **kwargs
44
79
  ):
45
80
  super().__init__()
@@ -86,6 +121,7 @@ class Residual(Module):
86
121
  # hyper connection residual streams
87
122
 
88
123
  class HyperConnections(Module):
124
+ @beartype
89
125
  def __init__(
90
126
  self,
91
127
  num_residual_streams,
@@ -108,7 +144,7 @@ class HyperConnections(Module):
108
144
 
109
145
  self.act = nn.Tanh() if tanh else nn.Identity()
110
146
 
111
- self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
147
+ self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
112
148
 
113
149
  assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
114
150
 
@@ -135,27 +171,6 @@ class HyperConnections(Module):
135
171
 
136
172
  self.channel_first = channel_first
137
173
 
138
- @classmethod
139
- def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
140
-
141
- if disable:
142
- return (identity, identity)
143
-
144
- expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
145
- reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
146
-
147
- return expand_fn, reduce_fn
148
-
149
- @classmethod
150
- def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
151
-
152
- hyper_conn_klass = cls if not disable else Residual
153
-
154
- init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
155
- expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams, disable = disable)
156
-
157
- return (init_hyper_conn_fn, *expand_reduce_fns)
158
-
159
174
  def width_connection(self, residuals):
160
175
  # width connection
161
176
 
@@ -233,6 +248,9 @@ class HyperConnections(Module):
233
248
 
234
249
  return add_residual_fn(branch_output)
235
250
 
251
+ HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
252
+ HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
253
+
236
254
  # stream embed
237
255
 
238
256
  class StreamEmbed(Module):
@@ -12,6 +12,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
12
12
 
13
13
  from einops import rearrange, repeat, reduce, einsum
14
14
 
15
+ from beartype import beartype
16
+
15
17
  """
16
18
  ein notation:
17
19
  b - batch
@@ -22,7 +24,7 @@ br - branch functions
22
24
  t - residual streams + num branch inputs
23
25
  """
24
26
 
25
- from hyper_connections.hyper_connections import Residual, StreamEmbed
27
+ from hyper_connections.hyper_connections import Residual, StreamEmbed, RMSNorm
26
28
 
27
29
  # helper functions
28
30
 
@@ -38,11 +40,32 @@ def divisible_by(num, den):
38
40
  def identity(t):
39
41
  return t
40
42
 
43
+ # main functions
44
+
45
+ def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
46
+ if disable:
47
+ return (identity, identity)
48
+
49
+ expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
50
+ reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
51
+
52
+ return expand_fn, reduce_fn
53
+
54
+ def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
55
+
56
+ hyper_conn_klass = HyperConnections if not disable else Residual
57
+
58
+ init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
59
+ expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
60
+
61
+ return (init_hyper_conn_fn, *expand_reduce_fns)
62
+
41
63
  # main classes
42
64
 
43
65
  # hyper connection residual streams
44
66
 
45
67
  class HyperConnections(Module):
68
+ @beartype
46
69
  def __init__(
47
70
  self,
48
71
  num_residual_streams,
@@ -74,7 +97,7 @@ class HyperConnections(Module):
74
97
 
75
98
  self.act = nn.Tanh() if tanh else nn.Identity()
76
99
 
77
- self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
100
+ self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
78
101
 
79
102
  self.num_residual_streams = num_residual_streams
80
103
  self.num_branch_inputs = num_branch_inputs
@@ -108,26 +131,6 @@ class HyperConnections(Module):
108
131
 
109
132
  self.channel_first = channel_first
110
133
 
111
- @classmethod
112
- def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
113
- if disable:
114
- return (identity, identity)
115
-
116
- expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
117
- reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
118
-
119
- return expand_fn, reduce_fn
120
-
121
- @classmethod
122
- def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
123
-
124
- hyper_conn_klass = cls if not disable else Residual
125
-
126
- init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
127
- expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams, disable = disable)
128
-
129
- return (init_hyper_conn_fn, *expand_reduce_fns)
130
-
131
134
  def width_connection(self, residuals):
132
135
  num_streams, num_branch_inputs = self.num_residual_streams, self.num_branch_inputs
133
136
 
@@ -225,3 +228,6 @@ class HyperConnections(Module):
225
228
  branch_output = torch.cat(branch_outputs)
226
229
 
227
230
  return add_residual_fn(branch_output)
231
+
232
+ HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
233
+ HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.0.20"
3
+ version = "0.0.22"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -23,6 +23,7 @@ classifiers=[
23
23
  ]
24
24
 
25
25
  dependencies = [
26
+ "beartype",
26
27
  "einops>=0.8.0",
27
28
  "torch>=2.3",
28
29
  ]