hyper-connections 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -34,22 +34,31 @@ def identity(t):
34
34
 
35
35
  # main functions
36
36
 
37
- def get_expand_reduce_stream_functions(num_streams, disable = False):
37
+ def get_expand_reduce_stream_functions(num_streams, add_stream_embed = False, dim = None, disable = False):
38
38
 
39
39
  if num_streams == 1 or disable:
40
40
  return (nn.Identity(), nn.Identity())
41
41
 
42
- expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
42
+ if add_stream_embed:
43
+ assert exists(dim), '`dim` must be passed into get_init_and_expand_reduce_stream_functions for returning an expansion function with stream embeddings added'
44
+
45
+ expand_fn = StreamEmbed(num_streams, dim, expand_to_streams = True)
46
+ else:
47
+ expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
48
+
43
49
  reduce_fn = Reduce(pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
44
50
 
45
51
  return expand_fn, reduce_fn
46
52
 
47
- def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
53
+ def get_init_and_expand_reduce_stream_functions(num_streams, dim = None, add_stream_embed = False, disable = False):
48
54
 
49
55
  hyper_conn_klass = HyperConnections if not disable else Residual
50
56
 
51
57
  init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
52
- expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
58
+ expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, add_stream_embed = add_stream_embed, dim = dim, disable = disable)
59
+
60
+ if exists(dim):
61
+ init_hyper_conn_fn = partial(init_hyper_conn_fn, dim = dim)
53
62
 
54
63
  return (init_hyper_conn_fn, *expand_reduce_fns)
55
64
 
@@ -132,6 +141,7 @@ class HyperConnections(Module):
132
141
  channel_first = False,
133
142
  dropout = 0.,
134
143
  residual_transform: Module | None = None, # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
144
+ add_branch_out_to_residual = True # will disable depth connections (weighted residual sum with beta) if set False
135
145
  ):
136
146
  """
137
147
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -151,7 +161,7 @@ class HyperConnections(Module):
151
161
  self.num_residual_streams = num_residual_streams
152
162
  init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
153
163
 
154
- self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
164
+ # width connection
155
165
 
156
166
  init_alpha0 = torch.zeros((num_residual_streams, 1))
157
167
  init_alpha0[init_residual_index, 0] = 1.
@@ -160,8 +170,15 @@ class HyperConnections(Module):
160
170
 
161
171
  self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
162
172
  self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
163
- self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
164
- self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
173
+
174
+ # depth connection related (beta)
175
+
176
+ self.add_branch_out_to_residual = add_branch_out_to_residual
177
+
178
+ if add_branch_out_to_residual:
179
+ self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
180
+ self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
181
+ self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
165
182
 
166
183
  # dropouts
167
184
 
@@ -196,9 +213,12 @@ class HyperConnections(Module):
196
213
 
197
214
  # beta for weights from branch output back to residual streams
198
215
 
199
- dc_weight = self.act(normed @ self.dynamic_beta_fn)
200
- dynamic_beta = dc_weight * self.dynamic_beta_scale
201
- beta = dynamic_beta + self.static_beta
216
+ beta = None
217
+
218
+ if self.add_branch_out_to_residual:
219
+ dc_weight = self.act(normed @ self.dynamic_beta_fn)
220
+ dynamic_beta = dc_weight * self.dynamic_beta_scale
221
+ beta = dynamic_beta + self.static_beta
202
222
 
203
223
  mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
204
224
 
@@ -210,6 +230,8 @@ class HyperConnections(Module):
210
230
  return branch_input, maybe_transformed_residuals, dict(beta = beta)
211
231
 
212
232
  def depth_connection(self, branch_output, residuals, *, beta):
233
+ assert self.add_branch_out_to_residual
234
+
213
235
  # 'depth' connection
214
236
 
215
237
  if self.channel_first:
@@ -244,6 +266,10 @@ class HyperConnections(Module):
244
266
  branch_input, residuals, residual_kwargs = self.width_connection(residuals)
245
267
 
246
268
  def add_residual_fn(branch_out):
269
+
270
+ if not self.add_branch_out_to_residual:
271
+ return branch_out
272
+
247
273
  (branch_out, *rest), tree_spec = tree_flatten(branch_out)
248
274
 
249
275
  branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.1.8
3
+ Version: 0.1.10
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -0,0 +1,9 @@
1
+ hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
2
+ hyper_connections/hyper_connections.py,sha256=Jk7Ux8fJPz63EkgZgKa7fQqpqCasr6cLZt7Fd06dPoE,11563
3
+ hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
4
+ hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
5
+ hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
6
+ hyper_connections-0.1.10.dist-info/METADATA,sha256=K2EgcNxhmXRGTOEbVvsVsQl_dKLZ6iw88dzZqD6zaf4,5231
7
+ hyper_connections-0.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ hyper_connections-0.1.10.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
9
+ hyper_connections-0.1.10.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
2
- hyper_connections/hyper_connections.py,sha256=L2e4DduzPGdH30NhfHuiSiVZTwXRgeZW2MDAZ0Z-TKk,10541
3
- hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
4
- hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
5
- hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
6
- hyper_connections-0.1.8.dist-info/METADATA,sha256=hjJ1feS21_VizDdYwE6lSPhh4kJXcQ5PXYPYKGtm2LI,5230
7
- hyper_connections-0.1.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- hyper_connections-0.1.8.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
9
- hyper_connections-0.1.8.dist-info/RECORD,,