hyper-connections 0.3.5__tar.gz → 0.3.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (16)
  1. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/PKG-INFO +1 -1
  2. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper_connections/manifold_constrained_hyper_connections.py +27 -16
  3. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/pyproject.toml +1 -1
  4. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/tests/test_hyper_connections.py +21 -0
  5. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/.github/workflows/python-publish.yml +0 -0
  6. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/.github/workflows/test.yml +0 -0
  7. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/.gitignore +0 -0
  8. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/LICENSE +0 -0
  9. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/README.md +0 -0
  10. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper-connections.png +0 -0
  11. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper_connections/__init__.py +0 -0
  12. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper_connections/hyper_connections.py +0 -0
  13. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper_connections/hyper_connections_channel_first.py +0 -0
  14. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
  15. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
  16. {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper_connections/residuals.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.3.5
3
+ Version: 0.3.7
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -46,6 +46,9 @@ def l1norm(t, dim):
46
46
  return F.normalize(t, p = 1, dim = dim)
47
47
 
48
48
  def sinkhorn_knopps(log_alpha, iters = 20):
49
+ dtype = log_alpha.dtype
50
+ log_alpha = log_alpha.float()
51
+
49
52
  log_alpha = log_alpha - log_alpha.amax(dim = -2, keepdim = True).detach()
50
53
 
51
54
  alpha = log_alpha.exp()
@@ -54,7 +57,7 @@ def sinkhorn_knopps(log_alpha, iters = 20):
54
57
  alpha = l1norm(alpha, dim = -2)
55
58
  alpha = l1norm(alpha, dim = -1)
56
59
 
57
- return alpha
60
+ return alpha.to(dtype)
58
61
 
59
62
  # main functions
60
63
 
@@ -222,7 +225,7 @@ class ManifoldConstrainedHyperConnections(Module):
222
225
 
223
226
  # they used layernorm in paper, but rmsnorm is fine given what we know now
224
227
 
225
- self.norm = RMSNorm(dim * num_residual_streams * num_fracs)
228
+ self.norm = RMSNorm(dim)
226
229
 
227
230
  assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
228
231
 
@@ -312,25 +315,25 @@ class ManifoldConstrainedHyperConnections(Module):
312
315
 
313
316
  # norm
314
317
 
315
- normed = rearrange(residuals, 'b ... f s d -> b ... (f s d)')
318
+ normed = self.norm(residuals)
316
319
 
317
- normed = self.norm(normed)
320
+ # alpha for weighted sum of residuals going into branch
318
321
 
319
- normed = rearrange(normed, 'b ... (f s d) -> b ... f s d', f = self.num_fracs, s = streams)
322
+ dtype = residuals.dtype
320
323
 
321
- # alpha for weighted sum of residuals going into branch
324
+ normed = normed.float()
322
325
 
323
- wc_weight = normed @ self.dynamic_alpha_fn
326
+ wc_weight = normed @ self.dynamic_alpha_fn.float()
324
327
 
325
- pre_branch_scale = repeat(self.pre_branch_scale, '1 -> s', s = self.num_fracs)
326
- residual_scale = repeat(self.residual_scale, '1 -> s', s = self.num_fracs * streams)
328
+ pre_branch_scale = repeat(self.pre_branch_scale.float(), '1 -> s', s = self.num_fracs)
329
+ residual_scale = repeat(self.residual_scale.float(), '1 -> s', s = self.num_fracs * streams)
327
330
  alpha_scale = cat((pre_branch_scale, residual_scale))
328
331
 
329
332
  alpha_scale = repeat(alpha_scale, 'n -> (v n)', v = self.num_input_views)
330
333
 
331
334
  dynamic_alpha = wc_weight * alpha_scale
332
335
 
333
- static_alpha = rearrange(self.static_alpha, '(f s) d -> f s d', s = streams)
336
+ static_alpha = rearrange(self.static_alpha.float(), '(f s) d -> f s d', s = streams)
334
337
 
335
338
  alpha = dynamic_alpha + static_alpha
336
339
 
@@ -355,20 +358,20 @@ class ManifoldConstrainedHyperConnections(Module):
355
358
  beta = None
356
359
 
357
360
  if self.add_branch_out_to_residual:
358
- dc_weight = normed @ self.dynamic_beta_fn
361
+ dc_weight = normed @ self.dynamic_beta_fn.float()
359
362
 
360
363
  dc_weight = dc_weight.sigmoid() * 2 # sigmoid * 2 for "H_post", corresponding to dc weight in original paper
361
364
 
362
365
  if not self.has_fracs:
363
366
  dc_weight = rearrange(dc_weight, '... -> ... 1')
364
367
 
365
- dynamic_beta = dc_weight * self.h_post_scale
368
+ dynamic_beta = dc_weight * self.h_post_scale.float()
366
369
 
367
- static_beta = rearrange(self.static_beta, '... (s f) -> ... s f', s = streams)
370
+ static_beta = rearrange(self.static_beta.float(), '... (s f) -> ... s f', s = streams)
368
371
 
369
372
  beta = dynamic_beta + static_beta
370
373
 
371
- mix_h = einsum(alpha, residuals, '... f1 s f2 t, ... f1 s d -> ... f2 t d')
374
+ mix_h = einsum(alpha, residuals.float(), '... f1 s f2 t, ... f1 s d -> ... f2 t d')
372
375
 
373
376
  if self.num_input_views == 1:
374
377
  branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
@@ -383,6 +386,12 @@ class ManifoldConstrainedHyperConnections(Module):
383
386
 
384
387
  branch_input = self.merge_fracs(branch_input)
385
388
 
389
+ branch_input = branch_input.to(dtype)
390
+ residuals = residuals.to(dtype)
391
+
392
+ if exists(beta):
393
+ beta = beta.to(dtype)
394
+
386
395
  return branch_input, maybe_transformed_residuals, dict(beta = beta)
387
396
 
388
397
  def depth_connection(
@@ -403,7 +412,9 @@ class ManifoldConstrainedHyperConnections(Module):
403
412
  if self.channel_first:
404
413
  branch_output = rearrange(branch_output, 'b d ... -> b ... d')
405
414
 
406
- output = einsum(branch_output, beta, 'b ... f1 d, b ... f1 s f2 -> b ... f2 s d')
415
+ dtype = residuals.dtype
416
+
417
+ output = einsum(branch_output.float(), beta.float(), 'b ... f1 d, b ... f1 s f2 -> b ... f2 s d')
407
418
 
408
419
  output = rearrange(output, 'b ... s d -> (b s) ... d')
409
420
 
@@ -416,7 +427,7 @@ class ManifoldConstrainedHyperConnections(Module):
416
427
  if self.channel_first:
417
428
  output = rearrange(output, 'b ... d -> b d ...')
418
429
 
419
- residuals = self.depth_residual_fn(output, residuals)
430
+ residuals = self.depth_residual_fn(output.to(dtype), residuals)
420
431
 
421
432
  return self.dropout(residuals)
422
433
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.3.5"
3
+ version = "0.3.7"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -231,3 +231,24 @@ def test_channel_first_hyper_connection(disable):
231
231
  after_residual = reduce_stream(residual)
232
232
 
233
233
  assert before_residual.shape == after_residual.shape
234
+
235
+ def test_mhc_dtype_restoration():
236
+ from hyper_connections.manifold_constrained_hyper_connections import ManifoldConstrainedHyperConnections
237
+
238
+ mhc = ManifoldConstrainedHyperConnections(
239
+ num_residual_streams = 4,
240
+ dim = 64,
241
+ add_branch_out_to_residual = True
242
+ )
243
+
244
+ residual = torch.randn(4, 1, 64).half()
245
+
246
+ branch_input, _, residual_kwargs = mhc.width_connection(residual)
247
+
248
+ assert branch_input.dtype == torch.half
249
+ assert residual_kwargs['beta'].dtype == torch.half
250
+
251
+ branch_output = torch.randn_like(branch_input).half()
252
+ residual = mhc.depth_connection(branch_output, residual, **residual_kwargs)
253
+
254
+ assert residual.dtype == torch.half