hyper-connections 0.0.10__tar.gz → 0.0.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.10
3
+ Version: 0.0.11
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -174,7 +174,7 @@ class HyperConnections(Module):
174
174
  def forward_and_add_residual(residual, *args, **kwargs):
175
175
  branch_input, add_residual = self.forward(residual)
176
176
 
177
- branch_output = branch(branch_input)
177
+ branch_output = branch(branch_input, *args, **kwargs)
178
178
 
179
179
  residual = add_residual(branch_output)
180
180
 
@@ -187,18 +187,18 @@ class HyperConnections(Module):
187
187
  branch_input, residuals, residual_kwargs = self.width_connection(residuals)
188
188
 
189
189
  def add_residual_fn(branch_out):
190
- return self.depth_connection(branch_out, residuals, **residual_kwargs)
190
+ (branch_out, *rest), tree_spec = tree_flatten(branch_out)
191
+
192
+ branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
193
+
194
+ return tree_unflatten((branch_out, *rest), tree_spec)
191
195
 
192
196
  if not exists(self.branch):
193
197
  return branch_input, add_residual_fn
194
198
 
195
199
  branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
196
200
 
197
- (branch_output, *rest), tree_spec = tree_flatten(branch_output)
198
-
199
- branch_output = add_residual_fn(branch_output)
200
-
201
- return tree_unflatten((branch_output, *rest), tree_spec)
201
+ return add_residual_fn(branch_output)
202
202
 
203
203
  # stream embed
204
204
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.0.10"
3
+ version = "0.0.11"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }