heavyball 2.1.1__tar.gz → 2.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-2.1.1 → heavyball-2.1.2}/PKG-INFO +3 -3
- {heavyball-2.1.1 → heavyball-2.1.2}/heavyball/utils.py +79 -16
- {heavyball-2.1.1 → heavyball-2.1.2}/heavyball.egg-info/PKG-INFO +3 -3
- {heavyball-2.1.1 → heavyball-2.1.2}/heavyball.egg-info/requires.txt +2 -2
- {heavyball-2.1.1 → heavyball-2.1.2}/pyproject.toml +3 -3
- {heavyball-2.1.1 → heavyball-2.1.2}/LICENSE +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/README.md +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/heavyball/__init__.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/heavyball/chainable.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/heavyball/helpers.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/setup.cfg +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_bf16_params.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_bf16_q.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_bf16_storage.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_caution.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_channels_last.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_clip.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_closure.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_ema.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_foreach.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_hook.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_mars.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_memory.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_memory_leak.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_merge.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_migrate_cli.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_nd_param.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_no_grad.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_psgd_precond_init_stability.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_save_restore.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_singular_values.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_soap.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_stochastic_updates.py +0 -0
- {heavyball-2.1.1 → heavyball-2.1.2}/test/test_toy_training.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: heavyball
|
3
|
-
Version: 2.1.
|
3
|
+
Version: 2.1.2
|
4
4
|
Summary: Efficient Optimizers
|
5
5
|
Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
|
6
6
|
Project-URL: source, https://github.com/HomebrewML/HeavyBall
|
@@ -16,8 +16,8 @@ Requires-Python: >=3.9
|
|
16
16
|
Description-Content-Type: text/markdown
|
17
17
|
License-File: LICENSE
|
18
18
|
Requires-Dist: opt-einsum>=3.4.0
|
19
|
-
Requires-Dist: torch
|
20
|
-
Requires-Dist: numpy
|
19
|
+
Requires-Dist: torch<3.0,>=2.2
|
20
|
+
Requires-Dist: numpy<2.0.0
|
21
21
|
Provides-Extra: dev
|
22
22
|
Requires-Dist: pre-commit; extra == "dev"
|
23
23
|
Requires-Dist: pytest; extra == "dev"
|
@@ -27,6 +27,7 @@ class ZerothPowerMode(enum.Enum):
|
|
27
27
|
qr = "qr"
|
28
28
|
svd = "svd"
|
29
29
|
legacy_svd = "legacy_svd"
|
30
|
+
thinky_polar_express = "thinky_polar_express"
|
30
31
|
|
31
32
|
|
32
33
|
class OrthoScaleMode(enum.Enum):
|
@@ -390,12 +391,12 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
|
|
390
391
|
) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
|
391
392
|
assert steps == 5
|
392
393
|
G = G.clone()
|
393
|
-
|
394
|
+
x = G if G.dtype == torch.float64 else stochastic_round_(G)
|
394
395
|
if G.size(-2) > G.size(-1):
|
395
|
-
|
396
|
+
x = x.mT
|
396
397
|
|
397
398
|
# X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
|
398
|
-
stochastic_divide_with_eps_(
|
399
|
+
stochastic_divide_with_eps_(x, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
|
399
400
|
# Perform the NS iterations
|
400
401
|
for a, b, c in [
|
401
402
|
(4.0848, -6.8946, 2.9270),
|
@@ -404,13 +405,70 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
|
|
404
405
|
(2.8769, -3.1427, 1.2046),
|
405
406
|
(2.8366, -3.0525, 1.2012),
|
406
407
|
]:
|
407
|
-
|
408
|
-
|
409
|
-
|
408
|
+
s = x @ x.mT
|
409
|
+
y = c * s
|
410
|
+
y.diagonal(dim1=-2, dim2=-1).add_(b)
|
411
|
+
y = y @ s
|
412
|
+
y.diagonal(dim1=-2, dim2=-1).add_(a)
|
413
|
+
x = y @ x
|
410
414
|
|
411
415
|
if G.size(-2) > G.size(-1):
|
412
|
-
|
413
|
-
return
|
416
|
+
x = x.mT
|
417
|
+
return x.to(G.dtype)
|
418
|
+
|
419
|
+
|
420
|
+
###### START
|
421
|
+
# Taken from https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
|
422
|
+
# under the MIT License
|
423
|
+
|
424
|
+
ABC_LIST: list[tuple[float, float, float]] = [
|
425
|
+
(8.28721201814563, -23.595886519098837, 17.300387312530933),
|
426
|
+
(4.107059111542203, -2.9478499167379106, 0.5448431082926601),
|
427
|
+
(3.9486908534822946, -2.908902115962949, 0.5518191394370137),
|
428
|
+
(3.3184196573706015, -2.488488024314874, 0.51004894012372),
|
429
|
+
(2.300652019954817, -1.6689039845747493, 0.4188073119525673),
|
430
|
+
(1.891301407787398, -1.2679958271945868, 0.37680408948524835),
|
431
|
+
(1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
|
432
|
+
(1.875, -1.25, 0.375),
|
433
|
+
]
|
434
|
+
|
435
|
+
# safety factor for numerical stability (but exclude last polynomial)
|
436
|
+
ABC_LIST_STABLE: list[tuple[float, float, float]] = [
|
437
|
+
(a / 1.01, b / 1.01**3, c / 1.01**5) for (a, b, c) in ABC_LIST[:-1]
|
438
|
+
] + [ABC_LIST[-1]]
|
439
|
+
|
440
|
+
|
441
|
+
def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
|
442
|
+
"""
|
443
|
+
Polar Express algorithm for the matrix sign function:
|
444
|
+
https://arxiv.org/abs/2505.16932
|
445
|
+
"""
|
446
|
+
assert G.ndim >= 2
|
447
|
+
should_transpose: bool = G.size(-2) > G.size(-1)
|
448
|
+
|
449
|
+
x = G if G.dtype == torch.float64 else stochastic_round_(G)
|
450
|
+
if should_transpose:
|
451
|
+
x = x.mT
|
452
|
+
|
453
|
+
x /= x.norm(dim=(-2, -1), keepdim=True) * 1.01
|
454
|
+
for step in range(steps):
|
455
|
+
a, b, c = ABC_LIST_STABLE[step] if step < len(ABC_LIST_STABLE) else ABC_LIST_STABLE[-1]
|
456
|
+
s = x @ x.mT
|
457
|
+
# goal is to compute x = a x + b S x + c S^2 x
|
458
|
+
# we can break this up into: x = (a I + (b I + c S) S) x
|
459
|
+
y = c * s
|
460
|
+
y.diagonal(dim1=-2, dim2=-1).add_(b)
|
461
|
+
y = y @ s
|
462
|
+
y.diagonal(dim1=-2, dim2=-1).add_(a)
|
463
|
+
x = y @ x
|
464
|
+
|
465
|
+
if should_transpose:
|
466
|
+
x = x.mT
|
467
|
+
x = torch.nan_to_num(x)
|
468
|
+
return x.float()
|
469
|
+
|
470
|
+
|
471
|
+
###### END
|
414
472
|
|
415
473
|
|
416
474
|
@decorator_knowngood
|
@@ -418,19 +476,22 @@ def legacy_zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
|
|
418
476
|
assert len(G.shape) == 2
|
419
477
|
a, b, c = (3.4445, -4.7750, 2.0315)
|
420
478
|
G = G.clone()
|
421
|
-
|
479
|
+
x = G if G.dtype == torch.float64 else stochastic_round_(G)
|
422
480
|
|
423
481
|
# X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
|
424
|
-
stochastic_divide_with_eps_(
|
482
|
+
stochastic_divide_with_eps_(x, G.norm(dim=(-2, -1)), eps) # ensure top singular value <= 1
|
425
483
|
if G.size(0) > G.size(1):
|
426
|
-
|
484
|
+
x = x.T
|
427
485
|
for _ in range(steps):
|
428
|
-
|
429
|
-
|
430
|
-
|
486
|
+
s = x @ x.mT
|
487
|
+
y = c * s
|
488
|
+
y.diagonal(dim1=-2, dim2=-1).add_(b)
|
489
|
+
y = y @ s
|
490
|
+
y.diagonal(dim1=-2, dim2=-1).add_(a)
|
491
|
+
x = y @ x
|
431
492
|
if G.size(0) > G.size(1):
|
432
|
-
|
433
|
-
return
|
493
|
+
x = x.T
|
494
|
+
return x.to(G.dtype)
|
434
495
|
|
435
496
|
|
436
497
|
@decorator_knowngood
|
@@ -492,6 +553,8 @@ def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor
|
|
492
553
|
scale_mode = OrthoScaleMode(scale_mode)
|
493
554
|
if mode == ZerothPowerMode.newtonschulz or x.shape[0] != x.shape[1]:
|
494
555
|
y = zeropower_via_newtonschulz5(x, 5)
|
556
|
+
elif mode == ZerothPowerMode.thinky_polar_express:
|
557
|
+
y = msign(x, 10)
|
495
558
|
elif mode == ZerothPowerMode.legacy_newtonschulz:
|
496
559
|
y = legacy_zeropower_via_newtonschulz5(x, 5)
|
497
560
|
elif mode == ZerothPowerMode.qr:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: heavyball
|
3
|
-
Version: 2.1.
|
3
|
+
Version: 2.1.2
|
4
4
|
Summary: Efficient Optimizers
|
5
5
|
Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
|
6
6
|
Project-URL: source, https://github.com/HomebrewML/HeavyBall
|
@@ -16,8 +16,8 @@ Requires-Python: >=3.9
|
|
16
16
|
Description-Content-Type: text/markdown
|
17
17
|
License-File: LICENSE
|
18
18
|
Requires-Dist: opt-einsum>=3.4.0
|
19
|
-
Requires-Dist: torch
|
20
|
-
Requires-Dist: numpy
|
19
|
+
Requires-Dist: torch<3.0,>=2.2
|
20
|
+
Requires-Dist: numpy<2.0.0
|
21
21
|
Provides-Extra: dev
|
22
22
|
Requires-Dist: pre-commit; extra == "dev"
|
23
23
|
Requires-Dist: pytest; extra == "dev"
|
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
|
5
5
|
[project]
|
6
6
|
name = "heavyball"
|
7
7
|
description = "Efficient Optimizers"
|
8
|
-
version = "2.1.
|
8
|
+
version = "2.1.2"
|
9
9
|
authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
|
10
10
|
classifiers = ["Intended Audience :: Developers",
|
11
11
|
"Intended Audience :: Science/Research",
|
@@ -15,8 +15,8 @@ classifiers = ["Intended Audience :: Developers",
|
|
15
15
|
"Programming Language :: Python :: 3",
|
16
16
|
]
|
17
17
|
dependencies = ["opt-einsum>=3.4.0",
|
18
|
-
"torch>=2.
|
19
|
-
"numpy",
|
18
|
+
"torch>=2.2,<3.0",
|
19
|
+
"numpy<2.0.0",
|
20
20
|
]
|
21
21
|
keywords = ["torch",
|
22
22
|
"optimizer",
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|