hyper-connections 0.0.11__tar.gz → 0.0.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.11
3
+ Version: 0.0.12
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -43,23 +43,37 @@ class Residual(Module):
43
43
  def depth_connection(self, branch_output, residuals):
44
44
  return branch_output + residuals
45
45
 
46
+ def decorate_branch(self, branch: Callable):
47
+ assert not exists(self.branch), 'branch was already wrapped on init'
48
+
49
+ def forward_and_add_residual(residual, *args, **kwargs):
50
+ branch_input, add_residual = self.forward(residual)
51
+
52
+ branch_output = branch(branch_input, *args, **kwargs)
53
+
54
+ residual = add_residual(branch_output)
55
+
56
+ return residual
57
+
58
+ return forward_and_add_residual
59
+
46
60
  def forward(self, residuals, *branch_args, **branch_kwargs):
47
61
 
48
62
  branch_input, residuals, residual_kwargs = self.width_connection(residuals)
49
63
 
50
64
  def add_residual_fn(branch_out):
51
- return self.depth_connection(branch_out, residuals, **residual_kwargs)
65
+ (branch_out, *rest), tree_spec = tree_flatten(branch_out)
66
+
67
+ branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
68
+
69
+ return tree_unflatten((branch_out, *rest), tree_spec)
52
70
 
53
71
  if not exists(self.branch):
54
72
  return branch_input, add_residual_fn
55
73
 
56
74
  branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
57
75
 
58
- (branch_output, *rest), tree_spec = tree_flatten(branch_output)
59
-
60
- branch_output = add_residual_fn(branch_output)
61
-
62
- return tree_unflatten((branch_output, *rest), tree_spec)
76
+ return add_residual_fn(branch_output)
63
77
 
64
78
  # hyper connection residual streams
65
79
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.0.11"
3
+ version = "0.0.12"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }