hyper-connections 0.0.10__tar.gz → 0.0.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.0.10 → hyper_connections-0.0.11}/PKG-INFO +1 -1
- {hyper_connections-0.0.10 → hyper_connections-0.0.11}/hyper_connections/hyper_connections.py +7 -7
- {hyper_connections-0.0.10 → hyper_connections-0.0.11}/pyproject.toml +1 -1
- {hyper_connections-0.0.10 → hyper_connections-0.0.11}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.10 → hyper_connections-0.0.11}/.gitignore +0 -0
- {hyper_connections-0.0.10 → hyper_connections-0.0.11}/LICENSE +0 -0
- {hyper_connections-0.0.10 → hyper_connections-0.0.11}/README.md +0 -0
- {hyper_connections-0.0.10 → hyper_connections-0.0.11}/hyper-connections.png +0 -0
- {hyper_connections-0.0.10 → hyper_connections-0.0.11}/hyper_connections/__init__.py +0 -0
{hyper_connections-0.0.10 → hyper_connections-0.0.11}/hyper_connections/hyper_connections.py
RENAMED
|
@@ -174,7 +174,7 @@ class HyperConnections(Module):
|
|
|
174
174
|
def forward_and_add_residual(residual, *args, **kwargs):
|
|
175
175
|
branch_input, add_residual = self.forward(residual)
|
|
176
176
|
|
|
177
|
-
branch_output = branch(branch_input)
|
|
177
|
+
branch_output = branch(branch_input, *args, **kwargs)
|
|
178
178
|
|
|
179
179
|
residual = add_residual(branch_output)
|
|
180
180
|
|
|
@@ -187,18 +187,18 @@ class HyperConnections(Module):
|
|
|
187
187
|
branch_input, residuals, residual_kwargs = self.width_connection(residuals)
|
|
188
188
|
|
|
189
189
|
def add_residual_fn(branch_out):
|
|
190
|
-
|
|
190
|
+
(branch_out, *rest), tree_spec = tree_flatten(branch_out)
|
|
191
|
+
|
|
192
|
+
branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
|
|
193
|
+
|
|
194
|
+
return tree_unflatten((branch_out, *rest), tree_spec)
|
|
191
195
|
|
|
192
196
|
if not exists(self.branch):
|
|
193
197
|
return branch_input, add_residual_fn
|
|
194
198
|
|
|
195
199
|
branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
|
|
196
200
|
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
branch_output = add_residual_fn(branch_output)
|
|
200
|
-
|
|
201
|
-
return tree_unflatten((branch_output, *rest), tree_spec)
|
|
201
|
+
return add_residual_fn(branch_output)
|
|
202
202
|
|
|
203
203
|
# stream embed
|
|
204
204
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|