heavyball 2.1.0__tar.gz → 2.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-2.1.0 → heavyball-2.1.2}/PKG-INFO +3 -3
- {heavyball-2.1.0 → heavyball-2.1.2}/heavyball/utils.py +96 -16
- {heavyball-2.1.0 → heavyball-2.1.2}/heavyball.egg-info/PKG-INFO +3 -3
- {heavyball-2.1.0 → heavyball-2.1.2}/heavyball.egg-info/requires.txt +2 -2
- {heavyball-2.1.0 → heavyball-2.1.2}/pyproject.toml +3 -3
- {heavyball-2.1.0 → heavyball-2.1.2}/LICENSE +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/README.md +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/heavyball/__init__.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/heavyball/chainable.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/heavyball/helpers.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/setup.cfg +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_bf16_params.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_bf16_q.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_bf16_storage.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_caution.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_channels_last.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_clip.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_closure.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_ema.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_foreach.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_hook.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_mars.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_memory.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_memory_leak.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_merge.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_migrate_cli.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_nd_param.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_no_grad.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_psgd_precond_init_stability.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_save_restore.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_singular_values.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_soap.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_stochastic_updates.py +0 -0
- {heavyball-2.1.0 → heavyball-2.1.2}/test/test_toy_training.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: heavyball
|
3
|
-
Version: 2.1.
|
3
|
+
Version: 2.1.2
|
4
4
|
Summary: Efficient Optimizers
|
5
5
|
Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
|
6
6
|
Project-URL: source, https://github.com/HomebrewML/HeavyBall
|
@@ -16,8 +16,8 @@ Requires-Python: >=3.9
|
|
16
16
|
Description-Content-Type: text/markdown
|
17
17
|
License-File: LICENSE
|
18
18
|
Requires-Dist: opt-einsum>=3.4.0
|
19
|
-
Requires-Dist: torch
|
20
|
-
Requires-Dist: numpy
|
19
|
+
Requires-Dist: torch<3.0,>=2.2
|
20
|
+
Requires-Dist: numpy<2.0.0
|
21
21
|
Provides-Extra: dev
|
22
22
|
Requires-Dist: pre-commit; extra == "dev"
|
23
23
|
Requires-Dist: pytest; extra == "dev"
|
@@ -27,6 +27,7 @@ class ZerothPowerMode(enum.Enum):
|
|
27
27
|
qr = "qr"
|
28
28
|
svd = "svd"
|
29
29
|
legacy_svd = "legacy_svd"
|
30
|
+
thinky_polar_express = "thinky_polar_express"
|
30
31
|
|
31
32
|
|
32
33
|
class OrthoScaleMode(enum.Enum):
|
@@ -390,11 +391,12 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
|
|
390
391
|
) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
|
391
392
|
assert steps == 5
|
392
393
|
G = G.clone()
|
393
|
-
|
394
|
+
x = G if G.dtype == torch.float64 else stochastic_round_(G)
|
394
395
|
if G.size(-2) > G.size(-1):
|
395
|
-
|
396
|
+
x = x.mT
|
396
397
|
|
397
|
-
|
398
|
+
# X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
|
399
|
+
stochastic_divide_with_eps_(x, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
|
398
400
|
# Perform the NS iterations
|
399
401
|
for a, b, c in [
|
400
402
|
(4.0848, -6.8946, 2.9270),
|
@@ -403,13 +405,70 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
|
|
403
405
|
(2.8769, -3.1427, 1.2046),
|
404
406
|
(2.8366, -3.0525, 1.2012),
|
405
407
|
]:
|
406
|
-
|
407
|
-
|
408
|
-
|
408
|
+
s = x @ x.mT
|
409
|
+
y = c * s
|
410
|
+
y.diagonal(dim1=-2, dim2=-1).add_(b)
|
411
|
+
y = y @ s
|
412
|
+
y.diagonal(dim1=-2, dim2=-1).add_(a)
|
413
|
+
x = y @ x
|
409
414
|
|
410
415
|
if G.size(-2) > G.size(-1):
|
411
|
-
|
412
|
-
return
|
416
|
+
x = x.mT
|
417
|
+
return x.to(G.dtype)
|
418
|
+
|
419
|
+
|
420
|
+
###### START
|
421
|
+
# Taken from https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
|
422
|
+
# under the MIT License
|
423
|
+
|
424
|
+
ABC_LIST: list[tuple[float, float, float]] = [
|
425
|
+
(8.28721201814563, -23.595886519098837, 17.300387312530933),
|
426
|
+
(4.107059111542203, -2.9478499167379106, 0.5448431082926601),
|
427
|
+
(3.9486908534822946, -2.908902115962949, 0.5518191394370137),
|
428
|
+
(3.3184196573706015, -2.488488024314874, 0.51004894012372),
|
429
|
+
(2.300652019954817, -1.6689039845747493, 0.4188073119525673),
|
430
|
+
(1.891301407787398, -1.2679958271945868, 0.37680408948524835),
|
431
|
+
(1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
|
432
|
+
(1.875, -1.25, 0.375),
|
433
|
+
]
|
434
|
+
|
435
|
+
# safety factor for numerical stability (but exclude last polynomial)
|
436
|
+
ABC_LIST_STABLE: list[tuple[float, float, float]] = [
|
437
|
+
(a / 1.01, b / 1.01**3, c / 1.01**5) for (a, b, c) in ABC_LIST[:-1]
|
438
|
+
] + [ABC_LIST[-1]]
|
439
|
+
|
440
|
+
|
441
|
+
def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
|
442
|
+
"""
|
443
|
+
Polar Express algorithm for the matrix sign function:
|
444
|
+
https://arxiv.org/abs/2505.16932
|
445
|
+
"""
|
446
|
+
assert G.ndim >= 2
|
447
|
+
should_transpose: bool = G.size(-2) > G.size(-1)
|
448
|
+
|
449
|
+
x = G if G.dtype == torch.float64 else stochastic_round_(G)
|
450
|
+
if should_transpose:
|
451
|
+
x = x.mT
|
452
|
+
|
453
|
+
x /= x.norm(dim=(-2, -1), keepdim=True) * 1.01
|
454
|
+
for step in range(steps):
|
455
|
+
a, b, c = ABC_LIST_STABLE[step] if step < len(ABC_LIST_STABLE) else ABC_LIST_STABLE[-1]
|
456
|
+
s = x @ x.mT
|
457
|
+
# goal is to compute x = a x + b S x + c S^2 x
|
458
|
+
# we can break this up into: x = (a I + (b I + c S) S) x
|
459
|
+
y = c * s
|
460
|
+
y.diagonal(dim1=-2, dim2=-1).add_(b)
|
461
|
+
y = y @ s
|
462
|
+
y.diagonal(dim1=-2, dim2=-1).add_(a)
|
463
|
+
x = y @ x
|
464
|
+
|
465
|
+
if should_transpose:
|
466
|
+
x = x.mT
|
467
|
+
x = torch.nan_to_num(x)
|
468
|
+
return x.float()
|
469
|
+
|
470
|
+
|
471
|
+
###### END
|
413
472
|
|
414
473
|
|
415
474
|
@decorator_knowngood
|
@@ -417,17 +476,22 @@ def legacy_zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
|
|
417
476
|
assert len(G.shape) == 2
|
418
477
|
a, b, c = (3.4445, -4.7750, 2.0315)
|
419
478
|
G = G.clone()
|
420
|
-
|
421
|
-
|
479
|
+
x = G if G.dtype == torch.float64 else stochastic_round_(G)
|
480
|
+
|
481
|
+
# X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
|
482
|
+
stochastic_divide_with_eps_(x, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
|
422
483
|
if G.size(0) > G.size(1):
|
423
|
-
|
484
|
+
x = x.T
|
424
485
|
for _ in range(steps):
|
425
|
-
|
426
|
-
|
427
|
-
|
486
|
+
s = x @ x.mT
|
487
|
+
y = c * s
|
488
|
+
y.diagonal(dim1=-2, dim2=-1).add_(b)
|
489
|
+
y = y @ s
|
490
|
+
y.diagonal(dim1=-2, dim2=-1).add_(a)
|
491
|
+
x = y @ x
|
428
492
|
if G.size(0) > G.size(1):
|
429
|
-
|
430
|
-
return
|
493
|
+
x = x.T
|
494
|
+
return x.to(G.dtype)
|
431
495
|
|
432
496
|
|
433
497
|
@decorator_knowngood
|
@@ -489,6 +553,8 @@ def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor
|
|
489
553
|
scale_mode = OrthoScaleMode(scale_mode)
|
490
554
|
if mode == ZerothPowerMode.newtonschulz or x.shape[0] != x.shape[1]:
|
491
555
|
y = zeropower_via_newtonschulz5(x, 5)
|
556
|
+
elif mode == ZerothPowerMode.thinky_polar_express:
|
557
|
+
y = msign(x, 10)
|
492
558
|
elif mode == ZerothPowerMode.legacy_newtonschulz:
|
493
559
|
y = legacy_zeropower_via_newtonschulz5(x, 5)
|
494
560
|
elif mode == ZerothPowerMode.qr:
|
@@ -755,6 +821,20 @@ def stochastic_multiply_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor):
|
|
755
821
|
_compilable_stochastic_multiply_(x, y)
|
756
822
|
|
757
823
|
|
824
|
+
@decorator_knowngood
|
825
|
+
def _compilable_stochastic_divide_with_eps_(x: List[Tensor], y: List[Tensor], eps: Tensor):
|
826
|
+
for x_, y_ in zip(x, y):
|
827
|
+
x32 = promote(x_)
|
828
|
+
y32 = promote(y_)
|
829
|
+
copy_stochastic_(x_, x32 / (y32 + eps))
|
830
|
+
|
831
|
+
|
832
|
+
def stochastic_divide_with_eps_(x: List[Tensor] | Tensor, y: List[Tensor] | Tensor, eps: float):
|
833
|
+
x, y = broadcastable_list_guard(x, y)
|
834
|
+
eps = scalar_guard(eps, y[0])
|
835
|
+
_compilable_stochastic_divide_with_eps_(x, y, eps)
|
836
|
+
|
837
|
+
|
758
838
|
@decorator
|
759
839
|
def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
|
760
840
|
"""
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: heavyball
|
3
|
-
Version: 2.1.
|
3
|
+
Version: 2.1.2
|
4
4
|
Summary: Efficient Optimizers
|
5
5
|
Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
|
6
6
|
Project-URL: source, https://github.com/HomebrewML/HeavyBall
|
@@ -16,8 +16,8 @@ Requires-Python: >=3.9
|
|
16
16
|
Description-Content-Type: text/markdown
|
17
17
|
License-File: LICENSE
|
18
18
|
Requires-Dist: opt-einsum>=3.4.0
|
19
|
-
Requires-Dist: torch
|
20
|
-
Requires-Dist: numpy
|
19
|
+
Requires-Dist: torch<3.0,>=2.2
|
20
|
+
Requires-Dist: numpy<2.0.0
|
21
21
|
Provides-Extra: dev
|
22
22
|
Requires-Dist: pre-commit; extra == "dev"
|
23
23
|
Requires-Dist: pytest; extra == "dev"
|
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
|
5
5
|
[project]
|
6
6
|
name = "heavyball"
|
7
7
|
description = "Efficient Optimizers"
|
8
|
-
version = "2.1.
|
8
|
+
version = "2.1.2"
|
9
9
|
authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
|
10
10
|
classifiers = ["Intended Audience :: Developers",
|
11
11
|
"Intended Audience :: Science/Research",
|
@@ -15,8 +15,8 @@ classifiers = ["Intended Audience :: Developers",
|
|
15
15
|
"Programming Language :: Python :: 3",
|
16
16
|
]
|
17
17
|
dependencies = ["opt-einsum>=3.4.0",
|
18
|
-
"torch>=2.
|
19
|
-
"numpy",
|
18
|
+
"torch>=2.2,<3.0",
|
19
|
+
"numpy<2.0.0",
|
20
20
|
]
|
21
21
|
keywords = ["torch",
|
22
22
|
"optimizer",
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|