hyper-connections 0.0.8__tar.gz → 0.0.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.8
3
+ Version: 0.0.10
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -130,6 +130,8 @@ branch_output = branch(branch_input)
130
130
 
131
131
  residual = add_residual(branch_output)
132
132
 
133
+ # or, equivalently, in one line: residual = hyper_conn.decorate_branch(branch)(residual)
134
+
133
135
  # 4. reduce 4 streams with a summation, this has to be done after your for loop trunk
134
136
 
135
137
  residual = reduce_stream(residual)
@@ -87,6 +87,8 @@ branch_output = branch(branch_input)
87
87
 
88
88
  residual = add_residual(branch_output)
89
89
 
90
+ # or, equivalently, in one line: residual = hyper_conn.decorate_branch(branch)(residual)
91
+
90
92
  # 4. reduce 4 streams with a summation, this has to be done after your for loop trunk
91
93
 
92
94
  residual = reduce_stream(residual)
@@ -1,4 +1,6 @@
1
1
  from __future__ import annotations
2
+ from typing import Callable
3
+
2
4
  from functools import partial
3
5
  from random import randrange
4
6
 
@@ -79,8 +81,11 @@ class HyperConnections(Module):
79
81
 
80
82
  self.branch = branch
81
83
 
84
+ # activation, seemingly results were wishy washy depending on using tanh or not
85
+
82
86
  self.act = nn.Tanh() if tanh else nn.Identity()
83
- self.norm = nn.RMSNorm(dim)
87
+
88
+ self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
84
89
 
85
90
  self.num_residual_streams = num_residual_streams
86
91
  init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
@@ -163,6 +168,20 @@ class HyperConnections(Module):
163
168
 
164
169
  return output
165
170
 
171
def decorate_branch(self, branch: Callable):
    """Wrap ``branch`` so a single call performs the whole hyper-connection step.

    The returned function takes the residual streams and does three things:
    it runs this module's ``forward`` to obtain the branch input and an
    ``add_residual`` callback, applies ``branch`` to that input, and feeds the
    branch output back through ``add_residual``.

    Args:
        branch: the callable to wrap. Only allowed when no branch was supplied
            at construction time (``self.branch`` must not exist).

    Returns:
        A function ``(residual, *args, **kwargs) -> residual``. Any extra
        positional and keyword arguments are forwarded to ``branch`` after the
        branch input.

    Raises:
        AssertionError: if a branch was already wrapped on init.
    """
    assert not exists(self.branch), 'branch was already wrapped on init'

    def forward_and_add_residual(residual, *args, **kwargs):
        # with no branch set on init, forward is expected to hand back the branch
        # input plus a callback that merges the branch output into the residuals
        branch_input, add_residual = self.forward(residual)

        # extra arguments used to be accepted here but silently dropped;
        # pass them through to the wrapped branch
        branch_output = branch(branch_input, *args, **kwargs)

        return add_residual(branch_output)

    return forward_and_add_residual
184
+
166
185
  def forward(self, residuals, *branch_args, **branch_kwargs):
167
186
 
168
187
  branch_input, residuals, residual_kwargs = self.width_connection(residuals)
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.0.8"
3
+ version = "0.0.10"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }