hyper-connections 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -175,7 +175,6 @@ class HyperConnections(Module):
175
175
  tanh = True,
176
176
  channel_first = False,
177
177
  dropout = 0.,
178
- residual_transform: Module | None = None, # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
179
178
  add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
180
179
  num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
181
180
  depth_residual_fn = add,
@@ -255,10 +254,6 @@ class HyperConnections(Module):
255
254
 
256
255
  self.channel_first = channel_first
257
256
 
258
- # maybe residual transform
259
-
260
- self.residual_transform = default(residual_transform, nn.Identity())
261
-
262
257
  # maybe custom depth connection residual function
263
258
  # this is to prepare for gating the addition of the branch outputs to the residual streams
264
259
  # needed for memory lanes a la RMT / LMM
@@ -271,8 +266,6 @@ class HyperConnections(Module):
271
266
  ):
272
267
  streams = self.num_residual_streams
273
268
 
274
- maybe_transformed_residuals = self.residual_transform(residuals)
275
-
276
269
  # width connection
277
270
 
278
271
  # handle channel first
@@ -334,7 +327,14 @@ class HyperConnections(Module):
334
327
 
335
328
  branch_input = self.merge_fracs(branch_input)
336
329
 
337
- return branch_input, maybe_transformed_residuals, dict(beta = beta)
330
+ # reshape residuals back
331
+
332
+ if self.channel_first:
333
+ residuals = rearrange(residuals, 'b ... f s d -> (b s) (f d) ...')
334
+ else:
335
+ residuals = rearrange(residuals, 'b ... f s d -> (b s) ... (f d)')
336
+
337
+ return branch_input, residuals, dict(beta = beta)
338
338
 
339
339
  def depth_connection(
340
340
  self,
@@ -84,7 +84,6 @@ class HyperConnections(Module):
84
84
  tanh = True,
85
85
  channel_first = True,
86
86
  dropout = 0.,
87
- residual_transform: Module | None = None, # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
88
87
  ):
89
88
  """
90
89
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -124,19 +123,12 @@ class HyperConnections(Module):
124
123
  self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
125
124
  self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
126
125
 
127
-
128
126
  # dropouts
129
127
 
130
128
  self.dropout = nn.Dropout(dropout)
131
129
 
132
- # maybe residual transform
133
-
134
- self.residual_transform = default(residual_transform, nn.Identity())
135
-
136
130
  def width_connection(self, residuals):
137
131
 
138
- maybe_transformed_residuals = self.residual_transform(residuals)
139
-
140
132
  # width connection
141
133
 
142
134
  normed = self.norm(residuals)
@@ -161,7 +153,9 @@ class HyperConnections(Module):
161
153
 
162
154
  branch_input, residuals = mix_h[:, 0, ...], mix_h[:, 1:, ...]
163
155
 
164
- return branch_input, maybe_transformed_residuals, dict(beta = beta)
156
+ residuals = rearrange(residuals, 'b s d ... -> (b s) d ...')
157
+
158
+ return branch_input, residuals, dict(beta = beta)
165
159
 
166
160
  def depth_connection(self, branch_output, residuals, *, beta):
167
161
  # 'depth' connection
@@ -307,7 +307,7 @@ class ManifoldConstrainedHyperConnections(Module):
307
307
  ):
308
308
  streams = self.num_residual_streams
309
309
 
310
- maybe_transformed_residuals = self.residual_transform(residuals)
310
+ residuals = self.residual_transform(residuals)
311
311
 
312
312
  # width connection
313
313
 
@@ -397,13 +397,14 @@ class ManifoldConstrainedHyperConnections(Module):
397
397
 
398
398
  branch_input = self.merge_fracs(branch_input)
399
399
 
400
- branch_input = branch_input.to(dtype)
401
- residuals = residuals.to(dtype)
400
+ residuals = rearrange(residuals, 'b ... f s d -> (b s) ... (f d)')
401
+
402
+ branch_input, residuals = tuple(t.to(dtype) for t in (branch_input, residuals))
402
403
 
403
404
  if exists(beta):
404
405
  beta = beta.to(dtype)
405
406
 
406
- return branch_input, maybe_transformed_residuals, dict(beta = beta)
407
+ return branch_input, residuals, dict(beta = beta)
407
408
 
408
409
  def depth_connection(
409
410
  self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.3.9
3
+ Version: 0.3.10
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -0,0 +1,12 @@
1
+ hyper_connections/__init__.py,sha256=BAGwi53ozXcnfPJAGur0RHA4vcolF1ORBhbZ9a8SkrE,602
2
+ hyper_connections/hyper_connections.py,sha256=rqFJj3U0LF3uDKNKNPBpRrmf0oa2BGWVbD6S-xdZdLo,14904
3
+ hyper_connections/hyper_connections_channel_first.py,sha256=Mh_hzhTi96ZoOPmhSKwUaF4TbHpNqhs83wNe5hNuL7o,6532
4
+ hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
5
+ hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
6
+ hyper_connections/manifold_constrained_hyper_connections.py,sha256=SkGAWpBHnrOlIcixb0iIGej9StO82O7KXrFjYuSKx7I,17424
7
+ hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
8
+ hyper_connections/vit.py,sha256=fTC8hAYkD4qm-KURAj8SJ66C6ZWtsBdHf_kS-4rJZGQ,5049
9
+ hyper_connections-0.3.10.dist-info/METADATA,sha256=tEYVvFTVY_13gYQbflz-mjWMQEwH4DvPvQk76X9Iq2E,6705
10
+ hyper_connections-0.3.10.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
+ hyper_connections-0.3.10.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
12
+ hyper_connections-0.3.10.dist-info/RECORD,,
@@ -1,12 +0,0 @@
1
- hyper_connections/__init__.py,sha256=BAGwi53ozXcnfPJAGur0RHA4vcolF1ORBhbZ9a8SkrE,602
2
- hyper_connections/hyper_connections.py,sha256=UHxZhyRwx89GRgmQVt53Gv6JeNhX8UCjjETlydMZjTk,15021
3
- hyper_connections/hyper_connections_channel_first.py,sha256=_1PM4LRcPpDqfCiHlBMc2nLV08sXM2nuyZGSKTiuqbE,6818
4
- hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
5
- hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
6
- hyper_connections/manifold_constrained_hyper_connections.py,sha256=uF9WALGLeEBdfUm_p8O8ZTmmsk3L44gg-G1GW1SCMO0,17382
7
- hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
8
- hyper_connections/vit.py,sha256=fTC8hAYkD4qm-KURAj8SJ66C6ZWtsBdHf_kS-4rJZGQ,5049
9
- hyper_connections-0.3.9.dist-info/METADATA,sha256=mAciMU5pRr1oxP1OvjUFnwJAxqU9RTOKjPs-I7xn1ns,6704
10
- hyper_connections-0.3.9.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
- hyper_connections-0.3.9.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
12
- hyper_connections-0.3.9.dist-info/RECORD,,