hyper-connections 0.3.6__tar.gz → 0.3.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (16)
  1. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/PKG-INFO +1 -1
  2. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper_connections/manifold_constrained_hyper_connections.py +26 -11
  3. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/pyproject.toml +1 -1
  4. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/tests/test_hyper_connections.py +21 -0
  5. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/.github/workflows/python-publish.yml +0 -0
  6. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/.github/workflows/test.yml +0 -0
  7. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/.gitignore +0 -0
  8. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/LICENSE +0 -0
  9. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/README.md +0 -0
  10. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper-connections.png +0 -0
  11. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper_connections/__init__.py +0 -0
  12. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper_connections/hyper_connections.py +0 -0
  13. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper_connections/hyper_connections_channel_first.py +0 -0
  14. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
  15. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
  16. {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper_connections/residuals.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.3.6
3
+ Version: 0.3.7
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -46,6 +46,9 @@ def l1norm(t, dim):
46
46
  return F.normalize(t, p = 1, dim = dim)
47
47
 
48
48
  def sinkhorn_knopps(log_alpha, iters = 20):
49
+ dtype = log_alpha.dtype
50
+ log_alpha = log_alpha.float()
51
+
49
52
  log_alpha = log_alpha - log_alpha.amax(dim = -2, keepdim = True).detach()
50
53
 
51
54
  alpha = log_alpha.exp()
@@ -54,7 +57,7 @@ def sinkhorn_knopps(log_alpha, iters = 20):
54
57
  alpha = l1norm(alpha, dim = -2)
55
58
  alpha = l1norm(alpha, dim = -1)
56
59
 
57
- return alpha
60
+ return alpha.to(dtype)
58
61
 
59
62
  # main functions
60
63
 
@@ -316,17 +319,21 @@ class ManifoldConstrainedHyperConnections(Module):
316
319
 
317
320
  # alpha for weighted sum of residuals going into branch
318
321
 
319
- wc_weight = normed @ self.dynamic_alpha_fn
322
+ dtype = residuals.dtype
323
+
324
+ normed = normed.float()
325
+
326
+ wc_weight = normed @ self.dynamic_alpha_fn.float()
320
327
 
321
- pre_branch_scale = repeat(self.pre_branch_scale, '1 -> s', s = self.num_fracs)
322
- residual_scale = repeat(self.residual_scale, '1 -> s', s = self.num_fracs * streams)
328
+ pre_branch_scale = repeat(self.pre_branch_scale.float(), '1 -> s', s = self.num_fracs)
329
+ residual_scale = repeat(self.residual_scale.float(), '1 -> s', s = self.num_fracs * streams)
323
330
  alpha_scale = cat((pre_branch_scale, residual_scale))
324
331
 
325
332
  alpha_scale = repeat(alpha_scale, 'n -> (v n)', v = self.num_input_views)
326
333
 
327
334
  dynamic_alpha = wc_weight * alpha_scale
328
335
 
329
- static_alpha = rearrange(self.static_alpha, '(f s) d -> f s d', s = streams)
336
+ static_alpha = rearrange(self.static_alpha.float(), '(f s) d -> f s d', s = streams)
330
337
 
331
338
  alpha = dynamic_alpha + static_alpha
332
339
 
@@ -351,20 +358,20 @@ class ManifoldConstrainedHyperConnections(Module):
351
358
  beta = None
352
359
 
353
360
  if self.add_branch_out_to_residual:
354
- dc_weight = normed @ self.dynamic_beta_fn
361
+ dc_weight = normed @ self.dynamic_beta_fn.float()
355
362
 
356
363
  dc_weight = dc_weight.sigmoid() * 2 # sigmoid * 2 for "H_post", corresponding to dc weight in original paper
357
364
 
358
365
  if not self.has_fracs:
359
366
  dc_weight = rearrange(dc_weight, '... -> ... 1')
360
367
 
361
- dynamic_beta = dc_weight * self.h_post_scale
368
+ dynamic_beta = dc_weight * self.h_post_scale.float()
362
369
 
363
- static_beta = rearrange(self.static_beta, '... (s f) -> ... s f', s = streams)
370
+ static_beta = rearrange(self.static_beta.float(), '... (s f) -> ... s f', s = streams)
364
371
 
365
372
  beta = dynamic_beta + static_beta
366
373
 
367
- mix_h = einsum(alpha, residuals, '... f1 s f2 t, ... f1 s d -> ... f2 t d')
374
+ mix_h = einsum(alpha, residuals.float(), '... f1 s f2 t, ... f1 s d -> ... f2 t d')
368
375
 
369
376
  if self.num_input_views == 1:
370
377
  branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
@@ -379,6 +386,12 @@ class ManifoldConstrainedHyperConnections(Module):
379
386
 
380
387
  branch_input = self.merge_fracs(branch_input)
381
388
 
389
+ branch_input = branch_input.to(dtype)
390
+ residuals = residuals.to(dtype)
391
+
392
+ if exists(beta):
393
+ beta = beta.to(dtype)
394
+
382
395
  return branch_input, maybe_transformed_residuals, dict(beta = beta)
383
396
 
384
397
  def depth_connection(
@@ -399,7 +412,9 @@ class ManifoldConstrainedHyperConnections(Module):
399
412
  if self.channel_first:
400
413
  branch_output = rearrange(branch_output, 'b d ... -> b ... d')
401
414
 
402
- output = einsum(branch_output, beta, 'b ... f1 d, b ... f1 s f2 -> b ... f2 s d')
415
+ dtype = residuals.dtype
416
+
417
+ output = einsum(branch_output.float(), beta.float(), 'b ... f1 d, b ... f1 s f2 -> b ... f2 s d')
403
418
 
404
419
  output = rearrange(output, 'b ... s d -> (b s) ... d')
405
420
 
@@ -412,7 +427,7 @@ class ManifoldConstrainedHyperConnections(Module):
412
427
  if self.channel_first:
413
428
  output = rearrange(output, 'b ... d -> b d ...')
414
429
 
415
- residuals = self.depth_residual_fn(output, residuals)
430
+ residuals = self.depth_residual_fn(output.to(dtype), residuals)
416
431
 
417
432
  return self.dropout(residuals)
418
433
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.3.6"
3
+ version = "0.3.7"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -231,3 +231,24 @@ def test_channel_first_hyper_connection(disable):
231
231
  after_residual = reduce_stream(residual)
232
232
 
233
233
  assert before_residual.shape == after_residual.shape
234
+
235
+ def test_mhc_dtype_restoration():
236
+ from hyper_connections.manifold_constrained_hyper_connections import ManifoldConstrainedHyperConnections
237
+
238
+ mhc = ManifoldConstrainedHyperConnections(
239
+ num_residual_streams = 4,
240
+ dim = 64,
241
+ add_branch_out_to_residual = True
242
+ )
243
+
244
+ residual = torch.randn(4, 1, 64).half()
245
+
246
+ branch_input, _, residual_kwargs = mhc.width_connection(residual)
247
+
248
+ assert branch_input.dtype == torch.half
249
+ assert residual_kwargs['beta'].dtype == torch.half
250
+
251
+ branch_output = torch.randn_like(branch_input).half()
252
+ residual = mhc.depth_connection(branch_output, residual, **residual_kwargs)
253
+
254
+ assert residual.dtype == torch.half