hyper-connections 0.0.20__tar.gz → 0.0.22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.0.20 → hyper_connections-0.0.22}/PKG-INFO +14 -6
- {hyper_connections-0.0.20 → hyper_connections-0.0.22}/README.md +12 -5
- {hyper_connections-0.0.20 → hyper_connections-0.0.22}/hyper_connections/__init__.py +2 -0
- {hyper_connections-0.0.20 → hyper_connections-0.0.22}/hyper_connections/hyper_connections.py +41 -23
- {hyper_connections-0.0.20 → hyper_connections-0.0.22}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +28 -22
- {hyper_connections-0.0.20 → hyper_connections-0.0.22}/pyproject.toml +2 -1
- {hyper_connections-0.0.20 → hyper_connections-0.0.22}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.20 → hyper_connections-0.0.22}/.gitignore +0 -0
- {hyper_connections-0.0.20 → hyper_connections-0.0.22}/LICENSE +0 -0
- {hyper_connections-0.0.20 → hyper_connections-0.0.22}/hyper-connections.png +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.0.20
|
|
3
|
+
Version: 0.0.22
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -34,6 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
|
|
|
34
34
|
Classifier: Programming Language :: Python :: 3.9
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
36
|
Requires-Python: >=3.9
|
|
37
|
+
Requires-Dist: beartype
|
|
37
38
|
Requires-Dist: einops>=0.8.0
|
|
38
39
|
Requires-Dist: torch>=2.3
|
|
39
40
|
Provides-Extra: examples
|
|
@@ -71,9 +72,9 @@ residual = branch(residual) + residual
|
|
|
71
72
|
|
|
72
73
|
# after, say 4 streams in paper
|
|
73
74
|
|
|
74
|
-
from hyper_connections import HyperConnections
|
|
75
|
+
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
75
76
|
|
|
76
|
-
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
|
|
77
|
+
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
|
|
77
78
|
|
|
78
79
|
# 1. wrap your branch function
|
|
79
80
|
|
|
@@ -110,9 +111,9 @@ residual = branch(residual) + residual
|
|
|
110
111
|
|
|
111
112
|
# after, say 4 streams in paper
|
|
112
113
|
|
|
113
|
-
from hyper_connections import HyperConnections
|
|
114
|
+
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
114
115
|
|
|
115
|
-
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
|
|
116
|
+
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
|
|
116
117
|
|
|
117
118
|
# 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
|
|
118
119
|
|
|
@@ -140,7 +141,7 @@ residual = reduce_stream(residual)
|
|
|
140
141
|
To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
|
|
141
142
|
|
|
142
143
|
```python
|
|
143
|
-
HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
|
144
|
+
get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
144
145
|
```
|
|
145
146
|
|
|
146
147
|
## Citation
|
|
@@ -155,3 +156,10 @@ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
|
155
156
|
url = {https://api.semanticscholar.org/CorpusID:272987528}
|
|
156
157
|
}
|
|
157
158
|
```
|
|
159
|
+
|
|
160
|
+
```bibtex
|
|
161
|
+
@misc{Rubin2024,
|
|
162
|
+
author = {Ohad Rubin},
|
|
163
|
+
url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
|
|
164
|
+
}
|
|
165
|
+
```
|
|
@@ -28,9 +28,9 @@ residual = branch(residual) + residual
|
|
|
28
28
|
|
|
29
29
|
# after, say 4 streams in paper
|
|
30
30
|
|
|
31
|
-
from hyper_connections import HyperConnections
|
|
31
|
+
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
32
32
|
|
|
33
|
-
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
|
|
33
|
+
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
|
|
34
34
|
|
|
35
35
|
# 1. wrap your branch function
|
|
36
36
|
|
|
@@ -67,9 +67,9 @@ residual = branch(residual) + residual
|
|
|
67
67
|
|
|
68
68
|
# after, say 4 streams in paper
|
|
69
69
|
|
|
70
|
-
from hyper_connections import HyperConnections
|
|
70
|
+
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
71
71
|
|
|
72
|
-
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
|
|
72
|
+
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
|
|
73
73
|
|
|
74
74
|
# 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
|
|
75
75
|
|
|
@@ -97,7 +97,7 @@ residual = reduce_stream(residual)
|
|
|
97
97
|
To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
|
|
98
98
|
|
|
99
99
|
```python
|
|
100
|
-
HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
|
100
|
+
get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
101
101
|
```
|
|
102
102
|
|
|
103
103
|
## Citation
|
|
@@ -112,3 +112,10 @@ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
|
112
112
|
url = {https://api.semanticscholar.org/CorpusID:272987528}
|
|
113
113
|
}
|
|
114
114
|
```
|
|
115
|
+
|
|
116
|
+
```bibtex
|
|
117
|
+
@misc{Rubin2024,
|
|
118
|
+
author = {Ohad Rubin},
|
|
119
|
+
url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
|
|
120
|
+
}
|
|
121
|
+
```
|
{hyper_connections-0.0.20 → hyper_connections-0.0.22}/hyper_connections/hyper_connections.py
RENAMED
|
@@ -12,6 +12,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
12
12
|
|
|
13
13
|
from einops import rearrange, repeat, reduce, einsum
|
|
14
14
|
|
|
15
|
+
from beartype import beartype
|
|
16
|
+
|
|
15
17
|
"""
|
|
16
18
|
ein notation:
|
|
17
19
|
b - batch
|
|
@@ -31,15 +33,48 @@ def default(v, d):
|
|
|
31
33
|
def identity(t):
|
|
32
34
|
return t
|
|
33
35
|
|
|
36
|
+
# main functions
|
|
37
|
+
|
|
38
|
+
def get_expand_reduce_stream_functions(num_streams, disable = False):
    """
    Return an `(expand, reduce)` pair of functions for moving between a plain
    batch and a batch of residual streams.

    - expand: einops `repeat` with pattern 'b ... -> (b s) ...', i.e. each batch
      element is repeated `num_streams` times, folded into the batch dimension.
    - reduce: einops `reduce` with pattern '(b s) ... -> b ...' and 'sum', i.e.
      the `num_streams` copies of each batch element are summed back together.

    When `disable` is True both functions are the identity, so the same calling
    code works unchanged for a plain residual setup.
    """
    if not disable:
        expand = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
        fold = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
        return expand, fold

    # disabled: no stream dimension is ever introduced, so both directions are no-ops
    return (identity, identity)
|
|
47
|
+
|
|
48
|
+
def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
    """
    Return `(init_hyper_conn, expand_stream, reduce_stream)` for `num_streams` streams.

    `init_hyper_conn` is a partial of the connection class with `num_streams`
    pre-bound as its first positional argument. It is `HyperConnections`
    normally, or `Residual` when `disable` is True. `Residual.__init__` accepts
    and ignores extra positional args via `*args`, so the pre-bound stream
    count is harmless there.

    `expand_stream` / `reduce_stream` come from `get_expand_reduce_stream_functions`
    with the same `disable` flag.
    """
    if disable:
        klass = Residual
    else:
        klass = HyperConnections

    init_fn = partial(klass, num_streams)

    return (init_fn, *get_expand_reduce_stream_functions(num_streams, disable = disable))
|
|
56
|
+
|
|
57
|
+
# norms
|
|
58
|
+
|
|
59
|
+
class RMSNorm(Module):
    """
    RMS normalization over the last dimension with a learned per-feature gain.

    The gain is parameterized as `gamma + 1`, with `gamma` initialized to zeros,
    so the module starts out as pure RMS normalization.
    NOTE(review): presumably this offset form is so weight decay on `gamma`
    pulls the effective gain toward 1 rather than 0 — confirm intent.
    """

    def __init__(self, dim):
        super().__init__()
        # zero-initialized offset; effective gain is (gamma + 1)
        self.gamma = nn.Parameter(torch.zeros(dim))
        # F.normalize divides by the L2 norm, which equals sqrt(dim) * rms(x);
        # multiplying back by sqrt(dim) turns it into x / rms(x)
        self.scale = dim ** 0.5

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        gain = self.gamma + 1
        return unit * self.scale * gain
|
|
67
|
+
|
|
34
68
|
# main classes
|
|
35
69
|
|
|
36
70
|
# residual base class
|
|
37
71
|
|
|
38
72
|
class Residual(Module):
|
|
73
|
+
@beartype
|
|
39
74
|
def __init__(
|
|
40
75
|
self,
|
|
41
76
|
*args,
|
|
42
|
-
branch = None,
|
|
77
|
+
branch: Module | None = None,
|
|
43
78
|
**kwargs
|
|
44
79
|
):
|
|
45
80
|
super().__init__()
|
|
@@ -86,6 +121,7 @@ class Residual(Module):
|
|
|
86
121
|
# hyper connection residual streams
|
|
87
122
|
|
|
88
123
|
class HyperConnections(Module):
|
|
124
|
+
@beartype
|
|
89
125
|
def __init__(
|
|
90
126
|
self,
|
|
91
127
|
num_residual_streams,
|
|
@@ -108,7 +144,7 @@ class HyperConnections(Module):
|
|
|
108
144
|
|
|
109
145
|
self.act = nn.Tanh() if tanh else nn.Identity()
|
|
110
146
|
|
|
111
|
-
self.norm =
|
|
147
|
+
self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
|
|
112
148
|
|
|
113
149
|
assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
|
|
114
150
|
|
|
@@ -135,27 +171,6 @@ class HyperConnections(Module):
|
|
|
135
171
|
|
|
136
172
|
self.channel_first = channel_first
|
|
137
173
|
|
|
138
|
-
@classmethod
|
|
139
|
-
def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
|
|
140
|
-
|
|
141
|
-
if disable:
|
|
142
|
-
return (identity, identity)
|
|
143
|
-
|
|
144
|
-
expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
|
|
145
|
-
reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
|
|
146
|
-
|
|
147
|
-
return expand_fn, reduce_fn
|
|
148
|
-
|
|
149
|
-
@classmethod
|
|
150
|
-
def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
|
|
151
|
-
|
|
152
|
-
hyper_conn_klass = cls if not disable else Residual
|
|
153
|
-
|
|
154
|
-
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
|
|
155
|
-
expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams, disable = disable)
|
|
156
|
-
|
|
157
|
-
return (init_hyper_conn_fn, *expand_reduce_fns)
|
|
158
|
-
|
|
159
174
|
def width_connection(self, residuals):
|
|
160
175
|
# width connection
|
|
161
176
|
|
|
@@ -233,6 +248,9 @@ class HyperConnections(Module):
|
|
|
233
248
|
|
|
234
249
|
return add_residual_fn(branch_output)
|
|
235
250
|
|
|
251
|
+
# Attach the module-level stream helpers to the class as static methods, so the
# former classmethod-style access (HyperConnections.get_...(4)) keeps working
# alongside the plain module-level functions.
setattr(HyperConnections, 'get_expand_reduce_stream_functions', staticmethod(get_expand_reduce_stream_functions))
setattr(HyperConnections, 'get_init_and_expand_reduce_stream_functions', staticmethod(get_init_and_expand_reduce_stream_functions))
|
|
253
|
+
|
|
236
254
|
# stream embed
|
|
237
255
|
|
|
238
256
|
class StreamEmbed(Module):
|
|
@@ -12,6 +12,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
12
12
|
|
|
13
13
|
from einops import rearrange, repeat, reduce, einsum
|
|
14
14
|
|
|
15
|
+
from beartype import beartype
|
|
16
|
+
|
|
15
17
|
"""
|
|
16
18
|
ein notation:
|
|
17
19
|
b - batch
|
|
@@ -22,7 +24,7 @@ br - branch functions
|
|
|
22
24
|
t - residual streams + num branch inputs
|
|
23
25
|
"""
|
|
24
26
|
|
|
25
|
-
from hyper_connections.hyper_connections import Residual, StreamEmbed
|
|
27
|
+
from hyper_connections.hyper_connections import Residual, StreamEmbed, RMSNorm
|
|
26
28
|
|
|
27
29
|
# helper functions
|
|
28
30
|
|
|
@@ -38,11 +40,32 @@ def divisible_by(num, den):
|
|
|
38
40
|
def identity(t):
|
|
39
41
|
return t
|
|
40
42
|
|
|
43
|
+
# main functions
|
|
44
|
+
|
|
45
|
+
def get_expand_reduce_stream_functions(num_streams, disable = False):
    """
    Return an `(expand, reduce)` pair of functions for moving between a plain
    batch and a batch of residual streams.

    - expand: einops `repeat`, 'b ... -> (b s) ...' (each batch element is
      repeated `num_streams` times into the batch dimension).
    - reduce: einops `reduce`, '(b s) ... -> b ...' with 'sum' (stream copies
      are summed back per batch element).

    When `disable` is True both are the identity.

    Note: this is a plain module-level function (also attached to
    `HyperConnections` as a staticmethod), so it takes no `cls` parameter.
    The previous stray `cls` made every call that omitted it fail.
    """
    if disable:
        return (identity, identity)

    expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
    reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)

    return expand_fn, reduce_fn

def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
    """
    Return `(init_hyper_conn, expand_stream, reduce_stream)` for `num_streams` streams.

    `init_hyper_conn` is a partial of `HyperConnections` (or `Residual` when
    `disable` is True) with `num_streams` pre-bound as the first positional
    argument. The expand/reduce pair comes from
    `get_expand_reduce_stream_functions` with the same `disable` flag.

    Takes no `cls` parameter. It previously did, which made both the call to
    `get_expand_reduce_stream_functions` below and the staticmethod
    attachment on `HyperConnections` raise TypeError.
    """
    hyper_conn_klass = HyperConnections if not disable else Residual

    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
    expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)

    return (init_hyper_conn_fn, *expand_reduce_fns)
|
|
62
|
+
|
|
41
63
|
# main classes
|
|
42
64
|
|
|
43
65
|
# hyper connection residual streams
|
|
44
66
|
|
|
45
67
|
class HyperConnections(Module):
|
|
68
|
+
@beartype
|
|
46
69
|
def __init__(
|
|
47
70
|
self,
|
|
48
71
|
num_residual_streams,
|
|
@@ -74,7 +97,7 @@ class HyperConnections(Module):
|
|
|
74
97
|
|
|
75
98
|
self.act = nn.Tanh() if tanh else nn.Identity()
|
|
76
99
|
|
|
77
|
-
self.norm =
|
|
100
|
+
self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
|
|
78
101
|
|
|
79
102
|
self.num_residual_streams = num_residual_streams
|
|
80
103
|
self.num_branch_inputs = num_branch_inputs
|
|
@@ -108,26 +131,6 @@ class HyperConnections(Module):
|
|
|
108
131
|
|
|
109
132
|
self.channel_first = channel_first
|
|
110
133
|
|
|
111
|
-
@classmethod
|
|
112
|
-
def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
|
|
113
|
-
if disable:
|
|
114
|
-
return (identity, identity)
|
|
115
|
-
|
|
116
|
-
expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
|
|
117
|
-
reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
|
|
118
|
-
|
|
119
|
-
return expand_fn, reduce_fn
|
|
120
|
-
|
|
121
|
-
@classmethod
|
|
122
|
-
def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
|
|
123
|
-
|
|
124
|
-
hyper_conn_klass = cls if not disable else Residual
|
|
125
|
-
|
|
126
|
-
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
|
|
127
|
-
expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams, disable = disable)
|
|
128
|
-
|
|
129
|
-
return (init_hyper_conn_fn, *expand_reduce_fns)
|
|
130
|
-
|
|
131
134
|
def width_connection(self, residuals):
|
|
132
135
|
num_streams, num_branch_inputs = self.num_residual_streams, self.num_branch_inputs
|
|
133
136
|
|
|
@@ -225,3 +228,6 @@ class HyperConnections(Module):
|
|
|
225
228
|
branch_output = torch.cat(branch_outputs)
|
|
226
229
|
|
|
227
230
|
return add_residual_fn(branch_output)
|
|
231
|
+
|
|
232
|
+
# Expose the module-level stream helpers on the class as static methods, so
# HyperConnections.get_init_and_expand_reduce_stream_functions(...) still works
# the way the removed classmethods did.
setattr(HyperConnections, 'get_expand_reduce_stream_functions', staticmethod(get_expand_reduce_stream_functions))
setattr(HyperConnections, 'get_init_and_expand_reduce_stream_functions', staticmethod(get_init_and_expand_reduce_stream_functions))
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "hyper-connections"
|
|
3
|
-
version = "0.0.20"
|
|
3
|
+
version = "0.0.22"
|
|
4
4
|
description = "Hyper-Connections"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
|
@@ -23,6 +23,7 @@ classifiers=[
|
|
|
23
23
|
]
|
|
24
24
|
|
|
25
25
|
dependencies = [
|
|
26
|
+
"beartype",
|
|
26
27
|
"einops>=0.8.0",
|
|
27
28
|
"torch>=2.3",
|
|
28
29
|
]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|