hyper-connections 0.0.9__tar.gz → 0.0.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.0.9 → hyper_connections-0.0.11}/PKG-INFO +1 -1
- {hyper_connections-0.0.9 → hyper_connections-0.0.11}/hyper_connections/hyper_connections.py +9 -7
- {hyper_connections-0.0.9 → hyper_connections-0.0.11}/pyproject.toml +1 -1
- {hyper_connections-0.0.9 → hyper_connections-0.0.11}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.9 → hyper_connections-0.0.11}/.gitignore +0 -0
- {hyper_connections-0.0.9 → hyper_connections-0.0.11}/LICENSE +0 -0
- {hyper_connections-0.0.9 → hyper_connections-0.0.11}/README.md +0 -0
- {hyper_connections-0.0.9 → hyper_connections-0.0.11}/hyper-connections.png +0 -0
- {hyper_connections-0.0.9 → hyper_connections-0.0.11}/hyper_connections/__init__.py +0 -0
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
2
4
|
from functools import partial
|
|
3
5
|
from random import randrange
|
|
4
6
|
|
|
@@ -172,7 +174,7 @@ class HyperConnections(Module):
|
|
|
172
174
|
def forward_and_add_residual(residual, *args, **kwargs):
|
|
173
175
|
branch_input, add_residual = self.forward(residual)
|
|
174
176
|
|
|
175
|
-
branch_output = branch(branch_input)
|
|
177
|
+
branch_output = branch(branch_input, *args, **kwargs)
|
|
176
178
|
|
|
177
179
|
residual = add_residual(branch_output)
|
|
178
180
|
|
|
@@ -185,18 +187,18 @@ class HyperConnections(Module):
|
|
|
185
187
|
branch_input, residuals, residual_kwargs = self.width_connection(residuals)
|
|
186
188
|
|
|
187
189
|
def add_residual_fn(branch_out):
|
|
188
|
-
|
|
190
|
+
(branch_out, *rest), tree_spec = tree_flatten(branch_out)
|
|
191
|
+
|
|
192
|
+
branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
|
|
193
|
+
|
|
194
|
+
return tree_unflatten((branch_out, *rest), tree_spec)
|
|
189
195
|
|
|
190
196
|
if not exists(self.branch):
|
|
191
197
|
return branch_input, add_residual_fn
|
|
192
198
|
|
|
193
199
|
branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
|
|
194
200
|
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
branch_output = add_residual_fn(branch_output)
|
|
198
|
-
|
|
199
|
-
return tree_unflatten((branch_output, *rest), tree_spec)
|
|
201
|
+
return add_residual_fn(branch_output)
|
|
200
202
|
|
|
201
203
|
# stream embed
|
|
202
204
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|