hyper-connections 0.0.21__tar.gz → 0.0.22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.0.21 → hyper_connections-0.0.22}/PKG-INFO +7 -6
- {hyper_connections-0.0.21 → hyper_connections-0.0.22}/README.md +5 -5
- {hyper_connections-0.0.21 → hyper_connections-0.0.22}/hyper_connections/__init__.py +2 -0
- {hyper_connections-0.0.21 → hyper_connections-0.0.22}/hyper_connections/hyper_connections.py +29 -22
- {hyper_connections-0.0.21 → hyper_connections-0.0.22}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +26 -20
- {hyper_connections-0.0.21 → hyper_connections-0.0.22}/pyproject.toml +2 -1
- {hyper_connections-0.0.21 → hyper_connections-0.0.22}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.21 → hyper_connections-0.0.22}/.gitignore +0 -0
- {hyper_connections-0.0.21 → hyper_connections-0.0.22}/LICENSE +0 -0
- {hyper_connections-0.0.21 → hyper_connections-0.0.22}/hyper-connections.png +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.0.21
|
|
3
|
+
Version: 0.0.22
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -34,6 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
|
|
|
34
34
|
Classifier: Programming Language :: Python :: 3.9
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
36
|
Requires-Python: >=3.9
|
|
37
|
+
Requires-Dist: beartype
|
|
37
38
|
Requires-Dist: einops>=0.8.0
|
|
38
39
|
Requires-Dist: torch>=2.3
|
|
39
40
|
Provides-Extra: examples
|
|
@@ -71,9 +72,9 @@ residual = branch(residual) + residual
|
|
|
71
72
|
|
|
72
73
|
# after, say 4 streams in paper
|
|
73
74
|
|
|
74
|
-
from hyper_connections import HyperConnections
|
|
75
|
+
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
75
76
|
|
|
76
|
-
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
|
|
77
|
+
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
|
|
77
78
|
|
|
78
79
|
# 1. wrap your branch function
|
|
79
80
|
|
|
@@ -110,9 +111,9 @@ residual = branch(residual) + residual
|
|
|
110
111
|
|
|
111
112
|
# after, say 4 streams in paper
|
|
112
113
|
|
|
113
|
-
from hyper_connections import HyperConnections
|
|
114
|
+
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
114
115
|
|
|
115
|
-
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
|
|
116
|
+
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
|
|
116
117
|
|
|
117
118
|
# 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
|
|
118
119
|
|
|
@@ -140,7 +141,7 @@ residual = reduce_stream(residual)
|
|
|
140
141
|
To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
|
|
141
142
|
|
|
142
143
|
```python
|
|
143
|
-
HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
144
|
+
get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
144
145
|
```
|
|
145
146
|
|
|
146
147
|
## Citation
|
|
@@ -28,9 +28,9 @@ residual = branch(residual) + residual
|
|
|
28
28
|
|
|
29
29
|
# after, say 4 streams in paper
|
|
30
30
|
|
|
31
|
-
from hyper_connections import HyperConnections
|
|
31
|
+
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
32
32
|
|
|
33
|
-
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
|
|
33
|
+
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
|
|
34
34
|
|
|
35
35
|
# 1. wrap your branch function
|
|
36
36
|
|
|
@@ -67,9 +67,9 @@ residual = branch(residual) + residual
|
|
|
67
67
|
|
|
68
68
|
# after, say 4 streams in paper
|
|
69
69
|
|
|
70
|
-
from hyper_connections import HyperConnections
|
|
70
|
+
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
71
71
|
|
|
72
|
-
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
|
|
72
|
+
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
|
|
73
73
|
|
|
74
74
|
# 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
|
|
75
75
|
|
|
@@ -97,7 +97,7 @@ residual = reduce_stream(residual)
|
|
|
97
97
|
To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
|
|
98
98
|
|
|
99
99
|
```python
|
|
100
|
-
HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
100
|
+
get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
101
101
|
```
|
|
102
102
|
|
|
103
103
|
## Citation
|
{hyper_connections-0.0.21 → hyper_connections-0.0.22}/hyper_connections/hyper_connections.py
RENAMED
|
@@ -12,6 +12,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
12
12
|
|
|
13
13
|
from einops import rearrange, repeat, reduce, einsum
|
|
14
14
|
|
|
15
|
+
from beartype import beartype
|
|
16
|
+
|
|
15
17
|
"""
|
|
16
18
|
ein notation:
|
|
17
19
|
b - batch
|
|
@@ -31,6 +33,27 @@ def default(v, d):
|
|
|
31
33
|
def identity(t):
|
|
32
34
|
return t
|
|
33
35
|
|
|
36
|
+
# main functions
|
|
37
|
+
|
|
38
|
+
def get_expand_reduce_stream_functions(num_streams, disable = False):
|
|
39
|
+
|
|
40
|
+
if disable:
|
|
41
|
+
return (identity, identity)
|
|
42
|
+
|
|
43
|
+
expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
|
|
44
|
+
reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
|
|
45
|
+
|
|
46
|
+
return expand_fn, reduce_fn
|
|
47
|
+
|
|
48
|
+
def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
|
|
49
|
+
|
|
50
|
+
hyper_conn_klass = HyperConnections if not disable else Residual
|
|
51
|
+
|
|
52
|
+
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
|
|
53
|
+
expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
|
|
54
|
+
|
|
55
|
+
return (init_hyper_conn_fn, *expand_reduce_fns)
|
|
56
|
+
|
|
34
57
|
# norms
|
|
35
58
|
|
|
36
59
|
class RMSNorm(Module):
|
|
@@ -47,10 +70,11 @@ class RMSNorm(Module):
|
|
|
47
70
|
# residual base class
|
|
48
71
|
|
|
49
72
|
class Residual(Module):
|
|
73
|
+
@beartype
|
|
50
74
|
def __init__(
|
|
51
75
|
self,
|
|
52
76
|
*args,
|
|
53
|
-
branch = None,
|
|
77
|
+
branch: Module | None = None,
|
|
54
78
|
**kwargs
|
|
55
79
|
):
|
|
56
80
|
super().__init__()
|
|
@@ -97,6 +121,7 @@ class Residual(Module):
|
|
|
97
121
|
# hyper connection residual streams
|
|
98
122
|
|
|
99
123
|
class HyperConnections(Module):
|
|
124
|
+
@beartype
|
|
100
125
|
def __init__(
|
|
101
126
|
self,
|
|
102
127
|
num_residual_streams,
|
|
@@ -146,27 +171,6 @@ class HyperConnections(Module):
|
|
|
146
171
|
|
|
147
172
|
self.channel_first = channel_first
|
|
148
173
|
|
|
149
|
-
@classmethod
|
|
150
|
-
def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
|
|
151
|
-
|
|
152
|
-
if disable:
|
|
153
|
-
return (identity, identity)
|
|
154
|
-
|
|
155
|
-
expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
|
|
156
|
-
reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
|
|
157
|
-
|
|
158
|
-
return expand_fn, reduce_fn
|
|
159
|
-
|
|
160
|
-
@classmethod
|
|
161
|
-
def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
|
|
162
|
-
|
|
163
|
-
hyper_conn_klass = cls if not disable else Residual
|
|
164
|
-
|
|
165
|
-
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
|
|
166
|
-
expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams, disable = disable)
|
|
167
|
-
|
|
168
|
-
return (init_hyper_conn_fn, *expand_reduce_fns)
|
|
169
|
-
|
|
170
174
|
def width_connection(self, residuals):
|
|
171
175
|
# width connection
|
|
172
176
|
|
|
@@ -244,6 +248,9 @@ class HyperConnections(Module):
|
|
|
244
248
|
|
|
245
249
|
return add_residual_fn(branch_output)
|
|
246
250
|
|
|
251
|
+
HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
|
|
252
|
+
HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
|
|
253
|
+
|
|
247
254
|
# stream embed
|
|
248
255
|
|
|
249
256
|
class StreamEmbed(Module):
|
|
@@ -12,6 +12,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
12
12
|
|
|
13
13
|
from einops import rearrange, repeat, reduce, einsum
|
|
14
14
|
|
|
15
|
+
from beartype import beartype
|
|
16
|
+
|
|
15
17
|
"""
|
|
16
18
|
ein notation:
|
|
17
19
|
b - batch
|
|
@@ -38,11 +40,32 @@ def divisible_by(num, den):
|
|
|
38
40
|
def identity(t):
|
|
39
41
|
return t
|
|
40
42
|
|
|
43
|
+
# main functions
|
|
44
|
+
|
|
45
|
+
def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
|
|
46
|
+
if disable:
|
|
47
|
+
return (identity, identity)
|
|
48
|
+
|
|
49
|
+
expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
|
|
50
|
+
reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
|
|
51
|
+
|
|
52
|
+
return expand_fn, reduce_fn
|
|
53
|
+
|
|
54
|
+
def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
|
|
55
|
+
|
|
56
|
+
hyper_conn_klass = HyperConnections if not disable else Residual
|
|
57
|
+
|
|
58
|
+
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
|
|
59
|
+
expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
|
|
60
|
+
|
|
61
|
+
return (init_hyper_conn_fn, *expand_reduce_fns)
|
|
62
|
+
|
|
41
63
|
# main classes
|
|
42
64
|
|
|
43
65
|
# hyper connection residual streams
|
|
44
66
|
|
|
45
67
|
class HyperConnections(Module):
|
|
68
|
+
@beartype
|
|
46
69
|
def __init__(
|
|
47
70
|
self,
|
|
48
71
|
num_residual_streams,
|
|
@@ -108,26 +131,6 @@ class HyperConnections(Module):
|
|
|
108
131
|
|
|
109
132
|
self.channel_first = channel_first
|
|
110
133
|
|
|
111
|
-
@classmethod
|
|
112
|
-
def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
|
|
113
|
-
if disable:
|
|
114
|
-
return (identity, identity)
|
|
115
|
-
|
|
116
|
-
expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
|
|
117
|
-
reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
|
|
118
|
-
|
|
119
|
-
return expand_fn, reduce_fn
|
|
120
|
-
|
|
121
|
-
@classmethod
|
|
122
|
-
def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
|
|
123
|
-
|
|
124
|
-
hyper_conn_klass = cls if not disable else Residual
|
|
125
|
-
|
|
126
|
-
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
|
|
127
|
-
expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams, disable = disable)
|
|
128
|
-
|
|
129
|
-
return (init_hyper_conn_fn, *expand_reduce_fns)
|
|
130
|
-
|
|
131
134
|
def width_connection(self, residuals):
|
|
132
135
|
num_streams, num_branch_inputs = self.num_residual_streams, self.num_branch_inputs
|
|
133
136
|
|
|
@@ -225,3 +228,6 @@ class HyperConnections(Module):
|
|
|
225
228
|
branch_output = torch.cat(branch_outputs)
|
|
226
229
|
|
|
227
230
|
return add_residual_fn(branch_output)
|
|
231
|
+
|
|
232
|
+
HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
|
|
233
|
+
HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "hyper-connections"
|
|
3
|
-
version = "0.0.21"
|
|
3
|
+
version = "0.0.22"
|
|
4
4
|
description = "Hyper-Connections"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
|
@@ -23,6 +23,7 @@ classifiers=[
|
|
|
23
23
|
]
|
|
24
24
|
|
|
25
25
|
dependencies = [
|
|
26
|
+
"beartype",
|
|
26
27
|
"einops>=0.8.0",
|
|
27
28
|
"torch>=2.3",
|
|
28
29
|
]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|