hyper-connections 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections.py +10 -2
- {hyper_connections-0.0.3.dist-info → hyper_connections-0.0.5.dist-info}/METADATA +6 -6
- hyper_connections-0.0.5.dist-info/RECORD +6 -0
- hyper_connections-0.0.3.dist-info/RECORD +0 -6
- {hyper_connections-0.0.3.dist-info → hyper_connections-0.0.5.dist-info}/WHEEL +0 -0
- {hyper_connections-0.0.3.dist-info → hyper_connections-0.0.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -69,6 +69,14 @@ class HyperConnections(Module):
|
|
|
69
69
|
|
|
70
70
|
return expand_fn, reduce_fn
|
|
71
71
|
|
|
72
|
+
@classmethod
|
|
73
|
+
def get_init_and_expand_reduce_stream_functions(cls, num_streams):
|
|
74
|
+
|
|
75
|
+
init_hyper_conn_fn = partial(cls, num_streams)
|
|
76
|
+
expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams)
|
|
77
|
+
|
|
78
|
+
return (init_hyper_conn_fn, *expand_reduce_fns)
|
|
79
|
+
|
|
72
80
|
def width_connection(self, residuals):
|
|
73
81
|
# width connection
|
|
74
82
|
|
|
@@ -114,7 +122,7 @@ class HyperConnections(Module):
|
|
|
114
122
|
|
|
115
123
|
return output
|
|
116
124
|
|
|
117
|
-
def forward(self, residuals, **branch_kwargs):
|
|
125
|
+
def forward(self, residuals, *branch_args, **branch_kwargs):
|
|
118
126
|
|
|
119
127
|
branch_input, residuals, beta = self.width_connection(residuals)
|
|
120
128
|
|
|
@@ -124,7 +132,7 @@ class HyperConnections(Module):
|
|
|
124
132
|
if not exists(self.branch):
|
|
125
133
|
return branch_input, add_residual_fn
|
|
126
134
|
|
|
127
|
-
branch_output = self.branch(branch_input, **branch_kwargs)
|
|
135
|
+
branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
|
|
128
136
|
|
|
129
137
|
(branch_output, *rest), tree_spec = tree_flatten(branch_output)
|
|
130
138
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.0.3
|
|
3
|
+
Version: 0.0.5
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -73,11 +73,11 @@ residual = branch(residual) + residual
|
|
|
73
73
|
|
|
74
74
|
from hyper_connections import HyperConnections
|
|
75
75
|
|
|
76
|
-
expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
|
|
76
|
+
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
|
|
77
77
|
|
|
78
78
|
# 1. wrap your branch function
|
|
79
79
|
|
|
80
|
-
hyper_conn_branch = HyperConnections(4, dim = 512, branch = branch)
|
|
80
|
+
hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)
|
|
81
81
|
|
|
82
82
|
# 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
|
|
83
83
|
|
|
@@ -112,11 +112,11 @@ residual = branch(residual) + residual
|
|
|
112
112
|
|
|
113
113
|
from hyper_connections import HyperConnections
|
|
114
114
|
|
|
115
|
-
expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
|
|
115
|
+
init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
|
|
116
116
|
|
|
117
|
-
# 1. instantiate hyper connection with correct number of streams (4 in this case)
|
|
117
|
+
# 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
|
|
118
118
|
|
|
119
|
-
hyper_conn = HyperConnections(4, dim = 512)
|
|
119
|
+
hyper_conn = init_hyper_conn(dim = 512)
|
|
120
120
|
|
|
121
121
|
# 2. expand to 4 streams
|
|
122
122
|
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
hyper_connections/__init__.py,sha256=xXx2Mb-dS1__UPzT-5VR1XZmyqKSSkT1DU6bAcK8jR0,73
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=2lZcPuW4hEKet3r8caN-sN-PzRaBNL1q-V3_uA1lVaM,4613
|
|
3
|
+
hyper_connections-0.0.5.dist-info/METADATA,sha256=bko4lEEBiulROYd0aOC_8rIk0hPinBg5TBQWf9DZe9M,4753
|
|
4
|
+
hyper_connections-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
+
hyper_connections-0.0.5.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
6
|
+
hyper_connections-0.0.5.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
hyper_connections/__init__.py,sha256=xXx2Mb-dS1__UPzT-5VR1XZmyqKSSkT1DU6bAcK8jR0,73
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=91QtTtnpffmErIZvrnTtosSf4JgBqcyGvxftmka-EOw,4303
|
|
3
|
-
hyper_connections-0.0.3.dist-info/METADATA,sha256=8XKTmC6Ys10uOyotPtvL17v4uZyepkWoeMRVM4B_TSQ,4676
|
|
4
|
-
hyper_connections-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
-
hyper_connections-0.0.3.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
6
|
-
hyper_connections-0.0.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|