hyper-connections 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,6 +6,7 @@ import torch
6
6
  from torch import nn
7
7
  from torch.nn import Module
8
8
  import torch.nn.functional as F
9
+ from torch.utils._pytree import tree_flatten, tree_unflatten
9
10
 
10
11
  from einops import rearrange, repeat, reduce, einsum
11
12
 
@@ -125,4 +126,8 @@ class HyperConnections(Module):
125
126
 
126
127
  branch_output = self.branch(branch_input, **branch_kwargs)
127
128
 
128
- return add_residual_fn(branch_output)
129
+ (branch_output, *rest), tree_spec = tree_flatten(branch_output)
130
+
131
+ branch_output = add_residual_fn(branch_output)
132
+
133
+ return tree_unflatten((branch_output, *rest), tree_spec)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.2
3
+ Version: 0.0.3
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -45,7 +45,7 @@ Description-Content-Type: text/markdown
45
45
 
46
46
  ## Hyper Connections
47
47
 
48
- Attempt to make the multiple residual stream approach proposed by Hyper-Connections paper by Bytedance AI more accessible as a reusable library, and for following any new research in this direction.
48
+ Attempt to make multiple residual streams, proposed in the [Hyper-Connections paper](https://arxiv.org/abs/2409.19606) out of the Bytedance AI lab, accessible as an easy-to-use library, as well as to follow any new research in this direction.
49
49
 
50
50
  ## Install
51
51
 
@@ -114,7 +114,7 @@ from hyper_connections import HyperConnections
114
114
 
115
115
  expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
116
116
 
117
- # 1. wrap your branch function
117
+ # 1. instantiate hyper connection with correct number of streams (4 in this case)
118
118
 
119
119
  hyper_conn = HyperConnections(4, dim = 512)
120
120
 
@@ -124,11 +124,11 @@ residual = expand_stream(residual)
124
124
 
125
125
  # 3. forward your residual into hyper connection for the branch input + add residual function (learned betas)
126
126
 
127
- branch_input, depth_connect = hyper_conn(residual)
127
+ branch_input, add_residual = hyper_conn(residual)
128
128
 
129
129
  branch_output = branch(branch_input)
130
130
 
131
- residual = depth_connect(branch_output)
131
+ residual = add_residual(branch_output)
132
132
 
133
133
  # 4. reduce the 4 streams with a summation; this has to be done after your for-loop trunk
134
134
 
@@ -0,0 +1,6 @@
1
+ hyper_connections/__init__.py,sha256=xXx2Mb-dS1__UPzT-5VR1XZmyqKSSkT1DU6bAcK8jR0,73
2
+ hyper_connections/hyper_connections.py,sha256=91QtTtnpffmErIZvrnTtosSf4JgBqcyGvxftmka-EOw,4303
3
+ hyper_connections-0.0.3.dist-info/METADATA,sha256=8XKTmC6Ys10uOyotPtvL17v4uZyepkWoeMRVM4B_TSQ,4676
4
+ hyper_connections-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hyper_connections-0.0.3.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
+ hyper_connections-0.0.3.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hyper_connections/__init__.py,sha256=xXx2Mb-dS1__UPzT-5VR1XZmyqKSSkT1DU6bAcK8jR0,73
2
- hyper_connections/hyper_connections.py,sha256=0AcrJ6O2Crnc3aiz-G5zExypcdujJHaTgmLJswdRA_c,4094
3
- hyper_connections-0.0.2.dist-info/METADATA,sha256=KptXVWgSC5aIh63S7cCSp4h7X7xCqxMOFRToGr6XEbo,4587
4
- hyper_connections-0.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hyper_connections-0.0.2.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
- hyper_connections-0.0.2.dist-info/RECORD,,