heavyball 2.0.0__tar.gz → 2.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {heavyball-2.0.0 → heavyball-2.1.0}/PKG-INFO +6 -7
  2. {heavyball-2.0.0 → heavyball-2.1.0}/README.md +4 -4
  3. {heavyball-2.0.0 → heavyball-2.1.0}/heavyball/__init__.py +81 -4
  4. {heavyball-2.0.0 → heavyball-2.1.0}/heavyball/chainable.py +37 -0
  5. {heavyball-2.0.0 → heavyball-2.1.0}/heavyball/utils.py +143 -16
  6. {heavyball-2.0.0 → heavyball-2.1.0}/heavyball.egg-info/PKG-INFO +6 -7
  7. {heavyball-2.0.0 → heavyball-2.1.0}/heavyball.egg-info/requires.txt +1 -2
  8. {heavyball-2.0.0 → heavyball-2.1.0}/pyproject.toml +2 -2
  9. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_bf16_params.py +1 -1
  10. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_bf16_q.py +1 -1
  11. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_bf16_storage.py +1 -1
  12. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_caution.py +1 -1
  13. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_channels_last.py +1 -1
  14. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_clip.py +1 -1
  15. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_closure.py +1 -1
  16. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_ema.py +1 -1
  17. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_foreach.py +1 -1
  18. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_hook.py +1 -1
  19. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_mars.py +1 -1
  20. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_memory.py +1 -1
  21. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_memory_leak.py +1 -1
  22. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_merge.py +1 -1
  23. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_nd_param.py +1 -1
  24. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_no_grad.py +1 -1
  25. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_save_restore.py +1 -1
  26. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_stochastic_updates.py +1 -1
  27. {heavyball-2.0.0 → heavyball-2.1.0}/LICENSE +0 -0
  28. {heavyball-2.0.0 → heavyball-2.1.0}/heavyball/helpers.py +0 -0
  29. {heavyball-2.0.0 → heavyball-2.1.0}/heavyball.egg-info/SOURCES.txt +0 -0
  30. {heavyball-2.0.0 → heavyball-2.1.0}/heavyball.egg-info/dependency_links.txt +0 -0
  31. {heavyball-2.0.0 → heavyball-2.1.0}/heavyball.egg-info/top_level.txt +0 -0
  32. {heavyball-2.0.0 → heavyball-2.1.0}/setup.cfg +0 -0
  33. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_migrate_cli.py +0 -0
  34. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_psgd_precond_init_stability.py +0 -0
  35. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_singular_values.py +0 -0
  36. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_soap.py +0 -0
  37. {heavyball-2.0.0 → heavyball-2.1.0}/test/test_toy_training.py +0 -0
{heavyball-2.0.0 → heavyball-2.1.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: heavyball
-Version: 2.0.0
+Version: 2.1.0
 Summary: Efficient Optimizers
 Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
 Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -24,13 +24,12 @@ Requires-Dist: pytest; extra == "dev"
 Requires-Dist: ruff; extra == "dev"
 Requires-Dist: matplotlib; extra == "dev"
 Requires-Dist: seaborn; extra == "dev"
-Requires-Dist: hyperopt; extra == "dev"
 Requires-Dist: pandas; extra == "dev"
 Requires-Dist: typer; extra == "dev"
 Requires-Dist: optuna; extra == "dev"
 Requires-Dist: optunahub; extra == "dev"
-Requires-Dist: botorch; extra == "dev"
 Requires-Dist: hebo; extra == "dev"
+Requires-Dist: lightbench; extra == "dev"
 Dynamic: license-file
 
 # heavyball
@@ -54,7 +53,7 @@ _High-performance, extensible, chainable optimizers for PyTorch._
 - Schedule-Free optimizers with dynamic learning rate adaptation.
 - Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling.
 - Chainable transforms for custom optimization recipes.
-- Comprehensive benchmark suite (`benchmark/`).
+- Comprehensive benchmark suite packaged separately as LightBench (`../LightBench`).
 - Detailed documentation and example-driven tutorials.
 
 ## Quickstart
@@ -85,14 +84,14 @@ for data, target in dataloader:
 
 ## Benchmarks
 
-> Reproduce benchmarks with:
+> Reproduce benchmarks with LightBench (install it via `pip install -e ../LightBench` from the repo root):
 > ```bash
-> python3 -m benchmark.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
+> python3 -m lightbench.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
 > ```
 
 ## Migrating from HeavyBall 1.x
 
-- Read the detailed [2.0.0 migration notes](docs/heavyball-2.0.0-migration.md) for an end-to-end checklist, including guidance for restoring legacy behaviour when needed.
+- Read the detailed [2.0.0 migration notes](docs/heavyball2.md) for an end-to-end checklist, including guidance for restoring legacy behaviour when needed.
 - Use `scripts/migrate_optimizer_state.py` to rewrite pre-2.0 optimizer checkpoints:
 ```bash
 python scripts/migrate_optimizer_state.py path/to/checkpoint.pt heavyball.ForeachAdamW --state-key optimizer
{heavyball-2.0.0 → heavyball-2.1.0}/README.md
@@ -19,7 +19,7 @@ _High-performance, extensible, chainable optimizers for PyTorch._
 - Schedule-Free optimizers with dynamic learning rate adaptation.
 - Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling.
 - Chainable transforms for custom optimization recipes.
-- Comprehensive benchmark suite (`benchmark/`).
+- Comprehensive benchmark suite packaged separately as LightBench (`../LightBench`).
 - Detailed documentation and example-driven tutorials.
 
 ## Quickstart
@@ -50,14 +50,14 @@ for data, target in dataloader:
 
 ## Benchmarks
 
-> Reproduce benchmarks with:
+> Reproduce benchmarks with LightBench (install it via `pip install -e ../LightBench` from the repo root):
 > ```bash
-> python3 -m benchmark.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
+> python3 -m lightbench.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
 > ```
 
 ## Migrating from HeavyBall 1.x
 
-- Read the detailed [2.0.0 migration notes](docs/heavyball-2.0.0-migration.md) for an end-to-end checklist, including guidance for restoring legacy behaviour when needed.
+- Read the detailed [2.0.0 migration notes](docs/heavyball2.md) for an end-to-end checklist, including guidance for restoring legacy behaviour when needed.
 - Use `scripts/migrate_optimizer_state.py` to rewrite pre-2.0 optimizer checkpoints:
 ```bash
 python scripts/migrate_optimizer_state.py path/to/checkpoint.pt heavyball.ForeachAdamW --state-key optimizer
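For orientation alongside the README changes above, here is a minimal, hedged quickstart sketch of using a HeavyBall optimizer as a drop-in replacement for `torch.optim.AdamW`; the model, data, and learning rate below are illustrative placeholders, not values taken from the package.

```python
import torch
from torch import nn

import heavyball

# Toy model and random data, purely for illustration.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-3)

for _ in range(10):
    x, y = torch.randn(8, 16), torch.randn(8, 1)
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
```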
{heavyball-2.0.0 → heavyball-2.1.0}/heavyball/__init__.py
@@ -100,6 +100,71 @@ class ForeachAdamW(C.BaseOpt):
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,))
 
 
+class UnscaledAdamW(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(
+            params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,)
+        )
+
+
+class SUDSAdamW(C.BaseOpt):
+    def __init__(
+        self,
+        params,
+        lr=0.0025,
+        betas=(0.9, 0.99),
+        eps=1e-8,
+        weight_decay=0,
+        warmup_steps=0,
+        foreach: bool = True,
+        storage_dtype: str = "float32",
+        mars: bool = False,
+        caution: bool = False,
+        mars_gamma: float = 0.0025,
+        gradient_clipping: C.str_or_fn = C.use_default,
+        update_clipping: C.str_or_fn = C.use_default,
+        palm: bool = C.use_default,
+        beta2_scale: float = 0.8,
+        precond_lr: float = 1e-2,
+        **kwargs,
+    ):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        defaults.update(defaults.pop("kwargs"))
+
+        if kwargs:
+            utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
+
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_suds,))
+
+
 class ForeachAdamC(C.BaseOpt):
     def __init__(
         self,
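The two classes added above follow the constructor pattern of the existing optimizers, so instantiation should look the same; a hedged sketch (argument values are illustrative, and `model` stands for any `nn.Module`):

```python
import heavyball

# UnscaledAdamW routes updates through C.scale_by_unscaled_adam.
opt_unscaled = heavyball.UnscaledAdamW(model.parameters(), lr=0.0025, betas=(0.9, 0.99))

# SUDSAdamW additionally exposes precond_lr, used by its rank-1 Oja preconditioner.
opt_suds = heavyball.SUDSAdamW(model.parameters(), lr=0.0025, precond_lr=1e-2)
```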
@@ -320,6 +385,7 @@ class ForeachMuon(C.BaseOpt):
         palm: bool = C.use_default,
         beta2_scale: float = 0.8,
         nesterov: bool = True,
+        heavyball_momentum: bool = False,
         **kwargs,
     ):
         defaults = locals()
@@ -330,6 +396,16 @@ class ForeachMuon(C.BaseOpt):
         if kwargs:
             utils.warn_once(f"Working with uncaptured keyword arguments: {kwargs}")
 
+        if heavyball_momentum:
+            if nesterov:
+                ema = C.nesterov_momentum
+            else:
+                ema = C.heavyball_momentum
+        elif nesterov:
+            ema = C.nesterov_ema
+        else:
+            ema = C.exp_avg
+
         super().__init__(
             params,
             defaults,
@@ -337,7 +413,7 @@ class ForeachMuon(C.BaseOpt):
             gradient_clipping,
             update_clipping,
             palm,
-            fns=(C.nesterov_ema if nesterov else C.exp_avg, C.orthogonalize_update),
+            fns=(ema, C.orthogonalize_update),
         )
 
 
@@ -743,7 +819,9 @@ class ForeachPSGDKron(C.BaseOpt):
         update_clipping = C.default(update_clipping, utils.trust_region_clip_)
         inverse_free = C.default(inverse_free, self.quad)
         if inverse_free:
-            raise ValueError("inverse_free (i.e., PSGD-QUAD) is not supported at the moment. Consider using https://github.com/evanatyourservice/quad_torch")
+            raise ValueError(
+                "inverse_free (i.e., PSGD-QUAD) is not supported at the moment. Consider using https://github.com/evanatyourservice/quad_torch"
+            )
 
         defaults = locals()
         defaults.pop("self")
@@ -796,7 +874,6 @@ class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
     hvp_interval = 2
 
 
-
 class ForeachPSGDLRA(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -953,5 +1030,5 @@ __all__ = [
     "MSAMLaProp",
     "NewtonPSGDKron",
     "ForeachAdamC",
-    "SGD"
+    "SGD",
 ]
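Relatedly, the `ForeachMuon` change above adds a `heavyball_momentum` flag that, together with `nesterov`, selects one of four momentum transforms (`C.nesterov_momentum`, `C.heavyball_momentum`, `C.nesterov_ema`, `C.exp_avg`). A hedged usage sketch of the new flag, with illustrative values:

```python
import heavyball

# Default: nesterov=True, heavyball_momentum=False -> C.nesterov_ema, the pre-2.1 behaviour.
muon_default = heavyball.ForeachMuon(model.parameters(), lr=0.0025)

# heavyball_momentum=True with nesterov=False selects the classic heavy-ball accumulator.
muon_heavyball = heavyball.ForeachMuon(
    model.parameters(), lr=0.0025, nesterov=False, heavyball_momentum=True
)
```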
{heavyball-2.0.0 → heavyball-2.1.0}/heavyball/chainable.py
@@ -480,6 +480,43 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
     raise SkipUpdate from None
 
 
+@zero_guard("exp_avg", "exp_avg_sq", "fisher_approx")
+@no_state_no_foreach
+def scale_by_suds(group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx):
+    if group["step"] == 1:
+        utils.copy_stochastic_(fisher_approx, update / update.norm().clamp(min=1e-8))
+        raise SkipUpdate from None
+
+    precond_update, w = utils.eigvecs_product_rank1(update.flatten(), fisher_approx.flatten().to(update.dtype))
+    precond_update = utils.adam_(
+        exp_avg,
+        exp_avg_sq,
+        precond_update.view_as(exp_avg),
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"] - 1,
+    )[0]
+    precond_update, _ = utils.eigvecs_product_rank1(precond_update.flatten(), fisher_approx.flatten(), w)
+
+    new_approx = utils.oja_update(fisher_approx.flatten().to(update.dtype), update.flatten(), group["precond_lr"])
+    utils.copy_stochastic_(fisher_approx, new_approx)
+    return precond_update
+
+
+@zero_guard("exp_avg", "exp_avg_sq")
+@no_state
+def scale_by_unscaled_adam(group, update, grad, param, exp_avg, exp_avg_sq):
+    update = utils.unscaled_adam_(
+        exp_avg,
+        exp_avg_sq,
+        update,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"],
+    )
+    return update
+
+
 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
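For intuition about what `scale_by_unscaled_adam` produces, the underlying `unscaled_adam_` (see the utils.py hunk further down) normalizes the gradient by the second-moment denominator, averages the normalized gradient, and then multiplies the denominator back in. Below is a standalone plain-torch sketch of that reading, ignoring HeavyBall's bias debiasing, stochastic rounding, and foreach helpers; it is an illustration, not the library's implementation.

```python
import torch

def unscaled_adam_step(exp_avg, exp_avg_sq, grad, beta1=0.9, beta2=0.99, eps=1e-8):
    # Second-moment EMA and denominator, as in standard Adam (bias correction omitted here).
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = exp_avg_sq.sqrt().add_(eps)
    # Average the normalized gradient, then undo the normalization on the returned update.
    exp_avg.lerp_(grad / denom, 1 - beta1)
    return exp_avg * denom

grad = torch.randn(4)
m, v = torch.zeros(4), torch.zeros(4)
print(unscaled_adam_step(m, v, grad))
```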
{heavyball-2.0.0 → heavyball-2.1.0}/heavyball/utils.py
@@ -1,5 +1,6 @@
 import collections
 import contextlib
+import enum
 import functools
 import gc
 import inspect
@@ -19,11 +20,26 @@ from torch.backends import cudnn, opt_einsum
 from torch.nn import functional as F
 from torch.utils._pytree import tree_map
 
+
+class ZerothPowerMode(enum.Enum):
+    newtonschulz = "newtonschulz"
+    legacy_newtonschulz = "legacy_newtonschulz"
+    qr = "qr"
+    svd = "svd"
+    legacy_svd = "legacy_svd"
+
+
+class OrthoScaleMode(enum.Enum):
+    none = "none"
+    scale = "scale"
+    graft = "graft"
+
+
 compile_mode = "max-autotune-no-cudagraphs"
 dynamic = False
 compile_mode_recommended_to_none = None
 zeroth_power_mode = "newtonschulz"
-precise_zeroth_power_mode = "qr"  # or svd
+precise_zeroth_power_mode = "qr"
 tiny_bf16 = torch.finfo(torch.bfloat16).tiny
 _cudnn_double_backward_pattern = re.compile(
     r"the derivative for .* is not implemented\. Double backwards .* To run double backwards"
@@ -373,6 +389,7 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
         G.ndim >= 2
     )  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
     assert steps == 5
+    G = G.clone()
     X = G if G.dtype == torch.float64 else stochastic_round_(G)
     if G.size(-2) > G.size(-1):
         X = X.mT
@@ -387,9 +404,7 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
         (2.8366, -3.0525, 1.2012),
     ]:
         A = X @ X.mT
-        B = (
-            b * A + c * A @ A
-        )  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
         X = a * X + B @ X
 
     if G.size(-2) > G.size(-1):
@@ -397,6 +412,24 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
     return X.to(G.dtype)
 
 
+@decorator_knowngood
+def legacy_zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
+    assert len(G.shape) == 2
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    G = G.clone()
+    X = G if G.dtype == torch.float64 else stochastic_round_(G)
+    stochastic_multiply_(X, G.norm(dim=(-2, -1)) + eps)  # ensure top singular value <= 1
+    if G.size(0) > G.size(1):
+        X = X.T
+    for _ in range(steps):
+        A = X @ X.T
+        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
+    if G.size(0) > G.size(1):
+        X = X.T
+    return X.to(G.dtype)
+
+
 @decorator_knowngood
 def _compilable_heavyball_momentum_(state, grad, beta):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
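For intuition, the Newton–Schulz quintic iteration above pushes the singular values of a pre-normalized matrix toward 1 without computing an SVD. Below is a standalone plain-torch sketch of the same iteration using the legacy coefficients; it divides by the Frobenius norm to bound the top singular value and is an illustration, not the library's compiled implementation:

```python
import torch

def newtonschulz5(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G / (G.norm() + eps)  # Frobenius norm bounds the spectral norm, so sigma_max(X) <= 1
    transpose = G.size(0) > G.size(1)
    if transpose:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transpose else X

G = torch.randn(64, 32, dtype=torch.float64)
X = newtonschulz5(G)
print(torch.linalg.svdvals(X))  # singular values cluster near 1 after a few steps
```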
@@ -449,21 +482,30 @@ def _compilable_grafting(magnitude, direction):
 
 
 @decorator_knowngood
-def _compilable_orthogonal_(x: Tensor, mode: str, out: Tensor | None, scale_mode: str):
-    if mode == "newtonschulz" or x.shape[0] != x.shape[1]:
+def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor | None, scale_mode: str):
+    if not isinstance(mode, ZerothPowerMode):
+        mode = ZerothPowerMode(mode)
+    if not isinstance(scale_mode, ZerothPowerMode):
+        scale_mode = OrthoScaleMode(scale_mode)
+    if mode == ZerothPowerMode.newtonschulz or x.shape[0] != x.shape[1]:
         y = zeropower_via_newtonschulz5(x, 5)
-    elif mode == "qr":
+    elif mode == ZerothPowerMode.legacy_newtonschulz:
+        y = legacy_zeropower_via_newtonschulz5(x, 5)
+    elif mode == ZerothPowerMode.qr:
         y = torch.linalg.qr(promote(x)).Q
-    elif mode == "svd":
-        u, _s, v = torch.linalg.svd(promote(x))
-        y = u @ v.T
+    elif mode == ZerothPowerMode.svd:
+        u, _s, vt = torch.linalg.svd(promote(x))
+        y = u @ vt
+    elif mode == ZerothPowerMode.legacy_svd:
+        u, _s, vt = torch.linalg.svd(promote(x))
+        y = u @ vt.T
     else:
         raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
-    if scale_mode == "none":
+    if scale_mode == OrthoScaleMode.none:
         pass
-    elif scale_mode == "scale":
-        y *= max(1, x.size(0) / x.size(1)) ** 0.5
-    elif scale_mode == "graft":
+    elif scale_mode == OrthoScaleMode.scale:
+        y *= max(1, x.size(-2) / x.size(-1)) ** 0.5
+    elif scale_mode == OrthoScaleMode.graft:
         y = _compilable_grafting(x, y)
     else:
         raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
@@ -1223,17 +1265,53 @@ def _compilable_adam_(
 
 
 def adam_(
+    exp_avg: List[Tensor] | Tensor,
+    exp_avg_sq: List[Tensor] | Tensor,
+    grad: List[Tensor] | Tensor,
+    beta1: float,
+    beta2: float,
+    step: int,
+    eps: float = 1e-8,
+) -> List[Tensor]:
+    exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
+    beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
+    _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
+    return grad
+
+
+@decorator_knowngood
+def _compilable_unscaled_adam_(
     exp_avg: List[Tensor],
     exp_avg_sq: List[Tensor],
     grad: List[Tensor],
+    beta1: Tensor,
+    beta2: Tensor,
+    step: Tensor,
+    eps: Tensor,
+):
+    beta1 = beta_debias(beta1, step)
+    beta2 = beta_debias(beta2, step)
+
+    g32 = list(map(promote, grad))
+    denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
+    g32 = torch._foreach_div(g32, denom)
+    exp_avg32 = _lerp(exp_avg, g32, beta1)
+    u32 = torch._foreach_mul(exp_avg32, denom)
+    copy_stochastic_list_(grad, u32)
+
+
+def unscaled_adam_(
+    exp_avg: List[Tensor] | Tensor,
+    exp_avg_sq: List[Tensor] | Tensor,
+    grad: List[Tensor] | Tensor,
     beta1: float,
     beta2: float,
     step: int,
     eps: float = 1e-8,
-):
+) -> List[Tensor]:
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
     beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
-    _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
+    _compilable_unscaled_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
     return grad
 
 
@@ -2285,6 +2363,55 @@ def cond(cond, true_fn, false_fn):
     return false_fn()
 
 
+@decorator_knowngood
+def _householder_vec_e1_to_v(v: Tensor, eps: float = 1e-12) -> Tensor:
+    """
+    Return w such that H = I - 2 w w^T is orthogonal and H e1 = v (v unit).
+    Applying from the right: G @ H = G - 2 (G @ w) w^T.
+    If v is (numerically) e1, returns w=0 and H=I.
+    """
+    v = v / v.norm().clamp(min=eps)
+    e1 = torch.zeros_like(v)
+    e1[0] = 1.0
+    w = e1 - v
+    return w / w.norm().clamp(min=eps)
+
+
+@decorator_knowngood
+def eigvecs_product_rank1(
+    G: Tensor, v: Tensor, w: Optional[Tensor] = None, eps: float = 1e-12
+) -> Tuple[Tensor, Tensor]:
+    """
+    Compute Y = G @ V where V is an eigenvector matrix for P = λ I + σ v v^T,
+    using the Householder reflector with first column v. Never materializes V.
+
+    Args:
+        G: shape (..., d) — gradient row(s) you want to rotate into eigenbasis.
+        v: shape (d,) — current unit direction (top eigenvector of P).
+        w: optional Householder vector w; pass to reuse across calls.
+
+    Returns:
+        (Y, w) where:
+            Y has shape (..., d) and equals G @ eigenvectors(P),
+            w is the Householder vector you can cache & reuse.
+    """
+    if w is None:
+        w = _householder_vec_e1_to_v(v, eps)
+    Y = G - 2.0 * compiled_einsum("...i,i,j->...j", G, w, w)
+    return Y, w
+
+
+@decorator_knowngood
+def oja_update(v: Tensor, g: Tensor, lr: float = 1e-2, eps: float = 1e-12) -> Tensor:
+    """
+    One Oja step to track the top eigendirection of the gradient covariance.
+    v <- v + lr * ((g^T v) g - (g^T v)^2 v); then renormalize.
+    """
+    gv = g @ v
+    v = v + lr * (gv * g - (gv * gv) * v)
+    return v / v.norm().clamp(min=eps)
+
+
 def cond_n(cond_val: Tensor, *fns):
     fns = list(fns)
     fn = fns.pop(0)
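To see why the Householder construction above works: with w ∝ e1 − v (both unit vectors), H = I − 2 w wᵀ is orthogonal and maps e1 to v, so right-multiplying by H rotates rows into the eigenbasis without materializing the matrix, and `oja_update` then tracks the top eigendirection with one Oja step. A small standalone check in plain torch (independent of the compiled helpers above):

```python
import torch

def householder_w(v: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Same construction as _householder_vec_e1_to_v, written without HeavyBall decorators.
    v = v / v.norm().clamp(min=eps)
    e1 = torch.zeros_like(v)
    e1[0] = 1.0
    w = e1 - v
    return w / w.norm().clamp(min=eps)

d = 8
v = torch.randn(d, dtype=torch.float64)
v = v / v.norm()
w = householder_w(v)
H = torch.eye(d, dtype=torch.float64) - 2.0 * torch.outer(w, w)

e1 = torch.zeros(d, dtype=torch.float64)
e1[0] = 1.0
print(torch.allclose(H @ e1, v, atol=1e-10))                                    # H e1 == v
print(torch.allclose(H @ H.T, torch.eye(d, dtype=torch.float64), atol=1e-10))   # H is orthogonal
```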
{heavyball-2.0.0 → heavyball-2.1.0}/heavyball.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: heavyball
-Version: 2.0.0
+Version: 2.1.0
 Summary: Efficient Optimizers
 Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
 Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -24,13 +24,12 @@ Requires-Dist: pytest; extra == "dev"
 Requires-Dist: ruff; extra == "dev"
 Requires-Dist: matplotlib; extra == "dev"
 Requires-Dist: seaborn; extra == "dev"
-Requires-Dist: hyperopt; extra == "dev"
 Requires-Dist: pandas; extra == "dev"
 Requires-Dist: typer; extra == "dev"
 Requires-Dist: optuna; extra == "dev"
 Requires-Dist: optunahub; extra == "dev"
-Requires-Dist: botorch; extra == "dev"
 Requires-Dist: hebo; extra == "dev"
+Requires-Dist: lightbench; extra == "dev"
 Dynamic: license-file
 
 # heavyball
@@ -54,7 +53,7 @@ _High-performance, extensible, chainable optimizers for PyTorch._
 - Schedule-Free optimizers with dynamic learning rate adaptation.
 - Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling.
 - Chainable transforms for custom optimization recipes.
-- Comprehensive benchmark suite (`benchmark/`).
+- Comprehensive benchmark suite packaged separately as LightBench (`../LightBench`).
 - Detailed documentation and example-driven tutorials.
 
 ## Quickstart
@@ -85,14 +84,14 @@ for data, target in dataloader:
 
 ## Benchmarks
 
-> Reproduce benchmarks with:
+> Reproduce benchmarks with LightBench (install it via `pip install -e ../LightBench` from the repo root):
 > ```bash
-> python3 -m benchmark.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
+> python3 -m lightbench.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
 > ```
 
 ## Migrating from HeavyBall 1.x
 
-- Read the detailed [2.0.0 migration notes](docs/heavyball-2.0.0-migration.md) for an end-to-end checklist, including guidance for restoring legacy behaviour when needed.
+- Read the detailed [2.0.0 migration notes](docs/heavyball2.md) for an end-to-end checklist, including guidance for restoring legacy behaviour when needed.
 - Use `scripts/migrate_optimizer_state.py` to rewrite pre-2.0 optimizer checkpoints:
 ```bash
 python scripts/migrate_optimizer_state.py path/to/checkpoint.pt heavyball.ForeachAdamW --state-key optimizer
{heavyball-2.0.0 → heavyball-2.1.0}/heavyball.egg-info/requires.txt
@@ -8,10 +8,9 @@ pytest
 ruff
 matplotlib
 seaborn
-hyperopt
 pandas
 typer
 optuna
 optunahub
-botorch
 hebo
+lightbench
{heavyball-2.0.0 → heavyball-2.1.0}/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "heavyball"
 description = "Efficient Optimizers"
-version = "2.0.0"
+version = "2.1.0"
 authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
 classifiers = ["Intended Audience :: Developers",
     "Intended Audience :: Science/Research",
@@ -28,7 +28,7 @@ readme = "README.md"
 requires-python = ">=3.9"
 
 [project.optional-dependencies]
-dev = ["pre-commit", "pytest", "ruff", "matplotlib", "seaborn", "hyperopt", "pandas", "typer", "optuna", "optunahub", "botorch", "hebo"]
+dev = ["pre-commit", "pytest", "ruff", "matplotlib", "seaborn", "pandas", "typer", "optuna", "optunahub", "hebo", "lightbench"]
 
 [project.urls]
 source = "https://github.com/HomebrewML/HeavyBall"
{heavyball-2.0.0 → heavyball-2.1.0}/test/test_bf16_params.py
@@ -3,12 +3,12 @@ import os
 
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
 
 os.environ["TORCH_LOGS"] = "+recompiles"

{heavyball-2.0.0 → heavyball-2.1.0}/test/test_bf16_q.py
@@ -1,11 +1,11 @@
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
 
 config.cache_size_limit = 128
{heavyball-2.0.0 → heavyball-2.1.0}/test/test_bf16_storage.py
@@ -1,11 +1,11 @@
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
 
 config.cache_size_limit = 128

{heavyball-2.0.0 → heavyball-2.1.0}/test/test_caution.py
@@ -4,12 +4,12 @@ os.environ["TORCH_LOGS"] = "+recompiles"
 
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
 
 config.cache_size_limit = 128
{heavyball-2.0.0 → heavyball-2.1.0}/test/test_channels_last.py
@@ -4,12 +4,12 @@ os.environ["TORCH_LOGS"] = "+recompiles"
 
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
 
 heavyball.utils.zeroth_power_mode = "newtonschulz"

{heavyball-2.0.0 → heavyball-2.1.0}/test/test_clip.py
@@ -6,11 +6,11 @@ import math
 
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import linalg, nn
 from torch._dynamo import config
 
 import heavyball
-from benchmark.utils import get_optim
 from heavyball import utils
 from heavyball.utils import (
     _compilable_global_l2norm_clip_,
{heavyball-2.0.0 → heavyball-2.1.0}/test/test_closure.py
@@ -2,11 +2,11 @@ from typing import List
 
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
 
 

{heavyball-2.0.0 → heavyball-2.1.0}/test/test_ema.py
@@ -1,11 +1,11 @@
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
 
 config.cache_size_limit = 128
{heavyball-2.0.0 → heavyball-2.1.0}/test/test_foreach.py
@@ -1,10 +1,10 @@
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
 
 

{heavyball-2.0.0 → heavyball-2.1.0}/test/test_hook.py
@@ -4,12 +4,12 @@ os.environ["TORCH_LOGS"] = "+recompiles"
 
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, hook_optimizer_into_model, set_torch
 
 heavyball.utils.compile_mode = "default"
{heavyball-2.0.0 → heavyball-2.1.0}/test/test_mars.py
@@ -1,11 +1,11 @@
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
 
 config.cache_size_limit = 128

{heavyball-2.0.0 → heavyball-2.1.0}/test/test_memory.py
@@ -1,10 +1,10 @@
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
 
 
{heavyball-2.0.0 → heavyball-2.1.0}/test/test_memory_leak.py
@@ -1,12 +1,12 @@
 import pytest
 import torch
 import tqdm
+from lightbench.utils import get_optim
 from torch import nn
 from torch.nn import functional as F
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
 
 

{heavyball-2.0.0 → heavyball-2.1.0}/test/test_merge.py
@@ -2,11 +2,11 @@ from typing import List
 
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
 
 
{heavyball-2.0.0 → heavyball-2.1.0}/test/test_nd_param.py
@@ -1,11 +1,11 @@
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import set_torch
 
 config.cache_size_limit = 2**20

{heavyball-2.0.0 → heavyball-2.1.0}/test/test_no_grad.py
@@ -2,11 +2,11 @@ from typing import List
 
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
 
 
{heavyball-2.0.0 → heavyball-2.1.0}/test/test_save_restore.py
@@ -5,13 +5,13 @@ os.environ["TORCH_LOGS"] = "+recompiles"
 
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 from torch._dynamo import config
 from torch.utils._pytree import tree_map
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import set_torch
 
 config.cache_size_limit = 128

{heavyball-2.0.0 → heavyball-2.1.0}/test/test_stochastic_updates.py
@@ -1,10 +1,10 @@
 import pytest
 import torch
+from lightbench.utils import get_optim
 from torch import nn
 
 import heavyball
 import heavyball.utils
-from benchmark.utils import get_optim
 from heavyball.utils import clean, set_torch
 
 