hyper-connections 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/__init__.py +2 -1
- hyper_connections/hyper_connections.py +49 -8
- {hyper_connections-0.0.5.dist-info → hyper_connections-0.0.6.dist-info}/METADATA +7 -1
- hyper_connections-0.0.6.dist-info/RECORD +6 -0
- hyper_connections-0.0.5.dist-info/RECORD +0 -6
- {hyper_connections-0.0.5.dist-info → hyper_connections-0.0.6.dist-info}/WHEEL +0 -0
- {hyper_connections-0.0.5.dist-info → hyper_connections-0.0.6.dist-info}/licenses/LICENSE +0 -0
hyper_connections/__init__.py
CHANGED
|
@@ -18,7 +18,46 @@ def exists(v):
|
|
|
18
18
|
def default(v, d):
    """Return `v` when it exists, otherwise the fallback value `d`."""
    if exists(v):
        return v
    return d
|
|
20
20
|
|
|
21
|
-
|
|
21
|
+
def identity(t):
    """Return the argument unchanged.

    Serves as a no-op stand-in wherever a transform function is expected
    (e.g. the expand/reduce stream functions when hyper-connections are disabled).
    """
    unchanged = t
    return unchanged
|
|
23
|
+
|
|
24
|
+
# main classes
|
|
25
|
+
|
|
26
|
+
# residual base class
|
|
27
|
+
|
|
28
|
+
class Residual(Module):
    """
    Plain residual connection exposing the same call interface as the
    hyper-connection module, so it can be swapped in without changing
    calling code.

    Any extra positional or keyword constructor arguments are accepted and
    ignored, which lets it be built with the same arguments as the
    hyper-connection class.
    """

    def __init__(
        self,
        *args,
        branch = None,
        **kwargs
    ):
        super().__init__()
        # optional wrapped branch; when None, `forward` hands the branch input
        # back to the caller together with a function that adds the residual
        self.branch = branch

    def width_connection(self, residuals):
        # no stream mixing: the branch receives the residual as-is, and the
        # depth step needs no extra keyword arguments
        extra_kwargs = {}
        return residuals, residuals, extra_kwargs

    def depth_connection(self, branch_output, residuals):
        # ordinary residual addition
        return branch_output + residuals

    def forward(self, residuals, *branch_args, **branch_kwargs):
        """
        Apply the residual connection around the (optional) branch.

        If no branch was supplied, returns `(branch_input, add_residual_fn)`.
        The caller runs its own branch on `branch_input` and then passes the
        result to `add_residual_fn`.

        Otherwise, calls the branch with `branch_input` plus any extra
        positional/keyword arguments. The residual is added to the first leaf
        of the (flattened) branch output, and the original output structure is
        rebuilt and returned.
        """
        branch_input, residuals, residual_kwargs = self.width_connection(residuals)

        def add_residual_fn(branch_out):
            return self.depth_connection(branch_out, residuals, **residual_kwargs)

        if not exists(self.branch):
            return branch_input, add_residual_fn

        out = self.branch(branch_input, *branch_args, **branch_kwargs)

        # the branch may return a nested structure (e.g. output plus extras);
        # only its first leaf receives the residual
        leaves, tree_spec = tree_flatten(out)
        first, *rest = leaves

        updated_leaves = (add_residual_fn(first), *rest)
        return tree_unflatten(updated_leaves, tree_spec)
|
|
22
61
|
|
|
23
62
|
# hyper connection residual streams
|
|
24
63
|
|
|
@@ -70,10 +109,12 @@ class HyperConnections(Module):
|
|
|
70
109
|
return expand_fn, reduce_fn
|
|
71
110
|
|
|
72
111
|
# Factory returning (init_hyper_conn_fn, *expand_reduce_fns). With disable = True it swaps in the
# plain Residual class and (identity, identity) expand/reduce functions, so the calling code can stay unchanged.
@classmethod
|
|
73
|
-
def get_init_and_expand_reduce_stream_functions(cls, num_streams):
|
|
112
|
+
def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
|
|
113
|
+
|
|
114
|
+
hyper_conn_klass = cls if not disable else Residual  # Residual swallows num_streams through its *args
|
|
74
115
|
|
|
75
|
-
init_hyper_conn_fn = partial(
|
|
76
|
-
expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams)
|
|
116
|
+
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)  # num_streams bound as first positional arg
|
|
117
|
+
expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams) if not disable else (identity, identity)  # no stream expand/reduce when disabled
|
|
77
118
|
|
|
78
119
|
return (init_hyper_conn_fn, *expand_reduce_fns)
|
|
79
120
|
|
|
@@ -106,9 +147,9 @@ class HyperConnections(Module):
|
|
|
106
147
|
if self.channel_first:
|
|
107
148
|
branch_input = rearrange(branch_input, 'b ... d -> b d ...')
|
|
108
149
|
|
|
109
|
-
return branch_input, residuals, beta
|
|
150
|
+
return branch_input, residuals, dict(beta = beta)
|
|
110
151
|
|
|
111
|
-
def depth_connection(self, branch_output, residuals, beta):
|
|
152
|
+
def depth_connection(self, branch_output, residuals, *, beta):
|
|
112
153
|
# 'depth' connection
|
|
113
154
|
|
|
114
155
|
if self.channel_first:
|
|
@@ -124,10 +165,10 @@ class HyperConnections(Module):
|
|
|
124
165
|
|
|
125
166
|
def forward(self, residuals, *branch_args, **branch_kwargs):
|
|
126
167
|
|
|
127
|
-
branch_input, residuals,
|
|
168
|
+
branch_input, residuals, residual_kwargs = self.width_connection(residuals)
|
|
128
169
|
|
|
129
170
|
def add_residual_fn(branch_out):
|
|
130
|
-
return self.depth_connection(branch_out, residuals,
|
|
171
|
+
return self.depth_connection(branch_out, residuals, **residual_kwargs)
|
|
131
172
|
|
|
132
173
|
if not exists(self.branch):
|
|
133
174
|
return branch_input, add_residual_fn
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.0.5
|
|
3
|
+
Version: 0.0.6
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -135,6 +135,12 @@ residual = add_residual(branch_output)
|
|
|
135
135
|
residual = reduce_stream(residual)
|
|
136
136
|
```
|
|
137
137
|
|
|
138
|
+
To compare hyper-connections against a plain residual without changing the rest of your code, pass `disable = True` when fetching the functions:
|
|
139
|
+
|
|
140
|
+
```python
|
|
141
|
+
HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
142
|
+
```
|
|
143
|
+
|
|
138
144
|
## Citation
|
|
139
145
|
|
|
140
146
|
```bibtex
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
hyper_connections/__init__.py,sha256=yEc-yNlGq084y0pR0_VVGLr-sH4ye-eVX0RNz7sTPCo,87
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=YibHh3ocMhkCWhEu8EF554HtGm7i4SPH5ChSMEyFPlI,5843
|
|
3
|
+
hyper_connections-0.0.6.dist-info/METADATA,sha256=c_2oOz7OtvUeLt71AHeO5AXj8d_oIOj0Qo1fcBvCN1A,4979
|
|
4
|
+
hyper_connections-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
+
hyper_connections-0.0.6.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
6
|
+
hyper_connections-0.0.6.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
hyper_connections/__init__.py,sha256=xXx2Mb-dS1__UPzT-5VR1XZmyqKSSkT1DU6bAcK8jR0,73
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=2lZcPuW4hEKet3r8caN-sN-PzRaBNL1q-V3_uA1lVaM,4613
|
|
3
|
-
hyper_connections-0.0.5.dist-info/METADATA,sha256=bko4lEEBiulROYd0aOC_8rIk0hPinBg5TBQWf9DZe9M,4753
|
|
4
|
-
hyper_connections-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
-
hyper_connections-0.0.5.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
6
|
-
hyper_connections-0.0.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|