hyper-connections 0.0.9__tar.gz → 0.0.11__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.9
3
+ Version: 0.0.11
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -1,4 +1,6 @@
1
1
  from __future__ import annotations
2
+ from typing import Callable
3
+
2
4
  from functools import partial
3
5
  from random import randrange
4
6
 
@@ -172,7 +174,7 @@ class HyperConnections(Module):
172
174
  def forward_and_add_residual(residual, *args, **kwargs):
173
175
  branch_input, add_residual = self.forward(residual)
174
176
 
175
- branch_output = branch(branch_input)
177
+ branch_output = branch(branch_input, *args, **kwargs)
176
178
 
177
179
  residual = add_residual(branch_output)
178
180
 
@@ -185,18 +187,18 @@ class HyperConnections(Module):
185
187
  branch_input, residuals, residual_kwargs = self.width_connection(residuals)
186
188
 
187
189
  def add_residual_fn(branch_out):
188
- return self.depth_connection(branch_out, residuals, **residual_kwargs)
190
+ (branch_out, *rest), tree_spec = tree_flatten(branch_out)
191
+
192
+ branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
193
+
194
+ return tree_unflatten((branch_out, *rest), tree_spec)
189
195
 
190
196
  if not exists(self.branch):
191
197
  return branch_input, add_residual_fn
192
198
 
193
199
  branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
194
200
 
195
- (branch_output, *rest), tree_spec = tree_flatten(branch_output)
196
-
197
- branch_output = add_residual_fn(branch_output)
198
-
199
- return tree_unflatten((branch_output, *rest), tree_spec)
201
+ return add_residual_fn(branch_output)
200
202
 
201
203
  # stream embed
202
204
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.0.9"
3
+ version = "0.0.11"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }