heavyball 1.7.0__tar.gz → 1.7.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {heavyball-1.7.0 → heavyball-1.7.1}/PKG-INFO +1 -1
  2. {heavyball-1.7.0 → heavyball-1.7.1}/heavyball/__init__.py +20 -1
  3. {heavyball-1.7.0 → heavyball-1.7.1}/heavyball/chainable.py +50 -8
  4. heavyball-1.7.1/heavyball/optimizations/__init__.py +38 -0
  5. heavyball-1.7.1/heavyball/optimizations/integrator.py +169 -0
  6. heavyball-1.7.1/heavyball/optimizations/optimizations.py +329 -0
  7. {heavyball-1.7.0 → heavyball-1.7.1}/heavyball/utils.py +518 -162
  8. {heavyball-1.7.0 → heavyball-1.7.1}/heavyball.egg-info/PKG-INFO +1 -1
  9. {heavyball-1.7.0 → heavyball-1.7.1}/heavyball.egg-info/SOURCES.txt +11 -1
  10. {heavyball-1.7.0 → heavyball-1.7.1}/pyproject.toml +1 -1
  11. {heavyball-1.7.0 → heavyball-1.7.1}/test/test_memory.py +12 -6
  12. heavyball-1.7.1/test/test_memory_leak.py +68 -0
  13. heavyball-1.7.1/tests/test_psgd_kron_line_optim.py +141 -0
  14. heavyball-1.7.1/tests/test_psgd_kron_regression.py +46 -0
  15. heavyball-1.7.1/tests/test_psgd_lra_regression.py +87 -0
  16. heavyball-1.7.1/tests/test_psgd_optimization.py +190 -0
  17. heavyball-1.7.1/tests/test_psgd_optimizations.py +97 -0
  18. heavyball-1.7.1/tests/test_psgd_training_performance.py +253 -0
  19. {heavyball-1.7.0 → heavyball-1.7.1}/LICENSE +0 -0
  20. {heavyball-1.7.0 → heavyball-1.7.1}/README.md +0 -0
  21. {heavyball-1.7.0 → heavyball-1.7.1}/heavyball.egg-info/dependency_links.txt +0 -0
  22. {heavyball-1.7.0 → heavyball-1.7.1}/heavyball.egg-info/requires.txt +0 -0
  23. {heavyball-1.7.0 → heavyball-1.7.1}/heavyball.egg-info/top_level.txt +0 -0
  24. {heavyball-1.7.0 → heavyball-1.7.1}/setup.cfg +0 -0
  25. {heavyball-1.7.0 → heavyball-1.7.1}/test/test_bf16_params.py +0 -0
  26. {heavyball-1.7.0 → heavyball-1.7.1}/test/test_bf16_q.py +0 -0
  27. {heavyball-1.7.0 → heavyball-1.7.1}/test/test_bf16_storage.py +0 -0
  28. {heavyball-1.7.0 → heavyball-1.7.1}/test/test_caution.py +0 -0
  29. {heavyball-1.7.0 → heavyball-1.7.1}/test/test_channels_last.py +0 -0
  30. {heavyball-1.7.0 → heavyball-1.7.1}/test/test_closure.py +0 -0
  31. {heavyball-1.7.0 → heavyball-1.7.1}/test/test_ema.py +0 -0
  32. {heavyball-1.7.0 → heavyball-1.7.1}/test/test_foreach.py +0 -0
  33. {heavyball-1.7.0 → heavyball-1.7.1}/test/test_hook.py +0 -0
  34. {heavyball-1.7.0 → heavyball-1.7.1}/test/test_mars.py +0 -0
  35. {heavyball-1.7.0 → heavyball-1.7.1}/test/test_merge.py +0 -0
  36. {heavyball-1.7.0 → heavyball-1.7.1}/test/test_no_grad.py +0 -0
  37. {heavyball-1.7.0 → heavyball-1.7.1}/test/test_soap.py +0 -0
  38. {heavyball-1.7.0 → heavyball-1.7.1}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: heavyball
- Version: 1.7.0
+ Version: 1.7.1
  Summary: Efficient Optimizers
  Author-email: HeavyBall Authors <github.heavyball@nestler.sh>
  Project-URL: source, https://github.com/HomebrewML/HeavyBall
@@ -1,4 +1,5 @@
  import functools
+ import math
  from typing import Optional

  from . import chainable as C
@@ -564,6 +565,10 @@ class ForeachCachedNewtonPSGD(ForeachCachedPSGDKron):
      hessian_approx = True


+ class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
+     hvp_interval = 2
+
+
  class ForeachPSGDLRA(C.BaseOpt):
      """
      Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -582,7 +587,7 @@ class ForeachPSGDLRA(C.BaseOpt):
          weight_decay=0.0,
          preconditioner_update_probability=None,
          momentum_into_precond_update=True,
-         rank: int = 4,
+         rank: Optional[int] = None,
          warmup_steps: int = 0,
          foreach: bool = True,
          q_dtype="float32",
@@ -608,6 +613,14 @@ class ForeachPSGDLRA(C.BaseOpt):
          )
          params = defaults.pop("params")

+         if rank is None:
+             utils.warn_once(
+                 f"{rank=}. It will be set to log2(param_count). This requires `params` to be of type list. Currently, {type(params)=}"
+             )
+             params = list(params)
+             defaults["rank"] = round(math.log2(sum(p.numel() for p in params)))
+             utils.warn_once(f"rank was set to {defaults['rank']}")
+
          delayed = C.default(delayed, self.delayed)
          exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
          update_clipping = C.default(update_clipping, utils.trust_region_clip_)
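For context, the new default means `ForeachPSGDLRA(params, rank=None)` derives the low-rank preconditioner size from the total parameter count as `round(math.log2(sum(p.numel() for p in params)))`. A minimal sketch of that arithmetic; the toy layer sizes are illustrative, not from the package:

```python
import math

# Hypothetical parameter counts for a small MLP: 784*256 weights + 256 biases + 256*10 weights + 10 biases
param_counts = [784 * 256, 256, 256 * 10, 10]
total = sum(param_counts)       # 203_530 parameters
rank = round(math.log2(total))  # log2(203_530) ≈ 17.6 -> rank 18
print(total, rank)
```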
@@ -632,6 +645,10 @@ class ForeachNewtonPSGDLRA(ForeachPSGDLRA):
      hessian_approx = True


+ class NewtonHybrid2PSGDLRA(ForeachNewtonPSGDLRA):
+     hvp_interval = 2
+
+
  PalmForEachSoap = PaLMForeachSOAP
  PaLMSOAP = PaLMForeachSOAP
  PaLMSFAdamW = PaLMForeachSFAdamW
@@ -696,4 +713,6 @@ __all__ = [
      "DelayedPSGD",
      "PSGDLRA",
      "NewtonPSGDLRA",
+     "NewtonHybrid2PSGDLRA",
+     "NewtonHybrid2PSGDKron",
  ]
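The two new exports are thin subclasses of the existing Newton-PSGD optimizers that only override `hvp_interval`. A minimal usage sketch, assuming the torch.optim-style closure interface heavyball's hessian-approximation optimizers use; the toy model and `lr` value are illustrative:

```python
import torch
import heavyball

model = torch.nn.Linear(16, 4)
opt = heavyball.NewtonHybrid2PSGDKron(model.parameters(), lr=1e-3)

def closure():
    opt.zero_grad()
    loss = model(torch.randn(32, 16)).square().mean()
    loss.backward()
    return loss

for _ in range(10):
    opt.step(closure)  # Newton / hessian-approx variants expect a closure
```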
@@ -1,4 +1,5 @@
  import functools
+ import math
  import random
  from typing import List, Literal, Optional, Union

@@ -43,7 +44,7 @@ class FunctionTransform:
          raise NotImplementedError

      def get_fn(self):
-         if hasattr(self.fn, "get_fn"):
+         if utils.hasattr_none(self.fn, "get_fn"):
              return self.fn.get_fn()
          return self.fn

@@ -426,7 +427,7 @@ def _store_std(state, group, update, grad, param):
      state["init_std"] = torch.std(grad, dim=0)


- @general_guard("init_std", init_fn=_store_std)
+ @general_guard("init_std", init_fn=_store_std, skip_first=False)
  @no_state
  def mup_approx(group, updates, grads, params, init_std):
      _updates = [(u, i) for u, i in zip(updates, init_std) if u.ndim > 1]
@@ -435,6 +436,40 @@ def mup_approx(group, updates, grads, params, init_std):
      return updates


+ def _init_delta(state, group, update, grad, param, log_space: bool):
+     val = group["initial_d"]
+     state["delta"] = torch.full((), math.log(val) if log_space else val, dtype=param.dtype, device=param.device)
+
+
+ def _init_full_delta(state, group, update, grad, param, log_space: bool):
+     val = group["initial_d"]
+     state["delta"] = torch.full_like(param, math.log(val) if log_space else val)
+
+
+ @zero_guard("state")
+ @general_guard("delta", init_fn=functools.partial(_init_delta, log_space=False), skip_first=False)
+ @no_state
+ def scale_by_d_adaptation(group, update, grad, param, state, delta):
+     utils.d_adaptation(grad, update, state, delta)
+     return update
+
+
+ @zero_guard("state")
+ @general_guard("delta", init_fn=functools.partial(_init_delta, log_space=True), skip_first=False)
+ @no_state
+ def scale_by_lr_adaptation(group, update, grad, param, state, delta):
+     utils.lr_adaptation(grad, update, state, delta, group["lr_lr"])
+     return update
+
+
+ @zero_guard("state")
+ @general_guard("delta", init_fn=functools.partial(_init_full_delta, log_space=True), skip_first=False)
+ @no_state
+ def scale_by_pointwise_lr_adaptation(group, update, grad, param, state, delta):
+     utils.pointwise_lr_adaptation(grad, update, state, delta, group["lr_lr"])
+     return update
+
+
  @zero_guard("momentum")
  @no_state
  def heavyball_momentum(group, updates, grads, params, momentum):
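The two new init helpers differ only in the shape of the stored state: `_init_delta` keeps a single scalar per parameter, while `_init_full_delta` keeps one value per element, and with `log_space=True` the stored value is `log(initial_d)`. A small standalone sketch of that distinction, with an illustrative `initial_d`:

```python
import math
import torch

param = torch.zeros(3, 4)
initial_d = 1e-6

scalar_delta = torch.full((), math.log(initial_d), dtype=param.dtype)  # one value, as in _init_delta
full_delta = torch.full_like(param, math.log(initial_d))               # one value per element, as in _init_full_delta
print(scalar_delta.shape, full_delta.shape)  # torch.Size([]) torch.Size([3, 4])
```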
@@ -484,18 +519,22 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
      if not group["is_preconditioning"]:
          return Q_mat

+     if utils.hasattr_none(param, "vector"):
+         vector, hessian_vector = param.vector, param.hessian_vector
+         del param.vector
+         del param.hessian_vector
+     else:
+         vector, hessian_vector = utils.dampen_grad(grad)
+
      utils.psgd_update_precond(
          Q_mat,
          exprs,
-         getattr(param, "hessian_vector", grad),
+         hessian_vector,
          group["precond_lr"],
          Q,
          group["store_triu_as_line"],
-         getattr(param, "vector", None),
+         vector,
      )
-     if hasattr(param, "vector"):
-         del param.vector
-         del param.hessian_vector

      if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
          if group["store_triu_as_line"]:
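Several call sites now go through `utils.hasattr_none` instead of plain `hasattr`. The helper itself lives in `utils.py` (heavily changed in this release but not shown in these hunks); judging from the replaced condition `hasattr(x, "hessian_vector") and x.hessian_vector is not None`, it treats an attribute that exists but is `None` as absent. A hedged sketch of that presumed behavior, not the package's actual implementation:

```python
def hasattr_none(obj, name: str) -> bool:
    """Presumed semantics: the attribute exists *and* is not None."""
    return getattr(obj, name, None) is not None
```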
@@ -566,9 +605,12 @@ def _update_lra(
      if not group["is_preconditioning"]:
          return utils.flatten(U, 1), utils.flatten(V, 1), utils.flatten(d)

-     if hasattr(params[0], "hessian_vector") and params[0].hessian_vector is not None:
+     if utils.hasattr_none(params[0], "hessian_vector"):
          vector = utils.flatten([p.vector for p in params])
          hessian_vector = utils.flatten([p.hessian_vector for p in params])
+         for p in params:
+             del p.vector
+             del p.hessian_vector
      else:
          vector, hessian_vector = utils.dampen_multiple(grads)
      return utils.update_lra_precond_(U, V, d, vector, hessian_vector, group["eps"], group["precond_lr"], delayed)
@@ -0,0 +1,38 @@
+ """
+ PSGD optimization module - optimized implementations of PSGD functions
+ to improve execution speed while maintaining numerical equivalence.
+ """
+
+ # Import optimized functions
+ # Import integrator API
+ from .integrator import (
+     enable_optimizations,
+     get_optimization_status,
+     restore_original_functions,
+ )
+ from .optimizations import (
+     # LRA optimizations
+     low_rank_mm_optimized,
+     lra_precond_optimized,
+     precond_grad_cached_optimized,
+     # KRON optimizations
+     psgd_calc_A_and_conjB_optimized,
+     psgd_precond_grad_optimized,
+     psgd_update_precond_optimized,
+     update_lra_precond_optimized,
+ )
+
+ __all__ = [
+     # Optimized functions
+     "low_rank_mm_optimized",
+     "update_lra_precond_optimized",
+     "lra_precond_optimized",
+     "psgd_calc_A_and_conjB_optimized",
+     "psgd_update_precond_optimized",
+     "psgd_precond_grad_optimized",
+     "precond_grad_cached_optimized",
+     # Integrator API
+     "enable_optimizations",
+     "restore_original_functions",
+     "get_optimization_status",
+ ]
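The new package re-exports both the optimized kernels and the integrator controls, so the usual entry point is the integrator API rather than the individual `*_optimized` functions. A minimal sketch of toggling it at runtime, using only the flags defined in `integrator.py` below:

```python
from heavyball.optimizations import (
    enable_optimizations,
    get_optimization_status,
    restore_original_functions,
)

enable_optimizations(enable=True, lra=True, kron=False, verbose=True)  # patch only the LRA code paths
print(get_optimization_status())  # {'enabled': True, 'lra_enabled': True, 'kron_enabled': False, ...}
restore_original_functions()      # undo the monkey patches
```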
@@ -0,0 +1,169 @@
+ """
+ Integration module to selectively enable optimized implementations
+ of PSGD functions while maintaining API compatibility.
+ """
+
+ import os
+ import sys
+ from typing import Any, Dict
+
+ import torch
+
+ from . import optimizations
+ from .. import utils
+
+ # Store original function references
+ _original_functions = {}
+ _optimized_functions = {}
+
+ # Mapping of original functions to their optimized versions
+ OPTIMIZATION_MAP = {
+     # LRA functions
+     utils.update_lra_precond_: optimizations.update_lra_precond_optimized,
+     utils.lra_precond: optimizations.lra_precond_optimized,
+     # KRON functions
+     utils.psgd_update_precond: optimizations.psgd_update_precond_optimized,
+     utils.psgd_precond_grad: optimizations.psgd_precond_grad_optimized,
+     utils.precond_grad_cached_: optimizations.precond_grad_cached_optimized,
+ }
+
+ # Config for enabling/disabling optimizations
+ _config = {
+     "enabled": os.environ.get("HEAVYBALL_OPTIMIZE", "1") == "1",
+     "torch_compile_allowed": os.environ.get("HEAVYBALL_USE_COMPILE", "1") == "1",
+     "enable_lra": True,
+     "enable_kron": True,
+     "verbose": os.environ.get("HEAVYBALL_VERBOSE", "0") == "1",
+ }
+
+
+ def _apply_monkey_patch(original_func, optimized_func):
+     """Monkey patch a function with its optimized version."""
+     if original_func not in _original_functions:
+         _original_functions[original_func] = original_func
+
+     # Store reference to the optimized function
+     _optimized_functions[original_func] = optimized_func
+
+     # Get the module where the original function is defined
+     module = original_func.__module__
+     func_name = original_func.__name__
+
+     # Replace the function in its module
+     if hasattr(sys.modules[module], func_name):
+         setattr(sys.modules[module], func_name, optimized_func)
+
+         if _config["verbose"]:
+             print(f"Replaced {module}.{func_name} with optimized version")
+     else:
+         if _config["verbose"]:
+             print(f"Warning: Could not find {func_name} in module {module}")
+
+
+ def enable_optimizations(
+     enable: bool = True, lra: bool = True, kron: bool = True, torch_compile: bool = True, verbose: bool = False
+ ):
+     """
+     Enable or disable PSGD optimizations.
+
+     Args:
+         enable: Whether to enable optimizations at all
+         lra: Whether to enable LRA-specific optimizations
+         kron: Whether to enable Kron-specific optimizations
+         torch_compile: Whether to allow torch.compile optimizations
+         verbose: Whether to print optimization status messages
+     """
+     _config["enabled"] = enable
+     _config["enable_lra"] = lra
+     _config["enable_kron"] = kron
+     _config["torch_compile_allowed"] = torch_compile
+     _config["verbose"] = verbose
+
+     if verbose:
+         print(f"PSGD Optimizations: {'enabled' if enable else 'disabled'}")
+         print(f"  - LRA optimizations: {'enabled' if lra else 'disabled'}")
+         print(f"  - KRON optimizations: {'enabled' if kron else 'disabled'}")
+         print(f"  - torch.compile: {'allowed' if torch_compile else 'disabled'}")
+
+     if not enable:
+         # Restore original functions
+         restore_original_functions()
+         return
+
+     # Apply optimizations based on config
+     for orig_func, opt_func in OPTIMIZATION_MAP.items():
+         # Skip LRA functions if disabled
+         if not _config["enable_lra"] and orig_func in [utils.update_lra_precond_, utils.lra_precond]:
+             continue
+
+         # Skip KRON functions if disabled
+         if not _config["enable_kron"] and orig_func in [
+             utils.psgd_update_precond,
+             utils.psgd_precond_grad,
+             utils.precond_grad_cached_,
+         ]:
+             continue
+
+         _apply_monkey_patch(orig_func, opt_func)
+
+     # Disable torch.compile if not allowed
+     if not _config["torch_compile_allowed"]:
+         # Monkey patch torch.compile to be a no-op
+         def _noop_compile(fn, **kwargs):
+             return fn
+
+         if not hasattr(torch, "_original_compile"):
+             torch._original_compile = torch.compile
+         torch.compile = _noop_compile
+         if verbose:
+             print("Disabled torch.compile (replaced with no-op)")
+     else:
+         # Restore original torch.compile
+         if hasattr(torch, "_original_compile"):
+             torch.compile = torch._original_compile
+             del torch._original_compile
+             if verbose:
+                 print("Restored original torch.compile")
+
+
+ def restore_original_functions():
+     """Restore all original function implementations."""
+     for orig_func, func_ref in _original_functions.items():
+         module = orig_func.__module__
+         func_name = orig_func.__name__
+
+         if hasattr(sys.modules[module], func_name):
+             setattr(sys.modules[module], func_name, func_ref)
+
+             if _config["verbose"]:
+                 print(f"Restored original implementation of {module}.{func_name}")
+
+     # Also restore torch.compile if it was modified
+     if hasattr(torch, "_original_compile"):
+         torch.compile = torch._original_compile
+         del torch._original_compile
+         if _config["verbose"]:
+             print("Restored original torch.compile")
+
+
+ def get_optimization_status() -> Dict[str, Any]:
+     """Get current optimization status."""
+     return {
+         "enabled": _config["enabled"],
+         "lra_enabled": _config["enable_lra"],
+         "kron_enabled": _config["enable_kron"],
+         "torch_compile_allowed": _config["torch_compile_allowed"],
+         "optimized_functions": list(_optimized_functions.keys()),
+         "original_functions": list(_original_functions.keys()),
+     }
+
+
+ # Auto-initialize optimizations based on environment
+ if os.environ.get("HEAVYBALL_AUTO_OPTIMIZE", "1") == "1":
+     enable_optimizations(
+         enable=_config["enabled"],
+         lra=_config["enable_lra"],
+         kron=_config["enable_kron"],
+         torch_compile=_config["torch_compile_allowed"],
+         verbose=_config["verbose"],
+     )
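Because the module self-initializes on import, the environment variables read into `_config` plus the `HEAVYBALL_AUTO_OPTIMIZE` guard are the main switches. A sketch of disabling everything before heavyball is imported; the values mirror the defaults in the file above:

```python
import os

# Must be set before `heavyball.optimizations` is imported anywhere.
os.environ["HEAVYBALL_OPTIMIZE"] = "0"       # start with optimizations disabled
os.environ["HEAVYBALL_AUTO_OPTIMIZE"] = "0"  # skip the import-time enable_optimizations() call
os.environ["HEAVYBALL_USE_COMPILE"] = "1"    # leave torch.compile untouched
os.environ["HEAVYBALL_VERBOSE"] = "1"        # print what gets patched/restored

import heavyball  # noqa: E402
```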