hyper-connections 0.0.8__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections.py +20 -1
- {hyper_connections-0.0.8.dist-info → hyper_connections-0.0.10.dist-info}/METADATA +3 -1
- hyper_connections-0.0.10.dist-info/RECORD +6 -0
- hyper_connections-0.0.8.dist-info/RECORD +0 -6
- {hyper_connections-0.0.8.dist-info → hyper_connections-0.0.10.dist-info}/WHEEL +0 -0
- {hyper_connections-0.0.8.dist-info → hyper_connections-0.0.10.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
2
4
|
from functools import partial
|
|
3
5
|
from random import randrange
|
|
4
6
|
|
|
@@ -79,8 +81,11 @@ class HyperConnections(Module):
|
|
|
79
81
|
|
|
80
82
|
self.branch = branch
|
|
81
83
|
|
|
84
|
+
# activation, seemingly results were wishy washy depending on using tanh or not
|
|
85
|
+
|
|
82
86
|
self.act = nn.Tanh() if tanh else nn.Identity()
|
|
83
|
-
|
|
87
|
+
|
|
88
|
+
self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
|
|
84
89
|
|
|
85
90
|
self.num_residual_streams = num_residual_streams
|
|
86
91
|
init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
|
|
@@ -163,6 +168,20 @@ class HyperConnections(Module):
|
|
|
163
168
|
|
|
164
169
|
return output
|
|
165
170
|
|
|
171
|
+
def decorate_branch(self, branch: Callable):
|
|
172
|
+
assert not exists(self.branch), 'branch was already wrapped on init'
|
|
173
|
+
|
|
174
|
+
def forward_and_add_residual(residual, *args, **kwargs):
|
|
175
|
+
branch_input, add_residual = self.forward(residual)
|
|
176
|
+
|
|
177
|
+
branch_output = branch(branch_input)
|
|
178
|
+
|
|
179
|
+
residual = add_residual(branch_output)
|
|
180
|
+
|
|
181
|
+
return residual
|
|
182
|
+
|
|
183
|
+
return forward_and_add_residual
|
|
184
|
+
|
|
166
185
|
def forward(self, residuals, *branch_args, **branch_kwargs):
|
|
167
186
|
|
|
168
187
|
branch_input, residuals, residual_kwargs = self.width_connection(residuals)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.0.8
|
|
3
|
+
Version: 0.0.10
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -130,6 +130,8 @@ branch_output = branch(branch_input)
|
|
|
130
130
|
|
|
131
131
|
residual = add_residual(branch_output)
|
|
132
132
|
|
|
133
|
+
# or you can do it in one line as so -> residual = hyper_conn.decorate_branch(branch)(residual)
|
|
134
|
+
|
|
133
135
|
# 4. reduce 4 streams with a summation, this has to be done after your for loop trunk
|
|
134
136
|
|
|
135
137
|
residual = reduce_stream(residual)
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=lPaO9tBKKI_a3SUKIOAVnZFYMx9-DsW1nZMEvrcJaVU,7587
|
|
3
|
+
hyper_connections-0.0.10.dist-info/METADATA,sha256=vfPTRI0tqXljeLVbtvQoSICekb7w4OWwpx76utQDq8Y,5076
|
|
4
|
+
hyper_connections-0.0.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
+
hyper_connections-0.0.10.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
6
|
+
hyper_connections-0.0.10.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=fdcr0DODcIQ1eggy7pa6faX6MqNIZST_q2aDMevViig,6964
|
|
3
|
-
hyper_connections-0.0.8.dist-info/METADATA,sha256=2q_q0AHjDyFHzliHiRXXTBaI9iA1uqECGAH_HMZmGis,4978
|
|
4
|
-
hyper_connections-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
-
hyper_connections-0.0.8.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
6
|
-
hyper_connections-0.0.8.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|