heavyball 2.1.1__tar.gz → 2.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. {heavyball-2.1.1 → heavyball-2.1.2}/PKG-INFO +3 -3
  2. {heavyball-2.1.1 → heavyball-2.1.2}/heavyball/utils.py +79 -16
  3. {heavyball-2.1.1 → heavyball-2.1.2}/heavyball.egg-info/PKG-INFO +3 -3
  4. {heavyball-2.1.1 → heavyball-2.1.2}/heavyball.egg-info/requires.txt +2 -2
  5. {heavyball-2.1.1 → heavyball-2.1.2}/pyproject.toml +3 -3
  6. {heavyball-2.1.1 → heavyball-2.1.2}/LICENSE +0 -0
  7. {heavyball-2.1.1 → heavyball-2.1.2}/README.md +0 -0
  8. {heavyball-2.1.1 → heavyball-2.1.2}/heavyball/__init__.py +0 -0
  9. {heavyball-2.1.1 → heavyball-2.1.2}/heavyball/chainable.py +0 -0
  10. {heavyball-2.1.1 → heavyball-2.1.2}/heavyball/helpers.py +0 -0
  11. {heavyball-2.1.1 → heavyball-2.1.2}/heavyball.egg-info/SOURCES.txt +0 -0
  12. {heavyball-2.1.1 → heavyball-2.1.2}/heavyball.egg-info/dependency_links.txt +0 -0
  13. {heavyball-2.1.1 → heavyball-2.1.2}/heavyball.egg-info/top_level.txt +0 -0
  14. {heavyball-2.1.1 → heavyball-2.1.2}/setup.cfg +0 -0
  15. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_bf16_params.py +0 -0
  16. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_bf16_q.py +0 -0
  17. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_bf16_storage.py +0 -0
  18. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_caution.py +0 -0
  19. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_channels_last.py +0 -0
  20. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_clip.py +0 -0
  21. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_closure.py +0 -0
  22. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_ema.py +0 -0
  23. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_foreach.py +0 -0
  24. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_hook.py +0 -0
  25. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_mars.py +0 -0
  26. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_memory.py +0 -0
  27. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_memory_leak.py +0 -0
  28. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_merge.py +0 -0
  29. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_migrate_cli.py +0 -0
  30. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_nd_param.py +0 -0
  31. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_no_grad.py +0 -0
  32. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_psgd_precond_init_stability.py +0 -0
  33. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_save_restore.py +0 -0
  34. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_singular_values.py +0 -0
  35. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_soap.py +0 -0
  36. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_stochastic_updates.py +0 -0
  37. {heavyball-2.1.1 → heavyball-2.1.2}/test/test_toy_training.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: heavyball
3
- Version: 2.1.1
3
+ Version: 2.1.2
4
4
  Summary: Efficient Optimizers
5
5
  Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
6
6
  Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -16,8 +16,8 @@ Requires-Python: >=3.9
16
16
  Description-Content-Type: text/markdown
17
17
  License-File: LICENSE
18
18
  Requires-Dist: opt-einsum>=3.4.0
19
- Requires-Dist: torch>=2.7.0
20
- Requires-Dist: numpy
19
+ Requires-Dist: torch<3.0,>=2.2
20
+ Requires-Dist: numpy<2.0.0
21
21
  Provides-Extra: dev
22
22
  Requires-Dist: pre-commit; extra == "dev"
23
23
  Requires-Dist: pytest; extra == "dev"
@@ -27,6 +27,7 @@ class ZerothPowerMode(enum.Enum):
27
27
  qr = "qr"
28
28
  svd = "svd"
29
29
  legacy_svd = "legacy_svd"
30
+ thinky_polar_express = "thinky_polar_express"
30
31
 
31
32
 
32
33
  class OrthoScaleMode(enum.Enum):
@@ -390,12 +391,12 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
390
391
  ) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
391
392
  assert steps == 5
392
393
  G = G.clone()
393
- X = G if G.dtype == torch.float64 else stochastic_round_(G)
394
+ x = G if G.dtype == torch.float64 else stochastic_round_(G)
394
395
  if G.size(-2) > G.size(-1):
395
- X = X.mT
396
+ x = x.mT
396
397
 
397
398
  # X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
398
- stochastic_divide_with_eps_(X, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
399
+ stochastic_divide_with_eps_(x, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
399
400
  # Perform the NS iterations
400
401
  for a, b, c in [
401
402
  (4.0848, -6.8946, 2.9270),
@@ -404,13 +405,70 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
404
405
  (2.8769, -3.1427, 1.2046),
405
406
  (2.8366, -3.0525, 1.2012),
406
407
  ]:
407
- A = X @ X.mT
408
- B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
409
- X = a * X + B @ X
408
+ s = x @ x.mT
409
+ y = c * s
410
+ y.diagonal(dim1=-2, dim2=-1).add_(b)
411
+ y = y @ s
412
+ y.diagonal(dim1=-2, dim2=-1).add_(a)
413
+ x = y @ x
410
414
 
411
415
  if G.size(-2) > G.size(-1):
412
- X = X.mT
413
- return X.to(G.dtype)
416
+ x = x.mT
417
+ return x.to(G.dtype)
418
+
419
+
420
+ ###### START
421
+ # Taken from https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
422
+ # under the MIT License
423
+
424
# Per-step (a, b, c) coefficients of the odd quintic update
#     x <- a x + b S x + c S^2 x,   with S = x x^T,
# used by msign() below. The final entry is reused for any additional steps.
ABC_LIST: list[tuple[float, float, float]] = [
    (8.28721201814563, -23.595886519098837, 17.300387312530933),
    (4.107059111542203, -2.9478499167379106, 0.5448431082926601),
    (3.9486908534822946, -2.908902115962949, 0.5518191394370137),
    (3.3184196573706015, -2.488488024314874, 0.51004894012372),
    (2.300652019954817, -1.6689039845747493, 0.4188073119525673),
    (1.891301407787398, -1.2679958271945868, 0.37680408948524835),
    (1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
    (1.875, -1.25, 0.375),
]

# safety factor for numerical stability (but exclude last polynomial)
# a, b, c multiply x, x^3 and x^5 respectively, hence the 1.01, 1.01**3 and
# 1.01**5 divisors.
ABC_LIST_STABLE: list[tuple[float, float, float]] = [
    (a / 1.01, b / 1.01**3, c / 1.01**5) for (a, b, c) in ABC_LIST[:-1]
] + [ABC_LIST[-1]]


def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
    """
    Polar Express algorithm for the matrix sign function:
    https://arxiv.org/abs/2505.16932

    Works on the last two dimensions of ``G``; the norm, transpose and matmul
    calls all act on ``(-2, -1)``, so leading dimensions are carried along.

    :param G: tensor with ``G.ndim >= 2``. This function performs no in-place
        operation on it, so the caller's tensor is left untouched.
    :param steps: number of polynomial iterations. Steps past the end of
        ``ABC_LIST_STABLE`` reuse its last coefficient triple.
    :return: float32 tensor with the same shape as ``G``. Non-finite values
        are replaced via ``torch.nan_to_num`` (so a zero-norm input, whose
        normalisation is 0/0, comes back as zeros rather than NaN).
    """
    assert G.ndim >= 2
    should_transpose: bool = G.size(-2) > G.size(-1)

    # float64 inputs are used as-is; everything else goes through
    # stochastic_round_ like the Newton-Schulz variants in this module.
    x = G if G.dtype == torch.float64 else stochastic_round_(G)
    if should_transpose:
        # iterate on the wide orientation so S = x x^T is the smaller Gram matrix
        x = x.mT

    # Deliberately out-of-place: for float64 input ``x`` is ``G`` itself (or a
    # transposed view of it), and ``x /= ...`` would overwrite the caller's
    # tensor. The sibling Newton-Schulz functions avoid this with G.clone().
    x = x / (x.norm(dim=(-2, -1), keepdim=True) * 1.01)

    last = len(ABC_LIST_STABLE) - 1
    for step in range(steps):
        a, b, c = ABC_LIST_STABLE[min(step, last)]
        s = x @ x.mT
        # goal is to compute x = a x + b S x + c S^2 x
        # we can break this up into: x = (a I + (b I + c S) S) x
        y = c * s
        y.diagonal(dim1=-2, dim2=-1).add_(b)  # y = b I + c S
        y = y @ s
        y.diagonal(dim1=-2, dim2=-1).add_(a)  # y = a I + (b I + c S) S
        x = y @ x

    if should_transpose:
        x = x.mT
    x = torch.nan_to_num(x)
    return x.float()
469
+
470
+
471
+ ###### END
414
472
 
415
473
 
416
474
  @decorator_knowngood
@@ -418,19 +476,22 @@ def legacy_zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
418
476
  assert len(G.shape) == 2
419
477
  a, b, c = (3.4445, -4.7750, 2.0315)
420
478
  G = G.clone()
421
- X = G if G.dtype == torch.float64 else stochastic_round_(G)
479
+ x = G if G.dtype == torch.float64 else stochastic_round_(G)
422
480
 
423
481
  # X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
424
- stochastic_divide_with_eps_(X, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
482
+ stochastic_divide_with_eps_(x, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
425
483
  if G.size(0) > G.size(1):
426
- X = X.T
484
+ x = x.T
427
485
  for _ in range(steps):
428
- A = X @ X.T
429
- B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
430
- X = a * X + B @ X
486
+ s = x @ x.mT
487
+ y = c * s
488
+ y.diagonal(dim1=-2, dim2=-1).add_(b)
489
+ y = y @ s
490
+ y.diagonal(dim1=-2, dim2=-1).add_(a)
491
+ x = y @ x
431
492
  if G.size(0) > G.size(1):
432
- X = X.T
433
- return X.to(G.dtype)
493
+ x = x.T
494
+ return x.to(G.dtype)
434
495
 
435
496
 
436
497
  @decorator_knowngood
@@ -492,6 +553,8 @@ def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor
492
553
  scale_mode = OrthoScaleMode(scale_mode)
493
554
  if mode == ZerothPowerMode.newtonschulz or x.shape[0] != x.shape[1]:
494
555
  y = zeropower_via_newtonschulz5(x, 5)
556
+ elif mode == ZerothPowerMode.thinky_polar_express:
557
+ y = msign(x, 10)
495
558
  elif mode == ZerothPowerMode.legacy_newtonschulz:
496
559
  y = legacy_zeropower_via_newtonschulz5(x, 5)
497
560
  elif mode == ZerothPowerMode.qr:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: heavyball
3
- Version: 2.1.1
3
+ Version: 2.1.2
4
4
  Summary: Efficient Optimizers
5
5
  Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
6
6
  Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -16,8 +16,8 @@ Requires-Python: >=3.9
16
16
  Description-Content-Type: text/markdown
17
17
  License-File: LICENSE
18
18
  Requires-Dist: opt-einsum>=3.4.0
19
- Requires-Dist: torch>=2.7.0
20
- Requires-Dist: numpy
19
+ Requires-Dist: torch<3.0,>=2.2
20
+ Requires-Dist: numpy<2.0.0
21
21
  Provides-Extra: dev
22
22
  Requires-Dist: pre-commit; extra == "dev"
23
23
  Requires-Dist: pytest; extra == "dev"
@@ -1,6 +1,6 @@
1
1
  opt-einsum>=3.4.0
2
- torch>=2.7.0
3
- numpy
2
+ torch<3.0,>=2.2
3
+ numpy<2.0.0
4
4
 
5
5
  [dev]
6
6
  pre-commit
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
5
5
  [project]
6
6
  name = "heavyball"
7
7
  description = "Efficient Optimizers"
8
- version = "2.1.1"
8
+ version = "2.1.2"
9
9
  authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
10
10
  classifiers = ["Intended Audience :: Developers",
11
11
  "Intended Audience :: Science/Research",
@@ -15,8 +15,8 @@ classifiers = ["Intended Audience :: Developers",
15
15
  "Programming Language :: Python :: 3",
16
16
  ]
17
17
  dependencies = ["opt-einsum>=3.4.0",
18
- "torch>=2.7.0",
19
- "numpy",
18
+ "torch>=2.2,<3.0",
19
+ "numpy<2.0.0",
20
20
  ]
21
21
  keywords = ["torch",
22
22
  "optimizer",
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes