hyper-connections 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -132,6 +132,7 @@ class HyperConnections(Module):
132
132
  channel_first = False,
133
133
  dropout = 0.,
134
134
  residual_transform: Module | None = None, # to support resnet blocks where the input dimension is not equal to the output dimension - usually a residual conv
135
+ add_branch_out_to_residual = True # will disable depth connections (weighted residual sum with beta) if set False
135
136
  ):
136
137
  """
137
138
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -151,7 +152,7 @@ class HyperConnections(Module):
151
152
  self.num_residual_streams = num_residual_streams
152
153
  init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
153
154
 
154
- self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
155
+ # width connection
155
156
 
156
157
  init_alpha0 = torch.zeros((num_residual_streams, 1))
157
158
  init_alpha0[init_residual_index, 0] = 1.
@@ -160,8 +161,15 @@ class HyperConnections(Module):
160
161
 
161
162
  self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
162
163
  self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
163
- self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
164
- self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
164
+
165
+ # depth connection related (beta)
166
+
167
+ self.add_branch_out_to_residual = add_branch_out_to_residual
168
+
169
+ if add_branch_out_to_residual:
170
+ self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
171
+ self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
172
+ self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
165
173
 
166
174
  # dropouts
167
175
 
@@ -196,9 +204,12 @@ class HyperConnections(Module):
196
204
 
197
205
  # beta for weights from branch output back to residual streams
198
206
 
199
- dc_weight = self.act(normed @ self.dynamic_beta_fn)
200
- dynamic_beta = dc_weight * self.dynamic_beta_scale
201
- beta = dynamic_beta + self.static_beta
207
+ beta = None
208
+
209
+ if self.add_branch_out_to_residual:
210
+ dc_weight = self.act(normed @ self.dynamic_beta_fn)
211
+ dynamic_beta = dc_weight * self.dynamic_beta_scale
212
+ beta = dynamic_beta + self.static_beta
202
213
 
203
214
  mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
204
215
 
@@ -210,6 +221,8 @@ class HyperConnections(Module):
210
221
  return branch_input, maybe_transformed_residuals, dict(beta = beta)
211
222
 
212
223
  def depth_connection(self, branch_output, residuals, *, beta):
224
+ assert self.add_branch_out_to_residual
225
+
213
226
  # 'depth' connection
214
227
 
215
228
  if self.channel_first:
@@ -244,6 +257,10 @@ class HyperConnections(Module):
244
257
  branch_input, residuals, residual_kwargs = self.width_connection(residuals)
245
258
 
246
259
  def add_residual_fn(branch_out):
260
+
261
+ if not self.add_branch_out_to_residual:
262
+ return branch_out
263
+
247
264
  (branch_out, *rest), tree_spec = tree_flatten(branch_out)
248
265
 
249
266
  branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.1.8
3
+ Version: 0.1.9
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -1,9 +1,9 @@
1
1
  hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
2
- hyper_connections/hyper_connections.py,sha256=L2e4DduzPGdH30NhfHuiSiVZTwXRgeZW2MDAZ0Z-TKk,10541
2
+ hyper_connections/hyper_connections.py,sha256=F81iJkGMpxgCZPaBTLf0c3CYIE-ROAVgZJWY3NlrsJw,11068
3
3
  hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
4
4
  hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
5
5
  hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
6
- hyper_connections-0.1.8.dist-info/METADATA,sha256=hjJ1feS21_VizDdYwE6lSPhh4kJXcQ5PXYPYKGtm2LI,5230
7
- hyper_connections-0.1.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- hyper_connections-0.1.8.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
9
- hyper_connections-0.1.8.dist-info/RECORD,,
6
+ hyper_connections-0.1.9.dist-info/METADATA,sha256=XxicphOwzNfTmBLF4Py89MhTWSpJqVk1EG-DV1gpFvo,5230
7
+ hyper_connections-0.1.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ hyper_connections-0.1.9.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
9
+ hyper_connections-0.1.9.dist-info/RECORD,,