hyper-connections 0.0.11__tar.gz → 0.0.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.0.11 → hyper_connections-0.0.12}/PKG-INFO +1 -1
- {hyper_connections-0.0.11 → hyper_connections-0.0.12}/hyper_connections/hyper_connections.py +20 -6
- {hyper_connections-0.0.11 → hyper_connections-0.0.12}/pyproject.toml +1 -1
- {hyper_connections-0.0.11 → hyper_connections-0.0.12}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.11 → hyper_connections-0.0.12}/.gitignore +0 -0
- {hyper_connections-0.0.11 → hyper_connections-0.0.12}/LICENSE +0 -0
- {hyper_connections-0.0.11 → hyper_connections-0.0.12}/README.md +0 -0
- {hyper_connections-0.0.11 → hyper_connections-0.0.12}/hyper-connections.png +0 -0
- {hyper_connections-0.0.11 → hyper_connections-0.0.12}/hyper_connections/__init__.py +0 -0
{hyper_connections-0.0.11 → hyper_connections-0.0.12}/hyper_connections/hyper_connections.py
RENAMED
|
@@ -43,23 +43,37 @@ class Residual(Module):
|
|
|
43
43
|
def depth_connection(self, branch_output, residuals):
    """Merge a branch's output back into the residual stream.

    For the plain residual connection this is just the sum of the two
    operands; no extra mixing or weighting is applied.
    """
    merged = branch_output + residuals
    return merged
|
|
45
45
|
|
|
46
|
+
def decorate_branch(self, branch: Callable):
    """Wrap `branch` so that calling the wrapper runs it inside this residual.

    Only permitted when no branch was supplied at construction time.

    The returned callable takes the residual stream as its first argument.
    It forwards any extra positional and keyword arguments to `branch`, and
    returns the residual stream after the branch output has been merged back
    in.
    """
    assert not exists(self.branch), 'branch was already wrapped on init'

    def forward_and_add_residual(residual, *args, **kwargs):
        # self.branch is expected to be unset (asserted above), so
        # self.forward hands back (branch_input, add_residual_fn) instead of
        # running a branch itself.
        branch_in, merge_back = self.forward(residual)
        return merge_back(branch(branch_in, *args, **kwargs))

    return forward_and_add_residual
|
|
59
|
+
|
|
46
60
|
def forward(self, residuals, *branch_args, **branch_kwargs):
    """Run one residual step around the (optional) wrapped branch.

    `self.width_connection` yields three things: the branch input, the
    residual stream to add back into, and extra keyword arguments for
    `self.depth_connection`.

    - Without `self.branch`: returns `(branch_input, add_residual_fn)`. The
      caller runs its own branch and passes the output to `add_residual_fn`.
    - With `self.branch`: runs it on the branch input with
      `*branch_args, **branch_kwargs` and returns the merged result.

    The branch output is flattened with `tree_flatten`. Only the first
    flattened element goes through `depth_connection`; the remaining
    elements pass through unchanged. The original structure is then rebuilt
    with `tree_unflatten`.
    """
    branch_input, residuals, residual_kwargs = self.width_connection(residuals)

    def add_residual_fn(branch_out):
        leaves, tree_spec = tree_flatten(branch_out)
        head, *tail = leaves
        merged = self.depth_connection(head, residuals, **residual_kwargs)
        return tree_unflatten((merged, *tail), tree_spec)

    if exists(self.branch):
        return add_residual_fn(self.branch(branch_input, *branch_args, **branch_kwargs))

    return branch_input, add_residual_fn
|
|
63
77
|
|
|
64
78
|
# hyper connection residual streams
|
|
65
79
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|