hyper-connections 0.3.6__tar.gz → 0.3.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (16)
  1. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/PKG-INFO +1 -1
  2. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper_connections/manifold_constrained_hyper_connections.py +38 -12
  3. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/pyproject.toml +1 -1
  4. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/tests/test_hyper_connections.py +21 -0
  5. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/.github/workflows/python-publish.yml +0 -0
  6. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/.github/workflows/test.yml +0 -0
  7. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/.gitignore +0 -0
  8. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/LICENSE +0 -0
  9. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/README.md +0 -0
  10. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper-connections.png +0 -0
  11. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper_connections/__init__.py +0 -0
  12. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper_connections/hyper_connections.py +0 -0
  13. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper_connections/hyper_connections_channel_first.py +0 -0
  14. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
  15. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
  16. {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper_connections/residuals.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.3.6
3
+ Version: 0.3.8
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -46,6 +46,9 @@ def l1norm(t, dim):
46
46
  return F.normalize(t, p = 1, dim = dim)
47
47
 
48
48
  def sinkhorn_knopps(log_alpha, iters = 20):
49
+ dtype = log_alpha.dtype
50
+ log_alpha = log_alpha.float()
51
+
49
52
  log_alpha = log_alpha - log_alpha.amax(dim = -2, keepdim = True).detach()
50
53
 
51
54
  alpha = log_alpha.exp()
@@ -54,7 +57,7 @@ def sinkhorn_knopps(log_alpha, iters = 20):
54
57
  alpha = l1norm(alpha, dim = -2)
55
58
  alpha = l1norm(alpha, dim = -1)
56
59
 
57
- return alpha
60
+ return alpha.to(dtype)
58
61
 
59
62
  # main functions
60
63
 
@@ -197,7 +200,8 @@ class ManifoldConstrainedHyperConnections(Module):
197
200
  num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
198
201
  depth_residual_fn = add,
199
202
  num_fracs = 1, # https://arxiv.org/abs/2503.14125
200
- sinkhorn_iters = 20
203
+ sinkhorn_iters = 20,
204
+ forward_method_names: tuple[str, ...] = (),
201
205
  ):
202
206
  """
203
207
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -287,6 +291,16 @@ class ManifoldConstrainedHyperConnections(Module):
287
291
 
288
292
  self.depth_residual_fn = depth_residual_fn
289
293
 
294
+ # forwarding method names
295
+
296
+ self.forward_method_names = forward_method_names
297
+
298
+ for forward_method_name in self.forward_method_names:
299
+ assert not hasattr(self, forward_method_name)
300
+
301
+ fn = getattr(self.branch, forward_method_name)
302
+ setattr(self, forward_method_name, fn)
303
+
290
304
  def width_connection(
291
305
  self,
292
306
  residuals
@@ -316,17 +330,21 @@ class ManifoldConstrainedHyperConnections(Module):
316
330
 
317
331
  # alpha for weighted sum of residuals going into branch
318
332
 
319
- wc_weight = normed @ self.dynamic_alpha_fn
333
+ dtype = residuals.dtype
320
334
 
321
- pre_branch_scale = repeat(self.pre_branch_scale, '1 -> s', s = self.num_fracs)
322
- residual_scale = repeat(self.residual_scale, '1 -> s', s = self.num_fracs * streams)
335
+ normed = normed.float()
336
+
337
+ wc_weight = normed @ self.dynamic_alpha_fn.float()
338
+
339
+ pre_branch_scale = repeat(self.pre_branch_scale.float(), '1 -> s', s = self.num_fracs)
340
+ residual_scale = repeat(self.residual_scale.float(), '1 -> s', s = self.num_fracs * streams)
323
341
  alpha_scale = cat((pre_branch_scale, residual_scale))
324
342
 
325
343
  alpha_scale = repeat(alpha_scale, 'n -> (v n)', v = self.num_input_views)
326
344
 
327
345
  dynamic_alpha = wc_weight * alpha_scale
328
346
 
329
- static_alpha = rearrange(self.static_alpha, '(f s) d -> f s d', s = streams)
347
+ static_alpha = rearrange(self.static_alpha.float(), '(f s) d -> f s d', s = streams)
330
348
 
331
349
  alpha = dynamic_alpha + static_alpha
332
350
 
@@ -351,20 +369,20 @@ class ManifoldConstrainedHyperConnections(Module):
351
369
  beta = None
352
370
 
353
371
  if self.add_branch_out_to_residual:
354
- dc_weight = normed @ self.dynamic_beta_fn
372
+ dc_weight = normed @ self.dynamic_beta_fn.float()
355
373
 
356
374
  dc_weight = dc_weight.sigmoid() * 2 # sigmoid * 2 for "H_post", corresponding to dc weight in original paper
357
375
 
358
376
  if not self.has_fracs:
359
377
  dc_weight = rearrange(dc_weight, '... -> ... 1')
360
378
 
361
- dynamic_beta = dc_weight * self.h_post_scale
379
+ dynamic_beta = dc_weight * self.h_post_scale.float()
362
380
 
363
- static_beta = rearrange(self.static_beta, '... (s f) -> ... s f', s = streams)
381
+ static_beta = rearrange(self.static_beta.float(), '... (s f) -> ... s f', s = streams)
364
382
 
365
383
  beta = dynamic_beta + static_beta
366
384
 
367
- mix_h = einsum(alpha, residuals, '... f1 s f2 t, ... f1 s d -> ... f2 t d')
385
+ mix_h = einsum(alpha, residuals.float(), '... f1 s f2 t, ... f1 s d -> ... f2 t d')
368
386
 
369
387
  if self.num_input_views == 1:
370
388
  branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
@@ -379,6 +397,12 @@ class ManifoldConstrainedHyperConnections(Module):
379
397
 
380
398
  branch_input = self.merge_fracs(branch_input)
381
399
 
400
+ branch_input = branch_input.to(dtype)
401
+ residuals = residuals.to(dtype)
402
+
403
+ if exists(beta):
404
+ beta = beta.to(dtype)
405
+
382
406
  return branch_input, maybe_transformed_residuals, dict(beta = beta)
383
407
 
384
408
  def depth_connection(
@@ -399,7 +423,9 @@ class ManifoldConstrainedHyperConnections(Module):
399
423
  if self.channel_first:
400
424
  branch_output = rearrange(branch_output, 'b d ... -> b ... d')
401
425
 
402
- output = einsum(branch_output, beta, 'b ... f1 d, b ... f1 s f2 -> b ... f2 s d')
426
+ dtype = residuals.dtype
427
+
428
+ output = einsum(branch_output.float(), beta.float(), 'b ... f1 d, b ... f1 s f2 -> b ... f2 s d')
403
429
 
404
430
  output = rearrange(output, 'b ... s d -> (b s) ... d')
405
431
 
@@ -412,7 +438,7 @@ class ManifoldConstrainedHyperConnections(Module):
412
438
  if self.channel_first:
413
439
  output = rearrange(output, 'b ... d -> b d ...')
414
440
 
415
- residuals = self.depth_residual_fn(output, residuals)
441
+ residuals = self.depth_residual_fn(output.to(dtype), residuals)
416
442
 
417
443
  return self.dropout(residuals)
418
444
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.3.6"
3
+ version = "0.3.8"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -231,3 +231,24 @@ def test_channel_first_hyper_connection(disable):
231
231
  after_residual = reduce_stream(residual)
232
232
 
233
233
  assert before_residual.shape == after_residual.shape
234
+
235
+ def test_mhc_dtype_restoration():
236
+ from hyper_connections.manifold_constrained_hyper_connections import ManifoldConstrainedHyperConnections
237
+
238
+ mhc = ManifoldConstrainedHyperConnections(
239
+ num_residual_streams = 4,
240
+ dim = 64,
241
+ add_branch_out_to_residual = True
242
+ )
243
+
244
+ residual = torch.randn(4, 1, 64).half()
245
+
246
+ branch_input, _, residual_kwargs = mhc.width_connection(residual)
247
+
248
+ assert branch_input.dtype == torch.half
249
+ assert residual_kwargs['beta'].dtype == torch.half
250
+
251
+ branch_output = torch.randn_like(branch_input).half()
252
+ residual = mhc.depth_connection(branch_output, residual, **residual_kwargs)
253
+
254
+ assert residual.dtype == torch.half