heavyball 2.0.0.tar.gz → 2.1.0.tar.gz
This diff compares publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
- {heavyball-2.0.0 → heavyball-2.1.0}/PKG-INFO +6 -7
- {heavyball-2.0.0 → heavyball-2.1.0}/README.md +4 -4
- {heavyball-2.0.0 → heavyball-2.1.0}/heavyball/__init__.py +81 -4
- {heavyball-2.0.0 → heavyball-2.1.0}/heavyball/chainable.py +37 -0
- {heavyball-2.0.0 → heavyball-2.1.0}/heavyball/utils.py +143 -16
- {heavyball-2.0.0 → heavyball-2.1.0}/heavyball.egg-info/PKG-INFO +6 -7
- {heavyball-2.0.0 → heavyball-2.1.0}/heavyball.egg-info/requires.txt +1 -2
- {heavyball-2.0.0 → heavyball-2.1.0}/pyproject.toml +2 -2
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_bf16_params.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_bf16_q.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_bf16_storage.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_caution.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_channels_last.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_clip.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_closure.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_ema.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_foreach.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_hook.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_mars.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_memory.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_memory_leak.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_merge.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_nd_param.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_no_grad.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_save_restore.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_stochastic_updates.py +1 -1
- {heavyball-2.0.0 → heavyball-2.1.0}/LICENSE +0 -0
- {heavyball-2.0.0 → heavyball-2.1.0}/heavyball/helpers.py +0 -0
- {heavyball-2.0.0 → heavyball-2.1.0}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-2.0.0 → heavyball-2.1.0}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-2.0.0 → heavyball-2.1.0}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-2.0.0 → heavyball-2.1.0}/setup.cfg +0 -0
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_migrate_cli.py +0 -0
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_psgd_precond_init_stability.py +0 -0
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_singular_values.py +0 -0
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_soap.py +0 -0
- {heavyball-2.0.0 → heavyball-2.1.0}/test/test_toy_training.py +0 -0
{heavyball-2.0.0 → heavyball-2.1.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: heavyball
-Version: 2.0.0
+Version: 2.1.0
 Summary: Efficient Optimizers
 Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
 Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -24,13 +24,12 @@ Requires-Dist: pytest; extra == "dev"
 Requires-Dist: ruff; extra == "dev"
 Requires-Dist: matplotlib; extra == "dev"
 Requires-Dist: seaborn; extra == "dev"
-Requires-Dist: hyperopt; extra == "dev"
 Requires-Dist: pandas; extra == "dev"
 Requires-Dist: typer; extra == "dev"
 Requires-Dist: optuna; extra == "dev"
 Requires-Dist: optunahub; extra == "dev"
-Requires-Dist: botorch; extra == "dev"
 Requires-Dist: hebo; extra == "dev"
+Requires-Dist: lightbench; extra == "dev"
 Dynamic: license-file

 # heavyball
@@ -54,7 +53,7 @@ _High-performance, extensible, chainable optimizers for PyTorch._
 - Schedule-Free optimizers with dynamic learning rate adaptation.
 - Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling.
 - Chainable transforms for custom optimization recipes.
-- Comprehensive benchmark suite (`
+- Comprehensive benchmark suite packaged separately as LightBench (`../LightBench`).
 - Detailed documentation and example-driven tutorials.

 ## Quickstart
@@ -85,14 +84,14 @@ for data, target in dataloader:

 ## Benchmarks

-> Reproduce benchmarks with:
+> Reproduce benchmarks with LightBench (install it via `pip install -e ../LightBench` from the repo root):
 > ```bash
-> python3 -m
+> python3 -m lightbench.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
 > ```

 ## Migrating from HeavyBall 1.x

-- Read the detailed [2.0.0 migration notes](docs/
+- Read the detailed [2.0.0 migration notes](docs/heavyball2.md) for an end-to-end checklist, including guidance for restoring legacy behaviour when needed.
 - Use `scripts/migrate_optimizer_state.py` to rewrite pre-2.0 optimizer checkpoints:
 ```bash
 python scripts/migrate_optimizer_state.py path/to/checkpoint.pt heavyball.ForeachAdamW --state-key optimizer
{heavyball-2.0.0 → heavyball-2.1.0}/README.md

@@ -19,7 +19,7 @@ _High-performance, extensible, chainable optimizers for PyTorch._
 - Schedule-Free optimizers with dynamic learning rate adaptation.
 - Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling.
 - Chainable transforms for custom optimization recipes.
-- Comprehensive benchmark suite (`
+- Comprehensive benchmark suite packaged separately as LightBench (`../LightBench`).
 - Detailed documentation and example-driven tutorials.

 ## Quickstart
@@ -50,14 +50,14 @@ for data, target in dataloader:

 ## Benchmarks

-> Reproduce benchmarks with:
+> Reproduce benchmarks with LightBench (install it via `pip install -e ../LightBench` from the repo root):
 > ```bash
-> python3 -m
+> python3 -m lightbench.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
 > ```

 ## Migrating from HeavyBall 1.x

-- Read the detailed [2.0.0 migration notes](docs/
+- Read the detailed [2.0.0 migration notes](docs/heavyball2.md) for an end-to-end checklist, including guidance for restoring legacy behaviour when needed.
 - Use `scripts/migrate_optimizer_state.py` to rewrite pre-2.0 optimizer checkpoints:
 ```bash
 python scripts/migrate_optimizer_state.py path/to/checkpoint.pt heavyball.ForeachAdamW --state-key optimizer
{heavyball-2.0.0 → heavyball-2.1.0}/heavyball/__init__.py

@@ -100,6 +100,71 @@ class ForeachAdamW(C.BaseOpt):
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,))


+class UnscaledAdamW(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(
+            params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,)
+        )
+
+
+class SUDSAdamW(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        precond_lr: float = 1e-2,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_suds,))
+
+
 class ForeachAdamC(C.BaseOpt):
     def __init__(
         self,
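The two classes above are new in 2.1.0 and route their updates through the new `C.scale_by_unscaled_adam` and `C.scale_by_suds` transforms added in `chainable.py` further down. A minimal usage sketch, not taken from the diff: the class names and constructor arguments come from the hunk above, while the model, dimensions, and hyperparameter values are illustrative, and the `step`/`zero_grad` calls assume the usual torch.optim interface that heavyball optimizers expose.

```python
import torch
from torch import nn

import heavyball

model = nn.Linear(16, 4)

# UnscaledAdamW: AdamW-style wrapper whose update goes through scale_by_unscaled_adam.
opt = heavyball.UnscaledAdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.99), weight_decay=1e-2)

# SUDSAdamW additionally tracks a rank-1 Fisher approximation; precond_lr is its Oja step size.
# opt = heavyball.SUDSAdamW(model.parameters(), lr=1e-3, precond_lr=1e-2)

loss = model(torch.randn(8, 16)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```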
@@ -320,6 +385,7 @@ class ForeachMuon(C.BaseOpt):
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
         nesterov: bool = True,
+        heavyball_momentum: bool = False,
         **kwargs,
     ):
         defaults = locals()
@@ -330,6 +396,16 @@ class ForeachMuon(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")

+        if heavyball_momentum:
+            if nesterov:
+                ema = C.nesterov_momentum
+            else:
+                ema = C.heavyball_momentum
+        elif nesterov:
+            ema = C.nesterov_ema
+        else:
+            ema = C.exp_avg
+
         super().__init__(
             params,
             defaults,
@@ -337,7 +413,7 @@ class ForeachMuon(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            fns=(
+            fns=(ema, C.orthogonalize_update),
         )

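For reference, the branch added above maps the two flags onto the chainable momentum transforms listed in the comments below; the mapping is read directly from the hunk, while the model and learning rate in the call are illustrative assumptions.

```python
import torch
from torch import nn

import heavyball

# heavyball_momentum=True,  nesterov=True  -> C.nesterov_momentum
# heavyball_momentum=True,  nesterov=False -> C.heavyball_momentum
# heavyball_momentum=False, nesterov=True  -> C.nesterov_ema
# heavyball_momentum=False, nesterov=False -> C.exp_avg
model = nn.Linear(32, 32)
opt = heavyball.ForeachMuon(model.parameters(), lr=1e-3, nesterov=True, heavyball_momentum=True)
```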
@@ -743,7 +819,9 @@ class ForeachPSGDKron(C.BaseOpt):
         update_clipping = C.default(update_clipping, utils.trust_region_clip_)
         inverse_free = C.default(inverse_free, self.quad)
         if inverse_free:
-            raise ValueError(
+            raise ValueError(
+                "inverse_free (i.e., PSGD-QUAD) is not supported at the moment. Consider using https://github.com/evanatyourservice/quad_torch"
+            )

         defaults = locals()
         defaults.pop("self")
@@ -796,7 +874,6 @@ class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
     hvp_interval = 2


-
 class ForeachPSGDLRA(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -953,5 +1030,5 @@ __all__ = [
     "MSAMLaProp",
     "NewtonPSGDKron",
     "ForeachAdamC",
-    "SGD"
+    "SGD",
 ]
{heavyball-2.0.0 → heavyball-2.1.0}/heavyball/chainable.py

@@ -480,6 +480,43 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
         raise SkipUpdate from None


+@zero_guard("exp_avg", "exp_avg_sq", "fisher_approx")
+@no_state_no_foreach
+def scale_by_suds(group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx):
+    if group["step"] == 1:
+        utils.copy_stochastic_(fisher_approx, update / update.norm().clamp(min=1e-8))
+        raise SkipUpdate from None
+
+    precond_update, w = utils.eigvecs_product_rank1(update.flatten(), fisher_approx.flatten().to(update.dtype))
+    precond_update = utils.adam_(
+        exp_avg,
+        exp_avg_sq,
+        precond_update.view_as(exp_avg),
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"] - 1,
+    )[0]
+    precond_update, _ = utils.eigvecs_product_rank1(precond_update.flatten(), fisher_approx.flatten(), w)
+
+    new_approx = utils.oja_update(fisher_approx.flatten().to(update.dtype), update.flatten(), group["precond_lr"])
+    utils.copy_stochastic_(fisher_approx, new_approx)
+    return precond_update
+
+
+@zero_guard("exp_avg", "exp_avg_sq")
+@no_state
+def scale_by_unscaled_adam(group, update, grad, param, exp_avg, exp_avg_sq):
+    update = utils.unscaled_adam_(
+        exp_avg,
+        exp_avg_sq,
+        update,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"],
+    )
+    return update
+
+
 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
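The `scale_by_suds` transform above (1) rotates the flattened update into the eigenbasis implied by a rank-1 Fisher approximation, (2) applies a debiased Adam step in that basis, (3) rotates back with the same Householder reflector, and (4) refreshes the tracked direction with one Oja step. A plain-tensor restatement of that flow, assuming the `utils` helpers added later in this diff; the dimension, betas, and step counter are illustrative, and this is not the library's exact call path (the real transform also uses stochastic rounding and the group's hyperparameters).

```python
import torch

from heavyball import utils

d = 8
update = torch.randn(d)
exp_avg = torch.zeros(d)
exp_avg_sq = torch.zeros(d)

# What the first step stores as the Fisher direction.
fisher_approx = update / update.norm().clamp(min=1e-8)

# Later steps: rotate, Adam in the rotated basis, rotate back, Oja refresh.
rotated, w = utils.eigvecs_product_rank1(update, fisher_approx)
stepped = utils.adam_(exp_avg, exp_avg_sq, rotated, 0.9, 0.99, 1)[0]
precond_update, _ = utils.eigvecs_product_rank1(stepped, fisher_approx, w)
fisher_approx = utils.oja_update(fisher_approx, update, lr=1e-2)
```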
{heavyball-2.0.0 → heavyball-2.1.0}/heavyball/utils.py

@@ -1,5 +1,6 @@
 import collections
 import contextlib
+import enum
 import functools
 import gc
 import inspect
@@ -19,11 +20,26 @@ from torch.backends import cudnn, opt_einsum
 from torch.nn import functional as F
 from torch.utils._pytree import tree_map

+
+class ZerothPowerMode(enum.Enum):
+    newtonschulz = "newtonschulz"
+    legacy_newtonschulz = "legacy_newtonschulz"
+    qr = "qr"
+    svd = "svd"
+    legacy_svd = "legacy_svd"
+
+
+class OrthoScaleMode(enum.Enum):
+    none = "none"
+    scale = "scale"
+    graft = "graft"
+
+
 compile_mode = "max-autotune-no-cudagraphs"
 dynamic = False
 compile_mode_recommended_to_none = None
 zeroth_power_mode = "newtonschulz"
-precise_zeroth_power_mode = "qr"
+precise_zeroth_power_mode = "qr"
 tiny_bf16 = torch.finfo(torch.bfloat16).tiny
 _cudnn_double_backward_pattern = re.compile(
     r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
@@ -373,6 +389,7 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
         G.ndim >= 2
     )  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
     assert steps == 5
+    G = G.clone()
     X = G if G.dtype == torch.float64 else stochastic_round_(G)
     if G.size(-2) > G.size(-1):
         X = X.mT
@@ -387,9 +404,7 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
         (2.8366, -3.0525, 1.2012),
     ]:
         A = X @ X.mT
-        B =
-            b * A + c * A @ A
-        )  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
         X = a * X + B @ X

     if G.size(-2) > G.size(-1):
@@ -397,6 +412,24 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
     return X.to(G.dtype)


+@decorator_knowngood
+def legacy_zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
+    assert len(G.shape) == 2
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    G = G.clone()
+    X = G if G.dtype == torch.float64 else stochastic_round_(G)
+    stochastic_multiply_(X, G.norm(dim=(-2, -1)) + eps)  # ensure top singular value <= 1
+    if G.size(0) > G.size(1):
+        X = X.T
+    for _ in range(steps):
+        A = X @ X.T
+        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
+    if G.size(0) > G.size(1):
+        X = X.T
+    return X.to(G.dtype)
+
+
 @decorator_knowngood
 def _compilable_heavyball_momentum_(state, grad, beta):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
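A quick, illustrative check of the Newton-Schulz orthogonalization above: for a tall matrix the result should be approximately semi-orthogonal (X^T X close to the identity), with a noticeable but bounded error because the quintic coefficients only push singular values into a band around 1. The import path matches this diff; the matrix shape and the printed diagnostic are assumptions.

```python
import torch

from heavyball.utils import zeropower_via_newtonschulz5

G = torch.randn(64, 32)
X = zeropower_via_newtonschulz5(G, steps=5)  # steps must be 5; the function asserts it

# X keeps G's shape; X^T X is only approximately the identity.
err = (X.mT @ X - torch.eye(32)).norm()
print(f"orthogonality error (Frobenius): {float(err):.3f}")
```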
@@ -449,21 +482,30 @@ def _compilable_grafting(magnitude, direction):


 @decorator_knowngood
-def _compilable_orthogonal_(x: Tensor, mode: str, out: Tensor | None, scale_mode: str):
-    if mode
+def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor | None, scale_mode: str):
+    if not isinstance(mode, ZerothPowerMode):
+        mode = ZerothPowerMode(mode)
+    if not isinstance(scale_mode, ZerothPowerMode):
+        scale_mode = OrthoScaleMode(scale_mode)
+    if mode == ZerothPowerMode.newtonschulz or x.shape[0] != x.shape[1]:
         y = zeropower_via_newtonschulz5(x, 5)
-    elif mode ==
+    elif mode == ZerothPowerMode.legacy_newtonschulz:
+        y = legacy_zeropower_via_newtonschulz5(x, 5)
+    elif mode == ZerothPowerMode.qr:
         y = torch.linalg.qr(promote(x)).Q
-    elif mode ==
-        u, _s,
-        y = u @
+    elif mode == ZerothPowerMode.svd:
+        u, _s, vt = torch.linalg.svd(promote(x))
+        y = u @ vt
+    elif mode == ZerothPowerMode.legacy_svd:
+        u, _s, vt = torch.linalg.svd(promote(x))
+        y = u @ vt.T
     else:
         raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
-    if scale_mode ==
+    if scale_mode == OrthoScaleMode.none:
         pass
-    elif scale_mode ==
-        y *= max(1, x.size(
-    elif scale_mode ==
+    elif scale_mode == OrthoScaleMode.scale:
+        y *= max(1, x.size(-2) / x.size(-1)) ** 0.5
+    elif scale_mode == OrthoScaleMode.graft:
         y = _compilable_grafting(x, y)
     else:
         raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
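With the enums above, the mode arguments accept either the 2.0.x string spellings or an enum member, since `_compilable_orthogonal_` now coerces strings via `ZerothPowerMode(mode)`. A small illustrative sketch; it assumes the module-level `zeroth_power_mode` switch is still what the orthogonalization path reads, as in 2.0.x.

```python
from heavyball import utils

# String spellings keep working ...
utils.zeroth_power_mode = "legacy_newtonschulz"

# ... and the enum members document the valid choices.
utils.zeroth_power_mode = utils.ZerothPowerMode.qr.value
# ZerothPowerMode: newtonschulz, legacy_newtonschulz, qr, svd, legacy_svd
# OrthoScaleMode:  none, scale, graft
```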
@@ -1223,17 +1265,53 @@ def _compilable_adam_(


 def adam_(
+    exp_avg: List[Tensor] | Tensor,
+    exp_avg_sq: List[Tensor] | Tensor,
+    grad: List[Tensor] | Tensor,
+    beta1: float,
+    beta2: float,
+    step: int,
+    eps: float = 1e-8,
+) -> List[Tensor]:
+    exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
+    beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
+    _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
+    return grad
+
+
+@decorator_knowngood
+def _compilable_unscaled_adam_(
     exp_avg: List[Tensor],
     exp_avg_sq: List[Tensor],
     grad: List[Tensor],
+    beta1: Tensor,
+    beta2: Tensor,
+    step: Tensor,
+    eps: Tensor,
+):
+    beta1 = beta_debias(beta1, step)
+    beta2 = beta_debias(beta2, step)
+
+    g32 = list(map(promote, grad))
+    denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
+    g32 = torch._foreach_div(g32, denom)
+    exp_avg32 = _lerp(exp_avg, g32, beta1)
+    u32 = torch._foreach_mul(exp_avg32, denom)
+    copy_stochastic_list_(grad, u32)
+
+
+def unscaled_adam_(
+    exp_avg: List[Tensor] | Tensor,
+    exp_avg_sq: List[Tensor] | Tensor,
+    grad: List[Tensor] | Tensor,
     beta1: float,
     beta2: float,
     step: int,
     eps: float = 1e-8,
-):
+) -> List[Tensor]:
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
-
+    _compilable_unscaled_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
     return grad

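Compared with `_compilable_adam_`, the unscaled variant above normalizes the gradient by the debiased second-moment RMS, averages those normalized gradients, and then multiplies the average back by the same RMS instead of dividing by it. A single-tensor restatement for clarity; this is an assumption-level sketch (float32, no stochastic rounding, `beta_debias` taken to be the usual 1 - (1-beta)/(1-beta^step) form), not the compiled foreach implementation.

```python
import torch


def beta_debias(beta: float, step: int) -> float:
    # Assumed form of heavyball's beta_debias helper.
    return 1 - (1 - beta) / (1 - beta**step)


def unscaled_adam_step(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps=1e-8):
    b1, b2 = beta_debias(beta1, step), beta_debias(beta2, step)
    exp_avg_sq.lerp_(grad * grad, 1 - b2)     # second-moment EMA
    denom = exp_avg_sq.sqrt().clamp(min=eps)  # debiased RMS
    exp_avg.lerp_(grad / denom, 1 - b1)       # EMA of *normalized* gradients
    return exp_avg * denom                    # re-scale instead of dividing
```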
@@ -2285,6 +2363,55 @@ def cond(cond, true_fn, false_fn):
     return false_fn()


+@decorator_knowngood
+def _householder_vec_e1_to_v(v: Tensor, eps: float = 1e-12) -> Tensor:
+    """
+    Return w such that H = I - 2 w w^T is orthogonal and H e1 = v (v unit).
+    Applying from the right: G @ H = G - 2 (G @ w) w^T.
+    If v is (numerically) e1, returns w=0 and H=I.
+    """
+    v = v / v.norm().clamp(min=eps)
+    e1 = torch.zeros_like(v)
+    e1[0] = 1.0
+    w = e1 - v
+    return w / w.norm().clamp(min=eps)
+
+
+@decorator_knowngood
+def eigvecs_product_rank1(
+    G: Tensor, v: Tensor, w: Optional[Tensor] = None, eps: float = 1e-12
+) -> Tuple[Tensor, Tensor]:
+    """
+    Compute Y = G @ V where V is an eigenvector matrix for P = λ I + σ v v^T,
+    using the Householder reflector with first column v. Never materializes V.
+
+    Args:
+        G: shape (..., d) — gradient row(s) you want to rotate into eigenbasis.
+        v: shape (d,) — current unit direction (top eigenvector of P).
+        w: optional Householder vector w; pass to reuse across calls.
+
+    Returns:
+        (Y, w) where:
+            Y has shape (..., d) and equals G @ eigenvectors(P),
+            w is the Householder vector you can cache & reuse.
+    """
+    if w is None:
+        w = _householder_vec_e1_to_v(v, eps)
+    Y = G - 2.0 * compiled_einsum("...i,i,j->...j", G, w, w)
+    return Y, w
+
+
+@decorator_knowngood
+def oja_update(v: Tensor, g: Tensor, lr: float = 1e-2, eps: float = 1e-12) -> Tensor:
+    """
+    One Oja step to track the top eigendirection of the gradient covariance.
+    v <- v + lr * ((g^T v) g - (g^T v)^2 v); then renormalize.
+    """
+    gv = g @ v
+    v = v + lr * (gv * g - (gv * gv) * v)
+    return v / v.norm().clamp(min=eps)
+
+
 def cond_n(cond_val: Tensor, *fns):
     fns = list(fns)
     fn = fns.pop(0)
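A small numerical check of the helpers above, useful for seeing what `scale_by_suds` relies on: the Householder reflector's first column is `v`, and `eigvecs_product_rank1` applies that reflector without materializing it. The dimension, tolerance, and the use of the private `_householder_vec_e1_to_v` helper are illustrative assumptions.

```python
import torch

from heavyball import utils

d = 6
v = torch.randn(d)
v = v / v.norm()

w = utils._householder_vec_e1_to_v(v)
H = torch.eye(d) - 2.0 * torch.outer(w, w)      # reflector, materialized only for the check
print(torch.allclose(H[:, 0], v, atol=1e-4))    # H @ e1 == v, so column 0 of H is v

g = torch.randn(d)
rotated, _ = utils.eigvecs_product_rank1(g, v, w)
print(torch.allclose(rotated, g @ H, atol=1e-4))  # same product, computed implicitly

# One Oja step nudges the tracked direction toward the dominant gradient direction.
v_next = utils.oja_update(v, g, lr=1e-2)
```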
{heavyball-2.0.0 → heavyball-2.1.0}/heavyball.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: heavyball
-Version: 2.0.0
+Version: 2.1.0
 Summary: Efficient Optimizers
 Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
 Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -24,13 +24,12 @@ Requires-Dist: pytest; extra == "dev"
 Requires-Dist: ruff; extra == "dev"
 Requires-Dist: matplotlib; extra == "dev"
 Requires-Dist: seaborn; extra == "dev"
-Requires-Dist: hyperopt; extra == "dev"
 Requires-Dist: pandas; extra == "dev"
 Requires-Dist: typer; extra == "dev"
 Requires-Dist: optuna; extra == "dev"
 Requires-Dist: optunahub; extra == "dev"
-Requires-Dist: botorch; extra == "dev"
 Requires-Dist: hebo; extra == "dev"
+Requires-Dist: lightbench; extra == "dev"
 Dynamic: license-file

 # heavyball
@@ -54,7 +53,7 @@ _High-performance, extensible, chainable optimizers for PyTorch._
 - Schedule-Free optimizers with dynamic learning rate adaptation.
 - Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling.
 - Chainable transforms for custom optimization recipes.
-- Comprehensive benchmark suite (`
+- Comprehensive benchmark suite packaged separately as LightBench (`../LightBench`).
 - Detailed documentation and example-driven tutorials.

 ## Quickstart
@@ -85,14 +84,14 @@ for data, target in dataloader:

 ## Benchmarks

-> Reproduce benchmarks with:
+> Reproduce benchmarks with LightBench (install it via `pip install -e ../LightBench` from the repo root):
 > ```bash
-> python3 -m
+> python3 -m lightbench.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
 > ```

 ## Migrating from HeavyBall 1.x

-- Read the detailed [2.0.0 migration notes](docs/
+- Read the detailed [2.0.0 migration notes](docs/heavyball2.md) for an end-to-end checklist, including guidance for restoring legacy behaviour when needed.
 - Use `scripts/migrate_optimizer_state.py` to rewrite pre-2.0 optimizer checkpoints:
 ```bash
 python scripts/migrate_optimizer_state.py path/to/checkpoint.pt heavyball.ForeachAdamW --state-key optimizer

{heavyball-2.0.0 → heavyball-2.1.0}/pyproject.toml

@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "heavyball"
 description = "Efficient Optimizers"
-version = "2.0.0"
+version = "2.1.0"
 authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
 classifiers = ["Intended Audience :: Developers",
     "Intended Audience :: Science/Research",
@@ -28,7 +28,7 @@ readme = "README.md"
 requires-python = ">=3.9"

 [project.optional-dependencies]
-dev = ["pre-commit", "pytest", "ruff", "matplotlib", "seaborn", "
+dev = ["pre-commit", "pytest", "ruff", "matplotlib", "seaborn", "pandas", "typer", "optuna", "optunahub", "hebo", "lightbench"]

 [project.urls]
 source = "https://github.com/HomebrewML/HeavyBall"
{heavyball-2.0.0 → heavyball-2.1.0}/test/ (one hunk per test module; see the file list above)

@@ -3,12 +3,12 @@ import os

 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config

 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch

 os.environ["TORCH_LOGS"] = "+recompiles"
@@ -1,11 +1,11 @@
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config

 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch

 config.cache_size_limit = 128
@@ -1,11 +1,11 @@
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config

 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch

 config.cache_size_limit = 128
@@ -4,12 +4,12 @@ os.environ["TORCH_LOGS"] = "+recompiles"

 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config

 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch

 config.cache_size_limit = 128
@@ -4,12 +4,12 @@ os.environ["TORCH_LOGS"] = "+recompiles"

 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config

 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch

 heavyball.utils.zeroth_power_mode = "newtonschulz"
@@ -6,11 +6,11 @@ import math

 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import linalg, nn
 from torch._dynamo import config

 import heavyball
-from benchmark.utils import get_optim
 from heavyball import utils
 from heavyball.utils import (
     _compilable_global_l2norm_clip_,
@@ -2,11 +2,11 @@ from typing import List

 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn

 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch


@@ -1,11 +1,11 @@
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config

 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch

 config.cache_size_limit = 128
@@ -4,12 +4,12 @@ os.environ["TORCH_LOGS"] = "+recompiles"

 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config

 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, hook_optimizer_into_model, set_torch

 heavyball.utils.compile_mode = "default"
@@ -1,11 +1,11 @@
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config

 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch

 config.cache_size_limit = 128
@@ -1,12 +1,12 @@
 import pytest
 import torch
 import tqdm
+from lightbench.utils import get_optim
 from torch import nn
 from torch.nn import functional as F

 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch


@@ -2,11 +2,11 @@ from typing import List

 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn

 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch


@@ -1,11 +1,11 @@
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config

 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import set_torch

 config.cache_size_limit = 2**20
@@ -2,11 +2,11 @@ from typing import List

 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn

 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch


@@ -5,13 +5,13 @@ os.environ["TORCH_LOGS"] = "+recompiles"

 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config
 from torch.utils._pytree import tree_map

 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import set_torch

 config.cache_size_limit = 128