heavyball 0.21.1__tar.gz → 0.21.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.21.1 → heavyball-0.21.2}/PKG-INFO +1 -1
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/cached_delayed_psgd_kron.py +1 -1
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/cached_psgd_kron.py +1 -1
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/delayed_psgd.py +8 -2
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.21.1 → heavyball-0.21.2}/setup.py +1 -1
- {heavyball-0.21.1 → heavyball-0.21.2}/LICENSE +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/README.md +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/__init__.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/p_adam.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/psgd_kron.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/pure_psgd.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball/utils.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/setup.cfg +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/test/test_bf16_params.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/test/test_bf16_q.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/test/test_bf16_storage.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/test/test_closure.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/test/test_ema.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/test/test_foreach.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/test/test_memory.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/test/test_merge.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/test/test_no_grad.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/test/test_psgd.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/test/test_soap.py +0 -0
- {heavyball-0.21.1 → heavyball-0.21.2}/test/test_stochastic_updates.py +0 -0
@@ -120,7 +120,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
|
|
120
120
|
q_orig = Q_list.pop(0)
|
121
121
|
ea = exp_avg_list.pop(0)
|
122
122
|
|
123
|
-
precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'],
|
123
|
+
precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay)
|
124
124
|
|
125
125
|
if should_update:
|
126
126
|
q = line_to_triu(q_orig) if store_triu_as_line else q_orig
|
@@ -128,4 +128,4 @@ class ForeachCachedPSGDKron(PSGDBase):
|
|
128
128
|
else:
|
129
129
|
torch.mul(q_.conj(), q_, out=c_)
|
130
130
|
|
131
|
-
precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'],
|
131
|
+
precond_grad_cached_(cached_q, ea, self.state_(p)['cache_expr'], p, lr, weight_decay)
|
@@ -11,6 +11,12 @@ from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust
|
|
11
11
|
split_p_and_g_in_group, triu_to_line, line_to_triu, promote
|
12
12
|
|
13
13
|
|
14
|
+
@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
|
15
|
+
def _compilable_psgd_precond_grad_(q, exprs, ea, p, lr, weight_decay):
|
16
|
+
new = psgd_precond_grad(q, exprs, ea)
|
17
|
+
update_param_([p], self.clip_fn([new]), lr, weight_decay)
|
18
|
+
|
19
|
+
|
14
20
|
class ForeachDelayedPSGD(PSGDBase):
|
15
21
|
"""
|
16
22
|
Implements PSGD with off-by-one preconditioning (akin to ADOPT and SOAP)
|
@@ -98,14 +104,14 @@ class ForeachDelayedPSGD(PSGDBase):
|
|
98
104
|
stochastic_lerp_(exp_avg_list, grad_list, beta_debias(beta, group["step"]))
|
99
105
|
|
100
106
|
lr = -warmup(lr, group['step'], group['warmup_steps'])
|
107
|
+
lr = torch.empty((), dtype=torch.float32, device=grad_list[0].device).fill_(lr)
|
101
108
|
|
102
109
|
Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
|
103
110
|
for i, (p, g) in enumerate(zip(p_list, grad_list)):
|
104
111
|
q_orig = Q_list.pop(0)
|
105
112
|
ea = exp_avg_list.pop(0)
|
106
113
|
q = line_to_triu(q_orig) if store_triu_as_line else q_orig
|
107
|
-
|
108
|
-
update_param_([p], self.clip_fn([new]), lr, weight_decay)
|
114
|
+
_compilable_psgd_precond_grad_(q, state["exprs"], ea, p, lr, weight_decay)
|
109
115
|
if should_update:
|
110
116
|
q32 = [promote(q_) for q_ in q]
|
111
117
|
self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|