hyper-connections 0.0.21__tar.gz → 0.0.23__tar.gz
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in the public registry.
- {hyper_connections-0.0.21 → hyper_connections-0.0.23}/PKG-INFO +7 -6
- {hyper_connections-0.0.21 → hyper_connections-0.0.23}/README.md +5 -5
- {hyper_connections-0.0.21 → hyper_connections-0.0.23}/hyper_connections/__init__.py +2 -0
- {hyper_connections-0.0.21 → hyper_connections-0.0.23}/hyper_connections/hyper_connections.py +29 -22
- {hyper_connections-0.0.21 → hyper_connections-0.0.23}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +26 -20
- hyper_connections-0.0.23/hyper_connections/hyper_connections_with_multi_input_streams.py +338 -0
- {hyper_connections-0.0.21 → hyper_connections-0.0.23}/pyproject.toml +2 -1
- {hyper_connections-0.0.21 → hyper_connections-0.0.23}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.21 → hyper_connections-0.0.23}/.gitignore +0 -0
- {hyper_connections-0.0.21 → hyper_connections-0.0.23}/LICENSE +0 -0
- {hyper_connections-0.0.21 → hyper_connections-0.0.23}/hyper-connections.png +0 -0
{hyper_connections-0.0.21 → hyper_connections-0.0.23}/PKG-INFO RENAMED

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hyper-connections
-Version: 0.0.21
+Version: 0.0.23
 Summary: Hyper-Connections
 Project-URL: Homepage, https://pypi.org/project/hyper-connections/
 Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -34,6 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
+Requires-Dist: beartype
 Requires-Dist: einops>=0.8.0
 Requires-Dist: torch>=2.3
 Provides-Extra: examples
````
````diff
@@ -71,9 +72,9 @@ residual = branch(residual) + residual
 
 # after, say 4 streams in paper
 
-from hyper_connections import HyperConnections
+from hyper_connections import get_init_and_expand_reduce_stream_functions
 
-init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
+init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
 
 # 1. wrap your branch function
 
@@ -110,9 +111,9 @@ residual = branch(residual) + residual
 
 # after, say 4 streams in paper
 
-from hyper_connections import HyperConnections
+from hyper_connections import get_init_and_expand_reduce_stream_functions
 
-init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
+init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
 
 # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
 
@@ -140,7 +141,7 @@ residual = reduce_stream(residual)
 To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
 
 ```python
-HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
+get_init_and_expand_reduce_stream_functions(4, disable = True)
 ```
 
 ## Citation
````
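The last hunk above captures the API contract for `disable = True`: the helper hands back the plain `Residual` class plus identity expand/reduce functions, so the surrounding training code runs unchanged as an ordinary residual network. A minimal sketch of that fallback (assuming the package exports the helper at top level, as the updated README import suggests):

```python
from hyper_connections import get_init_and_expand_reduce_stream_functions

init_residual, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, disable = True)

# with disable = True both stream helpers are the same module-level identity
# function, and init_residual builds the plain Residual wrapper
assert expand_stream is reduce_stream
```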
{hyper_connections-0.0.21 → hyper_connections-0.0.23}/README.md RENAMED

````diff
@@ -28,9 +28,9 @@ residual = branch(residual) + residual
 
 # after, say 4 streams in paper
 
-from hyper_connections import HyperConnections
+from hyper_connections import get_init_and_expand_reduce_stream_functions
 
-init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
+init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
 
 # 1. wrap your branch function
 
@@ -67,9 +67,9 @@ residual = branch(residual) + residual
 
 # after, say 4 streams in paper
 
-from hyper_connections import HyperConnections
+from hyper_connections import get_init_and_expand_reduce_stream_functions
 
-init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
+init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
 
 # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
 
@@ -97,7 +97,7 @@ residual = reduce_stream(residual)
 To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
 
 ```python
-HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
+get_init_and_expand_reduce_stream_functions(4, disable = True)
 ```
 
 ## Citation
````
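The README hunks mirror the PKG-INFO ones above: every documented call site moves from the old classmethods on `HyperConnections` to the new module-level `get_init_and_expand_reduce_stream_functions`. A minimal end-to-end sketch of the updated API (the `nn.Linear` branch, batch size, and `dim = 512` are illustrative, not taken from the diff):

```python
import torch
from torch import nn
from hyper_connections import get_init_and_expand_reduce_stream_functions

# helpers for 4 residual streams, as in the paper
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)

branch = nn.Linear(512, 512)  # stand-in for an attention or feedforward block
hyper_conn = init_hyper_conn(dim = 512, branch = branch)

x = torch.randn(2, 16, 512)

x = expand_stream(x)   # (2, 16, 512) -> (8, 16, 512): one copy per stream, folded into batch
x = hyper_conn(x)      # width connection -> branch -> depth connection
x = reduce_stream(x)   # (8, 16, 512) -> (2, 16, 512): streams summed back out
```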
{hyper_connections-0.0.21 → hyper_connections-0.0.23}/hyper_connections/hyper_connections.py RENAMED
````diff
@@ -12,6 +12,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from einops import rearrange, repeat, reduce, einsum
 
+from beartype import beartype
+
 """
 ein notation:
 b - batch
````
````diff
@@ -31,6 +33,27 @@ def default(v, d):
 def identity(t):
     return t
 
+# main functions
+
+def get_expand_reduce_stream_functions(num_streams, disable = False):
+
+    if disable:
+        return (identity, identity)
+
+    expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
+    reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
+
+    return expand_fn, reduce_fn
+
+def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
+
+    hyper_conn_klass = HyperConnections if not disable else Residual
+
+    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
+    expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
+
+    return (init_hyper_conn_fn, *expand_reduce_fns)
+
 # norms
 
 class RMSNorm(Module):
````
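The two hoisted helpers are thin `functools.partial` wrappers over einops, which makes the stream bookkeeping easy to check in isolation. A quick sketch of what the expand/reduce pair does to shapes (tensor sizes are illustrative):

```python
import torch
from functools import partial
from einops import repeat, reduce

num_streams = 4
expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)

x = torch.randn(2, 16, 512)
streams = expand_fn(x)       # (8, 16, 512): four identical copies per batch element
summed = reduce_fn(streams)  # (2, 16, 512): the four streams summed back together
assert torch.allclose(summed, 4 * x, atol = 1e-5)
```

Moving these out of the class also lets `get_init_and_expand_reduce_stream_functions` hand back either `HyperConnections` or `Residual` without the caller caring which was picked.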
````diff
@@ -47,10 +70,11 @@ class RMSNorm(Module):
 # residual base class
 
 class Residual(Module):
+    @beartype
     def __init__(
         self,
         *args,
-        branch = None,
+        branch: Module | None = None,
         **kwargs
     ):
         super().__init__()
````
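The new `beartype` dependency is what turns the `branch: Module | None` annotation from documentation into a call-time check. A standalone sketch of the failure mode it catches (assumes Python 3.10+ for the `|` union syntax; the class here is a stripped-down stand-in, not the library's `Residual`):

```python
from beartype import beartype
from beartype.roar import BeartypeCallHintParamViolation
from torch.nn import Module, Identity

class Residual:
    @beartype
    def __init__(self, *args, branch: Module | None = None, **kwargs):
        self.branch = branch

Residual(branch = Identity())  # ok: an nn.Module
Residual(branch = None)        # ok: None is allowed by the annotation

try:
    Residual(branch = lambda t: t)  # a bare callable is not a Module
except BeartypeCallHintParamViolation as err:
    print('rejected:', err)
```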
````diff
@@ -97,6 +121,7 @@ class Residual(Module):
 # hyper connection residual streams
 
 class HyperConnections(Module):
+    @beartype
     def __init__(
         self,
         num_residual_streams,
````
````diff
@@ -146,27 +171,6 @@ class HyperConnections(Module):
 
         self.channel_first = channel_first
 
-    @classmethod
-    def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
-
-        if disable:
-            return (identity, identity)
-
-        expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
-        reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
-
-        return expand_fn, reduce_fn
-
-    @classmethod
-    def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
-
-        hyper_conn_klass = cls if not disable else Residual
-
-        init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
-        expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams, disable = disable)
-
-        return (init_hyper_conn_fn, *expand_reduce_fns)
-
     def width_connection(self, residuals):
         # width connection
 
````
````diff
@@ -244,6 +248,9 @@
 
         return add_residual_fn(branch_output)
 
+HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
+HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
+
 # stream embed
 
 class StreamEmbed(Module):
````
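Reattaching the two functions as `staticmethod`s preserves the 0.0.21 call style alongside the new one. The `hyper_connections/__init__.py` change (+2 lines) is not shown in this diff body, so the exact new top-level exports are an assumption here, though the README hunks above import the function from the package root:

```python
from hyper_connections import HyperConnections, get_init_and_expand_reduce_stream_functions

# old (0.0.21) classmethod-style call still resolves, via the staticmethod assignment
old_style = HyperConnections.get_init_and_expand_reduce_stream_functions(4)

# new (0.0.23) module-level call
new_style = get_init_and_expand_reduce_stream_functions(4)
```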
{hyper_connections-0.0.21 → hyper_connections-0.0.23}/hyper_connections/hyper_connections_with_multi_branch_inputs.py RENAMED

````diff
@@ -12,6 +12,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from einops import rearrange, repeat, reduce, einsum
 
+from beartype import beartype
+
 """
 ein notation:
 b - batch
@@ -38,11 +40,32 @@ def divisible_by(num, den):
 def identity(t):
     return t
 
+# main functions
+
+def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
+    if disable:
+        return (identity, identity)
+
+    expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
+    reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
+
+    return expand_fn, reduce_fn
+
+def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
+
+    hyper_conn_klass = HyperConnections if not disable else Residual
+
+    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
+    expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
+
+    return (init_hyper_conn_fn, *expand_reduce_fns)
+
 # main classes
 
 # hyper connection residual streams
 
 class HyperConnections(Module):
+    @beartype
     def __init__(
         self,
         num_residual_streams,
@@ -108,26 +131,6 @@ class HyperConnections(Module):
 
         self.channel_first = channel_first
 
-    @classmethod
-    def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
-        if disable:
-            return (identity, identity)
-
-        expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
-        reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
-
-        return expand_fn, reduce_fn
-
-    @classmethod
-    def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
-
-        hyper_conn_klass = cls if not disable else Residual
-
-        init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
-        expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams, disable = disable)
-
-        return (init_hyper_conn_fn, *expand_reduce_fns)
-
     def width_connection(self, residuals):
         num_streams, num_branch_inputs = self.num_residual_streams, self.num_branch_inputs
 
@@ -225,3 +228,6 @@ class HyperConnections(Module):
         branch_output = torch.cat(branch_outputs)
 
         return add_residual_fn(branch_output)
+
+HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
+HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
````
hyper_connections-0.0.23/hyper_connections/hyper_connections_with_multi_input_streams.py ADDED

```python
from __future__ import annotations
from typing import Callable

from functools import partial
from random import randrange

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_flatten, tree_unflatten

from einops import rearrange, repeat, reduce, einsum
from einops.layers.torch import Rearrange

from beartype import beartype

"""
ein notation:
b - batch
d - feature dimension
s - residual streams
t - residual streams + num branch inputs
"""

# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def identity(t):
    return t

# main functions

def get_expand_reduce_stream_functions(num_streams, disable = False):

    if disable:
        return (identity, identity)

    expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
    reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)

    return expand_fn, reduce_fn

def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):

    hyper_conn_klass = HyperConnections if not disable else Residual

    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
    expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)

    return (init_hyper_conn_fn, *expand_reduce_fns)

# norms

class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * (self.gamma + 1)

class ProjActScale(Module):
    def __init__(
        self,
        dim,
        dim_out,
        activation: Module = nn.Identity(),
        scale_init: float = 1e-2,
        squeeze_output = False
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.proj = nn.Linear(dim, dim_out, bias = False)
        nn.init.zeros_(self.proj.weight)

        self.act = activation
        self.scale = nn.Parameter(torch.ones(()) * scale_init)
        self.maybe_squeeze = Rearrange('... 1 -> ...') if squeeze_output else nn.Identity()

    def forward(self, x):
        out = self.proj(x)
        out = self.act(out)
        return self.maybe_squeeze(out * self.scale)

# main classes

# residual base class

class Residual(Module):
    @beartype
    def __init__(
        self,
        *args,
        branch: Module | None = None,
        **kwargs
    ):
        super().__init__()
        self.branch = branch

    def width_connection(self, residuals, *args, **kwargs):
        return residuals, residuals, dict()

    def depth_connection(self, branch_output, residuals):
        return branch_output + residuals

    def decorate_branch(self, branch: Callable):
        assert not exists(self.branch), 'branch was already wrapped on init'

        def forward_and_add_residual(residual, *args, **kwargs):
            branch_input, add_residual = self.forward(residual, *args, **kwargs)

            branch_output = branch(branch_input, *args, **kwargs)

            residual = add_residual(branch_output)

            return residual

        return forward_and_add_residual

    def forward(self, residuals, *branch_args, **branch_kwargs):

        branch_input, residuals, residual_kwargs = self.width_connection(residuals, *branch_args, **branch_kwargs)

        def add_residual_fn(branch_out):
            (branch_out, *rest), tree_spec = tree_flatten(branch_out)

            branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)

            return tree_unflatten((branch_out, *rest), tree_spec)

        if not exists(self.branch):
            return branch_input, add_residual_fn

        branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)

        return add_residual_fn(branch_output)

# hyper connection with multiple input streams

InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int` + 1] and `str` points to **kwargs[`str`]

class HyperConnections(Module):
    @beartype
    def __init__(
        self,
        num_residual_streams,
        *,
        dim,
        additional_input_paths: (
            list[InputPathType |
            tuple[InputPathType, int]] # if the second residual has different dimensions, second tuple element is the dimension
            | None
        ) = None,
        branch: Module | None = None,
        layer_index = None,
        tanh = True,
        channel_first = False,
        dropout = 0.
    ):
        """
        Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
        """
        super().__init__()

        self.branch = branch
        act = nn.Tanh() if tanh else nn.Identity()

        self.num_residual_streams = num_residual_streams
        assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'

        # activation, seemingly results were wishy washy depending on using tanh or not

        self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now

        init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given

        init_alpha0 = torch.zeros((num_residual_streams, 1))
        init_alpha0[init_residual_index, 0] = 1.

        self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1)
        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

        self.dynamic_beta = ProjActScale(dim, 1, activation = act, squeeze_output = True)
        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))

        # additional input residual streams

        additional_input_paths = default(additional_input_paths, [])
        additional_input_paths = [one_path if isinstance(one_path, tuple) else (one_path, dim) for one_path in additional_input_paths]

        self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
        self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _, dim in additional_input_paths])
        self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0])])

        self.additional_input_paths = additional_input_paths

        # dropouts

        self.dropout = nn.Dropout(dropout)

        # channel first option

        self.channel_first = channel_first

    def width_connection(
        self,
        residuals,
        *branch_args,
        **branch_kwargs
    ):

        transpose = self.channel_first

        # width connection

        if transpose:
            residuals = rearrange(residuals, 'b d ... -> b ... d')

        residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = self.num_residual_streams)

        normed = self.norm(residuals)

        # alpha for weighted sum of residuals going into branch

        dynamic_alpha = self.dynamic_alpha_and_branch_input(normed)
        alpha = dynamic_alpha + self.static_alpha

        # beta for weights from branch output back to residual streams

        dynamic_beta = self.dynamic_beta(normed)
        beta = dynamic_beta + self.static_beta

        mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')

        branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]

        if transpose:
            branch_input = rearrange(branch_input, 'b ... d -> b d ...')

        # take care of additional inputs

        for (path, *_), norm, proj, learned_static in zip(self.additional_input_paths, self.additional_norms, self.additional_to_dynamic_input, self.additional_static_input):

            # get the residual streams from additional arguments

            if isinstance(path, int):
                additional_residuals = branch_args[path]
            elif isinstance(path, str):
                additional_residuals = branch_kwargs[path]

            assert torch.is_tensor(additional_residuals)

            # handle channel first

            if transpose:
                additional_residuals = rearrange('b d ... -> b ... d')

            additional_residuals = rearrange(additional_residuals, '(b s) ... d -> b ... s d', s = self.num_residual_streams)

            # norm

            additional_mix = proj(norm(additional_residuals))
            additional_mix = additional_mix + learned_static

            additional_residuals = einsum(additional_mix, additional_residuals, '... s, ... s d -> ... d')

            # transpose out

            if transpose:
                additional_residuals = rearrange('b ... d -> b d ...')

            # set back transformed residual

            if isinstance(path, int):
                branch_args[path] = additional_residuals
            elif isinstance(path, str):
                branch_kwargs[path] = additional_residuals

        return ([branch_input, *branch_args], branch_kwargs), residuals, dict(beta = beta)

    def depth_connection(self, branch_output, residuals, *, beta):
        # 'depth' connection

        if self.channel_first:
            branch_output = rearrange(branch_output, 'b d ... -> b ... d')

        residuals = einsum(branch_output, beta, 'b ... d, b ... s -> b ... s d') + residuals
        output = rearrange(residuals, 'b ... s d -> (b s) ... d')

        if self.channel_first:
            output = rearrange(output, 'b ... d -> b d ...')

        return self.dropout(output)

    def decorate_branch(self, branch: Callable):
        assert not exists(self.branch), 'branch was already wrapped on init'

        def forward_and_add_residual(residual, *args, **kwargs):
            ([branch_input, *args], kwargs), add_residual = self.forward(residual, *args, **kwargs)

            branch_output = branch(branch_input, *args, **kwargs)

            residual = add_residual(branch_output)

            return residual

        return forward_and_add_residual

    def forward(self, residuals, *branch_args, **branch_kwargs):

        (branch_args, branch_kwargs), residuals, residual_kwargs = self.width_connection(residuals, *branch_args, **branch_kwargs)

        def add_residual_fn(branch_out):
            (branch_out, *rest), tree_spec = tree_flatten(branch_out)

            branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)

            return tree_unflatten((branch_out, *rest), tree_spec)

        if not exists(self.branch):
            return (branch_args, branch_kwargs), add_residual_fn

        branch_output = self.branch(*branch_args, **branch_kwargs)

        return add_residual_fn(branch_output)

# add static methods

HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
```
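The new module generalizes hyper connections to branches that take more than one tensor: `additional_input_paths` names which positional (`int`) or keyword (`str`) branch arguments should themselves be treated as residual streams and mixed with learned weights before the branch runs. A hedged usage sketch (the import path follows the new file's location; `SumBranch`, the `context` kwarg, and all dimensions are illustrative, not from the diff):

```python
import torch
from torch import nn
from hyper_connections.hyper_connections_with_multi_input_streams import (
    get_init_and_expand_reduce_stream_functions,
)

class SumBranch(nn.Module):
    # hypothetical branch with a second tensor input passed by keyword
    def forward(self, x, *, context):
        return x + context

init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)

hyper_conn = init_hyper_conn(
    dim = 512,
    branch = SumBranch(),
    additional_input_paths = ['context']  # mix the 'context' kwarg as an extra residual stream
)

x = expand_stream(torch.randn(2, 16, 512))
context = expand_stream(torch.randn(2, 16, 512))

out = reduce_stream(hyper_conn(x, context = context))  # (2, 16, 512)
```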
{hyper_connections-0.0.21 → hyper_connections-0.0.23}/pyproject.toml RENAMED

````diff
@@ -1,6 +1,6 @@
 [project]
 name = "hyper-connections"
-version = "0.0.21"
+version = "0.0.23"
 description = "Hyper-Connections"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -23,6 +23,7 @@ classifiers=[
 ]
 
 dependencies = [
+    "beartype",
     "einops>=0.8.0",
     "torch>=2.3",
 ]
````
Files without changes:

- {hyper_connections-0.0.21 → hyper_connections-0.0.23}/.github/workflows/python-publish.yml
- {hyper_connections-0.0.21 → hyper_connections-0.0.23}/.gitignore
- {hyper_connections-0.0.21 → hyper_connections-0.0.23}/LICENSE
- {hyper_connections-0.0.21 → hyper_connections-0.0.23}/hyper-connections.png