hyper-connections 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,6 +19,7 @@ b - batch
19
19
  d - feature dimension
20
20
  s - residual streams
21
21
  t - residual streams + num branch inputs
22
+ v - number of views for branch input
22
23
  """
23
24
 
24
25
  # helper functions
@@ -141,7 +142,8 @@ class HyperConnections(Module):
141
142
  channel_first = False,
142
143
  dropout = 0.,
143
144
  residual_transform: Module | None = None, # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
144
- add_branch_out_to_residual = True # will disable depth connections (weighted residual sum with beta) if set False
145
+ add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
146
+ num_input_views = 1 # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
145
147
  ):
146
148
  """
147
149
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -161,14 +163,19 @@ class HyperConnections(Module):
161
163
  self.num_residual_streams = num_residual_streams
162
164
  init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
163
165
 
166
+ # width num residual streams
167
+
168
+ assert num_input_views >= 1
169
+ self.num_input_views = num_input_views
170
+
164
171
  # width connection
165
172
 
166
- init_alpha0 = torch.zeros((num_residual_streams, 1))
167
- init_alpha0[init_residual_index, 0] = 1.
173
+ init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
174
+ init_alpha0[init_residual_index, :] = 1.
168
175
 
169
176
  self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
170
177
 
171
- self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
178
+ self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
172
179
  self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
173
180
 
174
181
  # depth connection related (beta)
@@ -222,7 +229,11 @@ class HyperConnections(Module):
222
229
 
223
230
  mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
224
231
 
225
- branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
232
+ if self.num_input_views == 1:
233
+ branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
234
+ else:
235
+ branch_input, residuals = mix_h[..., :self.num_input_views, :], mix_h[..., self.num_input_views:, :]
236
+ branch_input = rearrange(branch_input, 'b ... v d -> v b ... d')
226
237
 
227
238
  if self.channel_first:
228
239
  branch_input = rearrange(branch_input, 'b ... d -> b d ...')
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.1.10
3
+ Version: 0.1.11
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -1,9 +1,9 @@
1
1
  hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
2
- hyper_connections/hyper_connections.py,sha256=Jk7Ux8fJPz63EkgZgKa7fQqpqCasr6cLZt7Fd06dPoE,11563
2
+ hyper_connections/hyper_connections.py,sha256=mUImPtaTE8Paygs-6vq7l_mlph1CkU__jRcE4TFim_Y,12137
3
3
  hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
4
4
  hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
5
5
  hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
6
- hyper_connections-0.1.10.dist-info/METADATA,sha256=K2EgcNxhmXRGTOEbVvsVsQl_dKLZ6iw88dzZqD6zaf4,5231
7
- hyper_connections-0.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- hyper_connections-0.1.10.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
9
- hyper_connections-0.1.10.dist-info/RECORD,,
6
+ hyper_connections-0.1.11.dist-info/METADATA,sha256=Ck3udilJMrT1ABRkqkNhfEdkjQXAKFBtsUAAalEu3No,5231
7
+ hyper_connections-0.1.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ hyper_connections-0.1.11.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
9
+ hyper_connections-0.1.11.dist-info/RECORD,,