hyper-connections 0.0.21__tar.gz → 0.0.22__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.21
3
+ Version: 0.0.22
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -34,6 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
34
34
  Classifier: Programming Language :: Python :: 3.9
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
+ Requires-Dist: beartype
37
38
  Requires-Dist: einops>=0.8.0
38
39
  Requires-Dist: torch>=2.3
39
40
  Provides-Extra: examples
@@ -71,9 +72,9 @@ residual = branch(residual) + residual
71
72
 
72
73
  # after, say 4 streams in paper
73
74
 
74
- from hyper_connections import HyperConnections
75
+ from hyper_connections import get_init_and_expand_reduce_stream_functions
75
76
 
76
- init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
77
+ init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
77
78
 
78
79
  # 1. wrap your branch function
79
80
 
@@ -110,9 +111,9 @@ residual = branch(residual) + residual
110
111
 
111
112
  # after, say 4 streams in paper
112
113
 
113
- from hyper_connections import HyperConnections
114
+ from hyper_connections import get_init_and_expand_reduce_stream_functions
114
115
 
115
- init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
116
+ init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
116
117
 
117
118
  # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
118
119
 
@@ -140,7 +141,7 @@ residual = reduce_stream(residual)
140
141
  To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
141
142
 
142
143
  ```python
143
- HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
144
+ get_init_and_expand_reduce_stream_functions(4, disable = True)
144
145
  ```
145
146
 
146
147
  ## Citation
@@ -28,9 +28,9 @@ residual = branch(residual) + residual
28
28
 
29
29
  # after, say 4 streams in paper
30
30
 
31
- from hyper_connections import HyperConnections
31
+ from hyper_connections import get_init_and_expand_reduce_stream_functions
32
32
 
33
- init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
33
+ init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
34
34
 
35
35
  # 1. wrap your branch function
36
36
 
@@ -67,9 +67,9 @@ residual = branch(residual) + residual
67
67
 
68
68
  # after, say 4 streams in paper
69
69
 
70
- from hyper_connections import HyperConnections
70
+ from hyper_connections import get_init_and_expand_reduce_stream_functions
71
71
 
72
- init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
72
+ init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
73
73
 
74
74
  # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
75
75
 
@@ -97,7 +97,7 @@ residual = reduce_stream(residual)
97
97
  To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
98
98
 
99
99
  ```python
100
- HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
100
+ get_init_and_expand_reduce_stream_functions(4, disable = True)
101
101
  ```
102
102
 
103
103
  ## Citation
@@ -1,5 +1,7 @@
1
1
  from hyper_connections.hyper_connections import (
2
2
  HyperConnections,
3
+ get_expand_reduce_stream_functions,
4
+ get_init_and_expand_reduce_stream_functions,
3
5
  Residual,
4
6
  StreamEmbed,
5
7
  AttentionPoolReduceStream
@@ -12,6 +12,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
12
12
 
13
13
  from einops import rearrange, repeat, reduce, einsum
14
14
 
15
+ from beartype import beartype
16
+
15
17
  """
16
18
  ein notation:
17
19
  b - batch
@@ -31,6 +33,27 @@ def default(v, d):
31
33
  def identity(t):
32
34
  return t
33
35
 
36
+ # main functions
37
+
38
def get_expand_reduce_stream_functions(num_streams, disable = False):
    """
    Return the `(expand_fn, reduce_fn)` pair used to enter and leave the residual streams.

    Args:
        num_streams: value bound to the `s` axis in both einops patterns.
        disable: when truthy, both returned functions are `identity`.

    Returns:
        A 2-tuple `(expand_fn, reduce_fn)` of single-argument callables.
    """
    if not disable:
        return (
            # 'b ... -> (b s) ...' : repeat the leading axis `num_streams` times
            partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams),
            # '(b s) ... -> b ...' : sum the `num_streams` copies back together
            partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams),
        )

    # disabled: leave the tensor untouched in both directions
    return (identity, identity)
47
+
48
def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
    """
    Return `(init_fn, expand_fn, reduce_fn)` for a given number of residual streams.

    Args:
        num_streams: number of streams; pre-bound as the first positional
            argument of the returned constructor and forwarded to
            `get_expand_reduce_stream_functions`.
        disable: when truthy, the constructor is `Residual` instead of
            `HyperConnections`, and the flag is forwarded so the
            expand / reduce functions are disabled too.

    Returns:
        A 3-tuple: the partially-applied constructor followed by the
        expand and reduce functions.
    """
    klass = Residual if disable else HyperConnections

    return (
        partial(klass, num_streams),
        *get_expand_reduce_stream_functions(num_streams, disable = disable),
    )
56
+
34
57
  # norms
35
58
 
36
59
  class RMSNorm(Module):
@@ -47,10 +70,11 @@ class RMSNorm(Module):
47
70
  # residual base class
48
71
 
49
72
  class Residual(Module):
73
+ @beartype
50
74
  def __init__(
51
75
  self,
52
76
  *args,
53
- branch = None,
77
+ branch: Module | None = None,
54
78
  **kwargs
55
79
  ):
56
80
  super().__init__()
@@ -97,6 +121,7 @@ class Residual(Module):
97
121
  # hyper connection residual streams
98
122
 
99
123
  class HyperConnections(Module):
124
+ @beartype
100
125
  def __init__(
101
126
  self,
102
127
  num_residual_streams,
@@ -146,27 +171,6 @@ class HyperConnections(Module):
146
171
 
147
172
  self.channel_first = channel_first
148
173
 
149
- @classmethod
150
- def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
151
-
152
- if disable:
153
- return (identity, identity)
154
-
155
- expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
156
- reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
157
-
158
- return expand_fn, reduce_fn
159
-
160
- @classmethod
161
- def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
162
-
163
- hyper_conn_klass = cls if not disable else Residual
164
-
165
- init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
166
- expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams, disable = disable)
167
-
168
- return (init_hyper_conn_fn, *expand_reduce_fns)
169
-
170
174
  def width_connection(self, residuals):
171
175
  # width connection
172
176
 
@@ -244,6 +248,9 @@ class HyperConnections(Module):
244
248
 
245
249
  return add_residual_fn(branch_output)
246
250
 
251
+ HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
252
+ HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
253
+
247
254
  # stream embed
248
255
 
249
256
  class StreamEmbed(Module):
@@ -12,6 +12,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
12
12
 
13
13
  from einops import rearrange, repeat, reduce, einsum
14
14
 
15
+ from beartype import beartype
16
+
15
17
  """
16
18
  ein notation:
17
19
  b - batch
@@ -38,11 +40,32 @@ def divisible_by(num, den):
38
40
  def identity(t):
39
41
  return t
40
42
 
43
+ # main functions
44
+
45
def get_expand_reduce_stream_functions(num_streams, disable = False):
    """
    Return the `(expand_fn, reduce_fn)` pair used to enter and leave the residual streams.

    This is a plain module-level function, so it takes no leading `cls`
    parameter: the first positional argument must be `num_streams`, both
    when called directly and when reached through a `staticmethod` alias
    on `HyperConnections`.

    Args:
        num_streams: value bound to the `s` axis in both einops patterns.
        disable: when truthy, both returned functions are `identity`.

    Returns:
        A 2-tuple `(expand_fn, reduce_fn)` of single-argument callables.
    """
    if disable:
        return (identity, identity)

    # 'b ... -> (b s) ...' : repeat the leading axis `num_streams` times
    expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)

    # '(b s) ... -> b ...' : sum the `num_streams` copies back together
    reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)

    return expand_fn, reduce_fn

def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
    """
    Return `(init_fn, expand_fn, reduce_fn)` for a given number of residual streams.

    Like `get_expand_reduce_stream_functions`, this takes no `cls`
    parameter, so `get_init_and_expand_reduce_stream_functions(4)` works
    as the first positional argument is `num_streams`.

    Args:
        num_streams: number of streams; pre-bound as the first positional
            argument of the returned constructor and forwarded to
            `get_expand_reduce_stream_functions`.
        disable: when truthy, the constructor is `Residual` instead of
            `HyperConnections`, and the expand / reduce functions are
            disabled as well.

    Returns:
        A 3-tuple: the partially-applied constructor followed by the
        expand and reduce functions.
    """
    hyper_conn_klass = HyperConnections if not disable else Residual

    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
    expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)

    return (init_hyper_conn_fn, *expand_reduce_fns)
62
+
41
63
  # main classes
42
64
 
43
65
  # hyper connection residual streams
44
66
 
45
67
  class HyperConnections(Module):
68
+ @beartype
46
69
  def __init__(
47
70
  self,
48
71
  num_residual_streams,
@@ -108,26 +131,6 @@ class HyperConnections(Module):
108
131
 
109
132
  self.channel_first = channel_first
110
133
 
111
- @classmethod
112
- def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
113
- if disable:
114
- return (identity, identity)
115
-
116
- expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
117
- reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
118
-
119
- return expand_fn, reduce_fn
120
-
121
- @classmethod
122
- def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
123
-
124
- hyper_conn_klass = cls if not disable else Residual
125
-
126
- init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
127
- expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams, disable = disable)
128
-
129
- return (init_hyper_conn_fn, *expand_reduce_fns)
130
-
131
134
  def width_connection(self, residuals):
132
135
  num_streams, num_branch_inputs = self.num_residual_streams, self.num_branch_inputs
133
136
 
@@ -225,3 +228,6 @@ class HyperConnections(Module):
225
228
  branch_output = torch.cat(branch_outputs)
226
229
 
227
230
  return add_residual_fn(branch_output)
231
+
232
+ HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
233
+ HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.0.21"
3
+ version = "0.0.22"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -23,6 +23,7 @@ classifiers=[
23
23
  ]
24
24
 
25
25
  dependencies = [
26
+ "beartype",
26
27
  "einops>=0.8.0",
27
28
  "torch>=2.3",
28
29
  ]