hyper-connections 0.0.10__tar.gz → 0.0.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.10
3
+ Version: 0.0.12
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -43,23 +43,37 @@ class Residual(Module):
43
43
  def depth_connection(self, branch_output, residuals):
44
44
  return branch_output + residuals
45
45
 
46
+ def decorate_branch(self, branch: Callable):
47
+ assert not exists(self.branch), 'branch was already wrapped on init'
48
+
49
+ def forward_and_add_residual(residual, *args, **kwargs):
50
+ branch_input, add_residual = self.forward(residual)
51
+
52
+ branch_output = branch(branch_input, *args, **kwargs)
53
+
54
+ residual = add_residual(branch_output)
55
+
56
+ return residual
57
+
58
+ return forward_and_add_residual
59
+
46
60
  def forward(self, residuals, *branch_args, **branch_kwargs):
47
61
 
48
62
  branch_input, residuals, residual_kwargs = self.width_connection(residuals)
49
63
 
50
64
  def add_residual_fn(branch_out):
51
- return self.depth_connection(branch_out, residuals, **residual_kwargs)
65
+ (branch_out, *rest), tree_spec = tree_flatten(branch_out)
66
+
67
+ branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
68
+
69
+ return tree_unflatten((branch_out, *rest), tree_spec)
52
70
 
53
71
  if not exists(self.branch):
54
72
  return branch_input, add_residual_fn
55
73
 
56
74
  branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
57
75
 
58
- (branch_output, *rest), tree_spec = tree_flatten(branch_output)
59
-
60
- branch_output = add_residual_fn(branch_output)
61
-
62
- return tree_unflatten((branch_output, *rest), tree_spec)
76
+ return add_residual_fn(branch_output)
63
77
 
64
78
  # hyper connection residual streams
65
79
 
@@ -174,7 +188,7 @@ class HyperConnections(Module):
174
188
  def forward_and_add_residual(residual, *args, **kwargs):
175
189
  branch_input, add_residual = self.forward(residual)
176
190
 
177
- branch_output = branch(branch_input)
191
+ branch_output = branch(branch_input, *args, **kwargs)
178
192
 
179
193
  residual = add_residual(branch_output)
180
194
 
@@ -187,18 +201,18 @@ class HyperConnections(Module):
187
201
  branch_input, residuals, residual_kwargs = self.width_connection(residuals)
188
202
 
189
203
  def add_residual_fn(branch_out):
190
- return self.depth_connection(branch_out, residuals, **residual_kwargs)
204
+ (branch_out, *rest), tree_spec = tree_flatten(branch_out)
205
+
206
+ branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
207
+
208
+ return tree_unflatten((branch_out, *rest), tree_spec)
191
209
 
192
210
  if not exists(self.branch):
193
211
  return branch_input, add_residual_fn
194
212
 
195
213
  branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
196
214
 
197
- (branch_output, *rest), tree_spec = tree_flatten(branch_output)
198
-
199
- branch_output = add_residual_fn(branch_output)
200
-
201
- return tree_unflatten((branch_output, *rest), tree_spec)
215
+ return add_residual_fn(branch_output)
202
216
 
203
217
  # stream embed
204
218
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.0.10"
3
+ version = "0.0.12"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }