heavyball 2.1.0.tar.gz → 2.1.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {heavyball-2.1.0 → heavyball-2.1.2}/PKG-INFO +3 -3
  2. {heavyball-2.1.0 → heavyball-2.1.2}/heavyball/utils.py +96 -16
  3. {heavyball-2.1.0 → heavyball-2.1.2}/heavyball.egg-info/PKG-INFO +3 -3
  4. {heavyball-2.1.0 → heavyball-2.1.2}/heavyball.egg-info/requires.txt +2 -2
  5. {heavyball-2.1.0 → heavyball-2.1.2}/pyproject.toml +3 -3
  6. {heavyball-2.1.0 → heavyball-2.1.2}/LICENSE +0 -0
  7. {heavyball-2.1.0 → heavyball-2.1.2}/README.md +0 -0
  8. {heavyball-2.1.0 → heavyball-2.1.2}/heavyball/__init__.py +0 -0
  9. {heavyball-2.1.0 → heavyball-2.1.2}/heavyball/chainable.py +0 -0
  10. {heavyball-2.1.0 → heavyball-2.1.2}/heavyball/helpers.py +0 -0
  11. {heavyball-2.1.0 → heavyball-2.1.2}/heavyball.egg-info/SOURCES.txt +0 -0
  12. {heavyball-2.1.0 → heavyball-2.1.2}/heavyball.egg-info/dependency_links.txt +0 -0
  13. {heavyball-2.1.0 → heavyball-2.1.2}/heavyball.egg-info/top_level.txt +0 -0
  14. {heavyball-2.1.0 → heavyball-2.1.2}/setup.cfg +0 -0
  15. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_bf16_params.py +0 -0
  16. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_bf16_q.py +0 -0
  17. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_bf16_storage.py +0 -0
  18. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_caution.py +0 -0
  19. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_channels_last.py +0 -0
  20. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_clip.py +0 -0
  21. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_closure.py +0 -0
  22. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_ema.py +0 -0
  23. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_foreach.py +0 -0
  24. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_hook.py +0 -0
  25. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_mars.py +0 -0
  26. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_memory.py +0 -0
  27. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_memory_leak.py +0 -0
  28. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_merge.py +0 -0
  29. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_migrate_cli.py +0 -0
  30. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_nd_param.py +0 -0
  31. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_no_grad.py +0 -0
  32. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_psgd_precond_init_stability.py +0 -0
  33. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_save_restore.py +0 -0
  34. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_singular_values.py +0 -0
  35. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_soap.py +0 -0
  36. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_stochastic_updates.py +0 -0
  37. {heavyball-2.1.0 → heavyball-2.1.2}/test/test_toy_training.py +0 -0

{heavyball-2.1.0 → heavyball-2.1.2}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: heavyball
-Version: 2.1.0
+Version: 2.1.2
 Summary: Efficient Optimizers
 Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
 Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -16,8 +16,8 @@ Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: opt-einsum>=3.4.0
-Requires-Dist: torch>=2.7.0
-Requires-Dist: numpy
+Requires-Dist: torch<3.0,>=2.2
+Requires-Dist: numpy<2.0.0
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
 Requires-Dist: pytest; extra == "dev"

{heavyball-2.1.0 → heavyball-2.1.2}/heavyball/utils.py
@@ -27,6 +27,7 @@ class ZerothPowerMode(enum.Enum):
     qr = "qr"
     svd = "svd"
     legacy_svd = "legacy_svd"
+    thinky_polar_express = "thinky_polar_express"


 class OrthoScaleMode(enum.Enum):
@@ -390,11 +391,12 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
     )  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
     assert steps == 5
     G = G.clone()
-    X = G if G.dtype == torch.float64 else stochastic_round_(G)
+    x = G if G.dtype == torch.float64 else stochastic_round_(G)
     if G.size(-2) > G.size(-1):
-        X = X.mT
+        x = x.mT

-    stochastic_multiply_(X, G.norm(dim=(-2, -1)) + eps)  # ensure top singular value <= 1
+    # X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
+    stochastic_divide_with_eps_(x, G.norm(dim=(-2, -1)), eps)  # ensure top singular value <= 1
     # Perform the NS iterations
     for a, b, c in [
         (4.0848, -6.8946, 2.9270),
@@ -403,13 +405,70 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
         (2.8769, -3.1427, 1.2046),
         (2.8366, -3.0525, 1.2012),
     ]:
-        A = X @ X.mT
-        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
-        X = a * X + B @ X
+        s = x @ x.mT
+        y = c * s
+        y.diagonal(dim1=-2, dim2=-1).add_(b)
+        y = y @ s
+        y.diagonal(dim1=-2, dim2=-1).add_(a)
+        x = y @ x

     if G.size(-2) > G.size(-1):
-        X = X.mT
-    return X.to(G.dtype)
+        x = x.mT
+    return x.to(G.dtype)
+
+
+###### START
+# Taken from https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
+# under the MIT License
+
+ABC_LIST: list[tuple[float, float, float]] = [
+    (8.28721201814563, -23.595886519098837, 17.300387312530933),
+    (4.107059111542203, -2.9478499167379106, 0.5448431082926601),
+    (3.9486908534822946, -2.908902115962949, 0.5518191394370137),
+    (3.3184196573706015, -2.488488024314874, 0.51004894012372),
+    (2.300652019954817, -1.6689039845747493, 0.4188073119525673),
+    (1.891301407787398, -1.2679958271945868, 0.37680408948524835),
+    (1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
+    (1.875, -1.25, 0.375),
+]
+
+# safety factor for numerical stability (but exclude last polynomial)
+ABC_LIST_STABLE: list[tuple[float, float, float]] = [
+    (a / 1.01, b / 1.01**3, c / 1.01**5) for (a, b, c) in ABC_LIST[:-1]
+] + [ABC_LIST[-1]]
+
+
+def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
+    """
+    Polar Express algorithm for the matrix sign function:
+    https://arxiv.org/abs/2505.16932
+    """
+    assert G.ndim >= 2
+    should_transpose: bool = G.size(-2) > G.size(-1)
+
+    x = G if G.dtype == torch.float64 else stochastic_round_(G)
+    if should_transpose:
+        x = x.mT
+
+    x /= x.norm(dim=(-2, -1), keepdim=True) * 1.01
+    for step in range(steps):
+        a, b, c = ABC_LIST_STABLE[step] if step < len(ABC_LIST_STABLE) else ABC_LIST_STABLE[-1]
+        s = x @ x.mT
+        # goal is to compute x = a x + b S x + c S^2 x
+        # we can break this up into: x = (a I + (b I + c S) S) x
+        y = c * s
+        y.diagonal(dim1=-2, dim2=-1).add_(b)
+        y = y @ s
+        y.diagonal(dim1=-2, dim2=-1).add_(a)
+        x = y @ x
+
+    if should_transpose:
+        x = x.mT
+    x = torch.nan_to_num(x)
+    return x.float()
+
+
+###### END


 @decorator_knowngood
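
Note: the Newton-Schulz loops in this file now evaluate the quintic update in Horner form, as the inline comment in msign spells out: with S = x xᵀ, the update a·x + b·S·x + c·S²·x is computed as (a·I + (b·I + c·S)·S)·x, trading two full-size scale-and-add temporaries for cheap in-place diagonal adds while keeping the same three matmuls. A minimal sketch (plain PyTorch, variable names chosen for illustration, not part of the package) checking that the factored form matches the original expression:

    import torch

    torch.manual_seed(0)
    a, b, c = 4.0848, -6.8946, 2.9270  # first NS5 coefficient triple from the diff
    x = torch.randn(16, 32, dtype=torch.float64)

    # Horner-style form, written exactly like the diff (in-place diagonal adds)
    s = x @ x.mT
    y = c * s
    y.diagonal(dim1=-2, dim2=-1).add_(b)   # y = b*I + c*S
    y = y @ s                              # y = b*S + c*S^2
    y.diagonal(dim1=-2, dim2=-1).add_(a)   # y = a*I + b*S + c*S^2
    x_new = y @ x

    # reference: the pre-2.1.2 formulation
    A = x @ x.mT
    B = b * A + c * A @ A
    x_ref = a * x + B @ x

    assert torch.allclose(x_new, x_ref)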
@@ -417,17 +476,22 @@ def legacy_zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
     assert len(G.shape) == 2
     a, b, c = (3.4445, -4.7750, 2.0315)
     G = G.clone()
-    X = G if G.dtype == torch.float64 else stochastic_round_(G)
-    stochastic_multiply_(X, G.norm(dim=(-2, -1)) + eps)  # ensure top singular value <= 1
+    x = G if G.dtype == torch.float64 else stochastic_round_(G)
+
+    # X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
+    stochastic_divide_with_eps_(x, G.norm(dim=(-2, -1)), eps)  # ensure top singular value <= 1
     if G.size(0) > G.size(1):
-        X = X.T
+        x = x.T
     for _ in range(steps):
-        A = X @ X.T
-        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
-        X = a * X + B @ X
+        s = x @ x.mT
+        y = c * s
+        y.diagonal(dim1=-2, dim2=-1).add_(b)
+        y = y @ s
+        y.diagonal(dim1=-2, dim2=-1).add_(a)
+        x = y @ x
     if G.size(0) > G.size(1):
-        X = X.T
-    return X.to(G.dtype)
+        x = x.T
+    return x.to(G.dtype)


 @decorator_knowngood
@@ -489,6 +553,8 @@ def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor
     scale_mode = OrthoScaleMode(scale_mode)
     if mode == ZerothPowerMode.newtonschulz or x.shape[0] != x.shape[1]:
         y = zeropower_via_newtonschulz5(x, 5)
+    elif mode == ZerothPowerMode.thinky_polar_express:
+        y = msign(x, 10)
     elif mode == ZerothPowerMode.legacy_newtonschulz:
         y = legacy_zeropower_via_newtonschulz5(x, 5)
     elif mode == ZerothPowerMode.qr:
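
Note: this dispatch wires the new thinky_polar_express member of ZerothPowerMode to msign, so the Polar Express iteration can be selected anywhere heavyball accepts a zeroth-power mode. A minimal sketch of exercising the new path directly, assuming only the msign function and enum member added in this diff (shapes and step count are illustrative):

    import torch
    from heavyball.utils import ZerothPowerMode, msign

    mode = ZerothPowerMode("thinky_polar_express")  # the new member round-trips from its string value

    # Orthogonalize a random "gradient" block with the Polar Express path.
    # msign returns an approximately semi-orthogonal factor of g, like
    # zeropower_via_newtonschulz5 does.
    g = torch.randn(128, 64)
    u = msign(g, steps=10)

    print(mode.value)                           # thinky_polar_express
    print(torch.dist(u.mT @ u, torch.eye(64)))  # small: columns of u are near-orthonormal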
@@ -755,6 +821,20 @@ def stochastic_multiply_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor):
     _compilable_stochastic_multiply_(x, y)


+@decorator_knowngood
+def _compilable_stochastic_divide_with_eps_(x: List[Tensor], y: List[Tensor], eps: Tensor):
+    for x_, y_ in zip(x, y):
+        x32 = promote(x_)
+        y32 = promote(y_)
+        copy_stochastic_(x_, x32 / (y32 + eps))
+
+
+def stochastic_divide_with_eps_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, eps: float):
+    x, y = broadcastable_list_guard(x, y)
+    eps = scalar_guard(eps, y[0])
+    _compilable_stochastic_divide_with_eps_(x, y, eps)
+
+
 @decorator
 def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
     """

{heavyball-2.1.0 → heavyball-2.1.2}/heavyball.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: heavyball
-Version: 2.1.0
+Version: 2.1.2
 Summary: Efficient Optimizers
 Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
 Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -16,8 +16,8 @@ Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: opt-einsum>=3.4.0
-Requires-Dist: torch>=2.7.0
-Requires-Dist: numpy
+Requires-Dist: torch<3.0,>=2.2
+Requires-Dist: numpy<2.0.0
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
 Requires-Dist: pytest; extra == "dev"

{heavyball-2.1.0 → heavyball-2.1.2}/heavyball.egg-info/requires.txt
@@ -1,6 +1,6 @@
 opt-einsum>=3.4.0
-torch>=2.7.0
-numpy
+torch<3.0,>=2.2
+numpy<2.0.0

 [dev]
 pre-commit

{heavyball-2.1.0 → heavyball-2.1.2}/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "heavyball"
 description = "Efficient Optimizers"
-version = "2.1.0"
+version = "2.1.2"
 authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
 classifiers = ["Intended Audience :: Developers",
     "Intended Audience :: Science/Research",
@@ -15,8 +15,8 @@ classifiers = ["Intended Audience :: Developers",
     "Programming Language :: Python :: 3",
 ]
 dependencies = ["opt-einsum>=3.4.0",
-    "torch>=2.7.0",
-    "numpy",
+    "torch>=2.2,<3.0",
+    "numpy<2.0.0",
 ]
 keywords = ["torch",
     "optimizer",