hyper-connections 0.0.4__tar.gz → 0.0.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.0.4 → hyper_connections-0.0.6}/PKG-INFO +12 -6
- {hyper_connections-0.0.4 → hyper_connections-0.0.6}/README.md +11 -5
- {hyper_connections-0.0.4 → hyper_connections-0.0.6}/hyper_connections/__init__.py +2 -1
- {hyper_connections-0.0.4 → hyper_connections-0.0.6}/hyper_connections/hyper_connections.py +54 -5
- {hyper_connections-0.0.4 → hyper_connections-0.0.6}/pyproject.toml +1 -1
- {hyper_connections-0.0.4 → hyper_connections-0.0.6}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.4 → hyper_connections-0.0.6}/.gitignore +0 -0
- {hyper_connections-0.0.4 → hyper_connections-0.0.6}/LICENSE +0 -0
- {hyper_connections-0.0.4 → hyper_connections-0.0.6}/hyper-connections.png +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.0.4
|
|
3
|
+
Version: 0.0.6
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -73,11 +73,11 @@ residual = branch(residual) + residual
|
|
|
73
73
|
|
|
74
74
|
from hyper_connections import HyperConnections
|
|
75
75
|
|
|
76
|
-
expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
|
|
76
|
+
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
|
|
77
77
|
|
|
78
78
|
# 1. wrap your branch function
|
|
79
79
|
|
|
80
|
-
hyper_conn_branch = HyperConnections(4, dim = 512, branch = branch)
|
|
80
|
+
hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)
|
|
81
81
|
|
|
82
82
|
# 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
|
|
83
83
|
|
|
@@ -112,11 +112,11 @@ residual = branch(residual) + residual
|
|
|
112
112
|
|
|
113
113
|
from hyper_connections import HyperConnections
|
|
114
114
|
|
|
115
|
-
expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
|
|
115
|
+
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
|
|
116
116
|
|
|
117
|
-
# 1. instantiate hyper connection with correct number of streams (4 in this case)
|
|
117
|
+
# 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
|
|
118
118
|
|
|
119
|
-
hyper_conn = HyperConnections(4, dim = 512)
|
|
119
|
+
hyper_conn = init_hyper_conn(dim = 512)
|
|
120
120
|
|
|
121
121
|
# 2. expand to 4 streams
|
|
122
122
|
|
|
@@ -135,6 +135,12 @@ residual = add_residual(branch_output)
|
|
|
135
135
|
residual = reduce_stream(residual)
|
|
136
136
|
```
|
|
137
137
|
|
|
138
|
+
To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
|
|
139
|
+
|
|
140
|
+
```python
|
|
141
|
+
HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
142
|
+
```
|
|
143
|
+
|
|
138
144
|
## Citation
|
|
139
145
|
|
|
140
146
|
```bibtex
|
|
@@ -30,11 +30,11 @@ residual = branch(residual) + residual
|
|
|
30
30
|
|
|
31
31
|
from hyper_connections import HyperConnections
|
|
32
32
|
|
|
33
|
-
expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
|
|
33
|
+
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
|
|
34
34
|
|
|
35
35
|
# 1. wrap your branch function
|
|
36
36
|
|
|
37
|
-
hyper_conn_branch = HyperConnections(4, dim = 512, branch = branch)
|
|
37
|
+
hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)
|
|
38
38
|
|
|
39
39
|
# 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
|
|
40
40
|
|
|
@@ -69,11 +69,11 @@ residual = branch(residual) + residual
|
|
|
69
69
|
|
|
70
70
|
from hyper_connections import HyperConnections
|
|
71
71
|
|
|
72
|
-
expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
|
|
72
|
+
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
|
|
73
73
|
|
|
74
|
-
# 1. instantiate hyper connection with correct number of streams (4 in this case)
|
|
74
|
+
# 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
|
|
75
75
|
|
|
76
|
-
hyper_conn = HyperConnections(4, dim = 512)
|
|
76
|
+
hyper_conn = init_hyper_conn(dim = 512)
|
|
77
77
|
|
|
78
78
|
# 2. expand to 4 streams
|
|
79
79
|
|
|
@@ -92,6 +92,12 @@ residual = add_residual(branch_output)
|
|
|
92
92
|
residual = reduce_stream(residual)
|
|
93
93
|
```
|
|
94
94
|
|
|
95
|
+
To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
|
|
96
|
+
|
|
97
|
+
```python
|
|
98
|
+
HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
99
|
+
```
|
|
100
|
+
|
|
95
101
|
## Citation
|
|
96
102
|
|
|
97
103
|
```bibtex
|
|
@@ -18,7 +18,46 @@ def exists(v):
|
|
|
18
18
|
def default(v, d):
|
|
19
19
|
return v if exists(v) else d
|
|
20
20
|
|
|
21
|
-
|
|
21
|
+
def identity(t):
|
|
22
|
+
return t
|
|
23
|
+
|
|
24
|
+
# main classes
|
|
25
|
+
|
|
26
|
+
# residual base class
|
|
27
|
+
|
|
28
|
+
class Residual(Module):
|
|
29
|
+
def __init__(
|
|
30
|
+
self,
|
|
31
|
+
*args,
|
|
32
|
+
branch = None,
|
|
33
|
+
**kwargs
|
|
34
|
+
):
|
|
35
|
+
super().__init__()
|
|
36
|
+
self.branch = branch
|
|
37
|
+
|
|
38
|
+
def width_connection(self, residuals):
|
|
39
|
+
return residuals, residuals, dict()
|
|
40
|
+
|
|
41
|
+
def depth_connection(self, branch_output, residuals):
|
|
42
|
+
return branch_output + residuals
|
|
43
|
+
|
|
44
|
+
def forward(self, residuals, *branch_args, **branch_kwargs):
|
|
45
|
+
|
|
46
|
+
branch_input, residuals, residual_kwargs = self.width_connection(residuals)
|
|
47
|
+
|
|
48
|
+
def add_residual_fn(branch_out):
|
|
49
|
+
return self.depth_connection(branch_out, residuals, **residual_kwargs)
|
|
50
|
+
|
|
51
|
+
if not exists(self.branch):
|
|
52
|
+
return branch_input, add_residual_fn
|
|
53
|
+
|
|
54
|
+
branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
|
|
55
|
+
|
|
56
|
+
(branch_output, *rest), tree_spec = tree_flatten(branch_output)
|
|
57
|
+
|
|
58
|
+
branch_output = add_residual_fn(branch_output)
|
|
59
|
+
|
|
60
|
+
return tree_unflatten((branch_output, *rest), tree_spec)
|
|
22
61
|
|
|
23
62
|
# hyper connection residual streams
|
|
24
63
|
|
|
@@ -69,6 +108,16 @@ class HyperConnections(Module):
|
|
|
69
108
|
|
|
70
109
|
return expand_fn, reduce_fn
|
|
71
110
|
|
|
111
|
+
@classmethod
|
|
112
|
+
def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
|
|
113
|
+
|
|
114
|
+
hyper_conn_klass = cls if not disable else Residual
|
|
115
|
+
|
|
116
|
+
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
|
|
117
|
+
expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams) if not disable else (identity, identity)
|
|
118
|
+
|
|
119
|
+
return (init_hyper_conn_fn, *expand_reduce_fns)
|
|
120
|
+
|
|
72
121
|
def width_connection(self, residuals):
|
|
73
122
|
# width connection
|
|
74
123
|
|
|
@@ -98,9 +147,9 @@ class HyperConnections(Module):
|
|
|
98
147
|
if self.channel_first:
|
|
99
148
|
branch_input = rearrange(branch_input, 'b ... d -> b d ...')
|
|
100
149
|
|
|
101
|
-
return branch_input, residuals, beta
|
|
150
|
+
return branch_input, residuals, dict(beta = beta)
|
|
102
151
|
|
|
103
|
-
def depth_connection(self, branch_output, residuals, beta):
|
|
152
|
+
def depth_connection(self, branch_output, residuals, *, beta):
|
|
104
153
|
# 'depth' connection
|
|
105
154
|
|
|
106
155
|
if self.channel_first:
|
|
@@ -116,10 +165,10 @@ class HyperConnections(Module):
|
|
|
116
165
|
|
|
117
166
|
def forward(self, residuals, *branch_args, **branch_kwargs):
|
|
118
167
|
|
|
119
|
-
branch_input, residuals, beta = self.width_connection(residuals)
|
|
168
|
+
branch_input, residuals, residual_kwargs = self.width_connection(residuals)
|
|
120
169
|
|
|
121
170
|
def add_residual_fn(branch_out):
|
|
122
|
-
return self.depth_connection(branch_out, residuals, beta)
|
|
171
|
+
return self.depth_connection(branch_out, residuals, **residual_kwargs)
|
|
123
172
|
|
|
124
173
|
if not exists(self.branch):
|
|
125
174
|
return branch_input, add_residual_fn
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|