hyper-connections 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,6 +19,7 @@ b - batch
19
19
  d - feature dimension
20
20
  s - residual streams
21
21
  t - residual streams + num branch inputs
22
+ v - number of views for branch input
22
23
  """
23
24
 
24
25
  # helper functions
@@ -34,22 +35,31 @@ def identity(t):
34
35
 
35
36
  # main functions
36
37
 
37
- def get_expand_reduce_stream_functions(num_streams, disable = False):
38
+ def get_expand_reduce_stream_functions(num_streams, add_stream_embed = False, dim = None, disable = False):
38
39
 
39
40
  if num_streams == 1 or disable:
40
41
  return (nn.Identity(), nn.Identity())
41
42
 
42
- expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
43
+ if add_stream_embed:
44
+ assert exists(dim), '`dim` must be passed into get_init_and_expand_reduce_stream_functions for returning an expansion function with stream embeddings added'
45
+
46
+ expand_fn = StreamEmbed(num_streams, dim, expand_to_streams = True)
47
+ else:
48
+ expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
49
+
43
50
  reduce_fn = Reduce(pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
44
51
 
45
52
  return expand_fn, reduce_fn
46
53
 
47
- def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
54
+ def get_init_and_expand_reduce_stream_functions(num_streams, dim = None, add_stream_embed = False, disable = False):
48
55
 
49
56
  hyper_conn_klass = HyperConnections if not disable else Residual
50
57
 
51
58
  init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
52
- expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
59
+ expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, add_stream_embed = add_stream_embed, dim = dim, disable = disable)
60
+
61
+ if exists(dim):
62
+ init_hyper_conn_fn = partial(init_hyper_conn_fn, dim = dim)
53
63
 
54
64
  return (init_hyper_conn_fn, *expand_reduce_fns)
55
65
 
@@ -132,7 +142,8 @@ class HyperConnections(Module):
132
142
  channel_first = False,
133
143
  dropout = 0.,
134
144
  residual_transform: Module | None = None, # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
135
- add_branch_out_to_residual = True # will disable depth connections (weighted residual sum with beta) if set False
145
+ add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
146
+ num_input_views = 1 # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
136
147
  ):
137
148
  """
138
149
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -152,14 +163,19 @@ class HyperConnections(Module):
152
163
  self.num_residual_streams = num_residual_streams
153
164
  init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
154
165
 
166
+ # width num residual streams
167
+
168
+ assert num_input_views >= 1
169
+ self.num_input_views = num_input_views
170
+
155
171
  # width connection
156
172
 
157
- init_alpha0 = torch.zeros((num_residual_streams, 1))
158
- init_alpha0[init_residual_index, 0] = 1.
173
+ init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
174
+ init_alpha0[init_residual_index, :] = 1.
159
175
 
160
176
  self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
161
177
 
162
- self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
178
+ self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
163
179
  self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
164
180
 
165
181
  # depth connection related (beta)
@@ -213,7 +229,11 @@ class HyperConnections(Module):
213
229
 
214
230
  mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
215
231
 
216
- branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
232
+ if self.num_input_views == 1:
233
+ branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
234
+ else:
235
+ branch_input, residuals = mix_h[..., :self.num_input_views, :], mix_h[..., self.num_input_views:, :]
236
+ branch_input = rearrange(branch_input, 'b ... v d -> v b ... d')
217
237
 
218
238
  if self.channel_first:
219
239
  branch_input = rearrange(branch_input, 'b ... d -> b d ...')
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.1.9
3
+ Version: 0.1.11
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -0,0 +1,9 @@
1
+ hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
2
+ hyper_connections/hyper_connections.py,sha256=mUImPtaTE8Paygs-6vq7l_mlph1CkU__jRcE4TFim_Y,12137
3
+ hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
4
+ hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
5
+ hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
6
+ hyper_connections-0.1.11.dist-info/METADATA,sha256=Ck3udilJMrT1ABRkqkNhfEdkjQXAKFBtsUAAalEu3No,5231
7
+ hyper_connections-0.1.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ hyper_connections-0.1.11.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
9
+ hyper_connections-0.1.11.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
2
- hyper_connections/hyper_connections.py,sha256=F81iJkGMpxgCZPaBTLf0c3CYIE-ROAVgZJWY3NlrsJw,11068
3
- hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
4
- hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
5
- hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
6
- hyper_connections-0.1.9.dist-info/METADATA,sha256=XxicphOwzNfTmBLF4Py89MhTWSpJqVk1EG-DV1gpFvo,5230
7
- hyper_connections-0.1.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- hyper_connections-0.1.9.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
9
- hyper_connections-0.1.9.dist-info/RECORD,,