hyper-connections 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -79,8 +79,11 @@ class HyperConnections(Module):
79
79
 
80
80
  self.branch = branch
81
81
 
82
+ # activation - results seemed inconsistent depending on whether tanh was used or not
83
+
82
84
  self.act = nn.Tanh() if tanh else nn.Identity()
83
- self.norm = nn.RMSNorm(dim)
85
+
86
+ self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
84
87
 
85
88
  self.num_residual_streams = num_residual_streams
86
89
  init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
@@ -163,6 +166,20 @@ class HyperConnections(Module):
163
166
 
164
167
  return output
165
168
 
169
def decorate_branch(self, branch: Callable):
    """Wrap `branch` so that one call runs this module around it.

    The returned function calls `self.forward` on the residual, feeds the
    resulting branch input to `branch`, and passes the branch output to the
    `add_residual` callable that `self.forward` returned.

    Args:
        branch: callable invoked as `branch(branch_input, *args, **kwargs)`.

    Returns:
        A function `forward_and_add_residual(residual, *args, **kwargs)`
        that returns the result of `add_residual(branch_output)`.

    Raises:
        AssertionError: if `self.branch` already exists, i.e. a branch was
            already wrapped on init.
    """
    assert not exists(self.branch), 'branch was already wrapped on init'

    def forward_and_add_residual(residual, *args, **kwargs):
        # with no branch stored on the module, forward hands back the branch
        # input together with a callable that adds the branch output back
        branch_input, add_residual = self.forward(residual)

        # forward any extra arguments to the branch - the wrapper accepts
        # them, so silently dropping them would surprise callers
        branch_output = branch(branch_input, *args, **kwargs)

        return add_residual(branch_output)

    return forward_and_add_residual
182
+
166
183
  def forward(self, residuals, *branch_args, **branch_kwargs):
167
184
 
168
185
  branch_input, residuals, residual_kwargs = self.width_connection(residuals)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.8
3
+ Version: 0.0.9
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -130,6 +130,8 @@ branch_output = branch(branch_input)
130
130
 
131
131
  residual = add_residual(branch_output)
132
132
 
133
+ # or you can do it in one line like so -> residual = hyper_conn.decorate_branch(branch)(residual)
134
+
133
135
  # 4. reduce 4 streams with a summation, this has to be done after your for loop trunk
134
136
 
135
137
  residual = reduce_stream(residual)
@@ -0,0 +1,6 @@
1
+ hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
2
+ hyper_connections/hyper_connections.py,sha256=YfEDcPcT93-S599gFUpaATtuUZ908vJ_pmjPeF4Po28,7558
3
+ hyper_connections-0.0.9.dist-info/METADATA,sha256=xur7rWt-ZdJU1XxXlpaO0D9aWDR1BgaVNhjfnQaedZQ,5075
4
+ hyper_connections-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hyper_connections-0.0.9.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
+ hyper_connections-0.0.9.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
2
- hyper_connections/hyper_connections.py,sha256=fdcr0DODcIQ1eggy7pa6faX6MqNIZST_q2aDMevViig,6964
3
- hyper_connections-0.0.8.dist-info/METADATA,sha256=2q_q0AHjDyFHzliHiRXXTBaI9iA1uqECGAH_HMZmGis,4978
4
- hyper_connections-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hyper_connections-0.0.8.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
- hyper_connections-0.0.8.dist-info/RECORD,,