heavyball 1.7.0__tar.gz → 1.7.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-1.7.0 → heavyball-1.7.1}/PKG-INFO +1 -1
- {heavyball-1.7.0 → heavyball-1.7.1}/heavyball/__init__.py +20 -1
- {heavyball-1.7.0 → heavyball-1.7.1}/heavyball/chainable.py +50 -8
- heavyball-1.7.1/heavyball/optimizations/__init__.py +38 -0
- heavyball-1.7.1/heavyball/optimizations/integrator.py +169 -0
- heavyball-1.7.1/heavyball/optimizations/optimizations.py +329 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/heavyball/utils.py +518 -162
- {heavyball-1.7.0 → heavyball-1.7.1}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-1.7.0 → heavyball-1.7.1}/heavyball.egg-info/SOURCES.txt +11 -1
- {heavyball-1.7.0 → heavyball-1.7.1}/pyproject.toml +1 -1
- {heavyball-1.7.0 → heavyball-1.7.1}/test/test_memory.py +12 -6
- heavyball-1.7.1/test/test_memory_leak.py +68 -0
- heavyball-1.7.1/tests/test_psgd_kron_line_optim.py +141 -0
- heavyball-1.7.1/tests/test_psgd_kron_regression.py +46 -0
- heavyball-1.7.1/tests/test_psgd_lra_regression.py +87 -0
- heavyball-1.7.1/tests/test_psgd_optimization.py +190 -0
- heavyball-1.7.1/tests/test_psgd_optimizations.py +97 -0
- heavyball-1.7.1/tests/test_psgd_training_performance.py +253 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/LICENSE +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/README.md +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/setup.cfg +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/test/test_bf16_params.py +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/test/test_bf16_q.py +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/test/test_bf16_storage.py +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/test/test_caution.py +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/test/test_channels_last.py +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/test/test_closure.py +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/test/test_ema.py +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/test/test_foreach.py +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/test/test_hook.py +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/test/test_mars.py +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/test/test_merge.py +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/test/test_no_grad.py +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/test/test_soap.py +0 -0
- {heavyball-1.7.0 → heavyball-1.7.1}/test/test_stochastic_updates.py +0 -0
@@ -1,4 +1,5 @@
|
|
1
1
|
import functools
|
2
|
+
import math
|
2
3
|
from typing import Optional
|
3
4
|
|
4
5
|
from . import chainable as C
|
@@ -564,6 +565,10 @@ class ForeachCachedNewtonPSGD(ForeachCachedPSGDKron):
|
|
564
565
|
hessian_approx = True
|
565
566
|
|
566
567
|
|
568
|
+
class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
    """Cached Newton-PSGD (Kron) optimizer with ``hvp_interval`` fixed to 2.

    All behaviour is inherited from ``ForeachCachedNewtonPSGD``; only the class-level
    ``hvp_interval`` differs. Presumably the base class uses it to decide how often a
    Hessian-vector product is computed — confirm against the base implementation.
    """

    hvp_interval: int = 2
|
570
|
+
|
571
|
+
|
567
572
|
class ForeachPSGDLRA(C.BaseOpt):
|
568
573
|
"""
|
569
574
|
Originally from Evan Walters and Omead Pooladzandi, 2024
|
@@ -582,7 +587,7 @@ class ForeachPSGDLRA(C.BaseOpt):
|
|
582
587
|
weight_decay=0.0,
|
583
588
|
preconditioner_update_probability=None,
|
584
589
|
momentum_into_precond_update=True,
|
585
|
-
rank: int =
|
590
|
+
rank: Optional[int] = None,
|
586
591
|
warmup_steps: int = 0,
|
587
592
|
foreach: bool = True,
|
588
593
|
q_dtype="float32",
|
@@ -608,6 +613,14 @@ class ForeachPSGDLRA(C.BaseOpt):
|
|
608
613
|
)
|
609
614
|
params = defaults.pop("params")
|
610
615
|
|
616
|
+
if rank is None:
|
617
|
+
utils.warn_once(
|
618
|
+
f"{rank=}. It will be set to log2(param_count). This requires `params` to be of type list. Currently, {type(params)=}"
|
619
|
+
)
|
620
|
+
params = list(params)
|
621
|
+
defaults["rank"] = round(math.log2(sum(p.numel() for p in params)))
|
622
|
+
utils.warn_once(f"rank was set to {defaults['rank']}")
|
623
|
+
|
611
624
|
delayed = C.default(delayed, self.delayed)
|
612
625
|
exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
|
613
626
|
update_clipping = C.default(update_clipping, utils.trust_region_clip_)
|
@@ -632,6 +645,10 @@ class ForeachNewtonPSGDLRA(ForeachPSGDLRA):
|
|
632
645
|
hessian_approx = True
|
633
646
|
|
634
647
|
|
648
|
+
class NewtonHybrid2PSGDLRA(ForeachNewtonPSGDLRA):
    """Newton-PSGD (low-rank approximation) optimizer with ``hvp_interval`` fixed to 2.

    All behaviour is inherited from ``ForeachNewtonPSGDLRA``; only the class-level
    ``hvp_interval`` differs. Presumably the base class uses it to decide how often a
    Hessian-vector product is computed — confirm against the base implementation.
    """

    hvp_interval: int = 2
|
650
|
+
|
651
|
+
|
635
652
|
PalmForEachSoap = PaLMForeachSOAP
|
636
653
|
PaLMSOAP = PaLMForeachSOAP
|
637
654
|
PaLMSFAdamW = PaLMForeachSFAdamW
|
@@ -696,4 +713,6 @@ __all__ = [
|
|
696
713
|
"DelayedPSGD",
|
697
714
|
"PSGDLRA",
|
698
715
|
"NewtonPSGDLRA",
|
716
|
+
"NewtonHybrid2PSGDLRA",
|
717
|
+
"NewtonHybrid2PSGDKron",
|
699
718
|
]
|
@@ -1,4 +1,5 @@
|
|
1
1
|
import functools
|
2
|
+
import math
|
2
3
|
import random
|
3
4
|
from typing import List, Literal, Optional, Union
|
4
5
|
|
@@ -43,7 +44,7 @@ class FunctionTransform:
|
|
43
44
|
raise NotImplementedError
|
44
45
|
|
45
46
|
def get_fn(self):
    """Return the callable wrapped by this transform.

    If ``self.fn`` itself exposes ``get_fn`` (as reported by ``utils.hasattr_none``),
    the lookup is delegated to it; otherwise ``self.fn`` is returned unchanged.
    """
    wrapped = self.fn
    if not utils.hasattr_none(wrapped, "get_fn"):
        return wrapped
    return wrapped.get_fn()
|
49
50
|
|
@@ -426,7 +427,7 @@ def _store_std(state, group, update, grad, param):
|
|
426
427
|
state["init_std"] = torch.std(grad, dim=0)
|
427
428
|
|
428
429
|
|
429
|
-
@general_guard("init_std", init_fn=_store_std)
|
430
|
+
@general_guard("init_std", init_fn=_store_std, skip_first=False)
|
430
431
|
@no_state
|
431
432
|
def mup_approx(group, updates, grads, params, init_std):
|
432
433
|
_updates = [(u, i) for u, i in zip(updates, init_std) if u.ndim > 1]
|
@@ -435,6 +436,40 @@ def mup_approx(group, updates, grads, params, init_std):
|
|
435
436
|
return updates
|
436
437
|
|
437
438
|
|
439
|
+
def _init_delta(state, group, update, grad, param, log_space: bool):
    """Create the scalar ``state["delta"]`` from ``group["initial_d"]``.

    The value is stored as ``log(initial_d)`` when ``log_space`` is true, otherwise as
    ``initial_d`` itself. The 0-dim tensor uses ``param``'s dtype and device.
    ``update`` and ``grad`` are unused; they are part of the init-function signature.
    """
    start = group["initial_d"]
    if log_space:
        start = math.log(start)
    state["delta"] = torch.full((), start, dtype=param.dtype, device=param.device)
|
442
|
+
|
443
|
+
|
444
|
+
def _init_full_delta(state, group, update, grad, param, log_space: bool):
    """Create an element-wise ``state["delta"]`` shaped like ``param``.

    Every entry starts at ``log(group["initial_d"])`` when ``log_space`` is true, or at
    ``group["initial_d"]`` otherwise. Dtype and device follow ``param``
    (``torch.full_like``). ``update`` and ``grad`` are unused.
    """
    start = group["initial_d"]
    if log_space:
        start = math.log(start)
    state["delta"] = torch.full_like(param, start)
|
447
|
+
|
448
|
+
|
449
|
+
# `state` is injected by zero_guard("state") and `delta` by general_guard("delta", ...);
# the decorator order and the parameter names are part of that injection contract.
@zero_guard("state")
@general_guard("delta", init_fn=functools.partial(_init_delta, log_space=False), skip_first=False)
@no_state
def scale_by_d_adaptation(group, update, grad, param, state, delta):
    """Adapt ``update`` via ``utils.d_adaptation`` using a scalar ``delta``.

    ``delta`` is created on first use by ``_init_delta`` with ``log_space=False``, so it
    starts at ``group["initial_d"]`` as a 0-dim tensor. ``state`` comes from
    ``zero_guard("state")`` (presumably a zero-initialised buffer — confirm).

    Returns ``update``. The helper's return value is discarded, so presumably
    ``utils.d_adaptation`` updates ``update``/``state``/``delta`` in place — TODO confirm.
    """
    utils.d_adaptation(grad, update, state, delta)
    return update
|
455
|
+
|
456
|
+
|
457
|
+
# `state` is injected by zero_guard("state") and `delta` by general_guard("delta", ...);
# the decorator order and the parameter names are part of that injection contract.
@zero_guard("state")
@general_guard("delta", init_fn=functools.partial(_init_delta, log_space=True), skip_first=False)
@no_state
def scale_by_lr_adaptation(group, update, grad, param, state, delta):
    """Adapt ``update`` via ``utils.lr_adaptation`` using a scalar log-space ``delta``.

    ``delta`` is created on first use by ``_init_delta`` with ``log_space=True``, so it
    starts at ``log(group["initial_d"])`` as a 0-dim tensor. ``group["lr_lr"]`` is
    forwarded to the helper (presumably the step size of the adaptation — confirm).

    Returns ``update``. The helper's return value is discarded, so presumably
    ``utils.lr_adaptation`` updates its tensor arguments in place — TODO confirm.
    """
    utils.lr_adaptation(grad, update, state, delta, group["lr_lr"])
    return update
|
463
|
+
|
464
|
+
|
465
|
+
# `state` is injected by zero_guard("state") and `delta` by general_guard("delta", ...);
# the decorator order and the parameter names are part of that injection contract.
@zero_guard("state")
@general_guard("delta", init_fn=functools.partial(_init_full_delta, log_space=True), skip_first=False)
@no_state
def scale_by_pointwise_lr_adaptation(group, update, grad, param, state, delta):
    """Adapt ``update`` via ``utils.pointwise_lr_adaptation`` with an element-wise ``delta``.

    ``delta`` is created on first use by ``_init_full_delta`` with ``log_space=True``:
    it has ``param``'s shape, and every entry starts at ``log(group["initial_d"])``.
    ``group["lr_lr"]`` is forwarded to the helper.

    Returns ``update``. The helper's return value is discarded, so presumably
    ``utils.pointwise_lr_adaptation`` updates its tensor arguments in place — TODO confirm.
    """
    utils.pointwise_lr_adaptation(grad, update, state, delta, group["lr_lr"])
    return update
|
471
|
+
|
472
|
+
|
438
473
|
@zero_guard("momentum")
|
439
474
|
@no_state
|
440
475
|
def heavyball_momentum(group, updates, grads, params, momentum):
|
@@ -484,18 +519,22 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p
|
|
484
519
|
if not group["is_preconditioning"]:
|
485
520
|
return Q_mat
|
486
521
|
|
522
|
+
if utils.hasattr_none(param, "vector"):
|
523
|
+
vector, hessian_vector = param.vector, param.hessian_vector
|
524
|
+
del param.vector
|
525
|
+
del param.hessian_vector
|
526
|
+
else:
|
527
|
+
vector, hessian_vector = utils.dampen_grad(grad)
|
528
|
+
|
487
529
|
utils.psgd_update_precond(
|
488
530
|
Q_mat,
|
489
531
|
exprs,
|
490
|
-
|
532
|
+
hessian_vector,
|
491
533
|
group["precond_lr"],
|
492
534
|
Q,
|
493
535
|
group["store_triu_as_line"],
|
494
|
-
|
536
|
+
vector,
|
495
537
|
)
|
496
|
-
if hasattr(param, "vector"):
|
497
|
-
del param.vector
|
498
|
-
del param.hessian_vector
|
499
538
|
|
500
539
|
if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"):
|
501
540
|
if group["store_triu_as_line"]:
|
@@ -566,9 +605,12 @@ def _update_lra(
|
|
566
605
|
if not group["is_preconditioning"]:
|
567
606
|
return utils.flatten(U, 1), utils.flatten(V, 1), utils.flatten(d)
|
568
607
|
|
569
|
-
if
|
608
|
+
if utils.hasattr_none(params[0], "hessian_vector"):
|
570
609
|
vector = utils.flatten([p.vector for p in params])
|
571
610
|
hessian_vector = utils.flatten([p.hessian_vector for p in params])
|
611
|
+
for p in params:
|
612
|
+
del p.vector
|
613
|
+
del p.hessian_vector
|
572
614
|
else:
|
573
615
|
vector, hessian_vector = utils.dampen_multiple(grads)
|
574
616
|
return utils.update_lra_precond_(U, V, d, vector, hessian_vector, group["eps"], group["precond_lr"], delayed)
|
@@ -0,0 +1,38 @@
|
|
1
|
+
"""
|
2
|
+
PSGD optimization module - optimized implementations of PSGD functions
|
3
|
+
to improve execution speed while maintaining numerical equivalence.
|
4
|
+
"""
|
5
|
+
|
6
|
+
# Imports: the integrator API first, then the optimized function implementations
|
7
|
+
# Import integrator API
|
8
|
+
from .integrator import (
|
9
|
+
enable_optimizations,
|
10
|
+
get_optimization_status,
|
11
|
+
restore_original_functions,
|
12
|
+
)
|
13
|
+
from .optimizations import (
|
14
|
+
# LRA optimizations
|
15
|
+
low_rank_mm_optimized,
|
16
|
+
lra_precond_optimized,
|
17
|
+
precond_grad_cached_optimized,
|
18
|
+
# KRON optimizations
|
19
|
+
psgd_calc_A_and_conjB_optimized,
|
20
|
+
psgd_precond_grad_optimized,
|
21
|
+
psgd_update_precond_optimized,
|
22
|
+
update_lra_precond_optimized,
|
23
|
+
)
|
24
|
+
|
25
|
+
# Public API of the optimizations package: the optimized implementations imported from
# `.optimizations`, followed by the integrator controls imported from `.integrator`.
# Note: `psgd_calc_A_and_conjB_optimized` appears both here and in the `.optimizations`
# import above.
__all__ = [
    # Optimized functions
    "low_rank_mm_optimized",
    "update_lra_precond_optimized",
    "lra_precond_optimized",
    "psgd_calc_A_and_conjB_optimized",
    "psgd_update_precond_optimized",
    "psgd_precond_grad_optimized",
    "precond_grad_cached_optimized",
    # Integrator API
    "enable_optimizations",
    "restore_original_functions",
    "get_optimization_status",
]
|
@@ -0,0 +1,169 @@
|
|
1
|
+
"""
|
2
|
+
Integration module to selectively enable optimized implementations
|
3
|
+
of PSGD functions while maintaining API compatibility.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import os
|
7
|
+
import sys
|
8
|
+
from typing import Any, Dict
|
9
|
+
|
10
|
+
import torch
|
11
|
+
|
12
|
+
from . import optimizations
|
13
|
+
from .. import utils
|
14
|
+
|
15
|
+
# Patch registries, keyed by the ORIGINAL function object.
_original_functions = {}  # original -> original, kept so the patch can be undone
_optimized_functions = {}  # original -> optimized replacement recorded for it

# Original `utils` implementation -> optimized drop-in replacement.
OPTIMIZATION_MAP = {
    # LRA functions
    utils.update_lra_precond_: optimizations.update_lra_precond_optimized,
    utils.lra_precond: optimizations.lra_precond_optimized,
    # KRON functions
    utils.psgd_update_precond: optimizations.psgd_update_precond_optimized,
    utils.psgd_precond_grad: optimizations.psgd_precond_grad_optimized,
    utils.precond_grad_cached_: optimizations.precond_grad_cached_optimized,
}

# Runtime switches. The environment variables only seed the defaults;
# `enable_optimizations` overwrites every entry when it is called.
_config = dict(
    enabled=os.environ.get("HEAVYBALL_OPTIMIZE", "1") == "1",
    torch_compile_allowed=os.environ.get("HEAVYBALL_USE_COMPILE", "1") == "1",
    enable_lra=True,
    enable_kron=True,
    verbose=os.environ.get("HEAVYBALL_VERBOSE", "0") == "1",
)
|
38
|
+
|
39
|
+
|
40
|
+
def _apply_monkey_patch(original_func, optimized_func):
    """Monkey patch a function with its optimized version.

    Sets the attribute ``original_func.__name__`` on the module named
    ``original_func.__module__`` to ``optimized_func``. The original is remembered in
    ``_original_functions`` so ``restore_original_functions`` can undo the patch.

    The replacement is recorded in ``_optimized_functions`` only if the patch was
    actually applied. Previously it was recorded unconditionally, so
    ``get_optimization_status`` could report functions that were never swapped in.
    A module missing from ``sys.modules`` now produces the same "could not find"
    warning instead of a ``KeyError``.
    """
    if original_func not in _original_functions:
        _original_functions[original_func] = original_func

    # Locate the module where the original function is defined.
    module = original_func.__module__
    func_name = original_func.__name__
    target = sys.modules.get(module)

    if target is None or not hasattr(target, func_name):
        if _config["verbose"]:
            print(f"Warning: Could not find {func_name} in module {module}")
        return

    # Replace the function in its module, then record the replacement.
    setattr(target, func_name, optimized_func)
    _optimized_functions[original_func] = optimized_func

    if _config["verbose"]:
        print(f"Replaced {module}.{func_name} with optimized version")
|
61
|
+
|
62
|
+
|
63
|
+
# OPTIMIZATION_MAP entries grouped by preconditioner family, matched by function NAME.
# Identity checks against `utils.<name>` break once a previous call has patched `utils`:
# those attributes then point at the optimized replacements and never equal the original
# functions used as OPTIMIZATION_MAP keys, so `lra=False` / `kron=False` would be ignored.
_LRA_FUNCTION_NAMES = frozenset({"update_lra_precond_", "lra_precond"})
_KRON_FUNCTION_NAMES = frozenset({"psgd_update_precond", "psgd_precond_grad", "precond_grad_cached_"})


def _noop_compile(model=None, **kwargs):
    """Stand-in for ``torch.compile`` that returns its input unchanged.

    Mirrors ``torch.compile(model=None, **options)``. It supports both
    ``torch.compile(fn, ...)`` and the decorator-factory form ``torch.compile(mode=...)(fn)``;
    the latter raised ``TypeError`` with the previous ``fn``-positional stub.
    """
    if model is None:
        return lambda fn: fn
    return model


def _unpatch(orig_func):
    """Undo a patch previously applied to ``orig_func`` (if any) and forget it."""
    optimized = _optimized_functions.pop(orig_func, None)
    if optimized is None:
        return
    module = sys.modules.get(orig_func.__module__)
    # Only restore if the module still holds the replacement we installed.
    if module is not None and getattr(module, orig_func.__name__, None) is optimized:
        setattr(module, orig_func.__name__, orig_func)
        if _config["verbose"]:
            print(f"Restored original implementation of {orig_func.__module__}.{orig_func.__name__}")


def enable_optimizations(
    enable: bool = True, lra: bool = True, kron: bool = True, torch_compile: bool = True, verbose: bool = False
):
    """
    Enable or disable PSGD optimizations.

    Args:
        enable: Whether to enable optimizations at all. ``False`` restores every
            original function (and ``torch.compile``) and returns.
        lra: Whether to enable LRA-specific optimizations. When ``False``, LRA functions
            patched by an earlier call are restored.
        kron: Whether to enable Kron-specific optimizations. When ``False``, Kron
            functions patched by an earlier call are restored.
        torch_compile: Whether to allow torch.compile optimizations. When ``False``,
            ``torch.compile`` is replaced process-wide by a no-op. Functions already
            compiled at import time are not affected.
        verbose: Whether to print optimization status messages
    """
    _config["enabled"] = enable
    _config["enable_lra"] = lra
    _config["enable_kron"] = kron
    _config["torch_compile_allowed"] = torch_compile
    _config["verbose"] = verbose

    if verbose:
        print(f"PSGD Optimizations: {'enabled' if enable else 'disabled'}")
        print(f" - LRA optimizations: {'enabled' if lra else 'disabled'}")
        print(f" - KRON optimizations: {'enabled' if kron else 'disabled'}")
        print(f" - torch.compile: {'allowed' if torch_compile else 'disabled'}")

    if not enable:
        # Restore original functions
        restore_original_functions()
        return

    # Apply optimizations based on config
    for orig_func, opt_func in OPTIMIZATION_MAP.items():
        name = orig_func.__name__
        category_disabled = (not _config["enable_lra"] and name in _LRA_FUNCTION_NAMES) or (
            not _config["enable_kron"] and name in _KRON_FUNCTION_NAMES
        )
        if category_disabled:
            # Make sure an earlier call did not leave this function patched.
            _unpatch(orig_func)
            continue

        _apply_monkey_patch(orig_func, opt_func)

    # Disable torch.compile if not allowed
    if not _config["torch_compile_allowed"]:
        if not hasattr(torch, "_original_compile"):
            torch._original_compile = torch.compile
        torch.compile = _noop_compile
        if verbose:
            print("Disabled torch.compile (replaced with no-op)")
    else:
        # Restore original torch.compile
        if hasattr(torch, "_original_compile"):
            torch.compile = torch._original_compile
            del torch._original_compile
            if verbose:
                print("Restored original torch.compile")
|
127
|
+
|
128
|
+
|
129
|
+
def restore_original_functions():
    """Restore all original function implementations.

    Every function recorded in ``_original_functions`` is written back onto its
    defining module, when that module is loaded and still has the attribute.
    ``_optimized_functions`` is then cleared, so ``get_optimization_status`` no longer
    reports replacements that are not in effect; previously it kept stale entries.
    If ``torch.compile`` was replaced by ``enable_optimizations``, it is restored as well.
    """
    for orig_func, func_ref in _original_functions.items():
        module = orig_func.__module__
        func_name = orig_func.__name__
        target = sys.modules.get(module)

        if target is not None and hasattr(target, func_name):
            setattr(target, func_name, func_ref)

            if _config["verbose"]:
                print(f"Restored original implementation of {module}.{func_name}")

    # Nothing is patched any more; keep the status registry consistent with that.
    _optimized_functions.clear()

    # Also restore torch.compile if it was modified
    if hasattr(torch, "_original_compile"):
        torch.compile = torch._original_compile
        del torch._original_compile
        if _config["verbose"]:
            print("Restored original torch.compile")
|
147
|
+
|
148
|
+
|
149
|
+
def get_optimization_status() -> Dict[str, Any]:
    """Return a snapshot of the optimization switches and patch registries.

    Keys: ``enabled``, ``lra_enabled``, ``kron_enabled``, ``torch_compile_allowed``
    (copied from ``_config``), plus ``optimized_functions`` / ``original_functions``
    (lists of the original function objects recorded in each registry).
    """
    config_keys = (
        ("enabled", "enabled"),
        ("lra_enabled", "enable_lra"),
        ("kron_enabled", "enable_kron"),
        ("torch_compile_allowed", "torch_compile_allowed"),
    )
    status: Dict[str, Any] = {public: _config[internal] for public, internal in config_keys}
    status["optimized_functions"] = list(_optimized_functions)
    status["original_functions"] = list(_original_functions)
    return status
|
159
|
+
|
160
|
+
|
161
|
+
# Auto-initialize optimizations based on environment.
# Unless HEAVYBALL_AUTO_OPTIMIZE is set to something other than "1", importing this module
# applies the env-seeded `_config`. That patches `utils` functions and, when torch.compile
# is disallowed, replaces `torch.compile` itself, all as an import-time side effect.
if os.environ.get("HEAVYBALL_AUTO_OPTIMIZE", "1") == "1":
    enable_optimizations(
        _config["enabled"],
        lra=_config["enable_lra"],
        kron=_config["enable_kron"],
        torch_compile=_config["torch_compile_allowed"],
        verbose=_config["verbose"],
    )
|